@@ -588,3 +588,66 @@ def _multi_conv(self, x, **kwargs):
588588 """Applies the convolutions with different sizes and concatenates outputs."""
589589
590590 return tf .concat ([conv (x , ** kwargs ) for conv in self .convs ], axis = - 1 )
591+
592+
593+ class ConfigurableMLP (tf .keras .Model ):
594+ """Implements a simple configurable MLP with optional residual connections and dropout."""
595+
596+ def __init__ (
597+ self , input_dim , hidden_dim = 512 , num_hidden = 2 , activation = "relu" , residual = True , dropout_rate = 0.05 , ** kwargs
598+ ):
599+ """
600+ Creates an instance of a flexible and simple MLP with optional residual connections and dropout.
601+
602+ Parameters:
603+ -----------
604+ input_dim : int
605+ The input dimensionality
606+ hidden_dim : int, optional, default: 512
607+ The dimensionality of the hidden layers
608+ num_hidden : int, optional, default: 2
609+ The number of hidden layers (minimum: 1)
610+ activation : string, optional, default: 'relu'
611+ The activation function of the dense layers
612+ residual : bool, optional, default: True
613+ Use residual connections in the MLP
614+ dropout_rate : float, optional, default: 0.05
615+ Dropout rate for the hidden layers in the MLP
616+ """
617+
618+ super ().__init__ (** kwargs )
619+
620+ self .input_dim = input_dim
621+ self .model = tf .keras .Sequential (
622+ [tf .keras .layers .Dense (hidden_dim , activation = activation ), tf .keras .layers .Dropout (dropout_rate )]
623+ )
624+ for _ in range (num_hidden ):
625+ self .model .add (
626+ ConfigurableHiddenBlock (
627+ hidden_dim ,
628+ activation = activation ,
629+ residual = residual ,
630+ dropout_rate = dropout_rate ,
631+ )
632+ )
633+ self .model .add (tf .keras .layers .Dense (input_dim ))
634+
635+ def call (self , inputs , ** kwargs ):
636+ return self .model (inputs , ** kwargs )
637+
638+
639+ class ConfigurableHiddenBlock (tf .keras .Model ):
640+ def __init__ (self , num_units , activation = "relu" , residual = True , dropout_rate = 0.0 ):
641+ super ().__init__ ()
642+
643+ self .act_fn = tf .keras .activations .get (activation )
644+ self .residual = residual
645+ self .dense_with_dropout = tf .keras .Sequential (
646+ [tf .keras .layers .Dense (num_units , activation = None ), tf .keras .layers .Dropout (dropout_rate )]
647+ )
648+
649+ def call (self , inputs , ** kwargs ):
650+ x = self .dense_with_dropout (inputs , ** kwargs )
651+ if self .residual :
652+ x = x + inputs
653+ return self .act_fn (x )
0 commit comments