diff --git a/bin/gen_mask_dataset.py b/bin/gen_mask_dataset.py
index 6e2ce3a9..8b4af943 100755
--- a/bin/gen_mask_dataset.py
+++ b/bin/gen_mask_dataset.py
@@ -19,9 +19,9 @@ def __init__(self, impl, variants_n=2):
         self.impl = impl
         self.variants_n = variants_n
 
-    def get_masks(self, img):
+    def get_masks(self, img, indir=None, path=None):
         img = np.transpose(np.array(img), (2, 0, 1))
-        return [self.impl(img)[0] for _ in range(self.variants_n)]
+        return [self.impl(img, indir=indir, path=path)[0] for _ in range(self.variants_n)]
 
 
 def process_images(src_images, indir, outdir, config):
@@ -29,8 +29,8 @@ def process_images(src_images, indir, outdir, config):
         mask_generator = SegmentationMask(**config.mask_generator_kwargs)
     elif config.generator_kind == 'random':
         variants_n = config.mask_generator_kwargs.pop('variants_n', 2)
-        mask_generator = MakeManyMasksWrapper(MixedMaskGenerator(**config.mask_generator_kwargs),
-                                              variants_n=variants_n)
+        mixed_mask_generator = MixedMaskGenerator(**config.mask_generator_kwargs)
+        mask_generator = MakeManyMasksWrapper(mixed_mask_generator, variants_n=variants_n)
     else:
         raise ValueError(f'Unexpected generator kind: {config.generator_kind}')
 
@@ -59,7 +59,7 @@ def process_images(src_images, indir, outdir, config):
             image = image.resize(out_size, resample=Image.BICUBIC)
 
         # generate and select masks
-        src_masks = mask_generator.get_masks(image)
+        src_masks = mask_generator.get_masks(image, indir, path=infile)
 
         filtered_image_mask_pairs = []
         for cur_mask in src_masks:
@@ -104,7 +104,13 @@ def main(args):
     os.makedirs(args.outdir, exist_ok=True)
 
     config = load_yaml(args.config)
-
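+    # --occ_indir overrides the config's occ_mask_indir placeholder, e.g.:
+    #   python3 bin/gen_mask_dataset.py configs/data_gen/random_thin_occ_512.yaml <indir> <outdir> --occ_indir <occ_masks_dir>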
+    if args.occ_indir:
+        if "occ_mask_indir" in config.mask_generator_kwargs:
+            config.mask_generator_kwargs["occ_mask_indir"] = args.occ_indir
+        else:
+            raise ValueError('--occ_indir was given, but the config defines no occ_mask_indir in mask_generator_kwargs')
+
     in_files = list(glob.glob(os.path.join(args.indir, '**', f'*.{args.ext}'), recursive=True))
     if args.n_jobs == 0:
         process_images(in_files, args.indir, args.outdir, config)
@@ -124,6 +130,7 @@ def main(args):
     aparser.add_argument('config', type=str, help='Path to config for dataset generation')
     aparser.add_argument('indir', type=str, help='Path to folder with images')
     aparser.add_argument('outdir', type=str, help='Path to folder to store aligned images and masks to')
+    aparser.add_argument('--occ_indir', type=str, default=None, help='Path to the folder with occlusion masks')
    aparser.add_argument('--n-jobs', type=int, default=0, help='How many processes to use')
     aparser.add_argument('--ext', type=str, default='jpg', help='Input image extension')
 
diff --git a/configs/data_gen/random_thin_occ_512.yaml b/configs/data_gen/random_thin_occ_512.yaml
new file mode 100644
index 00000000..ea027295
--- /dev/null
+++ b/configs/data_gen/random_thin_occ_512.yaml
@@ -0,0 +1,28 @@
+generator_kind: random
+
+mask_generator_kwargs:
+  irregular_proba: 1
+  irregular_kwargs:
+    min_times: 4
+    max_times: 50
+    max_width: 10
+    max_angle: 4
+    max_len: 40
+  box_proba: 0
+  segm_proba: 0
+  squares_proba: 0
+
+  occ_mask: True
+  occ_mask_indir: ${training.location.occ_mask_root_dir}  # overridden by --occ_indir when running gen_mask_dataset.py
+
+  variants_n: 5
+
+max_masks_per_image: 1
+
+cropping:
+  out_min_size: 256
+  handle_small_mode: upscale
+  out_square_crop: True
+  crop_min_overlap: 1
+
+max_tamper_area: 0.5
diff --git a/configs/training/data/abl-04-256-mh-dist.yaml b/configs/training/data/abl-04-256-mh-dist.yaml
index 203e6aa0..aacaeee4 100644
--- a/configs/training/data/abl-04-256-mh-dist.yaml
+++ b/configs/training/data/abl-04-256-mh-dist.yaml
@@ -1,8 +1,8 @@
 # @package _group_
 
-batch_size: 10
+batch_size: 8
 val_batch_size: 2
-num_workers: 3
+num_workers: 3
 
 train:
   indir: ${location.data_root_dir}/train
@@ -11,18 +11,21 @@ train:
    irregular_proba: 1
    irregular_kwargs:
      max_angle: 4
-      max_len: 200
-      max_width: 100
-      max_times: 5
+      max_len: 100
+      max_width: 20
+      max_times: 3
      min_times: 1
-    box_proba: 1
-    box_kwargs:
-      margin: 10
-      bbox_min_size: 30
-      bbox_max_size: 150
-      max_times: 4
-      min_times: 1
+    # box_proba: 1
+    # box_kwargs:
+    #   margin: 10
+    #   bbox_min_size: 30
+    #   bbox_max_size: 150
+    #   max_times: 4
+    #   min_times: 1
+
+    occ_mask: True
+    occ_mask_indir: ${location.occ_mask_root_dir}/train
 
    segm_proba: 0
 
diff --git a/configs/training/generator/ffc_resnet_075.yaml b/configs/training/generator/ffc_resnet_075.yaml
index 0bac88f9..0b695906 100644
--- a/configs/training/generator/ffc_resnet_075.yaml
+++ b/configs/training/generator/ffc_resnet_075.yaml
@@ -6,6 +6,7 @@ ngf: 64
 n_downsampling: 3
 n_blocks: 9
 add_out_act: sigmoid
+conv_kind: depthwise
 
 init_conv_kwargs:
   ratio_gin: 0
diff --git a/configs/training/lama-fourier.yaml b/configs/training/lama-fourier.yaml
index 0c8d3a92..3b8c373b 100644
--- a/configs/training/lama-fourier.yaml
+++ b/configs/training/lama-fourier.yaml
@@ -5,6 +5,7 @@ training_model:
   visualize_each_iters: 1000
   concat_mask: true
   store_discr_outputs_for_vis: true
+
 losses:
   l1:
     weight_missing: 0
diff --git a/configs/training/location/places_standard.yaml b/configs/training/location/places_standard.yaml
new file mode 100644
index 00000000..e08a55f9
--- /dev/null
+++ b/configs/training/location/places_standard.yaml
@@ -0,0 +1,6 @@
+# @package _group_
+data_root_dir: /home/isaacfs/datasets/places_standard_dataset/
+occ_mask_root_dir: /home/isaacfs/occlusions_mask/
+out_root_dir: /home/isaacfs/lama/experiments/
+tb_dir: /home/isaacfs/lama/tb_logs/
+pretrained_models: /home/isaacfs/lama/
diff --git a/configs/training/trainer/any_gpu_large_ssim_ddp_final.yaml b/configs/training/trainer/any_gpu_large_ssim_ddp_final.yaml
index 5da9ed3f..37cb0ec8 100644
--- a/configs/training/trainer/any_gpu_large_ssim_ddp_final.yaml
+++ b/configs/training/trainer/any_gpu_large_ssim_ddp_final.yaml
@@ -2,10 +2,10 @@
 kwargs:
   gpus: -1
   accelerator: ddp
-  max_epochs: 40
+  max_epochs: 20
   gradient_clip_val: 1
   log_gpu_memory: None  # set to min_max or all for debug
-  limit_train_batches: 25000
+  limit_train_batches: 10
   val_check_interval: ${trainer.kwargs.limit_train_batches}
   # fast_dev_run: True  # uncomment for faster debug
   # track_grad_norm: 2  # uncomment to track L2 gradients norm
@@ -22,10 +22,14 @@ kwargs:
   # limit_val_batches: 1000000
   replace_sampler_ddp: False
 
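+# metric logging granularity, consumed in saicinpainting/training/trainers/base.py (self.log_dict)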
+logs:
+  log_on_epoch: true
+  log_on_step: false
+
 checkpoint_kwargs:
   verbose: True
   save_top_k: 5
   save_last: True
   period: 1
   monitor: val_ssim_fid100_f1_total_mean
-  mode: max
\ No newline at end of file
+  mode: max
diff --git a/fetch_data/places_standard_test_val_gen_masks.sh b/fetch_data/places_standard_test_val_gen_masks.sh
index 46547797..8160c20b 100755
--- a/fetch_data/places_standard_test_val_gen_masks.sh
+++ b/fetch_data/places_standard_test_val_gen_masks.sh
@@ -3,11 +3,13 @@
 mkdir -p places_standard_dataset/visual_test/
 
 python3 bin/gen_mask_dataset.py \
-$(pwd)/configs/data_gen/random_thick_512.yaml \
-places_standard_dataset/val_hires/ \
-places_standard_dataset/val/
+$(pwd)/configs/data_gen/random_thin_occ_512.yaml \
+/home/isaacfs/places_standard_dataset/val_hires/ \
+/home/isaacfs/places_standard_dataset/val/ \
+--occ_indir /home/isaacfs/occlusions_mask/original/test/test_large/
 
 python3 bin/gen_mask_dataset.py \
-$(pwd)/configs/data_gen/random_thick_512.yaml \
-places_standard_dataset/visual_test_hires/ \
-places_standard_dataset/visual_test/
\ No newline at end of file
+$(pwd)/configs/data_gen/random_thin_occ_512.yaml \
+/home/isaacfs/places_standard_dataset/visual_test_hires/ \
+/home/isaacfs/places_standard_dataset/visual_test/ \
+--occ_indir /home/isaacfs/occlusions_mask/original/val/val_large/
\ No newline at end of file
diff --git a/fetch_data/places_standard_train_prepare.sh b/fetch_data/places_standard_train_prepare.sh
index aaf42924..9ee7b018 100644
--- a/fetch_data/places_standard_train_prepare.sh
+++ b/fetch_data/places_standard_train_prepare.sh
@@ -1,7 +1,7 @@
-mkdir -p places_standard_dataset/train
+#mkdir -p places_standard_dataset/train
 
 # untar without folder structure
-tar -xvf train_large_places365standard.tar -C places_standard_dataset/train
+#tar -xvf train_large_places365standard.tar -C places_standard_dataset/train
 
 # create location config places.yaml
 PWD=$(pwd)
diff --git a/fetch_data/sampler.py b/fetch_data/sampler.py
index 9bf48b7f..5f6a2e7c 100644
--- a/fetch_data/sampler.py
+++ b/fetch_data/sampler.py
@@ -1,7 +1,7 @@
 import os
 import random
 
-test_files_path = os.path.abspath('.') + '/places_standard_dataset/original/test/'
+test_files_path = os.path.abspath('.') + '/places_standard_dataset/original/test/test_large/'
 list_of_random_test_files = os.path.abspath('.') + '/places_standard_dataset/original/test_random_files.txt'
 
 test_files = [
@@ -22,7 +22,7 @@
 
 # --------------------------------
 
-val_files_path = os.path.abspath('.') + '/places_standard_dataset/original/val/'
+val_files_path = os.path.abspath('.') + '/places_standard_dataset/original/val/val_large/'
 list_of_random_val_files = os.path.abspath('.') + '/places_standard_dataset/original/val_random_files.txt'
 
 val_files = [
diff --git a/requirements.txt b/requirements.txt
index d412392c..a057469e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,8 +2,8 @@ pyyaml
 tqdm
 numpy
 easydict==1.9.0
-scikit-image==0.17.2
-scikit-learn==0.24.2
+scikit-image
+scikit-learn
 opencv-python
 tensorflow
 joblib
@@ -16,5 +16,5 @@ tabulate
 kornia==0.5.0
 webdataset
 packaging
-scikit-learn==0.24.2
 wldhx.yadisk-direct
+
diff --git a/saicinpainting/training/data/datasets.py b/saicinpainting/training/data/datasets.py
index c4f503da..9e455284 100644
--- a/saicinpainting/training/data/datasets.py
+++ b/saicinpainting/training/data/datasets.py
@@ -28,6 +28,7 @@ def __init__(self, indir, mask_generator, transform):
         self.mask_generator = mask_generator
         self.transform = transform
         self.iter_i = 0
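+        # remember the dataset root so occlusion-mask paths can be derived from image paths (see masks.py)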
+        self.indir = indir
 
     def __len__(self):
         return len(self.in_files)
@@ -39,7 +40,7 @@ def __getitem__(self, item):
         img = self.transform(image=img)['image']
         img = np.transpose(img, (2, 0, 1))
         # TODO: maybe generate mask before augmentations?  slower, but better for segmentation-based masks
-        mask = self.mask_generator(img, iter_i=self.iter_i)
+        mask = self.mask_generator(img, path=path, iter_i=self.iter_i, indir=self.indir)
         self.iter_i += 1
         return dict(image=img, mask=mask)
 
diff --git a/saicinpainting/training/data/masks.py b/saicinpainting/training/data/masks.py
index e91fc749..dbdbb82d 100644
--- a/saicinpainting/training/data/masks.py
+++ b/saicinpainting/training/data/masks.py
@@ -3,9 +3,9 @@
 import hashlib
 import logging
 from enum import Enum
-
 import cv2
 import numpy as np
+import os
 
 from saicinpainting.evaluation.masks.mask import SegmentationMask
 from saicinpainting.utils import LinearRamp
@@ -256,9 +256,12 @@ def __init__(self, irregular_proba=1/3, irregular_kwargs=None,
                  squares_proba=0, squares_kwargs=None,
                  superres_proba=0, superres_kwargs=None,
                  outpainting_proba=0, outpainting_kwargs=None,
+                 occ_mask=False, occ_mask_indir=None,
                  invert_proba=0):
         self.probas = []
         self.gens = []
+        self.occ_mask = occ_mask
+        self.occ_mask_indir = occ_mask_indir
 
         if irregular_proba > 0:
             self.probas.append(irregular_proba)
@@ -306,12 +309,37 @@ def __init__(self, irregular_proba=1/3, irregular_kwargs=None,
         self.probas /= self.probas.sum()
 
         self.invert_proba = invert_proba
 
-    def __call__(self, img, iter_i=None, raw_image=None):
+    def __call__(self, img, path=None, iter_i=None, raw_image=None, indir=None):
         kind = np.random.choice(len(self.probas), p=self.probas)
         gen = self.gens[kind]
         result = gen(img, iter_i=iter_i, raw_image=raw_image)
         if self.invert_proba > 0 and random.random() < self.invert_proba:
             result = 1 - result
+
+        # Training for parallax tasks: merge a precomputed occlusion mask into the generated mask
+        if self.occ_mask:
+            if path is None:
+                raise ValueError('occ_mask is enabled but no image path was provided; '
+                                 'check that the caller passes path= to the mask generator')
+
+            # Derive the occlusion mask path from the image path
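+            # e.g. <indir>/subdir/img.jpg -> <occ_mask_indir>/subdir/occlusion_img.jpg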
+            filename = os.path.basename(path)
+            occ_mask_path = path.replace(indir, self.occ_mask_indir)
+            occ_mask_path = occ_mask_path.replace(filename, "")
+            occ_mask_path = os.path.join(occ_mask_path, f"occlusion_{filename}")
+
+            occ_mask = cv2.imread(occ_mask_path)
+            if occ_mask is None:
+                raise FileNotFoundError(f'Occlusion mask not found: {occ_mask_path}')
+            occ_mask = cv2.cvtColor(occ_mask, cv2.COLOR_BGR2GRAY)
+
+            # Binarize to {0, 1}; thresholding directly makes a separate normalization step unnecessary
+            occ_mask = np.expand_dims(occ_mask, axis=0)
+            occ_mask = (occ_mask > 0).astype('float32')
+
+            # Blend the masks: union of the sampled mask and the occlusion mask
+            result = np.logical_or(result, occ_mask).astype(result.dtype)
+
         return result
diff --git a/saicinpainting/training/modules/ffc.py b/saicinpainting/training/modules/ffc.py
index 2f8aeb14..0761afea 100644
--- a/saicinpainting/training/modules/ffc.py
+++ b/saicinpainting/training/modules/ffc.py
@@ -7,7 +7,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-from saicinpainting.training.modules.base import get_activation, BaseDiscriminator
+from saicinpainting.training.modules.base import get_conv_block_ctor, get_activation, BaseDiscriminator
 from saicinpainting.training.modules.spatial_transform import LearnableSpatialTransformWrapper
 from saicinpainting.training.modules.squeeze_excitation import SELayer
 from saicinpainting.utils import get_shape
@@ -49,13 +49,16 @@ def forward(self, x):
 
 
 class FourierUnit(nn.Module):
 
     def __init__(self, in_channels, out_channels, groups=1, spatial_scale_factor=None, spatial_scale_mode='bilinear',
-                 spectral_pos_encoding=False, use_se=False, se_kwargs=None, ffc3d=False, fft_norm='ortho'):
+                 spectral_pos_encoding=False, use_se=False, se_kwargs=None, ffc3d=False, fft_norm='ortho', conv_kind='default'):
         # bn_layer not used
         super(FourierUnit, self).__init__()
         self.groups = groups
-        self.conv_layer = torch.nn.Conv2d(in_channels=in_channels * 2 + (2 if spectral_pos_encoding else 0),
-                                          out_channels=out_channels * 2,
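+        # resolve the conv constructor from conv_kind ('default' -> nn.Conv2d)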
+        conv_layer = get_conv_block_ctor(conv_kind)
+
+        self.conv_layer = conv_layer(in_channels * 2 + (2 if spectral_pos_encoding else 0),
+                                     out_channels * 2,
                                           kernel_size=1, stride=1, padding=0, groups=self.groups, bias=False)
         self.bn = torch.nn.BatchNorm2d(out_channels * 2)
         self.relu = torch.nn.ReLU(inplace=True)
@@ -115,7 +118,7 @@ def forward(self, x):
 
 class SpectralTransform(nn.Module):
 
-    def __init__(self, in_channels, out_channels, stride=1, groups=1, enable_lfu=True, **fu_kwargs):
+    def __init__(self, in_channels, out_channels, stride=1, groups=1, enable_lfu=True, conv_kind='default', **fu_kwargs):
         # bn_layer not used
         super(SpectralTransform, self).__init__()
         self.enable_lfu = enable_lfu
@@ -125,18 +128,22 @@ def __init__(self, in_channels, out_channels, stride=1, groups=1, enable_lfu=Tru
             self.downsample = nn.Identity()
 
         self.stride = stride
+
+        # conv implementation is configurable via conv_kind
+        conv_layer = get_conv_block_ctor(conv_kind)
+
         self.conv1 = nn.Sequential(
-            nn.Conv2d(in_channels, out_channels //
+            conv_layer(in_channels, out_channels //
                       2, kernel_size=1, groups=groups, bias=False),
             nn.BatchNorm2d(out_channels // 2),
             nn.ReLU(inplace=True)
         )
         self.fu = FourierUnit(
-            out_channels // 2, out_channels // 2, groups, **fu_kwargs)
+            out_channels // 2, out_channels // 2, groups, conv_kind=conv_kind, **fu_kwargs)
         if self.enable_lfu:
             self.lfu = FourierUnit(
-                out_channels // 2, out_channels // 2, groups)
-        self.conv2 = torch.nn.Conv2d(
+                out_channels // 2, out_channels // 2, groups, conv_kind=conv_kind)
+        self.conv2 = conv_layer(
             out_channels // 2, out_channels, kernel_size=1, groups=groups, bias=False)
 
     def forward(self, x):
@@ -168,7 +175,7 @@ class FFC(nn.Module):
     def __init__(self, in_channels, out_channels, kernel_size,
                  ratio_gin, ratio_gout, stride=1, padding=0,
                  dilation=1, groups=1, bias=False, enable_lfu=True,
-                 padding_type='reflect', gated=False, **spectral_kwargs):
+                 padding_type='reflect', conv_kind='default', gated=False, **spectral_kwargs):
         super(FFC, self).__init__()
 
         assert stride == 1 or stride == 2, "Stride should be 1 or 2."
@@ -185,21 +192,24 @@ def __init__(self, in_channels, out_channels, kernel_size,
         self.ratio_gout = ratio_gout
         self.global_in_num = in_cg
 
-        module = nn.Identity if in_cl == 0 or out_cl == 0 else nn.Conv2d
+        # conv implementation is configurable via conv_kind
+        conv_layer = get_conv_block_ctor(conv_kind)
+
+        module = nn.Identity if in_cl == 0 or out_cl == 0 else conv_layer
         self.convl2l = module(in_cl, out_cl, kernel_size,
-                              stride, padding, dilation, groups, bias, padding_mode=padding_type)
-        module = nn.Identity if in_cl == 0 or out_cg == 0 else nn.Conv2d
+                              stride, padding, dilation, bias=bias, groups=groups, padding_mode=padding_type)
+        module = nn.Identity if in_cl == 0 or out_cg == 0 else conv_layer
         self.convl2g = module(in_cl, out_cg, kernel_size,
-                              stride, padding, dilation, groups, bias, padding_mode=padding_type)
-        module = nn.Identity if in_cg == 0 or out_cl == 0 else nn.Conv2d
+                              stride, padding, dilation, bias=bias, groups=groups, padding_mode=padding_type)
+        module = nn.Identity if in_cg == 0 or out_cl == 0 else conv_layer
         self.convg2l = module(in_cg, out_cl, kernel_size,
-                              stride, padding, dilation, groups, bias, padding_mode=padding_type)
+                              stride, padding, dilation, bias=bias, groups=groups, padding_mode=padding_type)
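+        # bias and groups are passed by keyword above, since alternative conv ctors may not share nn.Conv2d's positional order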
         module = nn.Identity if in_cg == 0 or out_cg == 0 else SpectralTransform
         self.convg2g = module(
-            in_cg, out_cg, stride, 1 if groups == 1 else groups // 2, enable_lfu, **spectral_kwargs)
+            in_cg, out_cg, stride, 1 if groups == 1 else groups // 2, enable_lfu, conv_kind=conv_kind, **spectral_kwargs)
 
         self.gated = gated
-        module = nn.Identity if in_cg == 0 or out_cl == 0 or not self.gated else nn.Conv2d
+        module = nn.Identity if in_cg == 0 or out_cl == 0 or not self.gated else conv_layer
         self.gate = module(in_channels, 2, 1)
 
     def forward(self, x):
@@ -231,12 +241,12 @@
     def __init__(self, in_channels, out_channels,
                  kernel_size, ratio_gin, ratio_gout,
                  stride=1, padding=0, dilation=1, groups=1, bias=False,
                  norm_layer=nn.BatchNorm2d, activation_layer=nn.Identity,
-                 padding_type='reflect',
+                 padding_type='reflect', conv_kind='default',
                  enable_lfu=True, **kwargs):
         super(FFC_BN_ACT, self).__init__()
         self.ffc = FFC(in_channels, out_channels, kernel_size,
                        ratio_gin, ratio_gout, stride, padding, dilation,
-                       groups, bias, enable_lfu, padding_type=padding_type, **kwargs)
+                       groups, bias, enable_lfu, padding_type=padding_type, conv_kind=conv_kind, **kwargs)
         lnorm = nn.Identity if ratio_gout == 1 else norm_layer
         gnorm = nn.Identity if ratio_gout == 0 else norm_layer
         global_channels = int(out_channels * ratio_gout)
@@ -256,18 +266,18 @@ def forward(self, x):
 
 
 class FFCResnetBlock(nn.Module):
-    def __init__(self, dim, padding_type, norm_layer, activation_layer=nn.ReLU, dilation=1,
+    def __init__(self, dim, padding_type, norm_layer, activation_layer=nn.ReLU, dilation=1, conv_kind='default',
                  spatial_transform_kwargs=None, inline=False, **conv_kwargs):
         super().__init__()
         self.conv1 = FFC_BN_ACT(dim, dim, kernel_size=3, padding=dilation, dilation=dilation,
                                 norm_layer=norm_layer,
                                 activation_layer=activation_layer,
-                                padding_type=padding_type,
+                                padding_type=padding_type, conv_kind=conv_kind,
                                 **conv_kwargs)
         self.conv2 = FFC_BN_ACT(dim, dim, kernel_size=3, padding=dilation, dilation=dilation,
                                 norm_layer=norm_layer,
                                 activation_layer=activation_layer,
-                                padding_type=padding_type,
+                                padding_type=padding_type, conv_kind=conv_kind,
                                 **conv_kwargs)
         if spatial_transform_kwargs is not None:
             self.conv1 = LearnableSpatialTransformWrapper(self.conv1, **spatial_transform_kwargs)
@@ -303,7 +313,7 @@ def forward(self, x):
 
 
 class FFCResNetGenerator(nn.Module):
-    def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d,
+    def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d, conv_kind='default',
                  padding_type='reflect', activation_layer=nn.ReLU,
                  up_norm_layer=nn.BatchNorm2d, up_activation=nn.ReLU(True),
                  init_conv_kwargs={}, downsample_conv_kwargs={}, resnet_conv_kwargs={},
@@ -311,10 +321,11 @@ def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=9, no
                  add_out_act=True, max_features=1024, out_ffc=False, out_ffc_kwargs={}):
         assert (n_blocks >= 0)
         super().__init__()
-
+
+        print(f"DEBUG | convolution type: {conv_kind}")
         model = [nn.ReflectionPad2d(3),
                  FFC_BN_ACT(input_nc, ngf, kernel_size=7, padding=0, norm_layer=norm_layer,
-                            activation_layer=activation_layer, **init_conv_kwargs)]
+                            activation_layer=activation_layer, conv_kind=conv_kind, **init_conv_kwargs)]
 
         ### downsample
         for i in range(n_downsampling):
@@ -329,6 +340,7 @@ def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=9, no
                                  kernel_size=3, stride=2, padding=1,
                                  norm_layer=norm_layer,
                                  activation_layer=activation_layer,
+                                 conv_kind=conv_kind,
                                  **cur_conv_kwargs)]
 
         mult = 2 ** n_downsampling
@@ -337,7 +349,7 @@ def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=9, no
         ### resnet blocks
         for i in range(n_blocks):
             cur_resblock = FFCResnetBlock(feats_num_bottleneck, padding_type=padding_type, activation_layer=activation_layer,
-                                          norm_layer=norm_layer, **resnet_conv_kwargs)
+                                          norm_layer=norm_layer, conv_kind=conv_kind, **resnet_conv_kwargs)
             if spatial_transform_layers is not None and i in spatial_transform_layers:
                 cur_resblock = LearnableSpatialTransformWrapper(cur_resblock, **spatial_transform_kwargs)
             model += [cur_resblock]
@@ -355,7 +367,7 @@ def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=9, no
         if out_ffc:
             model += [FFCResnetBlock(ngf, padding_type=padding_type, activation_layer=activation_layer,
-                                     norm_layer=norm_layer, inline=True, **out_ffc_kwargs)]
+                                     norm_layer=norm_layer, conv_kind=conv_kind, inline=True, **out_ffc_kwargs)]
 
         model += [nn.ReflectionPad2d(3),
                   nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
diff --git a/saicinpainting/training/trainers/base.py b/saicinpainting/training/trainers/base.py
index f1b1c66f..cd35b0c6 100644
--- a/saicinpainting/training/trainers/base.py
+++ b/saicinpainting/training/trainers/base.py
@@ -78,6 +78,9 @@ def __init__(self, config, use_ddp, *args, predict_only=False, visualize_each_i
         self.val_evaluator = make_evaluator(**self.config.evaluator)
         self.test_evaluator = make_evaluator(**self.config.evaluator)
 
+        self.log_on_epoch = config.trainer.logs.log_on_epoch
+        self.log_on_step = config.trainer.logs.log_on_step
+
         if not get_has_ddp_rank():
             LOGGER.info(f'Discriminator\n{self.discriminator}')
 
@@ -174,7 +177,9 @@ def training_step_end(self, batch_parts_outputs):
                      if torch.is_tensor(batch_parts_outputs['loss'])  # loss is not tensor when no discriminator used
                      else torch.tensor(batch_parts_outputs['loss']).float().requires_grad_(True))
         log_info = {k: v.mean() for k, v in batch_parts_outputs['log_info'].items()}
-        self.log_dict(log_info, on_step=True, on_epoch=False)
+
+        self.log_dict(log_info, on_step=self.log_on_step, on_epoch=self.log_on_epoch)
+
         return full_loss
 
     def validation_epoch_end(self, outputs):
@@ -194,7 +199,10 @@
                         f'total {self.global_step} iterations:\n{val_evaluator_res_df}')
 
         for k, v in flatten_dict(val_evaluator_res).items():
-            self.log(f'val_{k}', v)
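+            # keep aggregate 'total' metrics at the top level; prefix per-bucket metrics with '_'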
+            metric_name = f'val_{k}'
+            if 'total' not in metric_name:
+                metric_name = '_' + metric_name
+            self.log(metric_name, v)
 
         # standard visual test
         test_evaluator_states = [s['test_evaluator_state'] for s in outputs
@@ -206,7 +214,11 @@
                         f'total {self.global_step} iterations:\n{test_evaluator_res_df}')
 
         for k, v in flatten_dict(test_evaluator_res).items():
-            self.log(f'test_{k}', v)
+            metric_name = f'test_{k}'
+            if 'total' not in metric_name:
+                metric_name = '_' + metric_name
+            self.log(metric_name, v)
+
         # extra validations
         if self.extra_evaluators: