Skip to content
This repository was archived by the owner on Jan 7, 2025. It is now read-only.

ISR Training error in Colab #253

@Leprechault

Description

@Leprechault

I tried to run the Colab notebook "ISR Traininig tutorial.ipynb". I made only two changes: I replaced the first line `!pip install ISR` with `!pip install ISR --no-deps`, and I installed TensorFlow beforehand (`!pip install tensorflow`). Otherwise I changed nothing and simply replicated the example in the step "Give the models to the Trainer":

# Build the ISR Trainer for the generator, discriminator and feature
# extractor created in earlier notebook cells. `rrdn`, `discr` and `f_ext`
# are not defined in this cell and must already exist.
from ISR.train import Trainer

# Weight of each network's loss term, keyed by network name.
loss_weights = {
    'generator': 0.0,
    'feature_extractor': 0.0833,
    'discriminator': 0.01,
}

# Loss function used for each network, keyed like `loss_weights`.
# This dict must be passed to the Trainer (see `losses=losses` below);
# otherwise it is dead configuration and the Trainer ignores it.
losses = {
    'generator': 'mae',
    'feature_extractor': 'mse',
    'discriminator': 'binary_crossentropy',
}

# Directories handed to the Trainer under the 'logs' and 'weights' keys.
log_dirs = {'logs': './logs', 'weights': './weights'}

# Learning-rate schedule. Presumably the rate is multiplied by
# `decay_factor` every `decay_frequency` epochs; confirm against the ISR docs.
learning_rate = {'initial_value': 0.0004, 'decay_factor': 0.5, 'decay_frequency': 30}

# Patch "flatness" threshold schedule (min/max plus a periodic increase).
# NOTE(review): exact semantics live in ISR's Trainer; verify there.
flatness = {'min': 0.0, 'max': 0.15, 'increase': 0.01, 'increase_frequency': 5}

# NOTE(review): the reported ValueError ("Argument(s) not recognized:
# {'lr': 0.0004}") is raised inside ISR's Trainer._combine_networks, which
# calls Keras' Adam with an `lr` keyword that the installed Keras
# (keras/src/optimizers) rejects. This cell cannot fix that. Presumably it
# needs an environment whose Keras still accepts `lr` (e.g. Keras 2 /
# TensorFlow < 2.16), or a patched ISR that passes `learning_rate`. Confirm.
trainer = Trainer(
    generator=rrdn,
    discriminator=discr,
    feature_extractor=f_ext,
    lr_train_dir='div2k/DIV2K_train_LR_bicubic/X2/',
    hr_train_dir='div2k/DIV2K_train_HR/',
    # Validation reuses the training directories, as in the original example.
    lr_valid_dir='div2k/DIV2K_train_LR_bicubic/X2/',
    hr_valid_dir='div2k/DIV2K_train_HR/',
    loss_weights=loss_weights,
    losses=losses,
    learning_rate=learning_rate,
    flatness=flatness,
    dataname='div2k',
    log_dirs=log_dirs,
    # No initial weight files are given for either network.
    weights_generator=None,
    weights_discriminator=None,
    n_validation=40,
)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
[<ipython-input-6-5b1979e121e0>](https://localhost:8080/#) in <cell line: 19>()
     17 flatness = {'min': 0.0, 'max': 0.15, 'increase': 0.01, 'increase_frequency': 5}
     18 
---> 19 trainer = Trainer(
     20     generator=rrdn,
     21     discriminator=discr,

4 frames
[/usr/local/lib/python3.10/dist-packages/ISR/train/trainer.py](https://localhost:8080/#) in __init__(self, generator, discriminator, feature_extractor, lr_train_dir, hr_train_dir, lr_valid_dir, hr_valid_dir, loss_weights, log_dirs, fallback_save_every_n_epochs, dataname, weights_generator, weights_discriminator, n_validation, flatness, learning_rate, adam_optimizer, losses, metrics)
    103             self.metrics['generator'] = PSNR
    104         self._parameters_sanity_check()
--> 105         self.model = self._combine_networks()
    106 
    107         self.settings = {}

[/usr/local/lib/python3.10/dist-packages/ISR/train/trainer.py](https://localhost:8080/#) in _combine_networks(self)
    197         combined = Model(inputs=lr, outputs=outputs)
    198         # https://stackoverflow.com/questions/42327543/adam-optimizer-goes-haywire-after-200k-batches-training-loss-grows
--> 199         optimizer = Adam(
    200             beta_1=self.adam_optimizer['beta1'],
    201             beta_2=self.adam_optimizer['beta2'],

[/usr/local/lib/python3.10/dist-packages/keras/src/optimizers/adam.py](https://localhost:8080/#) in __init__(self, learning_rate, beta_1, beta_2, epsilon, amsgrad, weight_decay, clipnorm, clipvalue, global_clipnorm, use_ema, ema_momentum, ema_overwrite_frequency, loss_scale_factor, gradient_accumulation_steps, name, **kwargs)
     60         **kwargs,
     61     ):
---> 62         super().__init__(
     63             learning_rate=learning_rate,
     64             name=name,

[/usr/local/lib/python3.10/dist-packages/keras/src/backend/tensorflow/optimizer.py](https://localhost:8080/#) in __init__(self, *args, **kwargs)
     20 
     21     def __init__(self, *args, **kwargs):
---> 22         super().__init__(*args, **kwargs)
     23         self._distribution_strategy = tf.distribute.get_strategy()
     24 

[/usr/local/lib/python3.10/dist-packages/keras/src/optimizers/base_optimizer.py](https://localhost:8080/#) in __init__(self, learning_rate, weight_decay, clipnorm, clipvalue, global_clipnorm, use_ema, ema_momentum, ema_overwrite_frequency, loss_scale_factor, gradient_accumulation_steps, name, **kwargs)
     35             )
     36         if kwargs:
---> 37             raise ValueError(f"Argument(s) not recognized: {kwargs}")
     38 
     39         if name is None:

ValueError: Argument(s) not recognized: {'lr': 0.0004}

Many errors occur. Any help with this would be appreciated.

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions