2020{-# LANGUAGE TypeFamilyDependencies#-}
2121{-# LANGUAGE TypeOperators#-}
2222{-# LANGUAGE UndecidableInstances#-}
23+ {-# LANGUAGE AllowAmbiguousTypes#-}
24+
2325
2426module Torch.Compose where
2527
@@ -31,6 +33,7 @@ import Data.Coerce
3133import Control.Exception
3234import System.IO.Unsafe
3335import qualified Language.Haskell.TH as TH
36+ import GHC.TypeLits
3437-- import qualified Language.Haskell.TH.Syntax as TH
3538
3639
@@ -144,15 +147,34 @@ instance (HasForwardAssoc f a) => HasForwardAssoc f (HList '[a]) where
144147 forwardAssoc f (HCons a HNil ) = toHList $ forwardAssoc f a
145148 forwardStochAssoc f (HCons a HNil ) = toHList <$> forwardStochAssoc f a
146149
150+ type family ToHNat (x :: Nat ) :: HNat where
151+ ToHNat 0 = HZero
152+ ToHNat x = HSucc (ToHNat ( x - 1 ))
153+
154+ dropLayers :: forall num a ys xs . (KnownNat num , Coercible a (HList xs ), HDrop (ToHNat num ) xs ys , HLengthGe xs (ToHNat num )) => a -> HList ys
155+ dropLayers m = hDrop (Proxy :: Proxy (ToHNat num )) (coerce m)
156+
157+ takeLayers :: forall num a ys xs . (KnownNat num , Coercible a (HList xs ), HTake (ToHNat num ) xs ys , SameLength' (HReplicateR (ToHNat num ) () ) ys , HLengthEq1 ys (ToHNat num ), HLengthEq2 ys (ToHNat num ), HLengthGe xs (ToHNat num )) => a -> HList ys
158+ takeLayers m = hTake (Proxy :: Proxy (ToHNat num )) (coerce m)
159+
160+ splitLayers :: forall num a xs ys xsys . (KnownNat num , Coercible a (HList xsys ), HSplitAt1 '[] (ToHNat num ) xsys xs ys , HAppendList1 xs ys xsys , SameLength' (HReplicateR (ToHNat num ) () ) xs , HLengthEq1 xs (ToHNat num ), HLengthEq2 xs (ToHNat num )) => a -> (HList xs , HList ys )
161+ splitLayers m = hSplitAt (Proxy :: Proxy (ToHNat num )) (coerce m)
162+
147163dropLastLayer :: (Coercible a (HList xs1 ), HRevApp xs2 '[x ] xs1 , HRevApp xs2 '[] sx , HRevApp xs1 '[] (x : xs2 ), HRevApp sx '[] xs2 ) => a -> HList sx
148164dropLastLayer m = hReverse (hDrop (Proxy :: Proxy (HSucc HZero )) (hReverse (coerce m)))
149165
150166addLastLayer :: HAppend l1 (HList '[e ]) => l1 -> e -> HAppendR l1 (HList '[e ])
151167addLastLayer a b = a `hAppend` (b .*. HNil )
152168
169+ dropFirstLayer :: Coercible a (HList (x : ys )) => a -> HList ys
170+ dropFirstLayer m = hDrop (Proxy :: Proxy (HSucc HZero )) (coerce m)
171+
153172getLastLayer :: (Coercible a (HList l1 ), HRevApp l1 '[] (e : l )) => a -> e
154173getLastLayer a = hLast (coerce a)
155174
175+ getFirstLayer :: Coercible a (HList (e : l )) => a -> e
176+ getFirstLayer a = hHead (coerce a)
177+
156178hScanl :: forall f z ls xs1 sx xs2 . (HScanr f z ls xs1 , HRevApp xs1 '[] sx , HRevApp sx '[] xs1 , HRevApp xs2 '[] ls , HRevApp ls '[] xs2 ) => f -> z -> HList xs2 -> HList sx
157179hScanl a b c = hReverse $ hScanr a b (hReverse c)
158180
@@ -259,12 +281,5 @@ instanceForwardAssoc model input output =
259281 |]
260282
261283instanceForwardAssocs :: [TH. Q TH. Type ] -> TH. Q TH. Type -> TH. Q TH. Type -> TH. DecsQ
262- instanceForwardAssocs models input output = do
263- decs <- forM models $ \ model ->
264- [d |
265- instance HasForwardAssoc $model $input where
266- type ForwardResult $model $input = $output
267- forwardAssoc = forward
268- forwardStochAssoc = forwardStoch
269- |]
270- return $ concat decs
284+ instanceForwardAssocs models input output =
285+ concat <$> forM models (\ model -> instanceForwardAssoc model input output)
0 commit comments