This repository was archived by the owner on Jan 7, 2025. It is now read-only.

Training on greyscale dataset #243

@Dolabok

Description


Hi, is it possible to train the model on a grayscale dataset without converting it to RGB? I tried the example code with c_dim=1 in the RRDN model, but I get an error about the dimensions of the subsequent layers:

WARNING:tensorflow:Model was constructed with shape (None, 80, 80, 3) for input KerasTensor(type_spec=TensorSpec(shape=(None, 80, 80, 3), dtype=tf.float32, name='input_2'), name='input_2', description="created by layer 'input_2'"), but it was called on an input with incompatible shape (None, 80, 80, 1).
Traceback (most recent call last):
  File "/home/dolabok/Documents/[...]/super_resolution.py", line 76, in <module>
    trainer = Trainer(
  File "/home/dolabok/anaconda3/lib/python3.9/site-packages/ISR-2.2.0-py3.9.egg/ISR/train/trainer.py", line 105, in __init__
  File "/home/dolabok/anaconda3/lib/python3.9/site-packages/ISR-2.2.0-py3.9.egg/ISR/train/trainer.py", line 185, in _combine_networks
  File "/home/dolabok/anaconda3/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 70, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/home/dolabok/anaconda3/lib/python3.9/site-packages/keras/engine/input_spec.py", line 277, in assert_input_compatibility
    raise ValueError(
ValueError: Exception encountered when calling layer "discriminator" (type Functional).

Input 0 of layer "Conv_1" is incompatible with the layer: expected axis -1 of input shape to have value 3, but received input with shape (None, 80, 80, 1)

Call arguments received by layer "discriminator" (type Functional):
  • inputs=tf.Tensor(shape=(None, 80, 80, 1), dtype=float32)
  • training=False
  • mask=None
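To make the mismatch concrete: with lr_train_patch_size = 40 and scale = 2, the generator emits 80 x 80 HR patches, which matches the (None, 80, 80, ...) shapes above, so only the last (channel) axis disagrees. The toy check below is a hypothetical re-creation of that kind of validation in plain Python; check_last_axis is an illustrative name, not Keras' actual implementation.

```python
# Hypothetical helper that mimics the kind of check Keras performs when a
# layer declares a fixed size for the last (channel) axis. NOT Keras code.
from typing import Optional, Tuple

Shape = Tuple[Optional[int], ...]


def check_last_axis(layer_name: str, expected_channels: int, received: Shape) -> None:
    """Raise ValueError if axis -1 of `received` differs from expected_channels."""
    if received[-1] != expected_channels:
        raise ValueError(
            f'Input 0 of layer "{layer_name}" is incompatible with the layer: '
            f"expected axis -1 of input shape to have value {expected_channels}, "
            f"but received input with shape {received}"
        )


lr_train_patch_size = 40
scale = 2
hr_train_patch_size = lr_train_patch_size * scale  # spatial size of HR patches

# An RRDN built with c_dim=1 produces single-channel HR patches...
generator_output: Shape = (None, hr_train_patch_size, hr_train_patch_size, 1)

# ...while the discriminator's first convolution was built for 3 channels.
try:
    check_last_axis("Conv_1", 3, generator_output)
except ValueError as err:
    print(err)
```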

It seems that the discriminator is hard-coded to 3 input channels.
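If the discriminator (and presumably also the VGG19 feature extractor, whose ImageNet weights were trained on RGB images) only accepts 3 channels, a common fallback is to keep the grayscale content but store it as three identical channels, then collapse the super-resolved output back to one channel. This is a sketch under that assumption, not an ISR feature; gray_to_rgb and rgb_to_gray are hypothetical helpers operating on NumPy batches shaped (N, H, W, 1) and (N, H, W, 3).

```python
import numpy as np


def gray_to_rgb(batch: np.ndarray) -> np.ndarray:
    """Replicate a (N, H, W, 1) grayscale batch into (N, H, W, 3)."""
    if batch.ndim != 4 or batch.shape[-1] != 1:
        raise ValueError(f"expected shape (N, H, W, 1), got {batch.shape}")
    return np.repeat(batch, 3, axis=-1)


def rgb_to_gray(batch: np.ndarray) -> np.ndarray:
    """Collapse a (N, H, W, 3) batch back to (N, H, W, 1) by averaging channels.

    For arrays built with gray_to_rgb the three channels are identical, so the
    mean recovers the original values; a trained model's three output channels
    need not be exactly identical.
    """
    if batch.ndim != 4 or batch.shape[-1] != 3:
        raise ValueError(f"expected shape (N, H, W, 3), got {batch.shape}")
    return batch.mean(axis=-1, keepdims=True)


# Example: random stand-in data with the HR patch shape from the traceback.
gray = np.random.default_rng(0).random((2, 80, 80, 1), dtype=np.float32)
rgb = gray_to_rgb(gray)
print(rgb.shape)
print(np.allclose(rgb_to_gray(rgb), gray))
```

Because the Trainer reads images from lr_train_dir and hr_train_dir, in practice the replication would more likely be applied when writing those image files (for example with Pillow's Image.convert("RGB")), with the RRDN built for 3 channels (c_dim=3).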

My code:

# Imports
from ISR.models import RRDN, Discriminator, Cut_VGG19
from ISR.train import Trainer

# Initialize the models
lr_train_patch_size = 40
layers_to_extract = [5, 9]
scale = 2
hr_train_patch_size = lr_train_patch_size * scale  # 80 x 80 HR patches

# c_dim=1: the generator works on single-channel (grayscale) images
rrdn = RRDN(arch_params={'C': 4, 'D': 1, 'G': 64, 'G0': 64, 'T': 10, 'x': scale}, c_dim=1, patch_size=lr_train_patch_size)
# No channel count is passed to the feature extractor or the discriminator
f_ext = Cut_VGG19(patch_size=hr_train_patch_size, layers_to_extract=layers_to_extract)
discr = Discriminator(patch_size=hr_train_patch_size, kernel_size=3)

# Training settings
loss_weights = {
    'generator': 0.0,
    'feature_extractor': 0.0833,
    'discriminator': 0.01
}
losses = {
    'generator': 'mae',
    'feature_extractor': 'mse',
    'discriminator': 'binary_crossentropy'
}

learning_rate = {'initial_value': 0.0004, 'decay_factor': 0.5, 'decay_frequency': 30}
flatness = {'min': 0.0, 'max': 0.15, 'increase': 0.01, 'increase_frequency': 5}
log_dirs = {'logs': './logs', 'weights': './weights'}

# Initialize the trainer
trainer = Trainer(
    generator=rrdn,
    discriminator=discr,
    feature_extractor=f_ext,
    lr_train_dir='low_res_train',
    hr_train_dir='high_res_train',
    lr_valid_dir='low_res_train',
    hr_valid_dir='high_res_train',
    loss_weights=loss_weights,
    losses=losses,  # otherwise the losses dict above is never used
    learning_rate=learning_rate,
    flatness=flatness,
    dataname='image_dataset',
    log_dirs=log_dirs,
    weights_generator=None,
    weights_discriminator=None,
    n_validation=40,
)

# Train the model
trainer.train(
    epochs=100,
    steps_per_epoch=500,
    batch_size=16,
    monitored_metrics={'val_PSNR_Y': 'max'}
)
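The learning_rate and flatness dictionaries read like step schedules. Assuming that interpretation (inferred from the key names, not checked against ISR's trainer.py), the per-epoch values they would produce can be sketched as follows; lr_at and flatness_at are hypothetical names.

```python
learning_rate = {"initial_value": 0.0004, "decay_factor": 0.5, "decay_frequency": 30}
flatness = {"min": 0.0, "max": 0.15, "increase": 0.01, "increase_frequency": 5}


def lr_at(epoch: int, cfg: dict) -> float:
    """Assumed step decay: multiply by decay_factor every decay_frequency epochs."""
    n_decays = epoch // cfg["decay_frequency"]
    return cfg["initial_value"] * cfg["decay_factor"] ** n_decays


def flatness_at(epoch: int, cfg: dict) -> float:
    """Assumed step increase from min by `increase` every increase_frequency epochs, capped at max."""
    n_increases = epoch // cfg["increase_frequency"]
    return min(cfg["max"], cfg["min"] + n_increases * cfg["increase"])


# Inspect a few epochs of the 100-epoch run configured above.
for epoch in (0, 29, 30, 60, 99):
    print(epoch, lr_at(epoch, learning_rate), round(flatness_at(epoch, flatness), 2))
```

Under this reading, the learning rate halves at epochs 30, 60 and 90, and the flatness threshold reaches its 0.15 cap at epoch 75.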

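The loss_weights and losses dictionaries pair up per model output. Assuming the Trainer compiles its combined network with both (the traceback shows it building one in _combine_networks), Keras minimizes the weighted sum of the per-output losses, so 'generator': 0.0 means the pixel-wise MAE term does not contribute to the gradient at all. total_loss and the sample values below are hypothetical, chosen only to show the arithmetic.

```python
loss_weights = {"generator": 0.0, "feature_extractor": 0.0833, "discriminator": 0.01}


def total_loss(component_losses: dict, weights: dict) -> float:
    """Weighted sum of per-output losses, the way Keras combines loss_weights."""
    return sum(weights[name] * value for name, value in component_losses.items())


# Hypothetical per-batch loss values (MAE, feature MSE, adversarial BCE).
example = {"generator": 0.05, "feature_extractor": 1.2, "discriminator": 0.7}
print(total_loss(example, loss_weights))
```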