PyTorch Lightning can't train a GAN: the generator's gradient is 0 #12339
Unanswered
2651084156
asked this question in
code help: CV
Replies: 1 comment
-
Can anyone help me?
Beta Was this translation helpful? Give feedback.
0 replies
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
-
Training the discriminator works,

but when I train the generator it doesn't work:
I found that the generator's gradient is 0,
and I don't know why.
I looked at this example:
https://pytorch-lightning.readthedocs.io/en/stable/notebooks/lightning_examples/basic-gan.html
but I can't find the mistake in my code.
Here is my code:
import os
import time
from collections import OrderedDict

import cv2
import matplotlib.pyplot as plt
import numpy
import numpy as np
import pandas as pd
import pytorch_lightning as pyl
import torch
import torch.nn as nn
from PIL import Image
from pl_bolts.models.gans import DCGAN
from pytorch_lightning.loggers import TensorBoardLogger
from torch.autograd import Variable
# NOTE(review): private torch module; may not exist in newer torch releases.
from torch.autograd._functions import tensor
from torch.utils.data import Dataset, DataLoader
class jian(nn.Module):
    """Convolutional network, presumably the GAN discriminator.

    Three 8x8, stride-2 convolutions (3 -> 300 -> 300 -> 3 channels), with
    BatchNorm + LeakyReLU(0.2) after the first two and LeakyReLU(0.2) after
    the last. No padding is used, so each conv maps a side of length ``s`` to
    ``(s - 8) // 2 + 1``; a 3x218x178 input therefore yields a 3x22x17 map.
    """

    def __init__(self):
        # Must be ``__init__``: a method named ``init`` is never invoked by
        # ``nn.Module``, so ``self.ganj`` would never be created.
        super().__init__()
        self.ganj = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=300, kernel_size=(8, 8), stride=(2, 2)),
            nn.BatchNorm2d(300),
            nn.LeakyReLU(0.2),
            nn.Conv2d(in_channels=300, out_channels=300, kernel_size=(8, 8), stride=(2, 2)),
            nn.BatchNorm2d(300),
            nn.LeakyReLU(0.2),
            nn.Conv2d(in_channels=300, out_channels=3, kernel_size=(8, 8), stride=(2, 2)),
            nn.LeakyReLU(0.2),
        )  # 218x178 input -> 22x17 output

    def forward(self, x):
        """Apply the conv stack to ``x``.

        ``x`` is assumed to be (N, 3, H, W) with H, W >= 50 (the smallest size
        that survives three unpadded 8x8/stride-2 convs).
        """
        return self.ganj(x)
class grent(nn.Module):
    """Fully-connected network, presumably the GAN generator.

    Maps a 100-dimensional input (presumably latent noise) through
    ``Linear(100, 6 * 22 * 17)`` followed by SELU, producing a flat
    (N, 2244) tensor.

    NOTE(review): the output is flat, while ``jian`` starts with a
    3-input-channel Conv2d and so needs a 4-D (N, 3, H, W) image. The
    reshape/upsampling step that bridges the two is not in the posted code;
    confirm it exists and keeps the autograd graph intact.
    """

    def __init__(self):
        # Must be ``__init__``: a method named ``init`` is never invoked by
        # ``nn.Module``, so ``self.ganout1`` would never be created.
        super().__init__()
        self.ganout1 = nn.Sequential(nn.Linear(100, 6 * 22 * 17), nn.SELU())

    def forward(self, z):
        """Return ``ganout1(z)``; ``z`` is assumed to be (N, 100)."""
        return self.ganout1(z)
class main_modle(pyl.LightningModule):
    """Top-level LightningModule passed to ``Trainer.fit``.

    NOTE(review): the posted code defines no ``training_step`` or
    ``configure_optimizers``. Lightning requires both to train, so the code
    responsible for the zero generator gradient is not visible here. Worth
    checking in that code (unconfirmed): in the generator step, the fake
    images must not be ``.detach()``-ed or rebuilt with ``torch.tensor(...)``
    / ``Variable(...)`` before going through the discriminator. Either one cuts
    the autograd graph and leaves the generator without gradients.
    """

    def __init__(self):
        # Must be ``__init__`` so LightningModule's own initialisation runs.
        super().__init__()
class dataset2(Dataset):
    """Map-style dataset over the entries of a single directory.

    Attributes:
        filll: sorted list of entry names found in the directory.
        num: number of entries (also returned by ``len()``).
        path: the directory path, guaranteed to end with a path separator,
            so ``self.path + name`` forms a full file path.

    NOTE(review): ``__getitem__`` is not in the posted code; DataLoader
    needs it to load a sample by index.
    """

    def __init__(self, csc_file):
        """Index the directory ``csc_file``.

        Args:
            csc_file: directory containing the samples (the name is a
                leftover from an earlier CSV-based loader).

        Raises:
            OSError (e.g. FileNotFoundError): if the directory cannot be listed.
        """
        # os.listdir order is arbitrary; sort so that index -> file is
        # deterministic across runs and platforms.
        self.filll = sorted(os.listdir(csc_file))
        self.num = len(self.filll)
        # Appends the platform separator only when it is missing.
        self.path = os.path.join(csc_file, "")

    def __len__(self):
        """Return the number of entries; needed by samplers (e.g. shuffle=True)."""
        return self.num
if __name__ == "__main__":
    # Machine-specific location of the aligned CelebA images; adjust as needed.
    aaa = dataset2(r'K:\aaaaa\Img\img_align_celeba')
    logger = TensorBoardLogger('tb_logs', name='my_model')
    # ``gpus=`` is deprecated in favour of ``accelerator`` / ``devices``.
    trainer = pyl.Trainer(accelerator="gpu", devices=1, logger=logger)
    # The ``train_dataloader=`` keyword was deprecated and then removed from
    # ``Trainer.fit``; passing the loader positionally works across versions.
    trainer.fit(main_modle(), DataLoader(aaa, shuffle=True, batch_size=1))
Beta Was this translation helpful? Give feedback.
All reactions