@@ -49,18 +49,21 @@ function create_state(self :: SGD, index :: Int, weight :: NDArray)
49
49
end
50
50
end
51
51
52
- function update (self :: SGD , index :: Int , weight :: NDArray , grad :: NDArray , state :: Union{ Void, NDArray} )
52
+ function update (self :: SGD , index :: Int , weight :: NDArray , grad :: NDArray , state :: Void )
53
53
lr = get_learning_rate (self. opts. lr_scheduler, self. state)
54
54
grad = normalized_gradient (self. opts, self. state, weight, grad)
55
+
56
+ @inplace weight += - lr * grad
57
+ end
55
58
56
- if isa (state, Void)
57
- # vanilla SGD, without momentum
58
- @inplace weight += - lr * grad
59
- else
60
- mom = state :: NDArray
61
- coef = get_momentum (self . opts . momentum_scheduler, self . state)
62
- @inplace mom .*= coef
63
- @inplace mom .+ = - lr * grad
64
- @inplace weight .+ = mom
65
- end
59
+ # update with momentum
60
+ function update (self :: SGD , index :: Int , weight :: NDArray , grad :: NDArray , state :: NDArray )
61
+ lr = get_learning_rate (self . opts . lr_scheduler, self . state)
62
+ grad = normalized_gradient (self . opts, self . state, weight, grad)
63
+
64
+ mom = state :: NDArray
65
+ coef = get_momentum (self . opts . momentum_scheduler, self . state)
66
+ @inplace mom .*= coef
67
+ @inplace mom .+ = - lr * grad
68
+ @inplace weight .+ = mom
66
69
end
0 commit comments