from helper import *
from model.generator import SkipEncoderDecoder, input_noise

-def remove_watermark(image_path, mask_path, max_dim, reg_noise, input_depth, lr, show_step, training_steps, tqdm_length = 100):
-    DTYPE = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
-    if not torch.cuda.is_available():
-        print('\nSetting device to "cpu", since torch is not built with "cuda" support...')
+
+def remove_watermark(image_path, mask_path, max_dim, reg_noise, input_depth, lr, show_step, training_steps, tqdm_length = 100):
+    DTYPE = torch.FloatTensor
+    has_set_device = False
+    if torch.cuda.is_available():
+        device = 'cuda'
+        has_set_device = True
+        print("Setting Device to CUDA...")
+    try:
+        if torch.backends.mps.is_available():
+            device = 'mps'
+            has_set_device = True
+            print("Setting Device to MPS...")
+    except Exception as e:
+        print(f"Your version of pytorch might be too old, which does not support MPS. Error: \n{e}")
+        pass
+    if not has_set_device:
+        device = 'cpu'
+        print('\nSetting device to "cpu", since torch is not built with "cuda" or "mps" support...')
        print('It is recommended to use GPU if possible...')

    image_np, mask_np = preprocess_images(image_path, mask_path, max_dim)
@@ -17,43 +32,43 @@ def remove_watermark(image_path, mask_path, max_dim, reg_noise, input_depth, lr,
        num_channels_down = [128] * 5,
        num_channels_up = [128] * 5,
        num_channels_skip = [128] * 5
-    ).type(DTYPE)
+    ).type(DTYPE).to(device)

-    objective = torch.nn.MSELoss().type(DTYPE)
+    objective = torch.nn.MSELoss().type(DTYPE).to(device)
    optimizer = optim.Adam(generator.parameters(), lr)

-    image_var = np_to_torch_array(image_np).type(DTYPE)
-    mask_var = np_to_torch_array(mask_np).type(DTYPE)
+    image_var = np_to_torch_array(image_np).type(DTYPE).to(device)
+    mask_var = np_to_torch_array(mask_np).type(DTYPE).to(device)

-    generator_input = input_noise(input_depth, image_np.shape[1:]).type(DTYPE)
+    generator_input = input_noise(input_depth, image_np.shape[1:]).type(DTYPE).to(device)

    generator_input_saved = generator_input.detach().clone()
    noise = generator_input.detach().clone()

    print('\nStarting training...\n')

-    progress_bar = tqdm(range(training_steps), desc = 'Completed', ncols = tqdm_length)
+    progress_bar = tqdm(range(training_steps), desc = 'Completed', ncols = tqdm_length)

    for step in progress_bar:
        optimizer.zero_grad()
        generator_input = generator_input_saved

        if reg_noise > 0:
            generator_input = generator_input_saved + (noise.normal_() * reg_noise)
-
+
        output = generator(generator_input)
-
+
        loss = objective(output * mask_var, image_var * mask_var)
        loss.backward()

        if step % show_step == 0:
            output_image = torch_to_np_array(output)
            visualize_sample(image_np, output_image, nrow = 2, size_factor = 10)
-
+
        progress_bar.set_postfix(Loss = loss.item())
-
+
        optimizer.step()
-
+
    output_image = torch_to_np_array(output)
    visualize_sample(output_image, nrow = 1, size_factor = 10)

@@ -62,4 +77,4 @@ def remove_watermark(image_path, mask_path, max_dim, reg_noise, input_depth, lr,
    output_path = image_path.split('/')[-1].split('.')[-2] + '-output.jpg'
    print(f'\nSaving final output image to: "{output_path}"\n')

-    pil_image.save(output_path)
+    pil_image.save(output_path)
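For reference, a minimal usage sketch of the patched function, assuming it is still defined in the repository's api.py and that the helper modules are importable; the paths and hyperparameter values below are illustrative, not defaults taken from this change:

# Minimal usage sketch (assumption: remove_watermark is importable from api.py in this repo;
# the file paths and hyperparameter values are illustrative only).
from api import remove_watermark

remove_watermark(
    image_path = 'images/watermarked.png',    # hypothetical input image path
    mask_path = 'images/watermark-mask.png',  # hypothetical binary mask covering the watermark
    max_dim = 512,          # resize cap applied during preprocessing
    reg_noise = 0.03,       # scale of the noise perturbation added to the generator input each step
    input_depth = 32,       # channel count of the random input tensor
    lr = 0.01,              # Adam learning rate
    show_step = 200,        # visualize an intermediate output every N steps
    training_steps = 3000,  # total optimization steps
)

With this change the device is selected automatically inside the function (CUDA, then MPS, then CPU), so no device argument is needed at the call site.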