1- {-# LANGUAGE TypeOperators#-}
2- {-# LANGUAGE FlexibleInstances#-}
3- {-# LANGUAGE MultiParamTypeClasses#-}
4- {-# LANGUAGE UndecidableInstances#-}
5- {-# LANGUAGE DeriveGeneric#-}
1+ {-# LANGUAGE DataKinds #-}
62{-# LANGUAGE DeriveAnyClass#-}
3+ {-# LANGUAGE DeriveGeneric#-}
74{-# LANGUAGE DuplicateRecordFields#-}
5+ {-# LANGUAGE FlexibleContexts#-}
6+ {-# LANGUAGE FlexibleInstances#-}
7+ {-# LANGUAGE FunctionalDependencies#-}
8+ {-# LANGUAGE GADTs#-}
9+ {-# LANGUAGE MultiParamTypeClasses#-}
10+ {-# LANGUAGE OverloadedRecordDot#-}
11+ {-# LANGUAGE PartialTypeSignatures #-}
12+ {-# LANGUAGE PolyKinds #-}
813{-# LANGUAGE RecordWildCards#-}
914{-# LANGUAGE ScopedTypeVariables#-}
1015{-# LANGUAGE TypeApplications#-}
1116{-# LANGUAGE TypeFamilies#-}
12- {-# LANGUAGE GADTs#-}
13- {-# LANGUAGE OverloadedRecordDot#-}
14- {-# LANGUAGE FlexibleContexts#-}
15- {-# LANGUAGE FunctionalDependencies#-}
1617{-# LANGUAGE TypeFamilyDependencies#-}
17- {-# LANGUAGE PartialTypeSignatures #-}
18+ {-# LANGUAGE TypeOperators#-}
19+ {-# LANGUAGE UndecidableInstances#-}
1820
1921
2022module Torch.Compose where
@@ -23,13 +25,31 @@ import Torch
2325import Torch.NN
2426import Torch.Functional
2527import GHC.Generics hiding ((:+:) )
28+ -- import Data.Void
29+ import Data.HList
30+ import Data.HList (hAppend )
31+ import Data.Kind
32+ import Data.Coerce
33+ import Control.Exception
34+ import System.IO.Unsafe
35+
36+ instance (Randomizable spec0 f0 , Randomizable (HList spec1 ) (HList f1 )) => Randomizable (HList (spec0 ': spec1 )) (HList (f0 ': f1 )) where
37+ sample (HCons s0 s1) = do
38+ f0 <- sample s0
39+ f1 <- sample s1
40+ return (f0 .*. f1)
2641
27- data (:>>: ) a b = (:>>:)
28- { head :: a
29- , tail :: b
30- } deriving (Show , Eq , Generic )
42+ instance Randomizable (HList '[] ) (HList '[] ) where
43+ sample HNil = do
44+ return HNil
45+
46+ instance (HasForward f a b , HasForward (HList g ) b c ) => HasForward (HList (f ': g )) a c where
47+ forward (HCons f g) a = forward g (forward f a)
48+ forwardStoch (HCons f g) a = forwardStoch f a >>= forwardStoch g
3149
32- infixr 5 :>>:
50+ instance HasForward (HList '[] ) a a where
51+ forward _ = id
52+ forwardStoch _ = pure
3353
3454data (://: ) a b = Fanout
3555 { head :: a
@@ -46,16 +66,6 @@ data (:++:) a b = Concat
4666 , tail :: b
4767 } deriving (Show , Eq , Generic )
4868
49- instance (Randomizable spec0 f0 , Randomizable spec1 f1 ) => Randomizable (spec0 :>>: spec1 ) (f0 :>>: f1 ) where
50- sample ((:>>:) s0 s1) = do
51- f0 <- sample s0
52- f1 <- sample s1
53- return ((:>>:) f0 f1)
54-
55- instance (HasForward f a b , HasForward g b c ) => HasForward (f :>>: g ) a c where
56- forward ((:>>:) f g) a = forward g (forward f a)
57- forwardStoch ((:>>:) f g) a = forwardStoch f a >>= forwardStoch g
58-
5969instance (Randomizable spec0 f0 , Randomizable spec1 f1 ) => Randomizable (spec0 ://: spec1 ) (f0 ://: f1 ) where
6070 sample (Fanout s0 s1) = do
6171 f0 <- sample s0
@@ -117,58 +127,78 @@ instance (HasForward a b b) => HasForward (Replicate a) b b where
117127 forwardStoch (Replicate [] ) input = pure input
118128 forwardStoch (Replicate (a: ax)) input = forwardStoch (Replicate ax) =<< forwardStoch a input
119129
120- type family LastLayer x where
121- LastLayer (a :>>: b ) = LastLayer b
122- LastLayer x = x
123-
124- class HasLast x r | x -> r where
125- getLast :: x -> r
126-
127- instance HasLast b r => HasLast (a :>>: b ) r where
128- getLast ((:>>:) _ b) = getLast b
129-
130- instance HasLast a a where
131- getLast = id
132-
133- type family FirstLayer x where
134- FirstLayer (a :>>: b ) = a
135- FirstLayer x = x
136-
137- class HasFirst x r | x -> r where
138- getFirst :: x -> r
139-
140- instance HasFirst a r => HasFirst (a :>>: b ) r where
141- getFirst ((:>>:) a _) = getFirst a
142-
143- instance HasFirst a a where
144- getFirst = id
145-
146130class HasForwardAssoc f a where
147- type ForwardResult f a
131+ type ForwardResult f a :: Type
148132 forwardAssoc :: f -> a -> ForwardResult f a
149133
150- class HasOutputs f a where
151- type Outputs f a
152- toOutputs :: f -> a -> Outputs f a
153-
154- instance (HasForwardAssoc f0 a , HasOutputs f0 a , HasOutputs f1 (ForwardResult f0 a )) => HasOutputs (f0 :>>: f1 ) a where
155- type Outputs (f0 :>>: f1 ) a = Outputs f0 a :>>: Outputs f1 (ForwardResult f0 a )
156- toOutputs ((:>>:) f0 f1) a = (:>>:) (toOutputs f0 a) (toOutputs f1 (forwardAssoc f0 a))
157-
158- class HasInputs f a where
159- type Inputs f a
160- toInputs :: f -> a -> Inputs f a
161-
162- instance (HasForwardAssoc f0 a , HasInputs f0 a , HasInputs f1 (ForwardResult f0 a )) => HasInputs (f0 :>>: f1 ) a where
163- type Inputs (f0 :>>: f1 ) a = Inputs f0 a :>>: Inputs f1 (ForwardResult f0 a )
164- toInputs ((:>>:) f0 f1) a = (:>>:) (toInputs f0 a) (toInputs f1 (forwardAssoc f0 a))
165-
166-
167- class HasOutputShapes f a where
168- type OutputShapes f a
169- toOutputShapes :: f -> a -> OutputShapes f a
170-
171- instance (HasForwardAssoc f0 a , HasOutputShapes f0 a , HasOutputShapes f1 (ForwardResult f0 a )) => HasOutputShapes (f0 :>>: f1 ) a where
172- type OutputShapes (f0 :>>: f1 ) a = OutputShapes f0 a :>>: OutputShapes f1 (ForwardResult f0 a )
173- toOutputShapes ((:>>:) f0 f1) a = (:>>:) (toOutputShapes f0 a) (toOutputShapes f1 (forwardAssoc f0 a))
174-
134+ toHList :: x -> HList '[x ]
135+ toHList x = HCons x HNil
136+
137+ instance (HasForwardAssoc f a ) => HasForwardAssoc f (HList '[a ]) where
138+ type ForwardResult f (HList '[a ]) = HList '[ForwardResult f a ]
139+ forwardAssoc f (HCons a HNil ) = toHList $ forwardAssoc f a
140+
141+ dropLastLayer :: (Coercible a (HList xs1 ), HRevApp xs2 '[x ] xs1 , HRevApp xs2 '[] sx , HRevApp xs1 '[] (x : xs2 ), HRevApp sx '[] xs2 ) => a -> HList sx
142+ dropLastLayer m = hReverse (hDrop (Proxy :: Proxy (HSucc HZero )) (hReverse (coerce m)))
143+
144+ addLastLayer :: HAppend l1 (HList '[e ]) => l1 -> e -> HAppendR l1 (HList '[e ])
145+ addLastLayer a b = a `hAppend` (b .*. HNil )
146+
147+ getLastLayer :: (Coercible a (HList l1 ), HRevApp l1 '[] (e : l )) => a -> e
148+ getLastLayer a = hLast (coerce a)
149+
150+ hScanl :: forall f z ls xs1 sx xs2 . (HScanr f z ls xs1 , HRevApp xs1 '[] sx , HRevApp sx '[] xs1 , HRevApp xs2 '[] ls , HRevApp ls '[] xs2 ) => f -> z -> HList xs2 -> HList sx
151+ hScanl a b c = hReverse $ hScanr a b (hReverse c)
152+
153+ safeEval :: forall a . a -> Maybe a
154+ safeEval x = unsafePerformIO $ do
155+ result <- try (evaluate @ a x) :: IO (Either SomeException a )
156+ case result of
157+ Left _ -> return Nothing
158+ Right v -> return (Just v)
159+
160+ type family ForwardMap (xs :: [* ]) (a :: * ) :: [* ] where
161+ ForwardMap '[] _ = '[]
162+ ForwardMap (x ': xs ) a = ForwardResult x a ': ForwardMap xs (ForwardResult x a )
163+
164+ class Outputs xs input where
165+ toOutputs' :: HList xs -> input -> HList (ForwardMap xs input )
166+
167+ instance HasForwardAssoc x a => HasForwardAssoc x (Maybe a ) where
168+ type ForwardResult x (Maybe a ) = Maybe (ForwardResult x a )
169+ forwardAssoc x (Just a) = Just $ forwardAssoc x a
170+ forwardAssoc x Nothing = Nothing
171+
172+
173+ instance (HasForwardAssoc x a , Outputs xs (ForwardResult x a )) => Outputs (x ': xs ) a where
174+ toOutputs' (HCons x xs) a =
175+ let out = forwardAssoc x a
176+ in HCons out $ toOutputs' xs out
177+
178+ instance Outputs '[] a where
179+ toOutputs' _ _ = HNil
180+
181+ toOutputs ::
182+ (Coercible a (HList xs ),
183+ Outputs xs input
184+ ) =>
185+ a -> input -> HList (ForwardMap xs input )
186+ toOutputs f = toOutputs' (coerce f)
187+
188+ toOutputShapes ::
189+ (Coercible a (HList xs ),
190+ HMapAux HList (Tensor -> [Int ]) (ForwardMap xs input ) b ,
191+ SameLength' b (ForwardMap xs input ),
192+ SameLength' (ForwardMap xs input ) b , Outputs xs input
193+ ) =>
194+ a -> input -> HList b
195+ toOutputShapes f a = hMap shape (toOutputs f a)
196+
197+ toMaybeOutputShapes ::
198+ (Coercible a (HList xs ),
199+ HMapAux HList (Tensor -> Maybe [Int ]) (ForwardMap xs input ) b ,
200+ SameLength' b (ForwardMap xs input ),
201+ SameLength' (ForwardMap xs input ) b , Outputs xs input
202+ ) =>
203+ a -> input -> HList b
204+ toMaybeOutputShapes f a = hMap (safeEval . shape) (toOutputs f a)
0 commit comments