@@ -88,67 +88,132 @@ def __call__(
8888
8989
9090################################################################################
91- # MARK: SwiGLU
91+ # MARK: LinearSwiGLU
9292################################################################################
9393
9494
95- class SwiGLU (nn .Module ):
96- """SwiGLU feed-forward network .
95+ class LinearSwiGLU (nn .Module ):
96+ """A Dense layer variant that outputs SwiGLU gating directly .
9797
9898 A gated feed-forward network using SiLU (Swish) activation for the gate,
9999 following "GLU Variants Improve Transformer" (Shazeer, 2020):
100100 https://arxiv.org/abs/2002.05202
101101
102- The forward pass is:
103-
104- gate_and_val = x @ W_up # (*, hidden_size) -> (*, ff_size * 2)
105- val, gate = split(gate_and_val) # (*, ff_size) each
106- x = val * SiLU(gate) # (*, ff_size)
107- x = dropout(x)
108- x = x @ W_down # (*, ff_size) -> (*, hidden_size)
109-
110- Attributes:
111- hidden_size: Output dimension (residual stream width).
112- ff_size: Intermediate dimension (before gating).
113- zero_init_output: If True, the down-projection kernel is initialized to
114- zeros so the block starts as identity.
115- dropout_rate: Dropout rate applied after gating.
116- dtype: Data type for computation.
102+ Projects the input dimension to features * 2, chunks the result across the
103+ last dimension, and gates the activation channel with SiLU.
117104 """
118105
119- hidden_size : int
120- ff_size : int
121- zero_init_output : bool = False
122- dropout_rate : float = 0.0
106+ features : int
107+ use_bias : bool = False
123108 dtype : DType = jnp .float32
124109
125110 @nn .compact
126111 @kt .typechecked
127- def __call__ (
128- self , x : Float ['batch *other_dims hidden_size' ], * , is_training : bool
129- ) -> Float ['batch *other_dims hidden_size' ]:
130- # Up-projection: (*, hidden_size) -> (*, ff_size * 2).
112+ def __call__ (self , x : Float ["*batch d_in" ]) -> Float ["*batch features" ]:
113+ # Project to double feature width
131114 gate_and_val = nn .Dense (
132- features = self .ff_size * 2 ,
133- use_bias = False ,
115+ features = self .features * 2 ,
116+ use_bias = self . use_bias ,
134117 dtype = self .dtype ,
135- name = 'Dense_Up' ,
118+ name = "Dense_Gate_Val" ,
136119 )(x )
137- # Split into value and gate, apply SiLU gating.
120+
121+ # Split and apply SiLU gating (mirrors torch.chunk(2, dim=-1))
138122 val , gate = jnp .split (gate_and_val , 2 , axis = - 1 )
139- x = val * nn .silu (gate )
140- x = nn .Dropout (rate = self .dropout_rate , deterministic = not is_training )(x )
141- # Down-projection: (*, ff_size) -> (*, hidden_size).
123+ return val * nn .silu (gate )
124+
125+
126+ ################################################################################
127+ # MARK: FeedForward Unified Block
128+ ################################################################################
129+
130+
131+ class FeedForward (nn .Module ):
132+ """A unified FeedForward block selecting between SwiGLU or traditional layers.
133+
134+ Attributes:
135+ output_size: Output dimension (residual stream width).
136+ hidden_size: Intermediate bottleneck network dimension.
137+ ffn_type: Layout type toggle. - 'swiglu' uses a gated SwiGLU projection
138+ layer. - 'standard' uses a classic dense projection followed by an
139+ activation.
140+ activation: Name of the activation function to use when
141+ `ffn_type='standard'` (e.g., 'gelu', 'silu', 'relu'). This parameter is
142+ explicitly ignored when `ffn_type='swiglu'` because the SwiGLU path uses
143+ its own mathematical gating mechanism (SiLU/Swish).
144+ zero_init_output: If True, the terminal linear projections are zeroed out
145+ ensuring the block satisfies identity-at-init behavior.
146+ dropout_rate: Activation state dropout regularization coefficient.
147+ dtype: Numerical precision layout representation format.
148+ """
149+
150+ output_size : int
151+ hidden_size : int
152+ ffn_type : str = "standard"
153+ activation : str = "gelu"
154+ zero_init_output : bool = False
155+ dropout_rate : float = 0.0
156+ dtype : DType = jnp .float32
157+
158+ def setup (self ):
159+ if self .ffn_type not in ("standard" , "swiglu" ):
160+ raise ValueError (
161+ f"Unknown ffn_type: { self .ffn_type } . Must be 'standard' or 'swiglu'."
162+ )
163+ # Regularization Dropout Layer
164+ self .dropout = nn .Dropout (rate = self .dropout_rate )
165+
166+ # Down Projection Layer Config
142167 down_kernel_init = (
143168 nn .initializers .zeros_init ()
144169 if self .zero_init_output
145170 else nn .initializers .lecun_normal ()
146171 )
147- x = nn .Dense (
148- features = self .hidden_size ,
149- use_bias = False ,
150- dtype = self .dtype ,
172+ # Standard SwiGLU down-projections generally omit biases
173+ self .use_down_bias = False if self .ffn_type == "swiglu" else True
174+
175+ self .down_proj = nn .Dense (
176+ features = self .output_size ,
177+ use_bias = self .use_down_bias ,
151178 kernel_init = down_kernel_init ,
152- name = 'Dense_Down' ,
153- )(x )
179+ dtype = self .dtype ,
180+ name = "Dense_Down" ,
181+ )
182+
183+ @nn .compact
184+ @kt .typechecked
185+ def __call__ (
186+ self , x : Float ["batch *other_dims output_size" ], * , is_training : bool
187+ ) -> Float ["batch *other_dims output_size" ]:
188+ # Up-projection step
189+ if self .ffn_type == "swiglu" :
190+ # Project to double feature width
191+ gate_and_val = nn .Dense (
192+ features = self .hidden_size * 2 ,
193+ use_bias = False ,
194+ dtype = self .dtype ,
195+ name = "Dense_Up" ,
196+ )(x )
197+ # Split and apply SiLU gating
198+ val , gate = jnp .split (gate_and_val , 2 , axis = - 1 )
199+ x = val * nn .silu (gate )
200+ elif self .ffn_type == "standard" :
201+ x = nn .Dense (
202+ features = self .hidden_size ,
203+ use_bias = True ,
204+ dtype = self .dtype ,
205+ name = "Dense_Up" ,
206+ )(x )
207+ # Apply the configured activation function
208+ activation_fn = getattr (nn , self .activation )
209+ x = activation_fn (x )
210+ else :
211+ raise ValueError (f"Unknown ffn_type mapping strategy: { self .ffn_type !r} " )
212+
213+ # Middle regularization step
214+ x = self .dropout (x , deterministic = not is_training )
215+
216+ # Final down-projection step
217+ x = self .down_proj (x )
218+
154219 return x
0 commit comments