Skip to content

Commit 44733ab

Browse files
Add splitLayers
1 parent d7220e6 commit 44733ab

File tree

3 files changed

+51
-22
lines changed

3 files changed

+51
-22
lines changed

src/Torch/Compose.hs

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
{-# LANGUAGE TypeFamilyDependencies#-}
2121
{-# LANGUAGE TypeOperators#-}
2222
{-# LANGUAGE UndecidableInstances#-}
23+
{-# LANGUAGE AllowAmbiguousTypes#-}
24+
2325

2426
module Torch.Compose where
2527

@@ -31,6 +33,7 @@ import Data.Coerce
3133
import Control.Exception
3234
import System.IO.Unsafe
3335
import qualified Language.Haskell.TH as TH
36+
import GHC.TypeLits
3437
-- import qualified Language.Haskell.TH.Syntax as TH
3538

3639

@@ -144,15 +147,34 @@ instance (HasForwardAssoc f a) => HasForwardAssoc f (HList '[a]) where
144147
forwardAssoc f (HCons a HNil) = toHList $ forwardAssoc f a
145148
forwardStochAssoc f (HCons a HNil) = toHList <$> forwardStochAssoc f a
146149

150+
type family ToHNat (x :: Nat) :: HNat where
151+
ToHNat 0 = HZero
152+
ToHNat x = HSucc (ToHNat ( x - 1 ))
153+
154+
dropLayers :: forall num a ys xs. (KnownNat num, Coercible a (HList xs), HDrop (ToHNat num) xs ys, HLengthGe xs (ToHNat num)) => a -> HList ys
155+
dropLayers m = hDrop (Proxy :: Proxy (ToHNat num)) (coerce m)
156+
157+
takeLayers :: forall num a ys xs. (KnownNat num, Coercible a (HList xs), HTake (ToHNat num) xs ys, SameLength' (HReplicateR (ToHNat num) ()) ys, HLengthEq1 ys (ToHNat num), HLengthEq2 ys (ToHNat num), HLengthGe xs (ToHNat num)) => a -> HList ys
158+
takeLayers m = hTake (Proxy :: Proxy (ToHNat num)) (coerce m)
159+
160+
splitLayers :: forall num a xs ys xsys. (KnownNat num, Coercible a (HList xsys), HSplitAt1 '[] (ToHNat num) xsys xs ys, HAppendList1 xs ys xsys, SameLength' (HReplicateR (ToHNat num) ()) xs, HLengthEq1 xs (ToHNat num), HLengthEq2 xs (ToHNat num)) => a -> (HList xs, HList ys)
161+
splitLayers m = hSplitAt (Proxy :: Proxy (ToHNat num)) (coerce m)
162+
147163
dropLastLayer :: (Coercible a (HList xs1), HRevApp xs2 '[x] xs1, HRevApp xs2 '[] sx, HRevApp xs1 '[] (x : xs2), HRevApp sx '[] xs2) => a -> HList sx
148164
dropLastLayer m = hReverse (hDrop (Proxy :: Proxy (HSucc HZero)) (hReverse (coerce m)))
149165

150166
addLastLayer :: HAppend l1 (HList '[e]) => l1 -> e -> HAppendR l1 (HList '[e])
151167
addLastLayer a b = a `hAppend` (b .*. HNil)
152168

169+
dropFirstLayer :: Coercible a (HList (x : ys)) => a -> HList ys
170+
dropFirstLayer m = hDrop (Proxy :: Proxy (HSucc HZero)) (coerce m)
171+
153172
getLastLayer :: (Coercible a (HList l1), HRevApp l1 '[] (e : l)) => a -> e
154173
getLastLayer a = hLast (coerce a)
155174

175+
getFirstLayer :: Coercible a (HList (e : l)) => a -> e
176+
getFirstLayer a = hHead (coerce a)
177+
156178
hScanl :: forall f z ls xs1 sx xs2. (HScanr f z ls xs1, HRevApp xs1 '[] sx, HRevApp sx '[] xs1, HRevApp xs2 '[] ls, HRevApp ls '[] xs2) => f -> z -> HList xs2 -> HList sx
157179
hScanl a b c = hReverse $ hScanr a b (hReverse c)
158180

@@ -259,12 +281,5 @@ instanceForwardAssoc model input output =
259281
|]
260282

261283
instanceForwardAssocs :: [TH.Q TH.Type] -> TH.Q TH.Type -> TH.Q TH.Type -> TH.DecsQ
262-
instanceForwardAssocs models input output = do
263-
decs <- forM models $ \model ->
264-
[d|
265-
instance HasForwardAssoc $model $input where
266-
type ForwardResult $model $input = $output
267-
forwardAssoc = forward
268-
forwardStochAssoc = forwardStoch
269-
|]
270-
return $ concat decs
284+
instanceForwardAssocs models input output =
285+
concat <$> forM models (\model -> instanceForwardAssoc model input output)

stack.yaml

Lines changed: 0 additions & 13 deletions
This file was deleted.

test/Spec.hs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,3 +121,30 @@ main = hspec $ do
121121
zero1 = over (types @Tensor) (ones' . shape) $ getLastLayer m1'
122122
model' = addLastLayer (dropLastLayer m0) (mergeParameters (+) layer0 zero1)
123123
return ()
124+
it "Fanin" $ do
125+
m0 <- sample mlpSpec
126+
m1 <- sample mlpSpec
127+
let l0 = getFirstLayer m0
128+
l1 = getFirstLayer m1
129+
fin = Fanin l0 l1
130+
model' = HCons fin $ dropFirstLayer m0
131+
input = ones' [2,784]
132+
out = forward model' (input,input)
133+
shape out `shouldBe` [2,10]
134+
it "Fanout" $ do
135+
m0 <- sample (LinearSpec 10 2)
136+
m1 <- sample (LinearSpec 10 3)
137+
let fout = Fanout m0 m1
138+
input = ones' [1,10]
139+
(out0,out1) = forward fout input
140+
shape out0 `shouldBe` [1,2]
141+
shape out1 `shouldBe` [1,3]
142+
it "Split layers" $ do
143+
m0 <- sample mlpSpec
144+
let (h, t) = splitLayers @2 m0
145+
input0 = ones' [1,784]
146+
input1 = ones' [1,64]
147+
output0 = forward h input0
148+
output1 = forward t input1
149+
shape output0 `shouldBe` [1,64]
150+
shape output1 `shouldBe` [1,10]

0 commit comments

Comments
 (0)