PyTorch Lightning can't train a GAN: the generator gradient is 0 #12339
Unanswered
2651084156
asked this question in
code help: CV
Replies: 1 comment
-
Can anyone help me?
Beta Was this translation helpful? Give feedback.
0 replies
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Uh oh!
There was an error while loading. Please reload this page.
-
Training the discriminator works,

but when I train the generator it doesn't:
I found that the generator's gradient is 0,
and I don't know why.
I looked at this example:
https://pytorch-lightning.readthedocs.io/en/stable/notebooks/lightning_examples/basic-gan.html
but I can't find the mistake in my code.
This is my code:
my code
import os
import time
from collections import OrderedDict

import cv2
import matplotlib.pyplot as plt
import numpy
import numpy as np
import pandas as pd
import pytorch_lightning as pyl
import torch
import torch.nn as nn
from PIL import Image
from pl_bolts.models.gans import DCGAN
from torch.autograd import Variable
from torch.autograd._functions import tensor
from torch.utils.data import Dataset, DataLoader
class jian(nn.Module):
    """Three-layer strided convolutional network; presumably the GAN discriminator.

    Maps a 3-channel image to a 3-channel feature map. Every convolution
    uses kernel 8, stride 2 and no padding, so each spatial size becomes
    ``(size - 8) // 2 + 1``. For example, a 218x178 input gives a 22x17
    output, and inputs smaller than 50x50 give an empty output.

    NOTE(review): the final activation is LeakyReLU, so outputs are
    unbounded rather than probabilities in [0, 1]. Any BCE-style loss on
    them should be the logits variant (``BCEWithLogitsLoss``). TODO: confirm
    against the training step, which is not visible here.
    """

    def __init__(self):
        # This must be ``__init__``: a method named ``init`` is never called
        # on construction, so ``self.ganj`` would not exist.
        super().__init__()
        self.ganj = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=300, kernel_size=(8, 8), stride=(2, 2)),
            nn.BatchNorm2d(300),
            nn.LeakyReLU(0.2),
            nn.Conv2d(in_channels=300, out_channels=300, kernel_size=(8, 8), stride=(2, 2)),
            nn.BatchNorm2d(300),
            nn.LeakyReLU(0.2),
            nn.Conv2d(in_channels=300, out_channels=3, kernel_size=(8, 8), stride=(2, 2)),
            nn.LeakyReLU(0.2),
        )  # 218x178 input -> 22x17 output

    def forward(self, x):
        """Run the convolution stack on ``x``.

        ``x`` must have 3 channels (the first Conv2d has ``in_channels=3``).
        It is presumably a batch shaped (N, 3, H, W); TODO confirm.
        """
        return self.ganj(x)
class grent(nn.Module):
    """Single fully connected layer; presumably the GAN generator.

    Maps a 100-feature input vector to a flat vector of
    ``6 * 22 * 17`` (= 2244) features, followed by a SELU activation.

    NOTE(review): the output is a flat vector, and no reshape or upsampling
    into an image is visible here. TODO: confirm how this output is turned
    into the 3-channel image input that the discriminator expects.
    """

    def __init__(self):
        # This must be ``__init__``: a method named ``init`` is never called
        # on construction, so ``self.ganout1`` would not exist.
        super().__init__()
        self.ganout1 = nn.Sequential(nn.Linear(100, 6 * 22 * 17), nn.SELU())

    def forward(self, z):
        """Map ``z`` (last dimension of size 100) to ``6 * 22 * 17`` features."""
        return self.ganout1(z)
class main_modle(pyl.LightningModule):
    """LightningModule passed to ``Trainer.fit`` by the training script.

    NOTE(review): only the constructor is visible in this file. No
    ``training_step``, ``configure_optimizers`` or sub-networks are defined
    here, although ``Trainer.fit`` needs them. TODO: confirm where the
    generator/discriminator and their optimizers are wired in.
    """

    def __init__(self):
        # This must be ``__init__``: a method named ``init`` is never run on
        # construction, and ``super().init()`` does not exist on the parent.
        super().__init__()
class dataset2(Dataset):
    """Map-style dataset over the entries of a single directory.

    Attributes:
        filll: sorted list of entry names from ``os.listdir`` on the directory.
        num: number of entries in ``filll``.
        path: the directory with exactly one trailing path separator, so that
            ``self.path + name`` forms a path to an entry.

    NOTE(review): no ``__getitem__`` is defined here. The base
    ``torch.utils.data.Dataset.__getitem__`` raises NotImplementedError, so
    indexing (and therefore iterating a DataLoader) fails until one is
    added. TODO: implement image loading.
    """

    def __init__(self, csc_file):
        """Index the entries of the directory ``csc_file``.

        Args:
            csc_file: path to the directory to index (despite the name,
                this is a directory, not a CSV file).

        Raises:
            FileNotFoundError / NotADirectoryError: from ``os.listdir`` when
                ``csc_file`` is not an existing directory.
        """
        # ``os.listdir`` returns entries in arbitrary order; sorting makes the
        # index -> file mapping the same on every run.
        self.filll = sorted(os.listdir(csc_file))
        self.num = len(self.filll)
        # Appends os.sep only when csc_file does not already end in a
        # separator. On Windows this also recognizes '\\', which the previous
        # ``csc_file[-1] == '/'`` check missed.
        self.path = os.path.join(csc_file, '')

    def __len__(self):
        # Required by DataLoader(shuffle=True): its RandomSampler calls len().
        return self.num
# Directory-backed dataset; the raw string keeps the Windows backslashes literal.
aaa = dataset2(r'K:\aaaaa\Img\img_align_celeba')
from pytorch_lightning.loggers import TensorBoardLogger
logger = TensorBoardLogger('tb_logs', name='my_model')
# NOTE(review): ``gpus=`` was removed in Lightning 2.0; on 2.x use
# ``accelerator="gpu", devices=1`` instead.
trainer = pyl.Trainer(gpus=1, logger=logger)
# Pass the DataLoader positionally: the keyword ``train_dataloader`` was
# renamed to ``train_dataloaders`` (and the old name dropped) in newer
# Lightning releases, whereas the second positional parameter of
# ``Trainer.fit`` is the training loader in both.
trainer.fit(main_modle(), DataLoader(aaa, shuffle=True, batch_size=1))
Beta Was this translation helpful? Give feedback.
All reactions