@@ -171,6 +171,111 @@ instance Randomizable Conv2dSpec' Conv2d' where
     , params = a
     }
 
+
+--------------------------------------------------------------------------------
+-- Multi-Head Attention Data Structures
+--------------------------------------------------------------------------------
+
+-- | Specification for initializing a MultiHeadAttention module.
+data MultiHeadAttentionSpec = MultiHeadAttentionSpec
+  { mhaEmbedDim :: Int -- ^ Model embedding dimension
+  , mhaNumHeads :: Int -- ^ Number of attention heads
+  } deriving (Show, Eq)
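-- A quick worked example of the spec (values chosen for illustration, not taken
-- from this change): mhaEmbedDim = 512 with mhaNumHeads = 8 gives
-- headDim = 512 `div` 8 = 64. mhaEmbedDim should be divisible by mhaNumHeads;
-- otherwise the integer division used during sampling silently drops the remainder.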
+
+-- | Data type that holds parameters for Multi-Head Attention.
+data MultiHeadAttention = MultiHeadAttention
+  { wQ      :: Linear -- ^ Linear projection for the queries
+  , wK      :: Linear -- ^ Linear projection for the keys
+  , wV      :: Linear -- ^ Linear projection for the values
+  , wO      :: Linear -- ^ Final linear projection after combining heads
+  , headDim :: Int    -- ^ Dimension per head = embedDim / numHeads
+  , nHeads  :: Int    -- ^ Number of attention heads
+  } deriving (Show)
+
+-- | Create random parameters for Multi-Head Attention given the specification.
+instance Randomizable MultiHeadAttentionSpec MultiHeadAttention where
+  sample MultiHeadAttentionSpec {..} = do
+    let headDim = mhaEmbedDim `Prelude.div` mhaNumHeads
+    wQ' <- sample $ LinearSpec mhaEmbedDim mhaEmbedDim
+    wK' <- sample $ LinearSpec mhaEmbedDim mhaEmbedDim
+    wV' <- sample $ LinearSpec mhaEmbedDim mhaEmbedDim
+    wO' <- sample $ LinearSpec mhaEmbedDim mhaEmbedDim
+    return $ MultiHeadAttention
+      { wQ = wQ'
+      , wK = wK'
+      , wV = wV'
+      , wO = wO'
+      , headDim = headDim
+      , nHeads = mhaNumHeads
+      }
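-- Sizing note (a rough calculation, assuming LinearSpec initializes both a weight
-- and a bias as in Torch.NN): each of the four projections is embedDim x embedDim,
-- so the module holds 4 * (embedDim^2 + embedDim) trainable parameters,
-- roughly 1.05M for embedDim = 512.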
+
+--------------------------------------------------------------------------------
+-- Forward Pass (Scaled Dot-Product Attention + Multi-Head Logic)
+--------------------------------------------------------------------------------
+
+-- | Compute scaled dot-product attention for query, key, value tensors.
+-- The typical shape for q, k, v is:
+--   [batchSize, numHeads, seqLen, headDim]
+--
+-- Returns: [batchSize, numHeads, seqLen, headDim]
+scaledDotProductAttention
+  :: Tensor -- ^ Queries (q)
+  -> Tensor -- ^ Keys (k)
+  -> Tensor -- ^ Values (v)
+  -> Tensor -- ^ Output (contextual embeddings)
+scaledDotProductAttention q k v =
+  let -- q * k^T -> [batchSize, numHeads, seqLen, seqLen]
+      dk = fromIntegral (shape q !! 3) -- headDim
+      -- transpose the last two dims of k: [.., seqLen, headDim] -> [.., headDim, seqLen]
+      scores = (q `matmul` transpose (Dim 2) (Dim 3) k) / Torch.sqrt (asTensor (dk :: Float))
+      attnWeights = softmax (Dim (-1)) scores -- softmax over the last dim (key positions)
+      output = attnWeights `matmul` v         -- weight the values by the attention scores
+  in output
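-- The definition above is the standard scaled dot-product attention
--   Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V
-- applied independently to every (batch, head) slice, with d_k = headDim.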
+
+-- | Forward pass for Multi-Head Attention (minimal, without any mask or dropout).
+multiHeadAttention
+  :: MultiHeadAttention
+  -> Tensor -- ^ Input queries [batchSize, seqLen, embedDim]
+  -> Tensor -- ^ Input keys    [batchSize, seqLen, embedDim]
+  -> Tensor -- ^ Input values  [batchSize, seqLen, embedDim]
+  -> Tensor -- ^ Output        [batchSize, seqLen, embedDim]
+multiHeadAttention MultiHeadAttention {..} queries keys values =
+  let
+      -- Project inputs into query, key, and value space
+      q = linear wQ queries -- [batchSize, seqLen, embedDim]
+      k = linear wK keys    -- [batchSize, seqLen, embedDim]
+      v = linear wV values  -- [batchSize, seqLen, embedDim]
+
+      batchSize = shape queries !! 0
+      seqLen    = shape queries !! 1
+
+      -- Reshape for multi-head: [batchSize, seqLen, nHeads*headDim]
+      --   -> [batchSize, seqLen, nHeads, headDim]
+      --   -> [batchSize, nHeads, seqLen, headDim]
+      reshapeForHeads t =
+        let t'  = view [batchSize, seqLen, nHeads * headDim] t
+            t'' = view [batchSize, seqLen, nHeads, headDim] t'
+        in permute [0, 2, 1, 3] t'' -- reorder to [batchSize, nHeads, seqLen, headDim]
+
+      qHeads = reshapeForHeads q
+      kHeads = reshapeForHeads k
+      vHeads = reshapeForHeads v
+
+      -- Apply scaled dot-product attention
+      attnOutput = scaledDotProductAttention qHeads kHeads vHeads
+      -- shape: [batchSize, nHeads, seqLen, headDim]
+
+      -- Convert back: [batchSize, nHeads, seqLen, headDim]
+      --   -> [batchSize, seqLen, nHeads, headDim]
+      --   -> [batchSize, seqLen, nHeads*headDim]
+      attnOutputTrans = permute [0, 2, 1, 3] attnOutput
+      -- reshape rather than view, since the permuted tensor is no longer contiguous
+      combinedHeads = reshape [batchSize, seqLen, nHeads * headDim] attnOutputTrans
+
+      -- Final linear projection
+      out = linear wO combinedHeads -- [batchSize, seqLen, embedDim]
+  in out
+
+
+
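-- End-to-end usage sketch (illustrative only: the dimensions and the ones'
-- input tensor are assumptions, not part of this change):
--
--   demo :: IO Tensor
--   demo = do
--     mha <- sample (MultiHeadAttentionSpec { mhaEmbedDim = 512, mhaNumHeads = 8 })
--     let x = ones' [2, 10, 512]            -- [batchSize = 2, seqLen = 10, embedDim = 512]
--     pure (multiHeadAttention mha x x x)   -- result: [2, 10, 512]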
 -- Generate HasForwardAssoc instances from HasForward instances.
 instanceForwardAssocs
   [ [t| Linear |]