Skip to content

Commit 94e84d1

Browse files
committed
Merge branch 'master' into stable
2 parents 48924ef + cac5625 commit 94e84d1

File tree

11 files changed

+38
-22
lines changed

11 files changed

+38
-22
lines changed

.travis.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ os:
66
- osx
77
julia:
88
- 0.5
9-
- nightly
9+
# - nightly 0.6 supports depends on #170
1010

1111
# dependent apt packages
1212
addons:

deps/build.jl

+1-2
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,7 @@ if !libmxnet_detected
3131
base_url = "https://github.com/dmlc/mxnet/releases/download/20160531/20160531_win10_x64_cpu.7z"
3232
if libmxnet_curr_ver == "master"
3333
# download_cmd uses powershell 2, but we need powershell 3 to do this
34-
ps_wget(url, file) = run(`powershell -NoProfile -Command "wget \"$url\" -o \"$file\""`)
35-
ps_wget("https://api.github.com/repos/yajiedesign/mxnet/releases/latest", "mxnet.json")
34+
run(`powershell -NoProfile -Command Invoke-WebRequest -Uri "https://api.github.com/repos/yajiedesign/mxnet/releases/latest" -OutFile "mxnet.json"`)
3635
curr_win = JSON.parsefile("mxnet.json")["tag_name"]
3736
info("Can't use MXNet master on Windows, using latest binaries from $curr_win.")
3837
end

examples/char-lstm/seq-data.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ using MXNet
55
function build_vocabulary(corpus_fn::AbstractString, vocab_fn::AbstractString; max_vocab=10000)
66
if isfile(vocab_fn)
77
info("Vocabulary already exists, reusing $vocab_fn...")
8-
vocab = Dict{Char,Int}([w => i for (i,w) in enumerate(readall(vocab_fn))])
8+
vocab = Dict{Char,Int}([w => i for (i,w) in enumerate(readstring(vocab_fn))])
99
else
1010
# count symbol frequency
1111
dict = Dict{Char,Int}()

examples/char-lstm/train.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ lstm = LSTM(LSTM_N_LAYER, SEQ_LENGTH, DIM_HIDDEN, DIM_EMBED,
1414

1515
#--data
1616
# load data
17-
text_all = readall(INPUT_FILE)
17+
text_all = readstring(INPUT_FILE)
1818
len_train = round(Int, length(text_all)*DATA_TR_RATIO)
1919
text_tr = text_all[1:len_train]
2020
text_val = text_all[len_train+1:end]

examples/mnist/mlp-test.jl

+3-2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ module MNISTTest
55
using MXNet
66
using Base.Test
77

8+
include("mnist-data.jl")
9+
810
function get_mnist_mlp()
911
mlp = @mx.chain mx.Variable(:data) =>
1012
mx.FullyConnected(name=:fc1, num_hidden=128) =>
@@ -17,7 +19,6 @@ function get_mnist_mlp()
1719
end
1820

1921
function get_mnist_data(batch_size=100)
20-
include("mnist-data.jl")
2122
return get_mnist_providers(batch_size)
2223
end
2324

@@ -40,7 +41,7 @@ function mnist_fit_and_predict(optimizer, initializer, n_epoch)
4041
end
4142
mlp_load = mx.load("$cp_prefix-symbol.json", mx.SymbolicNode)
4243
@test mx.to_json(mlp_load) == mx.to_json(mlp)
43-
mlp_load = mx.from_json(readall("$cp_prefix-symbol.json"), mx.SymbolicNode)
44+
mlp_load = mx.from_json(readstring("$cp_prefix-symbol.json"), mx.SymbolicNode)
4445
@test mx.to_json(mlp_load) == mx.to_json(mlp)
4546

4647
#--------------------------------------------------------------------------------

src/MXNet.jl

+4
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@ using Compat
1111
import Compat.String
1212
import Compat.view
1313

14+
if VERSION >= v"0.6.0-dev.1024"
15+
import Base.Iterators: filter
16+
end
17+
1418
using Formatting
1519

1620
# Functions from base that we can safely extend and that are defined by libmxnet.

src/model.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,7 @@ function fit(self :: FeedForward, optimizer :: AbstractOptimizer, data :: Abstra
389389
end
390390
end
391391

392-
train_execs = Array(Executor, num_dev)
392+
train_execs = Array{Executor}(num_dev)
393393
for i = 1:num_dev
394394
data_shapes = Dict(map((x) -> x[1] => tuple(x[2][1:end-1]...,length(slices[i])), provide_data(data)))
395395
label_shapes = Dict(map((x) -> x[1] => tuple(x[2][1:end-1]...,length(slices[i])), provide_label(data)))

src/optimizers/sgd.jl

+14-11
Original file line numberDiff line numberDiff line change
@@ -49,18 +49,21 @@ function create_state(self :: SGD, index :: Int, weight :: NDArray)
4949
end
5050
end
5151

52-
function update(self :: SGD, index :: Int, weight :: NDArray, grad :: NDArray, state :: Union{Void, NDArray})
52+
function update(self :: SGD, index :: Int, weight :: NDArray, grad :: NDArray, state :: Void)
5353
lr = get_learning_rate(self.opts.lr_scheduler, self.state)
5454
grad = normalized_gradient(self.opts, self.state, weight, grad)
55+
56+
@inplace weight += -lr * grad
57+
end
5558

56-
if isa(state, Void)
57-
# vanilla SGD, without momentum
58-
@inplace weight += -lr * grad
59-
else
60-
mom = state :: NDArray
61-
coef = get_momentum(self.opts.momentum_scheduler, self.state)
62-
@inplace mom .*= coef
63-
@inplace mom .+= -lr * grad
64-
@inplace weight .+= mom
65-
end
59+
# update with momentum
60+
function update(self :: SGD, index :: Int, weight :: NDArray, grad :: NDArray, state :: NDArray)
61+
lr = get_learning_rate(self.opts.lr_scheduler, self.state)
62+
grad = normalized_gradient(self.opts, self.state, weight, grad)
63+
64+
mom = state :: NDArray
65+
coef = get_momentum(self.opts.momentum_scheduler, self.state)
66+
@inplace mom .*= coef
67+
@inplace mom .+= -lr * grad
68+
@inplace weight .+= mom
6669
end

src/symbolic-node.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -627,8 +627,8 @@ function _define_atomic_symbol_creator(name :: String)
627627

628628
$(if key_narg != ""
629629
quote
630-
if !in(Symbol($key_narg), param_keys)
631-
push!(param_keys, Symbol($key_narg))
630+
if !in($key_narg, param_keys)
631+
push!(param_keys, $key_narg)
632632
push!(param_vals, string(length(args)))
633633
end
634634
end

test/unittest/kvstore.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ function test_aggregator()
6262

6363
for vv in vals
6464
for v in vv
65-
@test maximum(abs(copy(v)) - 2num_devs) == 0
65+
@test maximum(abs.(copy(v)) - 2 * num_devs) == 0
6666
end
6767
end
6868
end

test/unittest/symbolic-node.jl

+9
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,14 @@ function test_dot()
130130
@test reldiff(ret, 2*ones(100, 200)) < 1e-6
131131
end
132132

133+
function test_misc()
134+
info("SymbolicNode::Miscellaneous")
135+
# Test for #189
136+
a = mx.Variable("a")
137+
b = mx.Variable("b")
138+
symb = mx.ElementWiseSum(a,b)
139+
end
140+
133141
################################################################################
134142
# Run tests
135143
################################################################################
@@ -143,6 +151,7 @@ end
143151
test_attrs()
144152
test_functions()
145153
test_dot()
154+
test_misc()
146155
end
147156

148157
end

0 commit comments

Comments
 (0)