sharktide commited on
Commit
b3d42fc
·
verified ·
1 Parent(s): 9909420

Create custom_objects.py

Browse files
Files changed (1) hide show
  1. custom_objects.py +58 -0
custom_objects.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from tensorflow.keras.models import load_model
2
+ import tensorflow as tf
3
+ from tensorflow.keras.saving import register_keras_serializable
4
+ from tensorflow.keras import layers, models, backend as K
5
+
6
+ @register_keras_serializable()
7
+ class SSTAmplifier(tf.keras.layers.Layer):
8
+ def __init__(self, threshold=28.0, scale=0.1, **kwargs):
9
+ super().__init__(**kwargs)
10
+ self.threshold = threshold
11
+ self.scale = scale
12
+
13
+ def call(self, inputs):
14
+ sst = inputs[:, 0]
15
+ factor = tf.sigmoid((sst - self.threshold) * self.scale)
16
+ mod = 1.0 + 0.3 * factor
17
+ return tf.expand_dims(mod, -1)
18
+
19
+ @register_keras_serializable()
20
+ class ShearSuppressor(tf.keras.layers.Layer):
21
+ def __init__(self, threshold=14.0, scale=0.2, **kwargs):
22
+ super().__init__(**kwargs)
23
+ self.threshold = threshold
24
+ self.scale = scale
25
+
26
+ def call(self, inputs):
27
+ shear = inputs[:, 3]
28
+ suppress = tf.sigmoid((self.threshold - shear) * self.scale)
29
+ mod = 1.0 - 0.25 * suppress
30
+ return tf.expand_dims(mod, -1)
31
+
32
+ @register_keras_serializable()
33
+ class VorticityActivator(tf.keras.layers.Layer):
34
+ def __init__(self, threshold=1.2, scale=1.0, **kwargs):
35
+ super().__init__(**kwargs)
36
+ self.threshold = threshold
37
+ self.scale = scale
38
+
39
+ def call(self, inputs):
40
+ vort = inputs[:, 4]
41
+ activate = tf.sigmoid((vort - self.threshold) * self.scale)
42
+ mod = 1.0 + 0.2 * activate
43
+ return tf.expand_dims(mod, -1)
44
+
45
+ @register_keras_serializable()
46
+ class ModulationMixer(tf.keras.layers.Layer):
47
+ def call(self, inputs):
48
+ sst_mod, shear_mod, vort_mod = inputs
49
+ product = sst_mod * shear_mod * vort_mod
50
+ smooth = 1.0 + 0.25 * tf.tanh(product - 1.0)
51
+ return smooth
52
+
53
+ CUSTOM_OBJECTS = {
54
+ 'ModulationMixer': ModulationMixer,
55
+ 'VorticityActivator': VorticityActivator,
56
+ 'ShearSuppressor': ShearSuppressor,
57
+ 'SSTAmplifier': SSTAmplifier
58
+ }