Training NesT from scratch (on CIFAR10) #751
-
Hi everyone, I am trying to replicate the results of the NesT paper ( https://arxiv.org/abs/2105.12723 ) using timm's implementation (see timm/models/nest.py ) adapted to CIFAR10 (see details below). I tried to use the parameters of the original jax code given here, which resulted in this. Unfortunately, with these parameters, I only get around 89.5% test accuracy for a NesT-Tiny (S=1, depth=4), which is far below the 95-96%-ish promised by the paper (Table 1 & Fig. 6). Any ideas what may explain these differences? (Sorry in advance for any trivial differences between the original jax code and the current implementation/parameters that I surely overlooked.) Thanks!

Main changes to adapt timm's code to CIFAR-10:
Replies: 1 comment 7 replies
-
@cjsg having a glance at what you've done, here are some things that might get you a lot closer:

- patch_size=1. It's a little confusing when comparing to the initial implementation, as there the patch_size refers to the size of a "block" (as the term is used in the paper) in units of "patches". To be clear, in terms of official -> timm it's:
  - patch_size -> block_size, as determined here
  - init_patch_embed_size -> patch_size, as set in the model_kwargs
- Some check sums to make sure we're on the same page: Your image size is 32x32. You set num_levels to 4, which means the first hierarchical level has 8x8 "blocks", each with 4x4 pixels. And your patch size is 1x1, …
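
The check in the reply can be reproduced with a few lines of arithmetic. This is only a rough sketch, not timm's code: `nest_layout` is a hypothetical helper, and it assumes that each hierarchy level merges 2x2 neighbouring blocks, so that the first level has 4 ** (num_levels - 1) blocks.

```python
import math


def nest_layout(img_size: int, patch_size: int, num_levels: int) -> dict:
    """Hypothetical helper: block layout of the first hierarchy level.

    Assumption (not taken from timm): every level merges 2x2 blocks, so the
    first level starts with 4 ** (num_levels - 1) blocks.
    """
    num_blocks = 4 ** (num_levels - 1)
    blocks_per_side = math.isqrt(num_blocks)   # blocks form a square grid
    patches_per_side = img_size // patch_size  # patches along one image side
    if patches_per_side % blocks_per_side != 0:
        raise ValueError("the patch grid does not divide evenly into blocks")
    block_size = patches_per_side // blocks_per_side  # block side, in patches
    return {
        "blocks_per_side": blocks_per_side,
        "block_size_in_patches": block_size,
        "block_size_in_pixels": block_size * patch_size,
        "patches_per_block": block_size ** 2,
    }


# Setting from the thread: 32x32 images, 1x1 patches, 4 levels.
layout = nest_layout(img_size=32, patch_size=1, num_levels=4)
print(layout)
```

For the 32x32 / patch_size=1 / num_levels=4 setting this gives an 8x8 grid of blocks, each covering 4x4 pixels, which matches the numbers in the reply.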