@@ -44,7 +44,8 @@ def __init__(self, args):
44
44
self .img_size = args .img_size
45
45
self .img_ch = args .img_ch
46
46
47
- self .device = args .device
47
+ self .device = f'cuda:{ args .gpu_ids [0 ]} '
48
+ self .gpu_ids = args .gpu_ids
48
49
self .benchmark_flag = args .benchmark_flag
49
50
self .resume = args .resume
50
51
self .rho_clipper = args .rho_clipper
@@ -154,6 +155,14 @@ def train(self):
154
155
self .disLB .load_state_dict (params ['disLB' ])
155
156
print (" [*] Load {} Success" .format (self .pretrained_weights ))
156
157
158
+ if len (self .gpu_ids ) > 1 :
159
+ self .genA2B = nn .DataParallel (self .genA2B , device_ids = self .gpu_ids )
160
+ self .genB2A = nn .DataParallel (self .genB2A , device_ids = self .gpu_ids )
161
+ self .disGA = nn .DataParallel (self .disGA , device_ids = self .gpu_ids )
162
+ self .disGB = nn .DataParallel (self .disGB , device_ids = self .gpu_ids )
163
+ self .disLA = nn .DataParallel (self .disLA , device_ids = self .gpu_ids )
164
+ self .disLB = nn .DataParallel (self .disLB , device_ids = self .gpu_ids )
165
+
157
166
# training loop
158
167
print ('training start !' )
159
168
start_time = time .time ()
@@ -257,6 +266,9 @@ def train(self):
257
266
258
267
G_id_loss_A = self .facenet .cosine_distance (real_A , fake_A2B )
259
268
G_id_loss_B = self .facenet .cosine_distance (real_B , fake_B2A )
269
+ if len (self .gpu_ids ) > 1 :
270
+ G_id_loss_A = torch .mean (G_id_loss_A )
271
+ G_id_loss_B = torch .mean (G_id_loss_B )
260
272
261
273
G_cam_loss_A = self .BCE_loss (fake_B2A_cam_logit , torch .ones_like (fake_B2A_cam_logit ).to (self .device )) + \
262
274
self .BCE_loss (fake_A2A_cam_logit , torch .zeros_like (fake_A2A_cam_logit ).to (self .device ))
@@ -388,12 +400,22 @@ def train(self):
388
400
389
401
def save (self , dir , step ):
390
402
params = {}
391
- params ['genA2B' ] = self .genA2B .state_dict ()
392
- params ['genB2A' ] = self .genB2A .state_dict ()
393
- params ['disGA' ] = self .disGA .state_dict ()
394
- params ['disGB' ] = self .disGB .state_dict ()
395
- params ['disLA' ] = self .disLA .state_dict ()
396
- params ['disLB' ] = self .disLB .state_dict ()
403
+
404
+ if len (self .gpu_ids ) > 1 :
405
+ params ['genA2B' ] = self .genA2B .module .state_dict ()
406
+ params ['genB2A' ] = self .genB2A .module .state_dict ()
407
+ params ['disGA' ] = self .disGA .module .state_dict ()
408
+ params ['disGB' ] = self .disGB .module .state_dict ()
409
+ params ['disLA' ] = self .disLA .module .state_dict ()
410
+ params ['disLB' ] = self .disLB .module .state_dict ()
411
+
412
+ else :
413
+ params ['genA2B' ] = self .genA2B .state_dict ()
414
+ params ['genB2A' ] = self .genB2A .state_dict ()
415
+ params ['disGA' ] = self .disGA .state_dict ()
416
+ params ['disGB' ] = self .disGB .state_dict ()
417
+ params ['disLA' ] = self .disLA .state_dict ()
418
+ params ['disLB' ] = self .disLB .state_dict ()
397
419
torch .save (params , os .path .join (dir , self .dataset + '_params_%07d.pt' % step ))
398
420
399
421
def load (self , dir , step ):
0 commit comments