-
Notifications
You must be signed in to change notification settings - Fork 47
/
Copy pathcifar10.jl
227 lines (196 loc) · 8.16 KB
/
cifar10.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
# Register the "CIFAR10" DataDep with DataDeps.jl. The registration consists of:
# - the message shown to the user before downloading,
# - the URL of the binary (C-program) version of the archive,
# - its checksum,
# - a post-fetch step (`DataDeps.unpack`) that extracts the downloaded tarball.
# NOTE(review): presumably invoked from the package-level `__init__`; confirm.
function __init__cifar10()
    DEPNAME = "CIFAR10"
    register(DataDep(
        DEPNAME,
        # Julia dedents triple-quoted strings to the least-indented line (the
        # closing quotes count), so this indentation does not reach the message.
        """
        Dataset: The CIFAR-10 dataset
        Authors: Alex Krizhevsky, Vinod Nair, Geoffrey Hinton
        Website: https://www.cs.toronto.edu/~kriz/cifar.html
        Reference: https://www.cs.toronto.edu/~kriz/learning-features-2009-TR.pdf
        [Krizhevsky, 2009]
        Alex Krizhevsky.
        "Learning Multiple Layers of Features from Tiny Images",
        Tech Report, 2009.
        The CIFAR-10 dataset is a labeled subset of the 80
        million tiny images dataset. It consists of 60000
        32x32 colour images in 10 classes, with 6000 images
        per class.
        The compressed archive file that contains the
        complete dataset is available for download at the
        official website linked above; specifically the binary
        version for C programs. Note that using the data
        responsibly and respecting copyright remains your
        responsibility. The authors of CIFAR-10 aren't really
        explicit about any terms of use, so please read the
        website to make sure you want to download the
        dataset.
        """,
        "https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz",
        "c4a38c50a1bc5f3a1c5537f2155ab9d68f9f25eb1ed8d9ddda3db29a59bca1dd",
        post_fetch_method = DataDeps.unpack
    ))
end
"""
    CIFAR10(; Tx=Float32, split=:train, dir=nothing)
    CIFAR10([Tx, split])

The CIFAR10 dataset is a labeled subset of the 80
million tiny images dataset. It consists of 60000
32x32 colour images in 10 classes, with 6000 images
per class.

# Arguments

$ARGUMENTS_SUPERVISED_ARRAY
- `split`: selects the data partition. Can take the values `:train` or `:test`.

# Fields

$FIELDS_SUPERVISED_ARRAY
- `split`.

# Methods

$METHODS_SUPERVISED_ARRAY
- [`convert2image`](@ref) converts features to `RGB` images.

# Examples

```julia-repl
julia> using MLDatasets: CIFAR10

julia> dataset = CIFAR10()
CIFAR10:
metadata => Dict{String, Any} with 2 entries
split => :train
features => 32×32×3×50000 Array{Float32, 4}
targets => 50000-element Vector{Int64}

julia> dataset[1:5].targets
5-element Vector{Int64}:
6
9
9
4
1

julia> X, y = dataset[:];

julia> dataset = CIFAR10(Tx=Float64, split=:test)
CIFAR10:
metadata => Dict{String, Any} with 2 entries
split => :test
features => 32×32×3×10000 Array{Float64, 4}
targets => 10000-element Vector{Int64}

julia> dataset.metadata
Dict{String, Any} with 2 entries:
"n_observations" => 10000
"class_names" => ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]
```
"""
struct CIFAR10 <: SupervisedDataset
    # Filled by the loading constructor with "class_names" and "n_observations".
    metadata::Dict{String, Any}
    # Which partition was loaded (:train or :test).
    split::Symbol
    # Image tensor; the last dimension indexes observations.
    # NOTE(review): the element type is abstract (`<:Any`), so field access is not
    # type-stable. Parametrizing the struct would fix this but changes the type.
    features::Array{<:Any, 4}
    # One integer class label per observation.
    targets::Vector{Int}
end
# Keyword entry point: everything funnels into the positional `(Tx, split)` loader.
function CIFAR10(; split=:train, Tx=Float32, dir=nothing)
    return CIFAR10(Tx, split; dir=dir)
end

# Positional shortcut that only fixes the partition; other options pass through.
function CIFAR10(split::Symbol; kws...)
    return CIFAR10(; split=split, kws...)
end

# Positional shortcut that only fixes the feature element type; other options pass through.
function CIFAR10(Tx::Type; kws...)
    return CIFAR10(; Tx=Tx, kws...)
end
# Load one partition of CIFAR-10 and build a `CIFAR10` dataset.
#
# - `Tx`: element type the raw byte images are converted to (via `bytes_to_type`).
# - `split`: `:train` (the five `data_batch_*.bin` files) or `:test` (`test_batch.bin`).
# - `dir`: forwarded to `datafile` to locate the files (`nothing` by default).
#
# Throws `ArgumentError` if `split` is neither `:train` nor `:test`.
function CIFAR10(Tx::Type, split::Symbol; dir=nothing)
    # `split` is user input, so validate it with a real exception.
    # `@assert` may be disabled and is meant for internal invariants only.
    split ∈ (:train, :test) ||
        throw(ArgumentError("split must be :train or :test, got $(repr(split))"))

    DEPNAME = "CIFAR10"
    NCHUNKS = 5
    BATCHDIR = "cifar-10-batches-bin"
    TESTSET_FILENAME = joinpath(BATCHDIR, "test_batch.bin")
    filename_for_chunk(file_index::Int) =
        joinpath(BATCHDIR, "data_batch_$(file_index).bin")

    if split == :train
        # The training set ships as NCHUNKS separate batch files. Read each one,
        # then concatenate the images along dimension 4 and the labels end to end.
        Xs = Vector{Array{UInt8,4}}(undef, NCHUNKS)
        Ys = Vector{Vector{Int}}(undef, NCHUNKS)
        for file_index in 1:NCHUNKS
            file_path = datafile(DEPNAME, filename_for_chunk(file_index), dir)
            Xs[file_index], Ys[file_index] = CIFAR10Reader.readdata(file_path)
        end
        # TODO: provide a lazy variant that only reads the requested images.
        images = cat(Xs...; dims=4)::Array{UInt8,4}
        labels = reduce(vcat, Ys)::Vector{Int}
    else
        file_path = datafile(DEPNAME, TESTSET_FILENAME, dir)
        images, labels = CIFAR10Reader.readdata(file_path)
    end

    # The conversion is identical for both partitions, so do it once here.
    features = bytes_to_type(Tx, images)
    targets = labels

    metadata = Dict{String, Any}(
        "class_names" => ["airplane", "automobile", "bird", "cat", "deer",
                          "dog", "frog", "horse", "ship", "truck"],
        # Observations are stacked along the last dimension of `features`.
        "n_observations" => size(features, ndims(features)),
    )
    return CIFAR10(metadata, split, features, targets)
end
# Integer-valued features (e.g. raw 0-255 bytes) are converted to `UInt8` and then
# reinterpreted as `N0f8` (normalized 8-bit fixed point). The result is forwarded
# to the generic method below. `convert` throws if a value does not fit in a `UInt8`.
convert2image(::Type{<:CIFAR10}, x::AbstractArray{<:Integer}) =
    convert2image(CIFAR10, reinterpret(N0f8, convert(Array{UInt8}, x)))

# Turn a single image (3-d array) or a batch of images (4-d array) in the dataset's
# storage layout into `RGB` pixels.
#
# The permutation (3, 2, 1, ...) does two things:
# - it moves the channel axis (dim 3) to the front, which `colorview(RGB, ...)`
#   expects;
# - it swaps the two spatial axes.
# Any batch axis stays last.
#
# Throws `ArgumentError` for arrays that are neither 3-d nor 4-d.
function convert2image(::Type{<:CIFAR10}, x::AbstractArray{T,N}) where {T,N}
    N in (3, 4) || throw(ArgumentError(
        "expected a 3-d image or a 4-d batch of images, got a $N-d array"))
    channels_first = permutedims(x, (3, 2, 1, 4:N...))
    # ImageCore is loaded on demand rather than being a hard dependency.
    return checked_import(idImageCore).colorview(RGB, channels_first)
end
# DEPRECATED INTERFACE, REMOVE IN v0.7 (or 0.6.x)
#
# Keeps the old module-style accessors working, e.g. `CIFAR10.traintensor(...)` and
# `CIFAR10.testdata(...)`. Looking up one of the legacy names logs a deprecation
# warning and returns a local function that forwards to `CIFAR10(; split, Tx, dir)`.
# Any other name falls through to `getfield`, so ordinary type fields are unaffected.
function Base.getproperty(::Type{CIFAR10}, s::Symbol)
    if s === :traintensor
        @warn "CIFAR10.traintensor() is deprecated, use `CIFAR10(split=:train).features` instead."
        traintensor(; kws...) = traintensor(N0f8, :; kws...)
        traintensor(Tx::Type; kws...) = traintensor(Tx, :; kws...)
        traintensor(idx; kws...) = traintensor(N0f8, idx; kws...)
        traintensor(Tx::Type, idx; dir=nothing) = CIFAR10(; split=:train, Tx, dir)[idx][1]
        return traintensor
    end

    if s === :testtensor
        @warn "CIFAR10.testtensor() is deprecated, use `CIFAR10(split=:test).features` instead."
        testtensor(; kws...) = testtensor(N0f8, :; kws...)
        testtensor(Tx::Type; kws...) = testtensor(Tx, :; kws...)
        testtensor(idx; kws...) = testtensor(N0f8, idx; kws...)
        testtensor(Tx::Type, idx; dir=nothing) = CIFAR10(; split=:test, Tx, dir)[idx][1]
        return testtensor
    end

    if s === :trainlabels
        @warn "CIFAR10.trainlabels() is deprecated, use `CIFAR10(split=:train).targets` instead."
        trainlabels(; kws...) = trainlabels(:; kws...)
        trainlabels(idx; dir=nothing) = CIFAR10(; split=:train, dir)[idx][2]
        return trainlabels
    end

    if s === :testlabels
        @warn "CIFAR10.testlabels() is deprecated, use `CIFAR10(split=:test).targets` instead."
        testlabels(; kws...) = testlabels(:; kws...)
        testlabels(idx; dir=nothing) = CIFAR10(; split=:test, dir)[idx][2]
        return testlabels
    end

    if s === :traindata
        @warn "CIFAR10.traindata() is deprecated, use `CIFAR10(split=:train)[:]` instead."
        traindata(; kws...) = traindata(N0f8, :; kws...)
        traindata(Tx::Type; kws...) = traindata(Tx, :; kws...)
        traindata(idx; kws...) = traindata(N0f8, idx; kws...)
        traindata(Tx::Type, idx; dir=nothing) = CIFAR10(; split=:train, Tx, dir)[idx]
        return traindata
    end

    if s === :testdata
        @warn "CIFAR10.testdata() is deprecated, use `CIFAR10(split=:test)[:]` instead."
        testdata(; kws...) = testdata(N0f8, :; kws...)
        testdata(Tx::Type; kws...) = testdata(Tx, :; kws...)
        testdata(idx; kws...) = testdata(N0f8, idx; kws...)
        testdata(Tx::Type, idx; dir=nothing) = CIFAR10(; split=:test, Tx, dir)[idx]
        return testdata
    end

    if s === :convert2image
        @warn "CIFAR10.convert2image(x) is deprecated, use `convert2image(CIFAR10, x)` instead"
        return img -> convert2image(CIFAR10, img)
    end

    # Not a legacy accessor: behave like the default `getproperty` on a type.
    return getfield(CIFAR10, s)
end