
Commit 278bab6

fix ViT model output + rewrite attention layer + adapt torchvision script (#230)
* port ViT weights
* identify the problem
* fix ViT model
* LayerNormV2
* cleanup
* address comments
1 parent 0425c72 commit 278bab6

9 files changed: +140 additions, -60 deletions


scripts/CondaPkg.toml

Lines changed: 3 additions & 3 deletions
@@ -1,5 +1,5 @@
-channels = ["nvidia", "torch"]
+channels = ["pytorch"]
 
 [deps]
-pytorch = ""
-torchvision = ""
+pytorch = ">=2,<3"
+torchvision = ">=0.15"

scripts/Project.toml

Lines changed: 2 additions & 0 deletions
@@ -1,6 +1,8 @@
 [deps]
 BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
+DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
+Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0"
 MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
 Metalhead = "dbeba491-748d-5e0e-a39e-b530a07fa0cc"
 PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"

scripts/port_torchvision.jl

Lines changed: 24 additions & 25 deletions
@@ -7,31 +7,31 @@ const tvmodels = pyimport("torchvision.models")
 
 # name, weight, jlconstructor, pyconstructor
 model_list = [
-    ("vgg11", "IMAGENET1K_V1", () -> VGG(11), weights -> tvmodels.vgg11(weights=weights)),
-    ("vgg13", "IMAGENET1K_V1", () -> VGG(13), weights -> tvmodels.vgg13(weights=weights)),
-    ("vgg16", "IMAGENET1K_V1", () -> VGG(16), weights -> tvmodels.vgg16(weights=weights)),
-    ("vgg19", "IMAGENET1K_V1", () -> VGG(19), weights -> tvmodels.vgg19(weights=weights)),
-    ("resnet18", "IMAGENET1K_V1", () -> ResNet(18), weights -> tvmodels.resnet18(weights=weights)),
-    ("resnet34", "IMAGENET1K_V1", () -> ResNet(34), weights -> tvmodels.resnet34(weights=weights)),
-    ("resnet50", "IMAGENET1K_V1", () -> ResNet(50), weights -> tvmodels.resnet50(weights=weights)),
-    ("resnet101", "IMAGENET1K_V1", () -> ResNet(101), weights -> tvmodels.resnet101(weights=weights)),
-    ("resnet152", "IMAGENET1K_V1", () -> ResNet(152), weights -> tvmodels.resnet152(weights=weights)),
-    ("resnet50", "IMAGENET1K_V2", () -> ResNet(50), weights -> tvmodels.resnet50(weights=weights)),
-    ("resnet101", "IMAGENET1K_V2", () -> ResNet(101), weights -> tvmodels.resnet101(weights=weights)),
-    ("resnet152", "IMAGENET1K_V2", () -> ResNet(152), weights -> tvmodels.resnet152(weights=weights)),
-    ("resnext50_32x4d", "IMAGENET1K_V1", () -> ResNeXt(50; cardinality=32, base_width=4), weights -> tvmodels.resnext50_32x4d(weights=weights)),
-    ("resnext50_32x4d", "IMAGENET1K_V2", () -> ResNeXt(50; cardinality=32, base_width=4), weights -> tvmodels.resnext50_32x4d(weights=weights)),
-    ("resnext101_32x8d", "IMAGENET1K_V1", () -> ResNeXt(101; cardinality=32, base_width=8), weights -> tvmodels.resnext101_32x8d(weights=weights)),
-    ("resnext101_64x4d", "IMAGENET1K_V1", () -> ResNeXt(101; cardinality=64, base_width=4), weights -> tvmodels.resnext101_64x4d(weights=weights)),
-    ("resnext101_32x8d", "IMAGENET1K_V2", () -> ResNeXt(101; cardinality=32, base_width=8), weights -> tvmodels.resnext101_32x8d(weights=weights)),
-    ("wide_resnet50_2", "IMAGENET1K_V1", () -> WideResNet(50), weights -> tvmodels.wide_resnet50_2(weights=weights)),
-    ("wide_resnet50_2", "IMAGENET1K_V2", () -> WideResNet(50), weights -> tvmodels.wide_resnet50_2(weights=weights)),
-    ("wide_resnet101_2", "IMAGENET1K_V1", () -> WideResNet(101), weights -> tvmodels.wide_resnet101_2(weights=weights)),
-    ("wide_resnet101_2", "IMAGENET1K_V2", () -> WideResNet(101), weights -> tvmodels.wide_resnet101_2(weights=weights)),
-
+    ("vgg11", "IMAGENET1K_V1", () -> VGG(11), weights -> tvmodels.vgg11(; weights)),
+    ("vgg13", "IMAGENET1K_V1", () -> VGG(13), weights -> tvmodels.vgg13(; weights)),
+    ("vgg16", "IMAGENET1K_V1", () -> VGG(16), weights -> tvmodels.vgg16(; weights)),
+    ("vgg19", "IMAGENET1K_V1", () -> VGG(19), weights -> tvmodels.vgg19(; weights)),
+    ("resnet18", "IMAGENET1K_V1", () -> ResNet(18), weights -> tvmodels.resnet18(; weights)),
+    ("resnet34", "IMAGENET1K_V1", () -> ResNet(34), weights -> tvmodels.resnet34(; weights)),
+    ("resnet50", "IMAGENET1K_V1", () -> ResNet(50), weights -> tvmodels.resnet50(; weights)),
+    ("resnet101", "IMAGENET1K_V1", () -> ResNet(101), weights -> tvmodels.resnet101(; weights)),
+    ("resnet152", "IMAGENET1K_V1", () -> ResNet(152), weights -> tvmodels.resnet152(; weights)),
+    ("resnet50", "IMAGENET1K_V2", () -> ResNet(50), weights -> tvmodels.resnet50(; weights)),
+    ("resnet101", "IMAGENET1K_V2", () -> ResNet(101), weights -> tvmodels.resnet101(; weights)),
+    ("resnet152", "IMAGENET1K_V2", () -> ResNet(152), weights -> tvmodels.resnet152(; weights)),
+    ("resnext50_32x4d", "IMAGENET1K_V1", () -> ResNeXt(50; cardinality=32, base_width=4), weights -> tvmodels.resnext50_32x4d(; weights)),
+    ("resnext50_32x4d", "IMAGENET1K_V2", () -> ResNeXt(50; cardinality=32, base_width=4), weights -> tvmodels.resnext50_32x4d(; weights)),
+    ("resnext101_32x8d", "IMAGENET1K_V1", () -> ResNeXt(101; cardinality=32, base_width=8), weights -> tvmodels.resnext101_32x8d(; weights)),
+    ("resnext101_64x4d", "IMAGENET1K_V1", () -> ResNeXt(101; cardinality=64, base_width=4), weights -> tvmodels.resnext101_64x4d(; weights)),
+    ("resnext101_32x8d", "IMAGENET1K_V2", () -> ResNeXt(101; cardinality=32, base_width=8), weights -> tvmodels.resnext101_32x8d(; weights)),
+    ("wide_resnet50_2", "IMAGENET1K_V1", () -> WideResNet(50), weights -> tvmodels.wide_resnet50_2(; weights)),
+    ("wide_resnet50_2", "IMAGENET1K_V2", () -> WideResNet(50), weights -> tvmodels.wide_resnet50_2(; weights)),
+    ("wide_resnet101_2", "IMAGENET1K_V1", () -> WideResNet(101), weights -> tvmodels.wide_resnet101_2(; weights)),
+    ("wide_resnet101_2", "IMAGENET1K_V2", () -> WideResNet(101), weights -> tvmodels.wide_resnet101_2(; weights)),
+    ("vit_b_16", "IMAGENET1K_V1", () -> ViT(:base, imsize=(224,224), qkv_bias=true), weights -> tvmodels.vit_b_16(; weights)),
     ## NOT MATCHING BELOW
-    # ("squeezenet1_0", "IMAGENET1K_V1", () -> SqueezeNet(), weights -> tvmodels.squeezenet1_0(weights=weights)),
-    # ("densenet121", "IMAGENET1K_V1", () -> DenseNet(121), weights -> tvmodels.densenet121(weights=weights)),
+    # ("squeezenet1_0", "IMAGENET1K_V1", () -> SqueezeNet(), weights -> tvmodels.squeezenet1_0(; weights)),
+    # ("densenet121", "IMAGENET1K_V1", () -> DenseNet(121), weights -> tvmodels.densenet121(; weights)),
 ]
 
 
@@ -44,4 +44,3 @@ for (name, weights, jlconstructor, pyconstructor) in model_list
     BSON.@save joinpath(@__DIR__, "$(name)_$weights.bson") model=jlmodel
     println("Saved $(name)_$weights.bson")
 end
-
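
To make the table above concrete, here is a hedged sketch of how one `model_list` entry is presumably consumed by the porting loop whose tail appears in the second hunk. Only the `BSON.@save` and `println` lines are shown in the diff; the `pytorch2flux!` call and the surrounding setup are assumptions based on the rest of this commit.

```julia
# Hedged sketch, not the script verbatim: walking one model_list entry through the port.
# pytorch2flux! comes from scripts/pytorch2flux.jl (next file in this commit); everything
# else here is an assumption about the unseen top of the loop.
using BSON

name, weights, jlconstructor, pyconstructor = model_list[end]  # the new vit_b_16 entry
jlmodel = jlconstructor()              # ViT(:base, imsize=(224,224), qkv_bias=true)
pymodel = pyconstructor(weights)       # tvmodels.vit_b_16(; weights="IMAGENET1K_V1")
pytorch2flux!(jlmodel, pymodel)        # copy the torchvision weights into the Flux model
BSON.@save joinpath(@__DIR__, "$(name)_$weights.bson") model=jlmodel
```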

scripts/pytorch2flux.jl

Lines changed: 41 additions & 0 deletions
@@ -75,12 +75,41 @@ function _list_state(node::Dense, channel, prefix)
     end
 end
 
+function _list_state(node::Metalhead.Layers.ClassTokens, channel, prefix)
+    put!(channel, (prefix * ".classtoken", node.token))
+end
+
+function _list_state(node::Metalhead.Layers.ViPosEmbedding, channel, prefix)
+    put!(channel, (prefix * ".posembedding", node.vectors))
+end
+
+function _list_state(node::LayerNorm, channel, prefix)
+    put!(channel, (prefix * ".layernorm_scale", node.diag.scale))
+    put!(channel, (prefix * ".layernorm_bias", node.diag.bias))
+end
+
+function _list_state(node::Metalhead.Layers.LayerNormV2, channel, prefix)
+    put!(channel, (prefix * ".layernorm_scale", node.diag.scale))
+    put!(channel, (prefix * ".layernorm_bias", node.diag.bias))
+end
+
+function _list_state(node::Metalhead.Layers.MultiHeadSelfAttention, channel, prefix)
+    _list_state(node.qkv_layer, channel, prefix * ".qkv")
+    _list_state(node.projection, channel, prefix * ".proj")
+end
+
 function _list_state(node::Chain, channel, prefix)
     for (i, n) in enumerate(node.layers)
         _list_state(n, channel, prefix * ".layers[$i]")
     end
 end
 
+function _list_state(node::SkipConnection, channel, prefix)
+    for (i, n) in enumerate(node.layers)
+        _list_state(n, channel, prefix * ".layers[$i]")
+    end
+end
+
 function _list_state(node::Parallel, channel, prefix)
     # reverse to match PyTorch order, see https://github.com/FluxML/Metalhead.jl/issues/228
     for (i, n) in enumerate(reverse(node.layers))
@@ -102,6 +131,18 @@ function pytorch2flux!(jlmodel, pymodel; verb=false)
     state_dict = pymodel.state_dict()
     pystate = OrderedDict((py2jl(k), th2jl(v)) for (k, v) in state_dict.items() if
                           !occursin("num_batches_tracked", py2jl(k)))
+
+    jlkeys = collect(keys(jlstate))
+    pykeys = collect(keys(pystate))
+
+    ## handle class_token since it is not in the same order
+    jl_k = findfirst(k -> occursin("classtoken", k), jlkeys)
+    py_k = findfirst(k -> occursin("class_token", k), pykeys)
+    if jl_k !== nothing && py_k !== nothing
+        jlstate[jlkeys[jl_k]] .= pystate[pykeys[py_k]]
+        delete!(pystate, pykeys[py_k])
+        delete!(jlstate, jlkeys[jl_k])
+    end
 
     for ((flux_key, flux_param), (pytorch_key, pytorch_param)) in zip(jlstate, pystate)
         # @show flux_key size(flux_param) pytorch_key size(pytorch_param)
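
The `_list_state` methods added above all follow one pattern: each layer type `put!`s `(key, array)` pairs on a `Channel` and recurses into its children with an extended prefix. A minimal sketch of driving that traversal is below; `collect_state` is a hypothetical helper (the script builds `jlstate` its own way), and the exact key suffixes emitted for `Dense` are defined earlier in the file, outside this hunk.

```julia
# Minimal sketch, assuming the _list_state methods above are loaded.
# collect_state is hypothetical; it just drains the Channel into an OrderedDict.
using Flux, DataStructures

function collect_state(model)
    ch = Channel{Tuple{String, Any}}() do c
        _list_state(model, c, "model")
    end
    return OrderedDict(k => v for (k, v) in ch)
end

tiny = Chain(Dense(4 => 4), SkipConnection(Chain(Dense(4 => 4)), +))
collect_state(tiny)  # ordered keys like "model.layers[1]…" and "model.layers[2].layers[1]…"
```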

scripts/utils.jl

Lines changed: 6 additions & 1 deletion
@@ -25,9 +25,14 @@ function np2jl(x::Py)
 end
 
 function th2jl(x::Py)
-    x_jl = pyconvert(Array, x)
+    x_jl = pyconvert(Array, x.detach().numpy())
     x_jl = permutedims(x_jl, ndims(x_jl):-1:1)
     return x_jl
 end
 
 py2jl(x::Py) = pyconvert(Any, x)
+
+
+## SAVE STATE
+using Functors
+state_arrays(x) = fmapstructure(x -> x isa AbstractArray ? x : missing, x)
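
A quick illustration of the new `state_arrays` helper (a sketch, not part of the script): it mirrors a model's structure as nested named tuples, keeping arrays and replacing every other leaf with `missing`, which is handy for inspecting or saving just the numeric state.

```julia
# Minimal sketch of state_arrays on a tiny Flux model.
using Flux, Functors

state_arrays(x) = fmapstructure(x -> x isa AbstractArray ? x : missing, x)

m  = Chain(Dense(2 => 3, relu), Dense(3 => 1))
st = state_arrays(m)
st.layers[1].weight   # the 3×2 weight matrix is kept
st.layers[1].σ        # the activation function becomes `missing`
```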

src/layers/Layers.jl

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@ import Flux.testmode!
 include("../utilities.jl")
 
 include("attention.jl")
-export MHAttention
+export MultiHeadSelfAttention
 
 include("conv.jl")
 export conv_norm, basic_conv_bn, dwsep_conv_norm

src/layers/attention.jl

Lines changed: 11 additions & 23 deletions
@@ -1,5 +1,5 @@
 """
-    MHAttention(planes::Integer, nheads::Integer = 8; qkv_bias::Bool = false,
+    MultiHeadSelfAttention(planes::Integer, nheads::Integer = 8; qkv_bias::Bool = false,
                 attn_dropout_prob = 0., proj_dropout_prob = 0.)
 
 Multi-head self-attention layer.
@@ -12,39 +12,27 @@ Multi-head self-attention layer.
   - `attn_dropout_prob`: dropout probability after the self-attention layer
   - `proj_dropout_prob`: dropout probability after the projection layer
 """
-struct MHAttention{P, Q, R}
+struct MultiHeadSelfAttention{P, Q, R}
     nheads::Int
     qkv_layer::P
     attn_drop::Q
     projection::R
 end
-@functor MHAttention
+@functor MultiHeadSelfAttention
 
-function MHAttention(planes::Integer, nheads::Integer = 8; qkv_bias::Bool = false,
+function MultiHeadSelfAttention(planes::Integer, nheads::Integer = 8; qkv_bias::Bool = false,
                      attn_dropout_prob = 0.0, proj_dropout_prob = 0.0)
     @assert planes % nheads==0 "planes should be divisible by nheads"
     qkv_layer = Dense(planes, planes * 3; bias = qkv_bias)
     attn_drop = Dropout(attn_dropout_prob)
     proj = Chain(Dense(planes, planes), Dropout(proj_dropout_prob))
-    return MHAttention(nheads, qkv_layer, attn_drop, proj)
+    return MultiHeadSelfAttention(nheads, qkv_layer, attn_drop, proj)
 end
 
-function (m::MHAttention)(x::AbstractArray{T, 3}) where {T}
-    nfeatures, seq_len, batch_size = size(x)
-    x_reshaped = reshape(x, nfeatures, seq_len * batch_size)
-    qkv = m.qkv_layer(x_reshaped)
-    qkv_reshaped = reshape(qkv, nfeatures ÷ m.nheads, m.nheads, seq_len, 3 * batch_size)
-    query, key, value = chunk(qkv_reshaped, 3; dims = 4)
-    scale = convert(T, sqrt(size(query, 1) / m.nheads))
-    key_reshaped = reshape(permutedims(key, (2, 1, 3, 4)), m.nheads, nfeatures ÷ m.nheads,
-                           seq_len * batch_size)
-    query_reshaped = reshape(permutedims(query, (1, 2, 3, 4)), nfeatures ÷ m.nheads,
-                             m.nheads, seq_len * batch_size)
-    attention = m.attn_drop(softmax(batched_mul(query_reshaped, key_reshaped) .* scale))
-    value_reshaped = reshape(permutedims(value, (1, 2, 3, 4)), nfeatures ÷ m.nheads,
-                             m.nheads, seq_len * batch_size)
-    pre_projection = reshape(batched_mul(attention, value_reshaped),
-                             (nfeatures, seq_len, batch_size))
-    y = m.projection(reshape(pre_projection, size(pre_projection, 1), :))
-    return reshape(y, :, seq_len, batch_size)
+function (m::MultiHeadSelfAttention)(x::AbstractArray{<:Number, 3})
+    qkv = m.qkv_layer(x)
+    q, k, v = chunk(qkv, 3, dims = 1)
+    y, α = NNlib.dot_product_attention(q, k, v; m.nheads, fdrop = m.attn_drop)
+    y = m.projection(y)
+    return y
 end
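
The rewritten forward pass delegates head splitting, scaling, softmax, and attention dropout to `NNlib.dot_product_attention`, so the layer body reduces to the fused qkv projection, the attention call, and the output projection. A minimal usage sketch (the sizes are made up for illustration):

```julia
# Minimal sketch: the layer keeps the (features, sequence, batch) layout end to end.
using Metalhead, Flux

mhsa = Metalhead.Layers.MultiHeadSelfAttention(64, 8; qkv_bias = true)
x = randn(Float32, 64, 10, 2)   # 64 features, sequence length 10, batch of 2
y = mhsa(x)
size(y)                          # (64, 10, 2)
```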

src/layers/normalise.jl

Lines changed: 42 additions & 0 deletions
@@ -24,3 +24,45 @@ function ChannelLayerNorm(sz::Integer, λ = identity; ϵ = 1.0f-6)
 end
 
 (m::ChannelLayerNorm)(x) = m.diag(Flux.normalise(x; dims = ndims(x) - 1, ϵ = m.ϵ))
+
+
+"""
+    LayerNormV2(size..., λ=identity; affine=true, eps=1f-5)
+
+Same as Flux's LayerNorm but eps is added before taking the square root in the denominator.
+Therefore, LayerNormV2 matches pytorch's LayerNorm.
+"""
+struct LayerNormV2{F,D,T,N}
+    λ::F
+    diag::D
+    ϵ::T
+    size::NTuple{N,Int}
+    affine::Bool
+end
+
+function LayerNormV2(size::Tuple{Vararg{Int}}, λ=identity; affine::Bool=true, eps::Real=1f-5)
+    diag = affine ? Flux.Scale(size..., λ) : λ!=identity ? Base.Fix1(broadcast, λ) : identity
+    return LayerNormV2(λ, diag, eps, size, affine)
+end
+LayerNormV2(size::Integer...; kw...) = LayerNormV2(Int.(size); kw...)
+LayerNormV2(size_act...; kw...) = LayerNormV2(Int.(size_act[1:end-1]), size_act[end]; kw...)
+
+@functor LayerNormV2
+
+function (a::LayerNormV2)(x::AbstractArray)
+    eps = convert(float(eltype(x)), a.ϵ)  # avoids promotion for Float16 data, but should ε change too?
+    a.diag(_normalise(x; dims=1:length(a.size), eps))
+end
+
+function Base.show(io::IO, l::LayerNormV2)
+    print(io, "LayerNormV2(", join(l.size, ", "))
+    l.λ === identity || print(io, ", ", l.λ)
+    Flux.hasaffine(l) || print(io, ", affine=false")
+    print(io, ")")
+end
+
+@inline function _normalise(x::AbstractArray; dims=ndims(x), eps=Flux.ofeltype(x, 1e-5))
+    μ = mean(x, dims=dims)
+    σ² = var(x, dims=dims, mean=μ, corrected=false)
+    return @. (x - μ) / sqrt(σ² + eps)
+end
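
To see why a separate `LayerNormV2` is needed: Flux's `LayerNorm` (via `Flux.normalise`) adds `eps` to the standard deviation after the square root, while PyTorch adds it to the variance before the square root. The gap is tiny but enough to make ported ViT weights miss an exact output match. A small sketch of the two conventions, assuming nothing beyond `Statistics`:

```julia
# Minimal sketch of the two eps conventions.
using Statistics

x   = randn(Float32, 5)
μ   = mean(x)
σ²  = var(x; corrected = false)
eps = 1f-5

flux_style  = (x .- μ) ./ (sqrt(σ²) + eps)  # eps added after the square root (Flux.normalise)
torch_style = (x .- μ) ./ sqrt(σ² + eps)    # eps added before the square root (LayerNormV2 / PyTorch)

maximum(abs, flux_style .- torch_style)     # small, but nonzero
```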

src/vit-based/vit.jl

Lines changed: 10 additions & 7 deletions
@@ -13,9 +13,10 @@ Transformer as used in the base ViT architecture.
   - `dropout_prob`: dropout probability
 """
 function transformer_encoder(planes::Integer, depth::Integer, nheads::Integer;
-                             mlp_ratio = 4.0, dropout_prob = 0.0)
+                             mlp_ratio = 4.0, dropout_prob = 0.0, qkv_bias=false)
     layers = [Chain(SkipConnection(prenorm(planes,
-                                           MHAttention(planes, nheads;
+                                           MultiHeadSelfAttention(planes, nheads;
+                                                       qkv_bias,
                                                        attn_dropout_prob = dropout_prob,
                                                        proj_dropout_prob = dropout_prob)),
                                    +),
@@ -51,7 +52,8 @@ Creates a Vision Transformer (ViT) model.
 function vit(imsize::Dims{2} = (256, 256); inchannels::Integer = 3,
              patch_size::Dims{2} = (16, 16), embedplanes::Integer = 768,
              depth::Integer = 6, nheads::Integer = 16, mlp_ratio = 4.0, dropout_prob = 0.1,
-             emb_dropout_prob = 0.1, pool::Symbol = :class, nclasses::Integer = 1000)
+             emb_dropout_prob = 0.1, pool::Symbol = :class, nclasses::Integer = 1000,
+             qkv_bias = false)
     @assert pool in [:class, :mean]
     "Pool type must be either `:class` (class token) or `:mean` (mean pooling)"
     npatches = prod(imsize .÷ patch_size)
@@ -60,9 +62,9 @@ function vit(imsize::Dims{2} = (256, 256); inchannels::Integer = 3,
                   ViPosEmbedding(embedplanes, npatches + 1),
                   Dropout(emb_dropout_prob),
                   transformer_encoder(embedplanes, depth, nheads; mlp_ratio,
-                                      dropout_prob),
+                                      dropout_prob, qkv_bias),
                   pool === :class ? x -> x[:, 1, :] : seconddimmean),
-            Chain(LayerNorm(embedplanes), Dense(embedplanes, nclasses, tanh_fast)))
+            Chain(LayerNorm(embedplanes), Dense(embedplanes, nclasses)))
 end
 
 const VIT_CONFIGS = Dict(:tiny => (depth = 12, embedplanes = 192, nheads = 3),
@@ -100,9 +102,10 @@ end
 @functor ViT
 
 function ViT(config::Symbol; imsize::Dims{2} = (256, 256), patch_size::Dims{2} = (16, 16),
-             pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000)
+             pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000,
+             qkv_bias=false)
     _checkconfig(config, keys(VIT_CONFIGS))
-    layers = vit(imsize; inchannels, patch_size, nclasses, VIT_CONFIGS[config]...)
+    layers = vit(imsize; inchannels, patch_size, nclasses, qkv_bias, VIT_CONFIGS[config]...)
     if pretrain
         loadpretrain!(layers, string("vit", config))
     end
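
Putting the vit.jl changes together with the port script: the torchvision-compatible configuration is a `:base` ViT at 224×224 with `qkv_bias = true`, and the classifier head now returns raw logits since `tanh_fast` was dropped. A usage sketch:

```julia
# Minimal sketch: the configuration the updated port script uses for vit_b_16.
using Metalhead

model = ViT(:base; imsize = (224, 224), qkv_bias = true)
x = rand(Float32, 224, 224, 3, 1)   # WHCN image batch
size(model(x))                       # (1000, 1) — raw logits for the 1000 ImageNet classes
```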
