Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

lazy module loading #128

Merged
merged 17 commits on May 20, 2022
6 changes: 2 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ uuid = "eb30cadb-4394-5ae3-aed4-317e484a6458"
version = "0.6.0"

[deps]
BinDeps = "9e28174c-4ba2-5203-b857-d8d62c4213ee"
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
ColorTypes = "3da002f7-5984-5a60-b8a6-cbb66c0b333f"
DataDeps = "124859b0-ceae-595e-8997-d05f6a7a8dfe"
Expand All @@ -26,12 +25,11 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

[compat]
BinDeps = "1"
CSV = "0.10.2"
ColorTypes = "0.11"
DataDeps = "0.7"
DataFrames = "1.3"
FileIO = "1.13"
FileIO = "1.14"
FixedPointNumbers = "0.8"
GZip = "0.5"
Glob = "1.3"
Expand All @@ -41,8 +39,8 @@ JLD2 = "0.4.21"
JSON3 = "1"
MAT = "0.10"
MLUtils = "0.2.0"
Pickle = "0.3"
NPZ = "0.4.1"
Pickle = "0.3"
Requires = "1"
Tables = "1.6"
julia = "1.6"
Expand Down
2 changes: 1 addition & 1 deletion docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ makedocs(
"Text" => "datasets/text.md",
"Vision" => "datasets/vision.md",
],
"Creating Datasets" => Any["containers/overview.md"],
# "Creating Datasets" => Any["containers/overview.md"], # still experimental
"LICENSE.md",
],
strict = true,
Expand Down
47 changes: 26 additions & 21 deletions src/MLDatasets.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,24 @@ module MLDatasets

using FixedPointNumbers
using SparseArrays
using DataFrames, Tables
using Tables
using Glob
import ImageCore
using ColorTypes
# using DataFrames
# import ImageCore
using DataDeps
import MLUtils
using MLUtils: getobs, numobs, AbstractDataContainer
using ColorTypes

### I/O imports
import NPZ
import Pickle
using MAT: matopen, matread
import CSV
using HDF5
using JLD2
import JSON3
# import NPZ
# import Pickle
# using MAT: matopen, matread
using FileIO
# import CSV
# using HDF5
# using JLD2
# import JSON3
using DelimitedFiles: readdlm
##########

Expand All @@ -29,24 +32,26 @@ include("abstract_datasets.jl")
# export AbstractDataset,
# SupervisedDataset

include("imports.jl")
include("utils.jl")
export convert2image

include("io.jl")
# export read_csv, read_npy
# export read_csv, read_npy, ...

include("download.jl")

include("containers/filedataset.jl")
export FileDataset
include("containers/tabledataset.jl")
export TableDataset
include("containers/hdf5dataset.jl")
export HDF5Dataset
include("containers/jld2dataset.jl")
export JLD2Dataset
include("containers/cacheddataset.jl")
export CachedDataset
### API to be revisited with conditional module loading
# include("containers/filedataset.jl")
# export FileDataset
# include("containers/tabledataset.jl")
# export TableDataset
# include("containers/hdf5dataset.jl")
# export HDF5Dataset
# # include("containers/jld2dataset.jl")
# # export JLD2Dataset
# include("containers/cacheddataset.jl")
# export CachedDataset

# Misc.
include("datasets/misc/boston_housing.jl")
Expand Down
21 changes: 15 additions & 6 deletions src/abstract_datasets.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,12 @@ function leftalign(s::AbstractString, n::Int)
end
end

# Compact, human-readable description of a dataset field for `show`.
# Tables and containers are summarized by type/size instead of printing
# their full contents; scalars pass through unchanged.
# NOTE: the superseded pre-refactor methods (`_summary(x) = x` and the
# `Union{Dict, AbstractArray, DataFrame}` method, which referenced the
# no-longer-eagerly-imported DataFrames) are removed here — they were
# stale diff leftovers shadowed by the definitions below.
_summary(x) = Tables.istable(x) ? summary(x) : x
_summary(x::Symbol) = ":$x"
_summary(x::Dict) = summary(x)
_summary(x::Tuple) = map(_summary, x)
_summary(x::NamedTuple) = map(_summary, x)
_summary(x::AbstractArray) = summary(x)
_summary(x::BitVector) = "$(count(x))-trues BitVector"

"""
Expand All @@ -58,11 +60,18 @@ a `features` and a `targets` fields.
abstract type SupervisedDataset <: AbstractDataset end


# Number of observations in the dataset. Tables.jl tables take a dedicated
# path (`numobs_table`) because generic `MLUtils.numobs` does not treat an
# arbitrary table as a data container; everything else goes through MLUtils.
# NOTE: the stale pre-refactor one-line methods that this hunk replaced are
# removed — keeping both versions defined the same methods twice.
Base.length(d::SupervisedDataset) = Tables.istable(d.features) ? numobs_table(d.features) :
                                    numobs((d.features, d.targets))

# Indexing returns named tuples with `features` and `targets` fields.
Base.getindex(d::SupervisedDataset, ::Colon) = Tables.istable(d.features) ?
    (features = d.features, targets = d.targets) :
    getobs((; d.features, d.targets))

Base.getindex(d::SupervisedDataset, i) = Tables.istable(d.features) ?
    (features = getobs_table(d.features, i), targets = getobs_table(d.targets, i)) :
    getobs((; d.features, d.targets), i)

"""
UnsupervisedDataset <: AbstractDataset
Expand Down
6 changes: 3 additions & 3 deletions src/containers/tabledataset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ struct TableDataset{T} <: AbstractDataContainer
end

# Outer constructor: capture the concrete table type as the struct's
# type parameter so accessors can dispatch on it (e.g. `<:DataFrame`).
function TableDataset(table::T) where {T}
    return TableDataset{T}(table)
end
TableDataset(path::AbstractString) = TableDataset(DataFrame(CSV.File(path)))
# TableDataset(path::AbstractString) = TableDataset(DataFrame(CSV.File(path)))

# slow accesses based on Tables.jl
# Slow generic row access for any Tables.jl row iterator: lazily skip the
# first `i - 1` rows, then peel off the next one. O(i) per lookup, used
# only when no faster indexing method exists for the table type.
function _getobs_row(x, i)
    remaining = Iterators.drop(x, i - 1)
    row, _ = Iterators.peel(remaining)
    return row
end
Expand Down Expand Up @@ -54,8 +54,8 @@ Base.getindex(dataset::TableDataset{<:DataFrame}, i) = dataset.table[i, :]
# Fast path: a DataFrame knows its row count directly via `nrow`,
# so no generic Tables.jl iteration is needed.
function Base.length(dataset::TableDataset{<:DataFrame})
    return nrow(dataset.table)
end

# fast access for CSV.File
Base.getindex(dataset::TableDataset{<:CSV.File}, i) = dataset.table[i]
Base.length(dataset::TableDataset{<:CSV.File}) = length(dataset.table)
# Base.getindex(dataset::TableDataset{<:CSV.File}, i) = dataset.table[i]
# Base.length(dataset::TableDataset{<:CSV.File}) = length(dataset.table)

## Tables.jl interface

Expand Down
13 changes: 1 addition & 12 deletions src/datasets/graphs/planetoid.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,23 +83,12 @@ function read_planetoid_data(DEPNAME; dir=nothing, reverse_edges=true)
return metadata, g
end

# Read one Planetoid data file via `Pickle.npyload`.
# The "graph" entry is returned unmodified; for any other entry, a sparse
# matrix result is densified with `Matrix` before being returned.
function read_pickle_file(filename, name)
    loaded = Pickle.npyload(filename)
    name == "graph" && return loaded
    loaded isa SparseMatrixCSC && return Matrix(loaded)
    return loaded
end

function read_planetoid_file(DEPNAME, name, dir)
filename = datafile(DEPNAME, name, dir)
if endswith(name, "test.index")
out = 1 .+ vec(readdlm(filename, Int))
else
out = read_pickle_file(filename, name)
out = read_pickle(filename)
if out isa SparseMatrixCSC
out = Matrix(out)
end
Expand Down
6 changes: 3 additions & 3 deletions src/datasets/graphs/reddit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,9 @@ function Reddit(; full=true, dir=nothing)
feat_path = datafile(DEPNAME, DATA[5], dir)

# Read the json files
graph = open(JSON3.read, graph_json)
class_map = open(JSON3.read, class_map_json)
id_map = open(JSON3.read, id_map_json)
graph = read_json(graph_json)
class_map = read_json(class_map_json)
id_map = read_json(id_map_json)

# Metadata
directed = graph["directed"]
Expand Down
3 changes: 2 additions & 1 deletion src/datasets/misc/boston_housing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ function BostonHousing(; as_df = true, dir = nothing)
@assert dir === nothing "custom `dir` is not supported at the moment."
path = joinpath(@__DIR__, "..", "..", "..", "data", "boston_housing.csv")
df = read_csv(path)
features = df[!, Not(:MEDV)]
DFs = checked_import(idDataFrames)
features = df[!, DFs.Not(:MEDV)]
targets = df[!, [:MEDV]]

metadata = Dict{String, Any}()
Expand Down
5 changes: 3 additions & 2 deletions src/datasets/misc/iris.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,10 @@ end
function Iris(; dir = nothing, as_df = true)
path = datafile("Iris", "iris.data", dir)
df = read_csv(path, header=0)
rename!(df, ["sepallength", "sepalwidth", "petallength", "petalwidth", "class"])
DFs = checked_import(idDataFrames)
DFs.rename!(df, ["sepallength", "sepalwidth", "petallength", "petalwidth", "class"])

features = df[!, Not(:class)]
features = df[!, DFs.Not(:class)]
targets = df[!, [:class]]

metadata = Dict{String, Any}()
Expand Down
8 changes: 6 additions & 2 deletions src/datasets/misc/mutagenesis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ function Mutagenesis(split::Symbol; dir=nothing)

data_path = datafile(DEPNAME, DATA, dir)
metadata_path = datafile(DEPNAME, METADATA, dir)
samples = open(JSON3.read, data_path)
metadata = open(JSON3.read, metadata_path)
samples = read_json(data_path)
metadata = read_json(metadata_path)
labelkey = metadata["label"]
targets = map(i -> i[labelkey], samples)
features = map(x->delete!(copy(x), Symbol(labelkey)), samples)
Expand All @@ -101,6 +101,10 @@ function Mutagenesis(split::Symbol; dir=nothing)
Mutagenesis(metadata, split, indexes, features[indexes], targets[indexes])
end

# Observation API for Mutagenesis, delegating to MLUtils.
# NOTE(review): the original defined `Base.length(d, ::Colon)` — a
# two-argument `length` that no calling convention ever invokes, almost
# certainly a copy-paste of the `getindex(d, :)` signature below. The
# intended override is the one-argument form, matching the
# `SupervisedDataset` pattern elsewhere in this package.
Base.length(d::Mutagenesis) = numobs((; d.features, d.targets))

# Indexing returns named tuples with `features` and `targets` fields.
Base.getindex(d::Mutagenesis, ::Colon) = getobs((; d.features, d.targets))
Base.getindex(d::Mutagenesis, i) = getobs((; d.features, d.targets), i)

# deprecated in v0.6
function Base.getproperty(::Type{Mutagenesis}, s::Symbol)
if s == :traindata
Expand Down
4 changes: 3 additions & 1 deletion src/datasets/misc/titanic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@ function Titanic(; as_df = true, dir = nothing)
@assert dir === nothing "custom `dir` is not supported at the moment."
path = joinpath(@__DIR__, "..", "..", "..", "data", "titanic.csv")
df = read_csv(path)
features = df[!, Not(:Survived)]
DFs = checked_import(idDataFrames)

features = df[!, DFs.Not(:Survived)]
targets = df[!, [:Survived]]

metadata = Dict{String, Any}()
Expand Down
5 changes: 3 additions & 2 deletions src/datasets/vision/cifar10.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ function __init__cifar10()
""",
"https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz",
"c4a38c50a1bc5f3a1c5537f2155ab9d68f9f25eb1ed8d9ddda3db29a59bca1dd",
post_fetch_method = file -> (run(BinDeps.unpack_cmd(file, dirname(file), ".gz", ".tar")); rm(file))
post_fetch_method = DataDeps.unpack
))
end

Expand Down Expand Up @@ -165,7 +165,8 @@ convert2image(::Type{<:CIFAR10}, x::AbstractArray{<:Integer}) =
# Convert a raw CIFAR10 array (3-d single image or 4-d batch) into an
# `RGB` image array via ImageCore's `colorview`.
function convert2image(::Type{<:CIFAR10}, x::AbstractArray{T,N}) where {T,N}
    @assert N == 3 || N == 4
    # Reorder the first three dimensions into the layout `colorview`
    # expects (channel-first, transposed spatial axes); a trailing batch
    # dimension, if present, is kept in place.
    x = permutedims(x, (3, 2, 1, 4:N...))
    # ImageCore is loaded lazily, only when image conversion is requested.
    # NOTE: the stale pre-refactor `return ImageCore.colorview(RGB, x)`
    # line left above this call made it unreachable — removed, and the
    # result is now returned explicitly.
    return checked_import(idImageCore).colorview(RGB, x)
end


Expand Down
2 changes: 1 addition & 1 deletion src/datasets/vision/cifar100.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ function __init__cifar100()
""",
"https://www.cs.toronto.edu/~kriz/cifar-100-binary.tar.gz",
"58a81ae192c23a4be8b1804d68e518ed807d710a4eb253b1f2a199162a40d8ec",
post_fetch_method = file -> (run(BinDeps.unpack_cmd(file, dirname(file), ".gz", ".tar")); rm(file))
post_fetch_method = DataDeps.unpack
))
end

Expand Down
4 changes: 2 additions & 2 deletions src/datasets/vision/emnist.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ function __init__emnist()
""",
"http://www.itl.nist.gov/iaui/vip/cs_links/EMNIST/matlab.zip",
"e1fa805cdeae699a52da0b77c2db17f6feb77eed125f9b45c022e7990444df95",
post_fetch_method = file -> (run(BinDeps.unpack_cmd(file, dirname(file), ".zip", "")); rm(file))
post_fetch_method = DataDeps.unpack
))
end

Expand Down Expand Up @@ -119,7 +119,7 @@ function EMNIST(name, Tx::Type, split::Symbol; dir=nothing)
path = "matlab/emnist-$name.mat"

path = datafile("EMNIST", path, dir)
vars = matread(path)
vars = read_mat(path)
features = reshape(vars["dataset"]["$split"]["images"], :, 28, 28)
features = permutedims(features, (3, 2, 1))
targets = Int.(vars["dataset"]["$split"]["labels"] |> vec)
Expand Down
2 changes: 1 addition & 1 deletion src/datasets/vision/mnist.jl
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ convert2image(::Type{<:MNIST}, x::AbstractArray{<:Integer}) =
# Convert a raw MNIST array (2-d single image or 3-d batch) into a
# grayscale image array via ImageCore's `colorview`.
function convert2image(::Type{<:MNIST}, x::AbstractArray{T,N}) where {T,N}
    @assert N == 2 || N == 3
    # Transpose the two spatial axes; a trailing batch dimension, if
    # present, is kept in place.
    x = permutedims(x, (2, 1, 3:N...))
    # ImageCore is loaded lazily, only when image conversion is requested.
    # NOTE: the stale pre-refactor `return ImageCore.colorview(Gray, x)`
    # line left above this call made it unreachable — removed, and the
    # result is now returned explicitly.
    return checked_import(idImageCore).colorview(Gray, x)
end

# DEPRECATED INTERFACE, REMOVE IN v0.7 (or 0.6.x)
Expand Down
1 change: 0 additions & 1 deletion src/datasets/vision/mnist_reader/MNISTReader.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
module MNISTReader
using GZip
using BinDeps

export
readimages,
Expand Down
2 changes: 1 addition & 1 deletion src/datasets/vision/svhn2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ function SVHN2(Tx::Type, split::Symbol; dir=nothing)
end

path = datafile(DEPNAME, PATH, dir)
vars = matread(path)
vars = read_mat(path)
images = vars["X"]::Array{UInt8,4}
labels = vars["y"]
images = permutedims(images, (2, 1, 3, 4))
Expand Down
7 changes: 2 additions & 5 deletions src/download.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
import BinDeps
using DataDeps

function with_accept(f, manual_overwrite)
auto_accept = if manual_overwrite == nothing
auto_accept = if manual_overwrite === nothing
get(ENV, "DATADEPS_ALWAYS_ACCEPT", false)
else
manual_overwrite
Expand All @@ -12,7 +9,7 @@ end

function datadir(depname, dir = nothing; i_accept_the_terms_of_use = nothing)
with_accept(i_accept_the_terms_of_use) do
if dir == nothing
if dir === nothing
# use DataDeps defaults
@datadep_str depname
else
Expand Down
39 changes: 39 additions & 0 deletions src/imports.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@

# `Base.PkgId` handles for the lazily loaded ("weak") dependencies.
# These let `checked_import` load each package on demand at runtime
# instead of importing it when MLDatasets itself is loaded.
# NOTE(review): each UUID must match the corresponding entry in
# Project.toml — verify when adding or bumping a dependency.
const idDataFrames = Base.PkgId(Base.UUID("a93c6f00-e57d-5684-b7b6-d8193f3e46c0"), "DataFrames")
const idCSV = Base.PkgId(Base.UUID("336ed68f-0bac-5ca0-87d4-7b16caf5d00b"), "CSV")
const idImageCore = Base.PkgId(Base.UUID("a09fc81d-aa75-5fe9-8630-4744c3626534"), "ImageCore")
const idPickle = Base.PkgId(Base.UUID("fbb45041-c46e-462f-888f-7c521cafbc2c"), "Pickle")
const idHDF5 = Base.PkgId(Base.UUID("f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f"), "HDF5")
const idMAT = Base.PkgId(Base.UUID("23992714-dd62-5051-b70f-ba57cb901cac"), "MAT")
const idJSON3 = Base.PkgId(Base.UUID("0f8b85d8-7281-11e9-16c2-39a750bddbf1"), "JSON3")
# Candidates not yet lazily loaded (kept for reference):
# ColorTypes = "3da002f7-5984-5a60-b8a6-cbb66c0b333f"
# JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"

# Serializes concurrent first-time `Base.require` calls in `checked_import`.
const load_locker = Threads.ReentrantLock()

"""
    checked_import(pkgid) -> ImportedModule

Return the module identified by the `Base.PkgId` `pkgid`, wrapped in an
`ImportedModule`. If the package is already loaded it is reused;
otherwise it is loaded with `Base.require` while holding `load_locker`,
so concurrent first-time loads from multiple threads are serialized.
"""
function checked_import(pkgid)
    if Base.root_module_exists(pkgid)
        return ImportedModule(Base.root_module(pkgid))
    end
    loaded = lock(load_locker) do
        Base.require(pkgid)
    end
    return ImportedModule(loaded)
end

# Thin wrapper around a lazily loaded `Module` (see `checked_import`).
# Property access on the wrapper is intercepted by the `Base.getproperty`
# method below, which routes function lookups through `Base.invokelatest`
# so code from a package `require`d at runtime can be called directly.
struct ImportedModule
    mod::Module  # the wrapped, already-loaded module
end

# Property access on an `ImportedModule`:
#  - `:mod` returns the wrapped module itself;
#  - any other name returns a closure that looks the name up in the
#    wrapped module and calls it through `Base.invokelatest`, so
#    functions from a package loaded at runtime can be invoked.
function Base.getproperty(m::ImportedModule, s::Symbol)
    s === :mod && return getfield(m, s)
    return (args...; kws...) -> Base.invokelatest(getproperty(getfield(m, :mod), s), args...; kws...)
end
Loading