# HurricaneNet / custom_objects.py
# Author: sharktide
# Commit b3d42fc (verified): "Create custom_objects.py"
from tensorflow.keras.models import load_model
import tensorflow as tf
from tensorflow.keras.saving import register_keras_serializable
from tensorflow.keras import layers, models, backend as K
@register_keras_serializable()
class SSTAmplifier(tf.keras.layers.Layer):
    """Amplification factor computed from input feature column 0 (SST, per the name).

    Reads ``inputs[:, 0]``, passes ``(value - threshold) * scale`` through a
    sigmoid and returns ``1.0 + 0.3 * sigmoid(...)``. The output therefore
    lies in [1.0, 1.3]; it equals 1.15 when the feature equals ``threshold``
    and, for positive ``scale``, rises as the feature exceeds it. A trailing
    axis of size 1 is added to the result.

    Args:
        threshold: Feature value at which the sigmoid is centred. It is
            compared against the raw column, so it must share the input's
            units/scaling. NOTE(review): 28.0 suggests degrees Celsius and
            unnormalised inputs -- confirm against the data pipeline.
        scale: Sigmoid steepness; larger magnitudes give a sharper transition.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``,
            ``dtype``, ``trainable``).
    """

    def __init__(self, threshold=28.0, scale=0.1, **kwargs):
        super().__init__(**kwargs)
        self.threshold = threshold
        self.scale = scale

    def call(self, inputs):
        # Assumes inputs is (batch, features) -- TODO confirm; column 0 is SST.
        sst = inputs[:, 0]
        factor = tf.sigmoid((sst - self.threshold) * self.scale)
        mod = 1.0 + 0.3 * factor
        # Re-add a feature axis so a rank-2 input yields a (batch, 1) output.
        return tf.expand_dims(mod, -1)

    def get_config(self):
        """Return the layer config including ``threshold`` and ``scale``.

        Keras only captures constructor arguments automatically on some
        versions (older tf.keras raises NotImplementedError on save for
        layers with extra ``__init__`` arguments). Listing them explicitly
        keeps them in the saved config everywhere. The keys match what
        auto-config would write, so previously saved models still load.
        """
        config = super().get_config()
        config.update({"threshold": self.threshold, "scale": self.scale})
        return config
@register_keras_serializable()
class ShearSuppressor(tf.keras.layers.Layer):
    """Damping factor computed from input feature column 3 (shear, per the name).

    Reads ``inputs[:, 3]``, computes ``sigmoid((threshold - value) * scale)``
    and returns ``1.0 - 0.25 * sigmoid(...)``. The output therefore lies in
    [0.75, 1.0] and equals 0.875 when the feature equals ``threshold``. A
    trailing axis of size 1 is added to the result.

    NOTE(review): with positive ``scale`` the sigmoid approaches 1 when the
    feature is *below* ``threshold``. Low shear is therefore damped most
    (output near 0.75), and high shear is left nearly untouched (output
    near 1.0). If the intent was to suppress *high* shear, the sign looks
    inverted. Changing it would alter the outputs of any model already
    trained with this layer, so confirm before touching.

    Args:
        threshold: Feature value at which the sigmoid is centred. It is
            compared against the raw column, so it must share the input's
            units/scaling -- TODO confirm units.
        scale: Sigmoid steepness; larger magnitudes give a sharper transition.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``,
            ``dtype``, ``trainable``).
    """

    def __init__(self, threshold=14.0, scale=0.2, **kwargs):
        super().__init__(**kwargs)
        self.threshold = threshold
        self.scale = scale

    def call(self, inputs):
        # Assumes inputs is (batch, features) -- TODO confirm; column 3 is shear.
        shear = inputs[:, 3]
        suppress = tf.sigmoid((self.threshold - shear) * self.scale)
        mod = 1.0 - 0.25 * suppress
        # Re-add a feature axis so a rank-2 input yields a (batch, 1) output.
        return tf.expand_dims(mod, -1)

    def get_config(self):
        """Return the layer config including ``threshold`` and ``scale``.

        Keras only captures constructor arguments automatically on some
        versions (older tf.keras raises NotImplementedError on save for
        layers with extra ``__init__`` arguments). Listing them explicitly
        keeps them in the saved config everywhere. The keys match what
        auto-config would write, so previously saved models still load.
        """
        config = super().get_config()
        config.update({"threshold": self.threshold, "scale": self.scale})
        return config
@register_keras_serializable()
class VorticityActivator(tf.keras.layers.Layer):
    """Activation factor computed from input feature column 4 (vorticity, per the name).

    Reads ``inputs[:, 4]``, passes ``(value - threshold) * scale`` through a
    sigmoid and returns ``1.0 + 0.2 * sigmoid(...)``. The output therefore
    lies in [1.0, 1.2]; it equals 1.1 when the feature equals ``threshold``
    and, for positive ``scale``, rises as the feature exceeds it. A trailing
    axis of size 1 is added to the result.

    Args:
        threshold: Feature value at which the sigmoid is centred. It is
            compared against the raw column, so it must share the input's
            units/scaling -- TODO confirm units.
        scale: Sigmoid steepness; larger magnitudes give a sharper transition.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``,
            ``dtype``, ``trainable``).
    """

    def __init__(self, threshold=1.2, scale=1.0, **kwargs):
        super().__init__(**kwargs)
        self.threshold = threshold
        self.scale = scale

    def call(self, inputs):
        # Assumes inputs is (batch, features) -- TODO confirm; column 4 is vorticity.
        vort = inputs[:, 4]
        activate = tf.sigmoid((vort - self.threshold) * self.scale)
        mod = 1.0 + 0.2 * activate
        # Re-add a feature axis so a rank-2 input yields a (batch, 1) output.
        return tf.expand_dims(mod, -1)

    def get_config(self):
        """Return the layer config including ``threshold`` and ``scale``.

        Keras only captures constructor arguments automatically on some
        versions (older tf.keras raises NotImplementedError on save for
        layers with extra ``__init__`` arguments). Listing them explicitly
        keeps them in the saved config everywhere. The keys match what
        auto-config would write, so previously saved models still load.
        """
        config = super().get_config()
        config.update({"threshold": self.threshold, "scale": self.scale})
        return config
@register_keras_serializable()
class ModulationMixer(tf.keras.layers.Layer):
    """Fuse three modulation tensors into a single smoothed factor.

    ``inputs`` must unpack into exactly three tensors. The locals are named
    after SST, shear and vorticity, matching the original code. Their
    element-wise product ``p`` is mapped to ``1 + 0.25 * tanh(p - 1)``. That
    value equals 1 when ``p == 1`` and stays within [0.75, 1.25].
    """

    def call(self, inputs):
        sst_factor, shear_factor, vort_factor = inputs
        # Centre the joint product on 1 before squashing, so that a neutral
        # product (== 1) maps to a neutral output (== 1).
        deviation = sst_factor * shear_factor * vort_factor - 1.0
        return 1.0 + 0.25 * tf.tanh(deviation)
# Maps class name to class for the custom layers defined in this module.
# This is the shape Keras accepts as ``custom_objects=``; presumably it is
# passed to ``load_model`` by a caller not visible here.
CUSTOM_OBJECTS = dict(
    ModulationMixer=ModulationMixer,
    VorticityActivator=VorticityActivator,
    ShearSuppressor=ShearSuppressor,
    SSTAmplifier=SSTAmplifier,
)