@@ -21,6 +21,7 @@ module Torch.Compose.NN where
 
 import Torch
 import Torch.Compose
+import qualified Torch.Functional.Internal as T
 import System.IO.Unsafe (unsafePerformIO)
 import GHC.Generics hiding ((:+:))
 
@@ -289,3 +290,246 @@ instanceForwardAssocs
   ]
   [t| Tensor |] [t| Tensor |]
 
+-------------------------------------------------------------------------------
+-- 1. LayerNorm
+-------------------------------------------------------------------------------
+
+data LayerNormSpec = LayerNormSpec
+  { lnDim :: Int   -- ^ normalized dimension (e.g. embedDim)
+  , lnEps :: Float -- ^ small epsilon for numerical stability
+  }
+  deriving (Show, Eq)
+
+data LayerNorm = LayerNorm
+  { spec  :: LayerNormSpec
+  , gamma :: Parameter -- ^ scale
+  , beta  :: Parameter -- ^ bias
+  } deriving (Show)
+
+instance Randomizable LayerNormSpec LayerNorm where
+  sample s@LayerNormSpec{..} = do
+    let wInit = ones'  [lnDim]
+        bInit = zeros' [lnDim]
+    gammaParam <- makeIndependent wInit
+    betaParam  <- makeIndependent bInit
+    pure LayerNorm
+      { spec  = s
+      , gamma = gammaParam
+      , beta  = betaParam
+      }
+
+-------------------------------------------------------------------------------
+-- LayerNorm forward: mean/variance over the last dimension
+-------------------------------------------------------------------------------
+
+instance HasForward LayerNorm Tensor Tensor where
+  forward LayerNorm{..} input =
+    let
+      -- Mean and variance over the last dimension, with keepDim = True so
+      -- the results broadcast against the input ('meanDim' comes from
+      -- Torch.Functional, 'T.varDim' from Torch.Functional.Internal):
+      --   y = (x - mean) / sqrt(var + eps) * gamma + beta
+      mean' = meanDim (Dim (-1)) KeepDim Float input
+      var'  = T.varDim input (-1) True True
+      xNorm = (input - mean') / Torch.sqrt (var' + asTensor spec.lnEps)
+      out   = xNorm * toDependent gamma + toDependent beta
+    in out
+
+  forwardStoch ln = pure . forward ln
+
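+-- Usage sketch (illustrative, not part of the original commit): LayerNorm
+-- normalizes over the last dimension only, so leading batch/sequence
+-- dimensions pass through unchanged.
+--
+-- >>> ln <- sample (LayerNormSpec 8 1e-5)
+-- >>> x  <- randnIO' [2, 4, 8]   -- [batchSize, seqLen, embedDim]
+-- >>> shape (forward ln x)
+-- [2,4,8]
+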
+-------------------------------------------------------------------------------
+-- 2. Simple Feed-Forward Network
+-------------------------------------------------------------------------------
+
+data FeedForwardSpec = FeedForwardSpec
+  { ffInDim  :: Int -- ^ input (and output) dimension
+  , ffHidden :: Int -- ^ hidden dimension
+  }
+  deriving (Show, Eq)
+
+data FeedForward = FeedForward
+  { l1 :: Linear
+  , l2 :: Linear
+  }
+  deriving (Show)
+
+instance Randomizable FeedForwardSpec FeedForward where
+  sample FeedForwardSpec{..} = do
+    fc1 <- sample $ LinearSpec ffInDim ffHidden
+    fc2 <- sample $ LinearSpec ffHidden ffInDim
+    pure FeedForward { l1 = fc1, l2 = fc2 }
+
+-- | Expand to the hidden width, apply ReLU, and project back.
+instance HasForward FeedForward Tensor Tensor where
+  forward FeedForward{..} input =
+    let x1 = relu (linear l1 input)
+        x2 = linear l2 x1
+    in x2
+
+  forwardStoch ff = pure . forward ff
+
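+-- Usage sketch (illustrative): input and output widths match, so the block
+-- drops straight into a residual connection.
+--
+-- >>> ffn <- sample (FeedForwardSpec 8 32)
+-- >>> x   <- randnIO' [2, 4, 8]
+-- >>> shape (forward ffn x)
+-- [2,4,8]
+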
+-------------------------------------------------------------------------------
+-- 3. Causal Masking Utility
+-------------------------------------------------------------------------------
+
+-- | Create a causal mask of shape [seqLen, seqLen] with 1.0 = keep and
+-- 0.0 = block, so that position i cannot attend to any future position
+-- j > i (the strict upper triangle is blocked).
+createCausalMask :: Int -> Tensor
+createCausalMask seqLen =
+  let range  = arange' 0 (fromIntegral seqLen) 1 -- [seqLen]
+      rowIdx = unsqueeze (Dim (-1)) range        -- [seqLen, 1]
+      colIdx = unsqueeze (Dim 0) range           -- [1, seqLen]
+      -- rowIdx >= colIdx  =>  j <= i  =>  keep; otherwise block.
+      keepBool = rowIdx `ge` colIdx
+  -- 'ge' yields a Bool tensor; cast to Float so the mask really is the
+  -- documented 1.0/0.0.
+  in toDType Float keepBool
+
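+-- Illustrative example: for seqLen = 3 the mask keeps the lower triangle,
+-- so each position attends only to itself and earlier positions:
+--
+--   createCausalMask 3  ~>  1 0 0
+--                           1 1 0
+--                           1 1 1
+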
+-------------------------------------------------------------------------------
+-- 4. GPT-2 Decoder Block
+-------------------------------------------------------------------------------
+
+data GPT2BlockSpec = GPT2BlockSpec
+  { blockEmbedDim :: Int
+  , blockNumHeads :: Int
+  , blockFfHidden :: Int
+  , blockLnEps    :: Float
+  }
+  deriving (Show, Eq)
+
+data GPT2Block = GPT2Block
+  { ln1  :: LayerNorm
+  , attn :: MultiHeadAttention
+  , ln2  :: LayerNorm
+  , ff   :: FeedForward
+  }
+  deriving (Show)
+
+instance Randomizable GPT2BlockSpec GPT2Block where
+  sample GPT2BlockSpec{..} = do
+    let lnSpec  = LayerNormSpec blockEmbedDim blockLnEps
+        ffSpec  = FeedForwardSpec blockEmbedDim blockFfHidden
+        mhaSpec = MultiHeadAttentionSpec blockEmbedDim blockNumHeads
+    GPT2Block
+      <$> sample lnSpec
+      <*> sample mhaSpec
+      <*> sample lnSpec
+      <*> sample ffSpec
+
+-- | GPT2Block forward:
+--   1) LayerNorm, then masked self-attention
+--   2) residual connection
+--   3) LayerNorm, then feed-forward
+--   4) residual connection
+instance HasForward GPT2Block (Tensor, Tensor) Tensor where
+  -- Input is `(x, mask)`; the result is the new hidden states. The mask has
+  -- shape [1, seqLen, seqLen], or anything broadcastable to
+  -- [batchSize, seqLen, seqLen].
+  forward GPT2Block{..} (x, _mask) =
+    let xNorm = forward ln1 x
+        -- NOTE: this 'multiHeadAttention' does not accept a mask yet, so the
+        -- mask is threaded through but unused here. To enforce causality,
+        -- extend the MHA to take the mask and apply it to the attention
+        -- scores before the softmax (see the 'applyAttnMask' sketch below).
+        attnOut = multiHeadAttention attn xNorm xNorm xNorm
+        x1      = x + attnOut -- residual
+        x1Norm  = forward ln2 x1
+        ffOut   = forward ff x1Norm
+        x2      = x1 + ffOut  -- residual
+    in x2
+
+  forwardStoch block (x, mask) = pure $ forward block (x, mask)
+
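+-- Hedged sketch ('applyAttnMask' is a hypothetical helper, not called
+-- anywhere in this module): once 'multiHeadAttention' exposes its raw
+-- attention scores, a mask with 1.0 = keep / 0.0 = block can be applied by
+-- pushing blocked scores toward -infinity before the softmax.
+applyAttnMask
+  :: Tensor -- ^ mask, broadcastable to the score shape
+  -> Tensor -- ^ raw attention scores, e.g. [batchSize, numHeads, seqLen, seqLen]
+  -> Tensor -- ^ attention weights
+applyAttnMask mask scores =
+  let negInf = asTensor (-1e9 :: Float)
+  in softmax (Dim (-1)) (scores + (onesLike mask - mask) * negInf)
+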
+-------------------------------------------------------------------------------
+-- 5. The Full GPT2 Model
+-------------------------------------------------------------------------------
+
+data GPT2Spec = GPT2Spec
+  { vocabSize   :: Int
+  , maxPos      :: Int
+  , numLayers   :: Int
+  , embedDim    :: Int
+  , numHeads    :: Int
+  , ffHiddenDim :: Int
+  , lnEpsVal    :: Float
+  }
+  deriving (Show, Eq)
+
+data GPT2 = GPT2
+  { tokenEmbed    :: Parameter -- ^ [vocabSize, embedDim]
+  , positionEmbed :: Parameter -- ^ [maxPos, embedDim]
+  , blocks        :: [GPT2Block]
+  , lnFinal       :: LayerNorm
+  }
+  deriving (Show)
+
+instance Randomizable GPT2Spec GPT2 where
+  sample GPT2Spec{..} = do
+    tokenParam <- makeIndependent =<< randnIO' [vocabSize, embedDim]
+    posParam   <- makeIndependent =<< randnIO' [maxPos, embedDim]
+    let blockSpec = GPT2BlockSpec
+          { blockEmbedDim = embedDim
+          , blockNumHeads = numHeads
+          , blockFfHidden = ffHiddenDim
+          , blockLnEps    = lnEpsVal
+          }
+    gpt2Blocks <- mapM (const $ sample blockSpec) [1 .. numLayers]
+    finalNorm  <- sample $ LayerNormSpec embedDim lnEpsVal
+    pure GPT2
+      { tokenEmbed    = tokenParam
+      , positionEmbed = posParam
+      , blocks        = gpt2Blocks
+      , lnFinal       = finalNorm
+      }
+
+-- | HasForward for GPT2 takes just the input token IDs, of shape
+-- [batchSize, seqLen], and returns logits of shape
+-- [batchSize, seqLen, vocabSize].
+instance HasForward GPT2 Tensor Tensor where
+  forward GPT2{..} inputIds =
+    let (batchSize, seqLen) = case shape inputIds of
+          [b, s] -> (b, s)
+          _      -> error "GPT2 forward: expected [batchSize, seqLen]"
+        -- 1) Token embeddings: [batchSize, seqLen, embedDim]
+        xToken = embedding' (toDependent tokenEmbed) inputIds
+        -- 2) Position embeddings. Embedding indices must be Int64, and
+        --    arange' produces the default floating dtype, hence the cast.
+        positions = toDType Int64 (arange' 0 (fromIntegral seqLen) 1) -- [seqLen]
+        posEmbs   = embedding' (toDependent positionEmbed) positions -- [seqLen, embedDim]
+        posEmbs3d = unsqueeze (Dim 0) posEmbs                        -- [1, seqLen, embedDim]
+        posEmbsB  = expand posEmbs3d False [batchSize, seqLen, shape posEmbs3d !! 2]
+
+        x = xToken + posEmbsB
+        -- 3) Causal mask of shape [1, seqLen, seqLen]; it is threaded through
+        --    the blocks, which can use it once their MHA supports masking.
+        mask = unsqueeze (Dim 0) (createCausalMask seqLen)
+
+        -- 4) Pass through each GPT2Block.
+        xOut = foldl (\acc block -> forward block (acc, mask)) x blocks
+        -- 5) Final layer norm.
+        xNorm = forward lnFinal xOut
+        -- 6) Project to the vocabulary, tying the output projection to the
+        --    token embedding: logits = xNorm `matmul` transpose(tokenEmbed).
+        tokenWeightT = transpose2D (toDependent tokenEmbed) -- [embedDim, vocabSize]
+        logits = xNorm `matmul` tokenWeightT                -- [batchSize, seqLen, vocabSize]
+    in logits
+
+  forwardStoch net inputIds = pure $ forward net inputIds
+
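+-- | Illustrative smoke test (not part of the original commit; all
+-- hyperparameters are arbitrary): build a tiny GPT2 and check the logits
+-- shape end to end.
+demoGPT2 :: IO ()
+demoGPT2 = do
+  model <- sample GPT2Spec
+    { vocabSize = 100, maxPos = 64, numLayers = 2
+    , embedDim = 16, numHeads = 4, ffHiddenDim = 64, lnEpsVal = 1e-5 }
+  let ids = asTensor ([[1, 5, 9, 2]] :: [[Int]]) -- [batchSize = 1, seqLen = 4]
+  print (shape (forward model ids))              -- expected: [1,4,100]
+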
+-------------------------------------------------------------------------------
+-- 6. HasForwardAssoc instances (optional)
+-------------------------------------------------------------------------------
+
+-- If you are using `instanceForwardAssocs` to auto-generate the associated
+-- type families, you can include GPT2Block and GPT2 as well. For example:
+{-
+instanceForwardAssocs
+  [ [t| GPT2Block |] ]
+  [t| (Tensor, Tensor) |] -- GPT2Block takes (x, mask) as input
+  [t| Tensor |]
+
+instanceForwardAssocs
+  [ [t| GPT2 |] ]
+  [t| Tensor |] [t| Tensor |]
+-}