Skip to content

Commit fc23c6a

Browse files
committed
feat: D_ngf
1 parent ddcc22b commit fc23c6a

File tree

1 file changed

+18
-10
lines changed

1 file changed

+18
-10
lines changed

Diff for: models/modules/discriminators.py

+18-10
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ def __init__(
161161
input_nc,
162162
output_nc,
163163
D_num_downs,
164-
ngf=64,
164+
D_ngf=64,
165165
norm_layer=nn.BatchNorm2d,
166166
use_dropout=False,
167167
):
@@ -171,7 +171,7 @@ def __init__(
171171
output_nc (int) -- the number of channels in output images
172172
D_num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7,
173173
image of size 128x128 will become of size 1x1 # at the bottleneck
174-
ngf (int) -- the number of filters in the last conv layer, here ngf=64, so inner_nc=64*8=512
174+
D_ngf (int) -- the number of filters in the last conv layer, here ngf=64, so inner_nc=64*8=512
175175
norm_layer -- normalization layer
176176
177177
We construct the U-Net from the innermost layer to the outermost layer.
@@ -181,8 +181,8 @@ def __init__(
181181
# construct unet structure
182182
# add the innermost layer
183183
unet_block = UnetSkipConnectionBlock(
184-
ngf * 8,
185-
ngf * 8,
184+
D_ngf * 8,
185+
D_ngf * 8,
186186
input_nc=None,
187187
submodule=None,
188188
norm_layer=norm_layer,
@@ -191,28 +191,36 @@ def __init__(
191191
# add intermediate layers with ngf * 8 filters
192192
for i in range(num_downs - 5): # add intermediate layers with ngf * 8 filters
193193
unet_block = UnetSkipConnectionBlock(
194-
ngf * 8,
195-
ngf * 8,
194+
D_ngf * 8,
195+
D_ngf * 8,
196196
input_nc=None,
197197
submodule=unet_block,
198198
norm_layer=norm_layer,
199199
use_dropout=use_dropout,
200200
)
201201
# gradually reduce the number of filters from ngf * 8 to ngf
202202
unet_block = UnetSkipConnectionBlock(
203-
ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer
203+
D_ngf * 4,
204+
D_ngf * 8,
205+
input_nc=None,
206+
submodule=unet_block,
207+
norm_layer=norm_layer,
204208
)
205209
unet_block = UnetSkipConnectionBlock(
206-
ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer
210+
D_ngf * 2,
211+
D_ngf * 4,
212+
input_nc=None,
213+
submodule=unet_block,
214+
norm_layer=norm_layer,
207215
)
208216
unet_block = UnetSkipConnectionBlock(
209-
ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer
217+
D_ngf, D_ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer
210218
)
211219

212220
# add the outermost layer
213221
self.model = UnetSkipConnectionBlock(
214222
output_nc,
215-
ngf,
223+
D_ngf,
216224
input_nc=input_nc,
217225
submodule=unet_block,
218226
outermost=True,

0 commit comments

Comments
 (0)