Skip to content

Commit 1b44145

Browse files
committed
support multi-gpu training
1 parent 8404dae commit 1b44145

File tree

4 files changed

+41
-8
lines changed

4 files changed

+41
-8
lines changed

README.md

+5
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,11 @@ python train.py --dataset photo2cartoon
115115
python train.py --dataset photo2cartoon --pretrained_weights models/photo2cartoon_weights.pt
116116
```
117117

118+
多GPU训练(仍建议使用batch_size=1,单卡训练):
119+
```
120+
python train.py --dataset photo2cartoon --batch_size 4 --gpu_ids 0 1 2 3
121+
```
122+
118123
## Q&A
119124
#### Q:为什么开源的卡通化模型与小程序中的效果有差异?
120125

README_EN.md

+5
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,11 @@ Load pre-trained weights:
117117
python train.py --dataset photo2cartoon --pretrained_weights models/photo2cartoon_weights.pt
118118
```
119119

120+
Train with Multi-GPU:
121+
```
122+
python train.py --dataset photo2cartoon --batch_size 4 --gpu_ids 0 1 2 3
123+
```
124+
120125
## Q&A
121126
#### Q:Why is the result of this project different from mini program?
122127

models/UGATIT_sadalin_hourglass.py

+29-7
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@ def __init__(self, args):
4444
self.img_size = args.img_size
4545
self.img_ch = args.img_ch
4646

47-
self.device = args.device
47+
self.device = f'cuda:{args.gpu_ids[0]}'
48+
self.gpu_ids = args.gpu_ids
4849
self.benchmark_flag = args.benchmark_flag
4950
self.resume = args.resume
5051
self.rho_clipper = args.rho_clipper
@@ -154,6 +155,14 @@ def train(self):
154155
self.disLB.load_state_dict(params['disLB'])
155156
print(" [*] Load {} Success".format(self.pretrained_weights))
156157

158+
if len(self.gpu_ids) > 1:
159+
self.genA2B = nn.DataParallel(self.genA2B, device_ids=self.gpu_ids)
160+
self.genB2A = nn.DataParallel(self.genB2A, device_ids=self.gpu_ids)
161+
self.disGA = nn.DataParallel(self.disGA, device_ids=self.gpu_ids)
162+
self.disGB = nn.DataParallel(self.disGB, device_ids=self.gpu_ids)
163+
self.disLA = nn.DataParallel(self.disLA, device_ids=self.gpu_ids)
164+
self.disLB = nn.DataParallel(self.disLB, device_ids=self.gpu_ids)
165+
157166
# training loop
158167
print('training start !')
159168
start_time = time.time()
@@ -257,6 +266,9 @@ def train(self):
257266

258267
G_id_loss_A = self.facenet.cosine_distance(real_A, fake_A2B)
259268
G_id_loss_B = self.facenet.cosine_distance(real_B, fake_B2A)
269+
if len(self.gpu_ids) > 1:
270+
G_id_loss_A = torch.mean(G_id_loss_A)
271+
G_id_loss_B = torch.mean(G_id_loss_B)
260272

261273
G_cam_loss_A = self.BCE_loss(fake_B2A_cam_logit, torch.ones_like(fake_B2A_cam_logit).to(self.device)) + \
262274
self.BCE_loss(fake_A2A_cam_logit, torch.zeros_like(fake_A2A_cam_logit).to(self.device))
@@ -388,12 +400,22 @@ def train(self):
388400

389401
def save(self, dir, step):
390402
params = {}
391-
params['genA2B'] = self.genA2B.state_dict()
392-
params['genB2A'] = self.genB2A.state_dict()
393-
params['disGA'] = self.disGA.state_dict()
394-
params['disGB'] = self.disGB.state_dict()
395-
params['disLA'] = self.disLA.state_dict()
396-
params['disLB'] = self.disLB.state_dict()
403+
404+
if len(self.gpu_ids) > 1:
405+
params['genA2B'] = self.genA2B.module.state_dict()
406+
params['genB2A'] = self.genB2A.module.state_dict()
407+
params['disGA'] = self.disGA.module.state_dict()
408+
params['disGB'] = self.disGB.module.state_dict()
409+
params['disLA'] = self.disLA.module.state_dict()
410+
params['disLB'] = self.disLB.module.state_dict()
411+
412+
else:
413+
params['genA2B'] = self.genA2B.state_dict()
414+
params['genB2A'] = self.genB2A.state_dict()
415+
params['disGA'] = self.disGA.state_dict()
416+
params['disGB'] = self.disGB.state_dict()
417+
params['disLA'] = self.disLA.state_dict()
418+
params['disLB'] = self.disLB.state_dict()
397419
torch.save(params, os.path.join(dir, self.dataset + '_params_%07d.pt' % step))
398420

399421
def load(self, dir, step):

train.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ def parse_args():
3131
parser.add_argument('--img_size', type=int, default=256, help='The size of image')
3232
parser.add_argument('--img_ch', type=int, default=3, help='The size of image channel')
3333

34-
parser.add_argument('--device', type=str, default='cuda:0', help='Set gpu mode: [cpu, cuda]')
34+
# parser.add_argument('--device', type=str, default='cuda:0', help='Set gpu mode: [cpu, cuda]')
35+
parser.add_argument('--gpu_ids', type=int, default=[0], nargs='+', help='Set [0, 1, 2, 3] for multi-gpu training')
3536
parser.add_argument('--benchmark_flag', type=str2bool, default=False)
3637
parser.add_argument('--resume', type=str2bool, default=False)
3738
parser.add_argument('--rho_clipper', type=float, default=1.0)

0 commit comments

Comments
 (0)