
Commit c0ed0b8 ("Update")
Parent: eff89e2

File tree: 3 files changed (+42, -17 lines)


README.md

Lines changed: 8 additions & 9 deletions
````diff
@@ -1,18 +1,17 @@
 # Hasktorch Compose
 
-In hasktorch, model specifications, values, and inference are defined separately. there are cases combining commonly used models. For example, we might want to connect three linear layers, or add one linear layer to an existing model.
-This repo provides a library to easily compose existing models.
-We plan to provide both an untyped API and a typed API, but we will prioritize the development of the untyped API.
+In Hasktorch, model specifications, values, and inference are defined separately. This often necessitates combining commonly used models. For example, three linear layers may need to be connected in sequence, or a new linear layer could be added to an existing model. Hasktorch Compose provides a straightforward way to compose such models.
 
-This is an experimental library developed based on [hasktorch-skeleton](https://github.com/hasktorch/hasktorch-skeleton).
+In addition to simple model composition, this library aims to support extracting parts of models and sharing parameters between different models, such as ControlNet and LoRA. Both an untyped API and a typed API are planned, with initial development focused on the untyped API.
 
-List of planned features:
+Hasktorch Compose is an experimental library built on top of [hasktorch-skeleton](https://github.com/hasktorch/hasktorch-skeleton).
 
+**Planned Features:**
 - [x] Sequential
 - [ ] Extract layer
 - [ ] Test for each layer
 - [ ] Overlay layer
-- [ ] Concatenate layer
+- [x] Concatenate layer
 
 # Examples
 
@@ -30,11 +29,11 @@ mlpSpec =
   ReluSpec :>>:
   LinearSpec 64 32) :>>:
   ReluSpec :>>:
-  LinearSpec 32 10
+  LinearSpec 32 10 :>>:
+  LogSoftMaxSpec
 
 mlp :: (Randomizable MLPSpec MLP, HasForward MLP Tensor Tensor) => MLP -> Tensor -> Tensor
-mlp model input =
-  logSoftmax (Dim 1) $ forward model input
+mlp model input = forward model input
 ```
 
 ## Extract layer
````
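The `:>>:` combinator used in `mlpSpec` above chains layers so that each layer's output feeds the next. The following is a dependency-free toy sketch of that idea, not hasktorch-compose's actual implementation: `Layer`, `Forward`, and `forwardToy` are invented here for illustration.

```haskell
{-# LANGUAGE TypeOperators #-}
-- Toy sketch of sequential composition via (:>>:): the combinator pairs
-- two layers, and forward for (f :>>: g) pipes f's output into g.

infixr 5 :>>:
data a :>>: b = a :>>: b

-- A toy "layer" is just a function on Double.
newtype Layer = Layer (Double -> Double)

class Forward m where
  forwardToy :: m -> Double -> Double

instance Forward Layer where
  forwardToy (Layer f) = f

-- Composition runs the left component first, then the right one.
instance (Forward a, Forward b) => Forward (a :>>: b) where
  forwardToy (f :>>: g) = forwardToy g . forwardToy f

main :: IO ()
main = do
  let model = Layer (* 2) :>>: Layer (+ 1) :>>: Layer negate
  print (forwardToy model 3)  -- ((3 * 2) + 1) negated, i.e. -7.0
```

The `infixr` declaration is what lets a chain like `a :>>: b :>>: c` nest to the right without explicit parentheses, matching the flat spelling the commit introduces in `test/Spec.hs`.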

src/Torch/Compose/NN.hs

Lines changed: 27 additions & 0 deletions
```diff
@@ -117,6 +117,16 @@ instance Randomizable BatchNorm2dSpec BatchNorm2d where
     runningVar <- newMutableTensor $ ones' [channelSize]
     pure BatchNorm2d{..}
 
+data LogSoftMaxSpec = LogSoftMaxSpec deriving (Generic, Show, Eq)
+data LogSoftMax = LogSoftMax deriving (Generic, Parameterized, Show, Eq)
+instance Randomizable LogSoftMaxSpec LogSoftMax where
+  sample _ = pure LogSoftMax
+
+instance HasForward LogSoftMax Tensor Tensor where
+  forward _ = logSoftmax (Dim 1)
+  forwardStoch _ i = pure $ logSoftmax (Dim 1) i
+
+
 instance HasOutputs Linear Tensor where
   type Outputs Linear Tensor = Tensor
   toOutputs = forward
@@ -141,6 +151,19 @@ instance HasOutputShapes Relu Tensor where
   type OutputShapes Relu Tensor = [Int]
   toOutputShapes model a = shape $ forward model a
 
+instance HasOutputs LogSoftMax Tensor where
+  type Outputs LogSoftMax Tensor = Tensor
+  toOutputs = forward
+
+instance HasInputs LogSoftMax Tensor where
+  type Inputs LogSoftMax Tensor = Tensor
+  toInputs _ a = a
+
+instance HasOutputShapes LogSoftMax Tensor where
+  type OutputShapes LogSoftMax Tensor = [Int]
+  toOutputShapes model a = shape $ forward model a
+
+
 instance HasForwardAssoc Linear Tensor where
   type ForwardResult Linear Tensor = Tensor
   forwardAssoc = forward
@@ -149,3 +172,7 @@ instance HasForwardAssoc Relu Tensor where
   type ForwardResult Relu Tensor = Tensor
   forwardAssoc = forward
 
+instance HasForwardAssoc LogSoftMax Tensor where
+  type ForwardResult LogSoftMax Tensor = Tensor
+  forwardAssoc = forward
+
```
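The new `LogSoftMax` layer delegates to hasktorch's `logSoftmax (Dim 1)`, i.e. `log (softmax x)` applied along each row. As a dependency-free sketch of that computation on a plain list, using the standard max-subtraction trick for numerical stability (`logSoftmaxRow` is a hypothetical helper written for this sketch, not a library function):

```haskell
-- Per-row log-softmax: log (softmax x) = x - logsumexp x.
logSoftmaxRow :: [Double] -> [Double]
logSoftmaxRow xs = [x - m - log s | x <- xs]
  where
    m = maximum xs                  -- subtract the max for numerical stability
    s = sum [exp (x - m) | x <- xs]

main :: IO ()
main = do
  let ys = logSoftmaxRow [1, 2, 3]
  -- Exponentiating recovers the softmax, so the sum is 1 (up to rounding).
  print (sum (map exp ys))
```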

test/Spec.hs

Lines changed: 7 additions & 8 deletions
```diff
@@ -24,16 +24,17 @@ import Torch.Compose.NN
 import Torch
 import GHC.Generics hiding ((:+:))
 
-newtype MLPSpec = MLPSpec (LinearSpec :>>: (ReluSpec :>>: (LinearSpec :>>: (ReluSpec :>>: LinearSpec)))) deriving (Generic)
-newtype MLP = MLP (Linear :>>: (Relu :>>: (Linear :>>: (Relu :>>: Linear)))) deriving (Generic)
+newtype MLPSpec = MLPSpec (LinearSpec :>>: ReluSpec :>>: LinearSpec :>>: ReluSpec :>>: LinearSpec :>>: LogSoftMaxSpec) deriving (Generic)
+newtype MLP = MLP (Linear :>>: Relu :>>: Linear :>>: Relu :>>: Linear :>>: LogSoftMax) deriving (Generic)
 
 mlpSpec :: MLPSpec
 mlpSpec = MLPSpec $
   LinearSpec 784 64 :>>:
   ReluSpec :>>:
   LinearSpec 64 32 :>>:
   ReluSpec :>>:
-  LinearSpec 32 10
+  LinearSpec 32 10 :>>:
+  LogSoftMaxSpec
 
 instance HasForward MLP Tensor Tensor where
   forward (MLP model) = forward model
@@ -43,8 +44,7 @@ instance Randomizable MLPSpec MLP where
   sample (MLPSpec spec) = MLP <$> sample spec
 
 mlp :: MLP -> Tensor -> Tensor
-mlp model input =
-  logSoftmax (Dim 1) $ forward model input
+mlp model input = forward model input
 
 main :: IO ()
 main = hspec $ do
@@ -58,10 +58,9 @@ main = hspec $ do
   it "Extract the last layer" $ do
     (MLP (model :: a)) <- sample mlpSpec
     let layer = getLast model :: LastLayer a
-    shape layer.weight.toDependent `shouldBe` [10,32]
+    layer `shouldBe` LogSoftMax
   it "Extract all output shapes" $ do
     (MLP (model :: a)) <- sample mlpSpec
     let out = toOutputShapes model (ones' [2,784])
-        exp = [2,64] :>>: [2,64] :>>: [2,32] :>>: [2,32] :>>: [2,10]
+        exp = [2,64] :>>: [2,64] :>>: [2,32] :>>: [2,32] :>>: [2,10] :>>: [2,10]
     out `shouldBe` exp
-
```
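The updated expectation in the shape test follows from how each layer transforms the batch shape: a `Linear i o` maps `[n, i]` to `[n, o]`, while `Relu` and the new `LogSoftMax` leave the shape unchanged, which is why a second `[2,10]` appears at the end of the chain. A dependency-free toy sketch (`ToyLayer` and `outShape` are invented here) reproducing that chain:

```haskell
-- Toy shape propagation: Linear rewrites the feature dimension,
-- Relu / LogSoftMax preserve the shape.
data ToyLayer = ToyLinear Int Int | ToyRelu | ToyLogSoftMax

outShape :: [Int] -> ToyLayer -> [Int]
outShape [n, _] (ToyLinear _ o) = [n, o]
outShape s      _               = s  -- shape-preserving layers

main :: IO ()
main = print (tail (scanl outShape [2, 784]
  [ToyLinear 784 64, ToyRelu, ToyLinear 64 32, ToyRelu, ToyLinear 32 10, ToyLogSoftMax]))
  -- [[2,64],[2,64],[2,32],[2,32],[2,10],[2,10]]
```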