diff --git a/configs/byol/byol_r50_IM.yaml b/configs/byol/byol_r50_IM.yaml
index 56ae2970..a8a59fb8 100644
--- a/configs/byol/byol_r50_IM.yaml
+++ b/configs/byol/byol_r50_IM.yaml
@@ -1,5 +1,5 @@
 epochs: 300
-use_byol_iters: True
+use_simclr_iters: True
 total_images: 1281167
 global_batch_size: 4096
 output_dir: output_dir
@@ -84,12 +84,13 @@ dataloader:
 
 lr_scheduler:
-  name: CosineWarmup
-  learning_rate: 4.8
-  T_max: 93835
-  warmup_steps: 3127
-  start_lr: 0.0001
-  end_lr: 4.8
+  name: simclrCosineWarmup
+  learning_rate_scaling: linear
+  total_images: 1281167
+  warmup_epochs: 10
+  start_lr: 0
+  end_lr: 0.3
+  T_max: 300
 
 optimizer:
diff --git a/configs/simclr/simclr_r50_IM.yaml b/configs/simclr/simclr_r50_IM.yaml
index d555021c..54e0d339 100755
--- a/configs/simclr/simclr_r50_IM.yaml
+++ b/configs/simclr/simclr_r50_IM.yaml
@@ -2,6 +2,7 @@ epochs: 100
 use_simclr_iters: True
 global_batch_size: 4096
 output_dir: output_dir
+device: gpu
 
 model:
   name: SimCLR
@@ -21,7 +22,9 @@ model:
 
 dataloader:
   train:
-    num_workers: 6
+    loader:
+      num_workers: 4
+      use_shared_memory: True
     sampler:
       batch_size: 32
       shuffle: true
@@ -83,9 +86,11 @@ dataloader:
           std: [0.229, 0.224, 0.225]
   val:
-    num_workers: 4
+    loader:
+      num_workers: 4
+      use_shared_memory: True
     sampler:
-      batch_size: 512
+      batch_size: 256
       shuffle: false
       drop_last: false
     dataset:
@@ -105,18 +110,18 @@ dataloader:
 
 lr_scheduler:
   name: simclrCosineWarmup
-  learning_rate_scaling: sqrt
+  learning_rate_scaling: linear
   total_images: 1281167
   warmup_epochs: 10
   start_lr: 0
-  end_lr: 1.0
+  end_lr: 0.3
   T_max: 200
 
 optimizer:
   name: LarsMomentumOptimizer
   momentum: 0.9
-  lars_weight_decay: 0.0001
+  lars_weight_decay: 1e-6
   exclude_from_weight_decay: ["scale","offset",".bias"]
 
 log_config:
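Note on the config changes above: assuming `simclrCosineWarmup` follows the SimCLR paper's linear scaling rule (`peak_lr = end_lr * global_batch_size / 256`) and derives its step counts from `total_images`, the new BYOL block reproduces exactly the values the old `CosineWarmup` block hardcoded. A quick sanity check in pure Python, with the values taken from this diff:

```python
# Values from the BYOL config in this diff.
total_images = 1281167   # ImageNet-1k train set
global_batch_size = 4096
epochs = 300
warmup_epochs = 10
end_lr = 0.3             # base LR before batch-size scaling

peak_lr = end_lr * global_batch_size / 256                        # 4.8   == old learning_rate / end_lr
total_steps = epochs * total_images // global_batch_size          # 93835 == old T_max
warmup_steps = warmup_epochs * total_images // global_batch_size  # 3127  == old warmup_steps
```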
diff --git a/passl/engine/trainer.py b/passl/engine/trainer.py
index 649459a6..296ea737 100644
--- a/passl/engine/trainer.py
+++ b/passl/engine/trainer.py
@@ -145,21 +145,17 @@ def __init__(self, cfg):
         self.train_dataloader, self.mixup_fn = build_dataloader(
             cfg.dataloader.train, self.device)
         self.iters_per_epoch = len(self.train_dataloader)
-
+        self.batch_size = cfg.dataloader.train.sampler.batch_size
+        self.global_batch_size = self.batch_size * dist.get_world_size()
         # use byol iters
         if self.use_byol_iters:
-            self.global_batch_size = cfg.global_batch_size
             self.byol_total_iters = self.epochs * cfg.total_images // self.global_batch_size
-
-        if self.use_byol_iters:
             self.lr_scheduler = build_lr_scheduler(cfg.lr_scheduler,
                                                    self.byol_total_iters)
         elif self.use_simclr_iters:
-            self.batch_size = cfg.dataloader.train.sampler.batch_size
-            self.global_batch_size = cfg.global_batch_size
             self.epochs = cfg.epochs
             self.lr_scheduler = build_lr_scheduler_simclr(
-                cfg.lr_scheduler, self.iters_per_epoch, self.batch_size * 8,
+                cfg.lr_scheduler, self.iters_per_epoch, self.global_batch_size,
                 cfg.epochs, self.current_iter)
         else:
             self.lr_scheduler = build_lr_scheduler(cfg.lr_scheduler,
@@ -224,7 +220,7 @@ def __init__(self, cfg):
         self.add_train_hooks()
         self.add_custom_hooks()
         self.hooks = sorted(self.hooks, key=lambda x: x.priority)
-
+        print("hooks: ", self.hooks)
         if self.epochs:
             self.total_iters = self.epochs * self.iters_per_epoch
             self.by_epoch = True
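The trainer now derives `global_batch_size` from the sampler's per-device batch size times `dist.get_world_size()`, instead of trusting `cfg.global_batch_size` and the hardcoded `* 8` (8 GPUs) in the SimCLR branch, so the schedule stays correct for any device count. For reference, a minimal sketch of the warmup-plus-cosine schedule that `build_lr_scheduler_simclr` presumably constructs; this is a hypothetical helper, not the PASSL source, and the real implementation may differ in details:

```python
import math

def warmup_cosine_lr(step, peak_lr, warmup_steps, total_steps, start_lr=0.0):
    """Linear warmup from start_lr to peak_lr, then half-cosine decay to 0."""
    if step < warmup_steps:
        return start_lr + (peak_lr - start_lr) * step / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * progress))
```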
diff --git a/passl/modeling/architectures/BYOL.py b/passl/modeling/architectures/BYOL.py
index 01959526..236dd976 100644
--- a/passl/modeling/architectures/BYOL.py
+++ b/passl/modeling/architectures/BYOL.py
@@ -33,6 +33,7 @@
 import paddle
 import paddle.fluid.layers as layers
 
+
 def single_random_gaussian_blur(image, height, width, p=1.0):
     """Randomly blur an image.
     Args:
@@ -53,22 +54,23 @@ def single_random_gaussian_blur(image, height, width, p=1.0):
     x = paddle.arange(-radius, radius + 1, 1, "float32")
     blur_filter = paddle.exp(-paddle.pow(x, 2.0) /
                              (2.0 * paddle.pow(sigma, 2.0)))
-    blur_filter /= layers.reduce_sum(blur_filter)
-    blur_v = layers.reshape(blur_filter, [1, 1, kernel_size, 1])
-    blur_h = layers.reshape(blur_filter, [1, 1, 1, kernel_size])
+    blur_filter /= layers.nn.reduce_sum(blur_filter)
+    blur_v = paddle.reshape(blur_filter, [1, 1, kernel_size, 1])
+    blur_h = paddle.reshape(blur_filter, [1, 1, 1, kernel_size])
     num_channels = 3
     blur_h = paddle.tile(blur_h, [num_channels, 1, 1, 1])
     blur_v = paddle.tile(blur_v, [num_channels, 1, 1, 1])
-
+
     expand_batch_dim = len(image.shape) == 3
     if expand_batch_dim:
-        image = paddle.unsqueeze(image.transpose((2,0,1)), axis=0)
+        image = paddle.unsqueeze(image.transpose((2, 0, 1)), axis=0)
     blurred = paddle.nn.functional.conv2d(
-        image, blur_h, stride=1, padding=padding,groups=3)
+        image, blur_h, stride=1, padding=padding, groups=3)
     blurred = paddle.nn.functional.conv2d(
-        blurred, blur_v, stride=1, padding=padding,groups=3)
-    return blurred.transpose((0,2,3,1))
+        blurred, blur_v, stride=1, padding=padding, groups=3)
+    return blurred.transpose((0, 2, 3, 1))
+
 
 def random_gaussian_blur(image, height, width, p=1.0):
     """Randomly blur an image.
@@ -82,26 +84,29 @@ def random_gaussian_blur(image, height, width, p=1.0):
     """
     res = []
     for i in range(image.shape[0]):
-        res.append(single_random_gaussian_blur(image[i],height,width,p))
-    return paddle.concat(res,axis=0)
+        res.append(single_random_gaussian_blur(image[i], height, width, p))
+    return paddle.concat(res, axis=0)
 
-def random_solarization(img,threshold=0.5):
-    img = paddle.where(img < threshold, img, 1 -img)
+
+def random_solarization(img, threshold=0.5):
+    img = paddle.where(img < threshold, img, 1 - img)
     return img
 
-def img_normalize(img,mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]):
+
+def img_normalize(img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
     mean = paddle.to_tensor(mean, dtype='float32').reshape([1, 1, 1, 3])
     std = paddle.to_tensor(std, dtype='float32').reshape([1, 1, 1, 3])
     return (img - mean) / std
 
+
 def to_chw(img):
-    return img.transpose((0,3,1,2))
+    return img.transpose((0, 3, 1, 2))
 
-def batch_random_blur_solariza_normalize_chw(
-        view1,
-        view2,
-        blur_probability=(1.0,0.1),
-        solariza_probability=(0.0,0.2) ):
+
+def batch_random_blur_solariza_normalize_chw(view1,
+                                             view2,
+                                             blur_probability=(1.0, 0.1),
+                                             solariza_probability=(0.0, 0.2)):
     """Apply efficient batch data transformations.
     Args:
       images_list: a list of image tensors.
@@ -114,23 +119,23 @@ def batch_random_blur_solariza_normalize_chw(
 
     def generate_selector(p, bsz):
         shape = [bsz, 1, 1, 1]
-        p_tensor = layers.fill_constant(
+        p_tensor = paddle.tensor.fill_constant(
            shape=shape, dtype="float32", value=p)
-        selector = layers.cast(
-            layers.less_than(
-                layers.uniform_random(
+        selector = paddle.cast(
+            paddle.less_than(
+                paddle.uniform(
                     shape=shape, min=0, max=1, dtype="float32"), p_tensor),
             "float32")
         return selector
-
-    B,H,W,C = view1.shape
+
+    B, H, W, C = view1.shape
     img1 = view1
     img1_new = random_gaussian_blur(img1, H, W, p=1.0)
-    selector = generate_selector(blur_probability[0],B)
+    selector = generate_selector(blur_probability[0], B)
     img1_blur_res = img1_new * selector + img1 * (1 - selector)
-
-    selector = generate_selector(solariza_probability[0],B)
+
+    selector = generate_selector(solariza_probability[0], B)
     img1_sola_res = random_solarization(img1_blur_res)
     img1_sola_res = img1_sola_res * selector + img1_blur_res * (1 - selector)
     img1_sola_res = paddle.clip(img1_sola_res, min=0., max=1.)
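Context for the hunks above: `generate_selector` draws per-sample uniform noise, compares it against `p`, and yields a `[B, 1, 1, 1]` mask of zeros and ones; blending `transformed * selector + original * (1 - selector)` then applies an augmentation to each sample independently with probability `p`, without Python-side branching over the batch. A minimal sketch of the same pattern in plain Paddle 2.x (the helper name is illustrative):

```python
import paddle

def apply_with_prob(imgs, transform, p):
    """Apply `transform` per sample with probability p via a 0/1 selector
    mask, mirroring generate_selector above. imgs: [B, H, W, C] in [0, 1]."""
    bsz = imgs.shape[0]
    selector = paddle.cast(
        paddle.uniform([bsz, 1, 1, 1], min=0., max=1.) < p, "float32")
    return transform(imgs) * selector + imgs * (1 - selector)

# e.g. solarize the second view with probability 0.2, matching
# solariza_probability[1] above:
# view2 = apply_with_prob(view2, random_solarization, 0.2)
```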
@@ -140,24 +145,26 @@ def generate_selector(p, bsz):
 
     img2 = view2
     img2_new = random_gaussian_blur(img2, H, W, p=1.0)
-    selector = generate_selector(blur_probability[1],B)
+    selector = generate_selector(blur_probability[1], B)
     img2_blur_res = img2_new * selector + img2 * (1 - selector)
-
-    selector = generate_selector(solariza_probability[1],B)
+
+    selector = generate_selector(solariza_probability[1], B)
     img2_sola_res = random_solarization(img2_blur_res)
     img2_sola_res = img2_sola_res * selector + img2_blur_res * (1 - selector)
     img2_sola_res = paddle.clip(img2_sola_res, min=0., max=1.)
-    img2_sola_res.stop_gradient = True 
+    img2_sola_res.stop_gradient = True
     img2_tran_res = to_chw(img_normalize(img2_sola_res))
 
     return img1_tran_res, img2_tran_res
 
+
 @MODELS.register()
 class BYOL(nn.Layer):
     """
-    Build a MoCo model with: a query encoder, a key encoder, and a queue
-    https://arxiv.org/abs/1911.05722
+    Build a BYOL model referenced from paper
+    https://arxiv.org/abs/2006.07733
     """
+
     def __init__(self,
                  backbone,
                  neck=None,
@@ -169,8 +176,7 @@ def __init__(self,
                  target_decay_method='fixed',
                  target_decay_rate=0.996,
                  align_init_network=True,
-                 use_synch_bn=False
-                 ):
+                 use_synch_bn=False):
         """
         Args:
             backbone (dict): config of backbone.
@@ -184,77 +190,82 @@ def __init__(self,
         self.towers = nn.LayerList()
         self.base_m = target_decay_rate
         self.target_decay_method = target_decay_method
-
+
         neck1 = build_neck(neck)
         neck2 = build_neck(neck)
-
+
         self.towers.append(nn.Sequential(build_backbone(backbone), neck1))
         self.towers.append(nn.Sequential(build_backbone(backbone), neck2))
         self.net_init(self.towers)
 
         self.predictor = build_neck(predictor)
         self.net_init(self.predictor)
 
-        self.classifier = nn.Linear(embedding_dim,num_classes)
+        self.classifier = nn.Linear(embedding_dim, num_classes)
         self.net_init(self.classifier)
 
         self.backbone = self.towers[0][0]
         # self.neck1 = self.towers[0][1]
 
         # TODO IMPORTANT! Explore if the initialization requires to be synchronized
-        for param_q, param_k in zip(self.towers[0].parameters(),self.towers[1].parameters()):
+        for param_q, param_k in zip(self.towers[0].parameters(),
+                                    self.towers[1].parameters()):
             param_k.stop_gradient = True
 
         if align_init_network:
-            for param_q, param_k in zip(self.towers[0].parameters(),self.towers[1].parameters()):
+            for param_q, param_k in zip(self.towers[0].parameters(),
+                                        self.towers[1].parameters()):
                 param_k.set_value(param_q)  # initialize
-
+
         # Convert BatchNorm*d to SyncBatchNorm*d
         if use_synch_bn:
-            self.towers[0] = nn.SyncBatchNorm.convert_sync_batchnorm(self.towers[0])
-            self.towers[1] = nn.SyncBatchNorm.convert_sync_batchnorm(self.towers[1])
+            self.towers[0] = nn.SyncBatchNorm.convert_sync_batchnorm(
+                self.towers[0])
+            self.towers[1] = nn.SyncBatchNorm.convert_sync_batchnorm(
+                self.towers[1])
             #self.predictor = nn.SyncBatchNorm.convert_sync_batchnorm(self.predictor)
 
         self.head = build_head(head)
-
-    def net_init(self,network):
+
+    def net_init(self, network):
         for m in network.sublayers():
             if isinstance(m, nn.Conv2D):
-                init.kaiming_init(m,mode="fan_in",nonlinearity="conv2d")
+                init.kaiming_init(m, mode="fan_in", nonlinearity="conv2d")
             if isinstance(m, nn.Conv2D):
-                init.kaiming_init(m,mode="fan_in",nonlinearity="linear")
+                init.kaiming_init(m, mode="fan_in", nonlinearity="linear")
 
     def train_iter(self, *inputs, **kwargs):
-
+
         current_iter = kwargs['current_iter']
-        total_iters =  kwargs['total_iters']
-
+        total_iters = kwargs['total_iters']
+
         if self.target_decay_method == 'cosine':
-            self.m = 1 - (1-self.base_m) * (1 + math.cos(math.pi*(current_iter-0)/total_iters))/2.0   # 47.0
+            self.m = 1 - (1 - self.base_m) * (1 + math.cos(math.pi * (
+                current_iter - 0) / total_iters)) / 2.0  # 47.0
         elif self.target_decay_method == 'fixed':
-            self.m = self.base_m   # 55.7
+            self.m = self.base_m  # 55.7
         else:
             raise NotImplementedError
 
         # self.update_target_network()
         img_a, img_b, label = inputs
-        img_a, img_b = batch_random_blur_solariza_normalize_chw(img_a,img_b)
+        img_a, img_b = batch_random_blur_solariza_normalize_chw(img_a, img_b)
 
         embedding = self.towers[0][0](img_a)
         online_project_view1 = self.towers[0][1](embedding)
         online_predict_view1 = self.predictor(online_project_view1)
 
         online_project_view2 = self.towers[0](img_b)
         online_predict_view2 = self.predictor(online_project_view2)
-
+
         clone_x = embedding.clone()
-        clone_x.stop_gradient = True 
+        clone_x.stop_gradient = True
         classif_out = self.classifier(clone_x.squeeze())
-
+
         with paddle.no_grad():
             target_project_view1 = self.towers[1](img_a).clone().detach()
             target_project_view2 = self.towers[1](img_b).clone().detach()
 
         a1 = nn.functional.normalize(online_predict_view1, axis=1)
         b1 = nn.functional.normalize(target_project_view2, axis=1)
-        b1.stop_gradient = True 
+        b1.stop_gradient = True
 
         a2 = nn.functional.normalize(online_predict_view2, axis=1)
         b2 = nn.functional.normalize(target_project_view1, axis=1)
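The `'cosine'` branch of `train_iter` is the target-momentum schedule from the BYOL paper (arXiv:2006.07733): the EMA coefficient ramps from `target_decay_rate` (0.996) at the start of training to 1.0 at the end. Isolated for clarity, using the total step count computed earlier:

```python
import math

def byol_target_momentum(step, total_steps, base_m=0.996):
    """EMA momentum used by the 'cosine' branch of train_iter above:
    m(0) == base_m, m(total_steps) == 1.0."""
    return 1 - (1 - base_m) * (1 + math.cos(math.pi * step / total_steps)) / 2

assert abs(byol_target_momentum(0, 93835) - 0.996) < 1e-9   # starts at base_m
assert abs(byol_target_momentum(93835, 93835) - 1.0) < 1e-9  # ends at 1
```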
@@ -286,7 +297,9 @@ def update_target_network(self):
     def update_target_network_L1(self):
         for param_q, param_k in zip(self.towers[0].parameters(),
                                     self.towers[1].parameters()):
-            paddle.assign(param_k - (1-self.m)*paddle.sign(param_k-param_q), param_k)
+            paddle.assign(param_k -
+                          (1 - self.m) * paddle.sign(param_k - param_q),
+                          param_k)
             param_k.stop_gradient = True
 
     # L2 + L1
@@ -294,7 +307,10 @@ def update_target_network_L1(self):
     def update_target_network_clip(self):
         for param_q, param_k in zip(self.towers[0].parameters(),
                                     self.towers[1].parameters()):
-            paddle.assign(param_k - (1-self.m) * paddle.clip((param_k - param_q), min=-1.0, max=1.0) , param_k)
+            paddle.assign(
+                param_k - (1 - self.m) * paddle.clip(
+                    (param_k - param_q), min=-1.0, max=1.0),
+                param_k)
             param_k.stop_gradient = True
 
     @paddle.no_grad()
@@ -302,5 +318,8 @@ def update_target_network_LN_clip(self):
         for param_q, param_k in zip(self.towers[0].parameters(),
                                     self.towers[1].parameters()):
             paddle.assign((param_k * self.m + param_q * (1. - self.m)), param_k)
-            paddle.assign(param_k - (1-self.m) * paddle.clip((param_k - param_q), min=-1.0, max=1.0) , param_k)
+            paddle.assign(
+                param_k - (1 - self.m) * paddle.clip(
+                    (param_k - param_q), min=-1.0, max=1.0),
+                param_k)
             param_k.stop_gradient = True
diff --git a/passl/modeling/architectures/simclr.py b/passl/modeling/architectures/simclr.py
index 98d78b3b..f94623b8 100755
--- a/passl/modeling/architectures/simclr.py
+++ b/passl/modeling/architectures/simclr.py
@@ -23,42 +23,35 @@
 import paddle.nn.functional as F
 import paddle.fluid.layers as layers
 
-
 LARGE_NUM = 1e9
 
+
 @MODELS.register()
 class SimCLR(nn.Layer):
     """
     Simple image SimCLR.
     """
-    def __init__(self,
-                 backbone,
-                 neck=None,
-                 head=None,
-                 dim=128,
-                 T=0.5):
+    def __init__(self, backbone, neck=None, head=None, dim=128, T=0.5):
         super(SimCLR, self).__init__()
 
         self.T = T
-        self.encoder = nn.Sequential(build_backbone(backbone),
-                                     build_neck(neck))
-
+        self.encoder = nn.Sequential(build_backbone(backbone), build_neck(neck))
+
         self.backbone = self.encoder[0]
         self.head = build_head(head)
-
-
     def train_iter(self, *inputs, **kwargs):
         img_q, img_k = inputs
         img_con = [img_q, img_k]
         img_con = paddle.concat(img_con)
         con = self.encoder(img_con)
-        con = layers.l2_normalize(con, -1)
-        q, k = layers.split(con, num_or_sections=2, dim=0)
+        con = paddle.nn.functional.normalize(con, axis=-1)
+        q, k = paddle.split(con, num_or_sections=2, axis=0)
        outputs = self.head(q, k)
-
+
         return outputs
 
+
     def test_iter(self, *inputs, **kwargs):
         with paddle.no_grad():
             img, label = inputs
@@ -76,6 +69,3 @@ def forward(self, *inputs, mode='train', **kwargs):
             return self.backbone(*inputs)
         else:
             raise Exception("No such mode: {}".format(mode))
-
-
-
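The changes in `simclr.py` (and in the head and necks below) are largely a migration off the deprecated `paddle.fluid.layers` namespace onto the Paddle 2.x API. The replacements this PR uses, side by side; a sketch only, with `x` illustrative:

```python
import paddle
import paddle.nn.functional as F

x = paddle.rand([8, 128])

# fluid (deprecated)                 ->  Paddle 2.x (as used in this diff)
# layers.l2_normalize(x, -1)         ->  F.normalize(x, axis=-1)
# layers.split(x, 2, dim=0)          ->  paddle.split(x, num_or_sections=2, axis=0)
# layers.reduce_mean(x)              ->  paddle.mean(x)
# layers.accuracy(input, label)      ->  paddle.metric.accuracy(input, label)
q, k = paddle.split(F.normalize(x, axis=-1), num_or_sections=2, axis=0)
```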
diff --git a/passl/modeling/heads/simclr_contrastive_head.py b/passl/modeling/heads/simclr_contrastive_head.py
index 3906d4ac..96fc50f6 100755
--- a/passl/modeling/heads/simclr_contrastive_head.py
+++ b/passl/modeling/heads/simclr_contrastive_head.py
@@ -55,22 +55,24 @@ def forward(self, pos, neg):
         hidden1_large = hidden1
         hidden2_large = hidden2
         labels = F.one_hot(
-            paddle.reshape(paddle.arange(0, batch_size, 1, "int32"),
-                           [batch_size]), batch_size * 2)
+            paddle.reshape(
+                paddle.arange(0, batch_size, 1, "int32"), [batch_size]),
+            batch_size * 2)
         masks = F.one_hot(
-            paddle.reshape(paddle.arange(0, batch_size, 1, "int32"),
-                           [batch_size]), batch_size)
+            paddle.reshape(
+                paddle.arange(0, batch_size, 1, "int32"), [batch_size]),
+            batch_size)
 
-        logits_aa = paddle.matmul(hidden1, hidden1_large,
-                                  transpose_y=True) / self.temperature
+        logits_aa = paddle.matmul(
+            hidden1, hidden1_large, transpose_y=True) / self.temperature
         logits_aa = logits_aa - masks * LARGE_NUM
-        logits_bb = paddle.matmul(hidden2, hidden2_large,
-                                  transpose_y=True) / self.temperature
+        logits_bb = paddle.matmul(
+            hidden2, hidden2_large, transpose_y=True) / self.temperature
         logits_bb = logits_bb - masks * LARGE_NUM
-        logits_ab = paddle.matmul(hidden1, hidden2_large,
-                                  transpose_y=True) / self.temperature
-        logits_ba = paddle.matmul(hidden2, hidden1_large,
-                                  transpose_y=True) / self.temperature
+        logits_ab = paddle.matmul(
+            hidden1, hidden2_large, transpose_y=True) / self.temperature
+        logits_ba = paddle.matmul(
+            hidden2, hidden1_large, transpose_y=True) / self.temperature
 
         loss_a = paddle.nn.functional.softmax_with_cross_entropy(
             paddle.concat([logits_ab, logits_aa], 1), labels, soft_label=True)
@@ -91,10 +93,10 @@ def forward(self, pos, neg):
         co2_loss = 1 * (kl_1 + kl_2)
 
         total_contrast_loss = contrast_loss + 3 * co2_loss
-        loss = layers.reduce_mean(total_contrast_loss)
+        loss = paddle.mean(total_contrast_loss)
 
         contrastive_label = paddle.unsqueeze(paddle.argmax(labels, axis=1), 1)
-        acc1 = layers.accuracy(input=logits_ab, label=contrastive_label)
+        acc1 = paddle.metric.accuracy(input=logits_ab, label=contrastive_label)
 
         outputs = dict()
         outputs['loss'] = loss
         outputs['acc1'] = acc1
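For context, the masking logic these hunks reformat is the standard NT-Xent construction: positives sit on the diagonal of `h1 @ h2^T`, while `h1 @ h1^T` supplies extra negatives whose self-similarity diagonal is pushed to effectively -inf by subtracting `LARGE_NUM`. A minimal sketch of one direction of the loss (the head additionally symmetrizes it and adds a CO2 consistency term weighted 3x, which this sketch omits):

```python
import paddle
import paddle.nn.functional as F

def nt_xent_one_direction(h1, h2, temperature=0.5, large_num=1e9):
    """h1, h2: [B, D], already l2-normalized. Sketch of the contrastive
    part of SimCLRContrastiveHead.forward."""
    bsz = h1.shape[0]
    labels = F.one_hot(paddle.arange(bsz, dtype="int64"), bsz * 2)
    masks = F.one_hot(paddle.arange(bsz, dtype="int64"), bsz)
    logits_ab = paddle.matmul(h1, h2, transpose_y=True) / temperature
    logits_aa = paddle.matmul(h1, h1, transpose_y=True) / temperature
    logits_aa = logits_aa - masks * large_num  # drop i==i self pairs
    loss = F.softmax_with_cross_entropy(
        paddle.concat([logits_ab, logits_aa], axis=1), labels, soft_label=True)
    return paddle.mean(loss)
```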
diff --git a/passl/modeling/necks/base_neck.py b/passl/modeling/necks/base_neck.py
index 21f16255..021874a3 100644
--- a/passl/modeling/necks/base_neck.py
+++ b/passl/modeling/necks/base_neck.py
@@ -80,9 +80,9 @@ def __init__(self,
 
         if with_avg_pool:
             self.avgpool = nn.AdaptiveAvgPool2D((1, 1))
-        self.mlp = nn.Sequential(nn.Linear(in_channels,
-                                           hid_channels), nn.ReLU(),
-                                 nn.Linear(hid_channels, out_channels))
+        self.mlp = nn.Sequential(
+            nn.Linear(in_channels, hid_channels),
+            nn.ReLU(), nn.Linear(hid_channels, out_channels))
 
         # init_backbone_weight(self.mlp)
         self.init_parameters()
@@ -113,9 +113,12 @@ def __init__(self,
 
         if with_avg_pool:
             self.avgpool = nn.AdaptiveAvgPool2D((1, 1))
-        self.mlp = nn.Sequential(nn.Linear(in_channels, hid_channels, bias_attr=with_bias),
-                                 nn.BatchNorm1D(hid_channels), nn.ReLU(),
-                                 nn.Linear(hid_channels, out_channels))
+        self.mlp = nn.Sequential(
+            nn.Linear(
+                in_channels, hid_channels, bias_attr=with_bias),
+            nn.BatchNorm1D(hid_channels),
+            nn.ReLU(),
+            nn.Linear(hid_channels, out_channels))
 
         # init_backbone_weight(self.mlp)
         # self.init_parameters()
@@ -190,9 +193,9 @@ def __init__(self,
 
         self.conv = BottleneckBlock(in_channels, in_channels // 4)
 
-        self.mlp = nn.Sequential(nn.Linear(in_channels,
-                                           hid_channels), nn.ReLU(),
-                                 nn.Linear(hid_channels, out_channels))
+        self.mlp = nn.Sequential(
+            nn.Linear(in_channels, hid_channels),
+            nn.ReLU(), nn.Linear(hid_channels, out_channels))
 
         init_backbone_weight(self.mlp)
@@ -220,12 +223,14 @@ def __init__(self,
         self.with_avg_pool = with_avg_pool
         if with_avg_pool:
             self.avgpool = nn.AdaptiveAvgPool2D((1, 1))
-        self.mlp = nn.Sequential(nn.Linear(in_channels, hid_channels),
-                                 nn.BatchNorm1D(hid_channels), nn.ReLU(),
-                                 nn.Linear(hid_channels, hid_channels),
-                                 nn.BatchNorm1D(hid_channels), nn.ReLU(),
-                                 nn.Linear(hid_channels, out_channels),
-                                 nn.BatchNorm1D(out_channels))
+        self.mlp = nn.Sequential(
+            nn.Linear(in_channels, hid_channels),
+            nn.BatchNorm1D(hid_channels),
+            nn.ReLU(),
+            nn.Linear(hid_channels, hid_channels),
+            nn.BatchNorm1D(hid_channels),
+            nn.ReLU(),
+            nn.Linear(hid_channels, out_channels), nn.BatchNorm1D(out_channels))
 
         init_backbone_weight_simclr(self.mlp)
@@ -233,9 +238,9 @@ def init_parameters(self, init_linear='normal'):
         _init_parameters(self, init_linear)
 
     def forward(self, x):
-        x = layers.squeeze(x, axes=[])
+        x = paddle.squeeze(x)
         hidden = self.mlp(x)
-        hidden = layers.l2_normalize(hidden, -1)
+        hidden = paddle.nn.functional.normalize(hidden, axis=-1)
 
         return hidden
@@ -255,13 +260,21 @@ def __init__(self,
         self.with_avg_pool = with_avg_pool
         if with_avg_pool:
             self.avgpool = nn.AdaptiveAvgPool2D((1, 1))
-        self.mlp = nn.Sequential(nn.Linear(in_channels, hid_channels, bias_attr=with_bias),
-                                 nn.BatchNorm1D(hid_channels), nn.ReLU(),
-                                 nn.Linear(hid_channels, hid_channels, bias_attr=with_bias),
-                                 nn.BatchNorm1D(hid_channels), nn.ReLU(),
-                                 nn.Linear(hid_channels, out_channels, bias_attr=with_bias),
-                                 nn.BatchNorm1D(out_channels,
-                                                weight_attr=with_last_bn_affine, bias_attr=with_last_bn_affine))
+        self.mlp = nn.Sequential(
+            nn.Linear(
+                in_channels, hid_channels, bias_attr=with_bias),
+            nn.BatchNorm1D(hid_channels),
+            nn.ReLU(),
+            nn.Linear(
+                hid_channels, hid_channels, bias_attr=with_bias),
+            nn.BatchNorm1D(hid_channels),
+            nn.ReLU(),
+            nn.Linear(
+                hid_channels, out_channels, bias_attr=with_bias),
+            nn.BatchNorm1D(
+                out_channels,
+                weight_attr=with_last_bn_affine,
+                bias_attr=with_last_bn_affine))
 
         init_backbone_weight_simclr(self.mlp)
@@ -278,6 +291,7 @@ def forward(self, x):
 
 class SwAVNeck(nn.Layer):
     """The non-linear neck in SwAV: fc-bn-relu-fc-normalization.
     """
+
     def __init__(self,
                  in_channels,
                  hid_channels,
@@ -297,9 +311,8 @@ def __init__(self,
         else:
             self.projection_neck = nn.Sequential(
                 nn.Linear(in_channels, hid_channels),
-                nn.BatchNorm1D(hid_channels), nn.ReLU(),
-                nn.Linear(hid_channels, out_channels)
-            )
+                nn.BatchNorm1D(hid_channels),
+                nn.ReLU(), nn.Linear(hid_channels, out_channels))
 
     def forward_projection(self, x):
         if self.projection_neck is not None:
@@ -330,20 +343,22 @@ class MLP2d(nn.Layer):
     def __init__(self, in_channels, hid_channels=4096, out_channels=256):
         super(MLP2d, self).__init__()
 
-        self.linear1 = nn.Conv2D(in_channels,
-                                 hid_channels,
-                                 kernel_size=1,
-                                 stride=1,
-                                 padding=0,
-                                 bias_attr=True)
+        self.linear1 = nn.Conv2D(
+            in_channels,
+            hid_channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias_attr=True)
         self.bn1 = nn.BatchNorm2D(hid_channels)
         self.relu1 = nn.ReLU()
 
-        self.linear2 = nn.Conv2D(hid_channels,
-                                 out_channels,
-                                 kernel_size=1,
-                                 stride=1,
-                                 padding=0,
-                                 bias_attr=True)
+        self.linear2 = nn.Conv2D(
+            hid_channels,
+            out_channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias_attr=True)
 
         self.init_parameters()
 
     def init_parameters(self, init_linear='kaiming'):
@@ -363,23 +378,20 @@ def forward(self, x):
 class DenseCLNeck(nn.Layer):
     """The non-linear neck in DenseCL: fc-relu-fc, conv-relu-conv.
     """
-    def __init__(self,
-                 in_channels,
-                 hid_channels,
-                 out_channels,
-                 num_grid=None):
+
+    def __init__(self, in_channels, hid_channels, out_channels, num_grid=None):
         super(DenseCLNeck, self).__init__()
 
         self.avgpool = nn.AdaptiveAvgPool2D((1, 1))
         self.mlp = nn.Sequential(
-            nn.Linear(in_channels,hid_channels), nn.ReLU(),
-            nn.Linear(hid_channels, out_channels))
+            nn.Linear(in_channels, hid_channels),
+            nn.ReLU(), nn.Linear(hid_channels, out_channels))
 
         self.with_pool = num_grid != None
         if self.with_pool:
             self.pool = nn.AdaptiveAvgPool2D((num_grid, num_grid))
         self.mlp2 = nn.Sequential(
-            nn.Conv2D(in_channels, hid_channels, 1), nn.ReLU(),
-            nn.Conv2D(hid_channels, out_channels, 1))
+            nn.Conv2D(in_channels, hid_channels, 1),
+            nn.ReLU(), nn.Conv2D(hid_channels, out_channels, 1))
         self.avgpool2 = nn.AdaptiveAvgPool2D((1, 1))
         # init_backbone_weight(self.mlp and self.mlp2)
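One small behavioral note on `x = paddle.squeeze(x)` in the SimCLR neck's `forward`: with no `axis`, Paddle removes every size-1 dimension, which is right for pooled `[N, C, 1, 1]` features but would also drop the batch dimension when `N == 1`. A sketch of the difference (shapes are illustrative):

```python
import paddle

feat = paddle.rand([32, 2048, 1, 1])             # pooled backbone output
print(paddle.squeeze(feat).shape)                # [32, 2048]

feat1 = paddle.rand([1, 2048, 1, 1])             # batch of one
print(paddle.squeeze(feat1).shape)               # [2048]    - batch dim lost
print(paddle.squeeze(feat1, axis=[2, 3]).shape)  # [1, 2048] - explicit axes keep it
```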