
Commit 47187f7

Merge pull request kohya-ss#1285 from ccharest93/main
Hyperparameter tracking
2 parents e3ddd1f + b886d0a · commit 47187f7

10 files changed (+36 −9 lines)

fine_tune.py (+1 −1)

@@ -310,7 +310,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
             init_kwargs["wandb"] = {"name": args.wandb_run_name}
         if args.log_tracker_config is not None:
             init_kwargs = toml.load(args.log_tracker_config)
-        accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
+        accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs)

     # For --sample_at_first
     train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
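
For context, this is the Accelerate API the change builds on: accelerator.init_trackers() accepts an optional config mapping that each active tracker records as the run's hyperparameters (for wandb it becomes the run config), while init_kwargs carries tracker-specific constructor arguments. A minimal sketch, with hypothetical run name and values:

    from accelerate import Accelerator

    accelerator = Accelerator(log_with="wandb")  # requires wandb to be installed

    # `config` is recorded as the run's hyperparameters; `init_kwargs`
    # holds per-tracker constructor arguments such as the run name.
    accelerator.init_trackers(
        "finetuning",
        config={"learning_rate": 1e-4, "max_train_steps": 1000},  # hypothetical values
        init_kwargs={"wandb": {"name": "my-run"}},
    )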

library/train_util.py (+27 −0)

@@ -3388,6 +3388,33 @@ def add_masked_loss_arguments(parser: argparse.ArgumentParser):
         help="apply mask for calculating loss. conditioning_data_dir is required for dataset. / 損失計算時にマスクを適用する。datasetにはconditioning_data_dirが必要",
     )

+def filter_sensitive_args(args: argparse.Namespace):
+    sensitive_args = ["wandb_api_key", "huggingface_token"]
+    sensitive_path_args = [
+        "pretrained_model_name_or_path",
+        "vae",
+        "tokenizer_cache_dir",
+        "train_data_dir",
+        "conditioning_data_dir",
+        "reg_data_dir",
+        "output_dir",
+        "logging_dir",
+    ]
+    filtered_args = {}
+    for k, v in vars(args).items():
+        # filter out sensitive values
+        if k not in sensitive_args + sensitive_path_args:
+            # Accelerate values need to have type `bool`, `str`, `float`, `int`, or `None`.
+            if v is None or isinstance(v, bool) or isinstance(v, str) or isinstance(v, float) or isinstance(v, int):
+                filtered_args[k] = v
+            # accelerate does not support lists
+            elif isinstance(v, list):
+                filtered_args[k] = f"{v}"
+            # accelerate does not support objects
+            elif isinstance(v, object):
+                filtered_args[k] = f"{v}"
+
+    return filtered_args

 # verify command line args for training
 def verify_command_line_training_args(args: argparse.Namespace):
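
For illustration, a minimal sketch of how the new helper behaves, assuming the sd-scripts repository root is on sys.path; the argument names and values below are invented for the example. Scalars pass through unchanged, keys on either sensitive list are dropped, and everything else is stringified. Note that the final elif isinstance(v, object) branch matches any remaining value, so it acts as a catch-all stringifier.

    import argparse

    from library import train_util

    # Hypothetical parsed args; values are invented for this example.
    args = argparse.Namespace(
        learning_rate=1e-4,                    # scalar: kept as-is
        max_train_steps=1000,                  # scalar: kept as-is
        wandb_api_key="secret",                # in sensitive_args: dropped
        train_data_dir="/data/train",          # in sensitive_path_args: dropped
        optimizer_args=["weight_decay=0.01"],  # list: stringified
    )

    print(train_util.filter_sensitive_args(args))
    # {'learning_rate': 0.0001, 'max_train_steps': 1000,
    #  'optimizer_args': "['weight_decay=0.01']"}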

sdxl_train.py (+1 −1)

@@ -589,7 +589,7 @@ def optimizer_hook(parameter: torch.Tensor):
             init_kwargs["wandb"] = {"name": args.wandb_run_name}
         if args.log_tracker_config is not None:
             init_kwargs = toml.load(args.log_tracker_config)
-        accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
+        accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs)

     # For --sample_at_first
     sdxl_train_util.sample_images(

sdxl_train_control_net_lllite.py (+1 −1)

@@ -354,7 +354,7 @@ def train(args):
         if args.log_tracker_config is not None:
             init_kwargs = toml.load(args.log_tracker_config)
         accelerator.init_trackers(
-            "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
+            "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs
         )

     loss_recorder = train_util.LossRecorder()

sdxl_train_control_net_lllite_old.py (+1 −1)

@@ -324,7 +324,7 @@ def train(args):
         if args.log_tracker_config is not None:
             init_kwargs = toml.load(args.log_tracker_config)
         accelerator.init_trackers(
-            "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
+            "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs
         )

     loss_recorder = train_util.LossRecorder()

train_controlnet.py (+1 −1)

@@ -344,7 +344,7 @@ def train(args):
         if args.log_tracker_config is not None:
             init_kwargs = toml.load(args.log_tracker_config)
         accelerator.init_trackers(
-            "controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
+            "controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs
         )

     loss_recorder = train_util.LossRecorder()

train_db.py (+1 −1)

@@ -290,7 +290,7 @@ def train(args):
             init_kwargs["wandb"] = {"name": args.wandb_run_name}
         if args.log_tracker_config is not None:
             init_kwargs = toml.load(args.log_tracker_config)
-        accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
+        accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs)

     # For --sample_at_first
     train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)

train_network.py (+1 −1)

@@ -774,7 +774,7 @@ def load_model_hook(models, input_dir):
         if args.log_tracker_config is not None:
             init_kwargs = toml.load(args.log_tracker_config)
         accelerator.init_trackers(
-            "network_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
+            "network_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs
         )

     loss_recorder = train_util.LossRecorder()

train_textual_inversion.py (+1 −1)

@@ -510,7 +510,7 @@ def train(self, args):
         if args.log_tracker_config is not None:
             init_kwargs = toml.load(args.log_tracker_config)
         accelerator.init_trackers(
-            "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
+            "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs
         )

     # function for saving/removing

train_textual_inversion_XTI.py (+1 −1)

@@ -407,7 +407,7 @@ def train(args):
         if args.log_tracker_config is not None:
             init_kwargs = toml.load(args.log_tracker_config)
         accelerator.init_trackers(
-            "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
+            "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs
         )

     # function for saving/removing
