Skip to content

Commit 01d8f90

Browse files
committed
init
0 parents  commit 01d8f90

16 files changed

+1195
-0
lines changed

.gitignore

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
data/*
2+
checkpoint/*
3+
logs/*
4+
others/*
5+
6+
*.pyc
7+
*.bak
8+
*.log
9+
*.tar
10+
*.pth

README.md

+66
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
## Improving Generalization via Scalable Neighborhood Component Analysis
2+
3+
This repo contains the PyTorch implementation of the ECCV 2018 paper [(arxiv)](https://arxiv.org/pdf/.pdf).
4+
The project is about deep learning feature representations optimized for
5+
nearest neighbor classifiers, which may generalize to new object categories.
6+
7+
Much of the code is borrowed from the previous [unsupervised learning project](https://arxiv.org/pdf/1805.01978.pdf).
8+
Please refer to [this repo](https://github.com/zhirongw/lemniscate.pytorch) for more details.
9+
10+
11+
## Pretrained Model
12+
13+
Currently, we provide 3 pretrained ResNet models.
14+
Each release contains the feature representation of all ImageNet training images (600 mb) and model weights (100-200mb).
15+
You can also compute these representations yourself by running a forward pass of the network over all ImageNet training images.
16+
17+
- [ResNet 18](http://zhirongw.westus2.cloudapp.azure.com/models/snca_resnet18.pth.tar) (top 1 accuracy 70.59%)
18+
- [ResNet 34](http://zhirongw.westus2.cloudapp.azure.com/models/snca_resnet34.pth.tar) (top 1 accuracy 74.41%)
19+
- [ResNet 50](http://zhirongw.westus2.cloudapp.azure.com/models/snca_resnet50.pth.tar) (top 1 accuracy 76.57%)
20+
21+
## Nearest Neighbor
22+
23+
Please follow [this link](http://zhirongw.westus2.cloudapp.azure.com/nn.html) for a list of nearest neighbors on ImageNet.
24+
Results are visualized from our ResNet50 feature, compared with baseline ResNet50 feature, raw image features and supervised features.
25+
First column is the query image, followed by 20 retrievals ranked by the similarity.
26+
27+
## Usage
28+
29+
Our code extends the pytorch implementation of imagenet classification in [official pytorch release](https://github.com/pytorch/examples/tree/master/imagenet).
30+
Please refer to the official repo for details of data preparation and hardware configurations.
31+
32+
- install python2 and [pytorch=0.3](http://pytorch.org)
33+
34+
- clone this repo: `git clone https://github.com/zhirongw/snca.pytorch`
35+
36+
- Training on ImageNet:
37+
38+
`python main.py DATAPATH --arch resnet18 -j 32 --temperature 0.05 --low-dim 128 -b 256 `
39+
40+
- During training, we monitor the supervised validation accuracy by K nearest neighbor with k=1, as it's faster, and gives a good estimation of the feature quality.
41+
42+
- Testing on ImageNet:
43+
44+
`python main.py DATAPATH --arch resnet18 --resume input_model.pth.tar -e` runs testing with default K=30 neighbors.
45+
46+
- Training on CIFAR10:
47+
48+
`python cifar.py --temperature 0.05 --lr 0.1`
49+
50+
51+
## Citation
52+
```
53+
@inproceedings{wu2018improving,
54+
title={Improving Generalization via Scalable Neighborhood Component Analysis},
55+
author={Wu, Zhirong and Efros, Alexei A and Yu, Stella},
56+
booktitle={European Conference on Computer Vision (ECCV) 2018},
57+
year={2018}
58+
}
59+
```
60+
61+
## Contact
62+
63+
For any questions, please feel free to reach out to
64+
```
65+
Zhirong Wu: [email protected]
66+
```

cifar.py

+165
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
'''Train CIFAR10 with PyTorch.'''
2+
from __future__ import print_function
3+
4+
import sys
5+
import torch
6+
import torch.nn as nn
7+
import torch.optim as optim
8+
import torch.nn.functional as F
9+
import torch.backends.cudnn as cudnn
10+
11+
import torchvision
12+
import torchvision.transforms as transforms
13+
14+
import os
15+
import argparse
16+
import time
17+
18+
import models
19+
import datasets
20+
import math
21+
22+
from lib.LinearAverage import LinearAverage
23+
from lib.NCA import NCACrossEntropy
24+
from lib.utils import AverageMeter
25+
from test import NN, kNN
26+
27+
# Command-line interface of the CIFAR10 training script.
parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument('--lr', type=float, default=0.1,
                    help='learning rate')
parser.add_argument('--resume', '-r', type=str, default='',
                    help='resume from checkpoint')
parser.add_argument('--test-only', action='store_true',
                    help='test only')
parser.add_argument('--low-dim', type=int, default=128, metavar='D',
                    help='feature dimension')
parser.add_argument('--temperature', type=float, default=0.05, metavar='T',
                    help='temperature parameter for softmax')
parser.add_argument('--memory-momentum', type=float, default=0.5, metavar='M',
                    help='momentum for non-parametric updates')

args = parser.parse_args()

# Run-wide state read (and, for best_acc, updated) by the code below.
use_cuda = torch.cuda.is_available()
best_acc = 0  # best test accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch
43+
44+
# Data
print('==> Preparing data..')

# Both splits are normalised with the same per-channel mean / std.
_normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

# Training augmentation: random resized crop, grayscale, colour jitter, flip.
transform_train = transforms.Compose([
    transforms.RandomResizedCrop(size=32, scale=(0.2, 1.)),
    transforms.RandomGrayscale(p=0.2),
    transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    _normalize,
])

# Evaluation images are only converted to tensors and normalised.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    _normalize,
])

trainset = datasets.CIFAR10Instance(
    root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=128, shuffle=True, num_workers=2)

testset = datasets.CIFAR10Instance(
    root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
# Number of training samples; used to size the feature memory bank.
ndata = len(trainset)
69+
70+
# Model
if args.test_only or len(args.resume) > 0:
    # Load checkpoint.
    print('==> Resuming from checkpoint..')
    assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
    # With --test-only but no --resume the name would be empty and the
    # directory itself would be passed to torch.load; fall back to the file
    # written by the training loop in this script.
    checkpoint_name = args.resume if args.resume else 'ckpt.t7'
    # NOTE: torch.load unpickles arbitrary Python objects -- only load
    # checkpoints that come from a trusted source.
    checkpoint = torch.load('./checkpoint/' + checkpoint_name)
    net = checkpoint['net']
    lemniscate = checkpoint['lemniscate']
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']
else:
    print('==> Building model..')
    net = models.__dict__['ResNet18'](low_dim=args.low_dim)
    # define leminiscate (non-parametric memory bank over the training set)
    lemniscate = LinearAverage(args.low_dim, ndata, args.temperature, args.memory_momentum)

# define loss function
# The criterion is built from the full table of training labels. Newer
# torchvision releases expose it as ``targets``; older ones as ``train_labels``.
_train_labels = getattr(trainloader.dataset, 'targets', None)
if _train_labels is None:
    _train_labels = trainloader.dataset.train_labels
criterion = NCACrossEntropy(torch.LongTensor(_train_labels))

if use_cuda:
    net.cuda()
    net = torch.nn.DataParallel(net, device_ids=list(range(torch.cuda.device_count())))
    lemniscate.cuda()
    criterion.cuda()
    cudnn.benchmark = True

if args.test_only:
    # Evaluate once with 30 neighbours and exit without training.
    acc = kNN(0, net, lemniscate, trainloader, testloader, 30, args.temperature)
    sys.exit(0)

optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4, nesterov=True)
101+
102+
def adjust_learning_rate(optimizer, epoch):
    """Set the learning rate to the initial LR decayed by 10 every 50 epochs.

    The resulting rate is printed and written into every parameter group
    of ``optimizer``.
    """
    decay_steps = epoch // 50
    lr = args.lr * (0.1 ** decay_steps)
    print(lr)
    for group in optimizer.param_groups:
        group['lr'] = lr
108+
109+
# Training
110+
def train(epoch):
    """Run one training epoch over ``trainloader``.

    Updates the learning rate for ``epoch``, then for every batch computes
    features with ``net``, scores them against the memory bank
    (``lemniscate``), applies ``criterion`` and takes one optimizer step.
    Timing and loss statistics are printed after every batch.
    """
    print('\nEpoch: %d' % epoch)
    adjust_learning_rate(optimizer, epoch)
    train_loss = AverageMeter()
    data_time = AverageMeter()
    batch_time = AverageMeter()

    # switch to train mode
    net.train()

    end = time.time()
    for batch_idx, (inputs, targets, indexes) in enumerate(trainloader):
        # Time spent waiting for the data loader.
        data_time.update(time.time() - end)
        if use_cuda:
            inputs, targets, indexes = inputs.cuda(), targets.cuda(), indexes.cuda()
        optimizer.zero_grad()

        features = net(inputs)
        outputs = lemniscate(features, indexes)
        # The criterion receives the sample indexes rather than `targets`;
        # it was constructed with the table of training labels.
        loss = criterion(outputs, indexes)

        loss.backward()
        optimizer.step()

        train_loss.update(loss.item(), inputs.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        print('Epoch: [{}][{}/{}] '
              'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
              'Data: {data_time.val:.3f} ({data_time.avg:.3f}) '
              'Loss: {train_loss.val:.4f} ({train_loss.avg:.4f})'.format(
                  epoch, batch_idx, len(trainloader),
                  batch_time=batch_time, data_time=data_time, train_loss=train_loss))
147+
148+
# Train for 200 epochs, evaluating with kNN (30 neighbours) after each one
# and checkpointing whenever the accuracy improves.
for epoch in range(start_epoch, start_epoch + 200):
    train(epoch)
    acc = kNN(epoch, net, lemniscate, trainloader, testloader, 30, args.temperature)

    if acc > best_acc:
        print('Saving..')
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        # Store the bare network (unwrapped from DataParallel when on GPU).
        state = {
            'net': net.module if use_cuda else net,
            'lemniscate': lemniscate,
            'acc': acc,
            'epoch': epoch,
        }
        torch.save(state, './checkpoint/ckpt.t7')
        best_acc = acc

    # NOTE(review): indentation was lost in the reviewed source; this print
    # is assumed to run once per epoch (inside the loop) -- confirm.
    print('best accuracy: {:.2f}'.format(best_acc*100))

datasets/__init__.py

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from .folder import ImageFolderInstance
2+
from .cifar import CIFAR10Instance, CIFAR100Instance
3+
4+
__all__ = ('ImageFolderInstance', 'CIFAR10Instance', 'CIFAR100Instance')
5+

datasets/cifar.py

+42
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
from __future__ import print_function
2+
from PIL import Image
3+
import torchvision.datasets as datasets
4+
import torch.utils.data as data
5+
6+
class CIFAR10Instance(datasets.CIFAR10):
    """CIFAR10 dataset whose samples also carry their dataset index.

    ``__getitem__`` returns ``(img, target, index)`` instead of the usual
    ``(img, target)`` pair.
    """

    def __getitem__(self, index):
        """Return the sample stored at ``index`` together with ``index``.

        Args:
            index (int): Position of the sample in the active split.

        Returns:
            tuple: ``(img, target, index)`` where ``img`` and ``target`` have
            been passed through ``transform`` / ``target_transform`` when
            those are set.
        """
        if hasattr(self, 'data') and hasattr(self, 'targets'):
            # Newer torchvision releases keep either split under
            # ``data`` / ``targets``.
            img, target = self.data[index], self.targets[index]
        elif self.train:
            img, target = self.train_data[index], self.train_labels[index]
        else:
            img, target = self.test_data[index], self.test_labels[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target, index
26+
27+
class CIFAR100Instance(CIFAR10Instance):
    """CIFAR100 counterpart of `CIFAR10Instance`.

    Only the archive metadata (folder, URL, file names and checksums) is
    overridden; sample access is inherited from `CIFAR10Instance`.
    """
    # Location and checksum of the CIFAR-100 archive.
    base_folder = 'cifar-100-python'
    url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
    filename = "cifar-100-python.tar.gz"
    tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'

    # [file name, md5] entries for each split.
    train_list = [['train', '16019d7e3df5f24257cddd939b257f8d']]
    test_list = [['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc']]

datasets/folder.py

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import torchvision.datasets as datasets
2+
3+
class ImageFolderInstance(datasets.ImageFolder):
    """Folder dataset which returns the index of the image as well."""

    def __getitem__(self, index):
        """Load the sample at ``index``.

        Args:
            index (int): Index

        Returns:
            tuple: (image, target, index) where target is class_index of the
            target class and index is the ``index`` argument itself.
        """
        image_path, label = self.imgs[index]
        image = self.loader(image_path)

        # Optional transforms for the image and for the label.
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)

        return image, label, index
21+

lib/LinearAverage.py

+58
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import torch
2+
from torch.autograd import Function
3+
from torch import nn
4+
import math
5+
6+
class LinearAverageOp(Function):
    """Temperature-scaled inner product against a feature memory bank.

    Forward computes ``x @ memory.t() / T``. Backward returns the gradient
    with respect to ``x`` only and, as a side effect, refreshes the memory
    rows selected by ``y`` with a momentum-weighted, L2-normalised blend of
    the stored rows and the current features.
    """

    @staticmethod
    def forward(ctx, x, y, memory, params):
        """Return the scaled similarities between ``x`` and every memory row.

        Args:
            ctx: autograd context used to stash tensors for ``backward``.
            x: 2-D feature matrix (one row per sample).
            y: indices of the memory rows that belong to the rows of ``x``.
            memory: 2-D memory bank (one row per stored sample).
            params: tensor holding ``[T, momentum]``.

        Returns:
            Tensor with one row per sample of ``x`` and one column per
            memory row.
        """
        temperature = params[0].item()

        # inner product, then temperature scaling
        out = torch.mm(x.data, memory.t())
        out.div_(temperature)

        ctx.save_for_backward(x, memory, y, params)

        return out

    @staticmethod
    def backward(ctx, gradOutput):
        """Return the gradient w.r.t. ``x`` and update the memory rows at ``y``.

        Args:
            ctx: autograd context populated by ``forward``.
            gradOutput: gradient of the loss w.r.t. the forward output.

        Returns:
            tuple: ``(gradInput, None, None, None)``; ``y``, ``memory`` and
            ``params`` receive no gradient.
        """
        x, memory, y, params = ctx.saved_tensors
        temperature = params[0].item()
        momentum = params[1].item()

        # add temperature -- out of place, so the incoming gradient tensor
        # (which autograd may share with other consumers) is left untouched
        scaled_grad = gradOutput.data / temperature

        # gradient of linear; torch.mm already yields the shape of x
        gradInput = torch.mm(scaled_grad, memory)

        # update the non-parametric data: blend stored rows with the new
        # features, renormalise to unit length and write them back
        rows = y.data.view(-1)
        weight_pos = memory.index_select(0, rows)
        weight_pos.mul_(momentum)
        weight_pos.add_(torch.mul(x.data, 1 - momentum))
        w_norm = weight_pos.pow(2).sum(1, keepdim=True).pow(0.5)
        updated_weight = weight_pos.div(w_norm)
        memory.index_copy_(0, rows, updated_weight)

        return gradInput, None, None, None
43+
44+
class LinearAverage(nn.Module):
    """Module wrapper around `LinearAverageOp` that owns the memory bank.

    Two buffers are registered: ``params`` (``[T, momentum]``) and
    ``memory`` (``outputSize x inputSize``), so both follow the module
    across devices and are stored in its state dict without being trained
    by the optimizer.
    """

    def __init__(self, inputSize, outputSize, T=0.05, momentum=0.5):
        """Create the memory bank.

        Args:
            inputSize: dimensionality of each feature vector.
            outputSize: number of memory rows.
            T: temperature dividing the similarities.
            momentum: weight of the stored row when a memory row is updated.
        """
        super(LinearAverage, self).__init__()
        self.nLem = outputSize

        self.register_buffer('params', torch.tensor([T, momentum]))
        # Uniform initialisation in [-stdv, stdv); the float literal keeps
        # the division exact under Python 2 as well.
        stdv = 1. / math.sqrt(inputSize / 3.0)
        self.register_buffer('memory', torch.rand(outputSize, inputSize).mul_(2*stdv).add_(-stdv))

    def forward(self, x, y):
        """Return the similarities of ``x`` to every memory row.

        ``y`` holds the memory indices of the rows of ``x``; those rows are
        refreshed during the backward pass of `LinearAverageOp`.
        """
        return LinearAverageOp.apply(x, y, self.memory, self.params)
58+

0 commit comments

Comments
 (0)