Training on greyscale dataset #243
Description
Hi, is it possible to train the model on a greyscale dataset without converting it to RGB? I tried the example code with c_dim=1 in the RRDN model, but I get an error about the dimensions of the downstream layers:
WARNING:tensorflow:Model was constructed with shape (None, 80, 80, 3) for input KerasTensor(type_spec=TensorSpec(shape=(None, 80, 80, 3), dtype=tf.float32, name='input_2'), name='input_2', description="created by layer 'input_2'"), but it was called on an input with incompatible shape (None, 80, 80, 1).
Traceback (most recent call last):
File "/home/dolabok/Documents/[...]/super_resolution.py", line 76, in
trainer = Trainer(
File "/home/dolabok/anaconda3/lib/python3.9/site-packages/ISR-2.2.0-py3.9.egg/ISR/train/trainer.py", line 105, in init
File "/home/dolabok/anaconda3/lib/python3.9/site-packages/ISR-2.2.0-py3.9.egg/ISR/train/trainer.py", line 185, in _combine_networks
File "/home/dolabok/anaconda3/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 70, in error_handler
raise e.with_traceback(filtered_tb) from None
File "/home/dolabok/anaconda3/lib/python3.9/site-packages/keras/engine/input_spec.py", line 277, in assert_input_compatibility
raise ValueError(
ValueError: Exception encountered when calling layer "discriminator" (type Functional).
Input 0 of layer "Conv_1" is incompatible with the layer: expected axis -1 of input shape to have value 3, but received input with shape (None, 80, 80, 1)
Call arguments received by layer "discriminator" (type Functional):
• inputs=tf.Tensor(shape=(None, 80, 80, 1), dtype=float32)
• training=False
• mask=None
It seems that the discriminator is hard-coded to 3 channels.
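For what it's worth, the same shape check can be reproduced outside ISR with a tiny Keras model built for a 3-channel input; this sketch is my own and only mimics the error, it is not ISR code:

# Self-contained Keras sketch (not ISR code) reproducing the same shape check:
# a model built for 3-channel input rejects a single-channel tensor.
import numpy as np
import tensorflow as tf

inputs = tf.keras.Input(shape=(80, 80, 3))                 # 3 channels, as in the discriminator
x = tf.keras.layers.Conv2D(64, 3, padding='same', name='Conv_1')(inputs)
toy_discriminator = tf.keras.Model(inputs, x, name='discriminator')

grey_patches = np.zeros((1, 80, 80, 1), dtype=np.float32)  # greyscale batch, only 1 channel
toy_discriminator(grey_patches)  # ValueError: expected axis -1 of input shape to have value 3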
My code:
from ISR.models import RRDN, Discriminator, Cut_VGG19
from ISR.train import Trainer

# Initialize the models
lr_train_patch_size = 40
layers_to_extract = [5, 9]
scale = 2
hr_train_patch_size = lr_train_patch_size * scale

rrdn = RRDN(arch_params={'C': 4, 'D': 1, 'G': 64, 'G0': 64, 'T': 10, 'x': scale}, c_dim=1, patch_size=lr_train_patch_size)
f_ext = Cut_VGG19(patch_size=hr_train_patch_size, layers_to_extract=layers_to_extract)
discr = Discriminator(patch_size=hr_train_patch_size, kernel_size=3)

# Training settings
loss_weights = {
    'generator': 0.0,
    'feature_extractor': 0.0833,
    'discriminator': 0.01
}
losses = {
    'generator': 'mae',
    'feature_extractor': 'mse',
    'discriminator': 'binary_crossentropy'
}
learning_rate = {'initial_value': 0.0004, 'decay_factor': 0.5, 'decay_frequency': 30}
flatness = {'min': 0.0, 'max': 0.15, 'increase': 0.01, 'increase_frequency': 5}
log_dirs = {'logs': './logs', 'weights': './weights'}

# Initialize the trainer
trainer = Trainer(
    generator=rrdn,
    discriminator=discr,
    feature_extractor=f_ext,
    lr_train_dir='low_res_train',
    hr_train_dir='high_res_train',
    lr_valid_dir='low_res_train',
    hr_valid_dir='high_res_train',
    loss_weights=loss_weights,
    losses=losses,  # pass the per-network losses defined above
    learning_rate=learning_rate,
    flatness=flatness,
    dataname='image_dataset',
    log_dirs=log_dirs,
    weights_generator=None,
    weights_discriminator=None,
    n_validation=40,
)

# Train the model
trainer.train(
    epochs=100,
    steps_per_epoch=500,
    batch_size=16,
    monitored_metrics={'val_PSNR_Y': 'max'}
)
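The obvious stopgap would be to replicate the grey channel three times on disk and train with c_dim=3, which is exactly the conversion I was hoping to avoid. A minimal sketch of that preprocessing (Pillow, the *.png glob and the *_rgb destination folders are my own assumptions, not part of ISR):

# Stopgap preprocessing sketch: copy the single grey channel into R, G and B
# so that the 3-channel discriminator and VGG19 feature extractor are satisfied.
from pathlib import Path
from PIL import Image

def greyscale_to_rgb(src_dir, dst_dir):
    dst = Path(dst_dir)
    dst.mkdir(parents=True, exist_ok=True)
    for path in sorted(Path(src_dir).glob('*.png')):
        img = Image.open(path).convert('L')       # force a single channel
        img.convert('RGB').save(dst / path.name)  # replicate it into all three channels

greyscale_to_rgb('low_res_train', 'low_res_train_rgb')
greyscale_to_rgb('high_res_train', 'high_res_train_rgb')

The trainer would then point at the *_rgb folders and the RRDN would be built with c_dim=3, at the cost of tripling the input size. Is there a supported way to train directly on single-channel data instead?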