@@ -53,6 +53,9 @@
 from diffusers.utils.torch_utils import is_compiled_module


+if is_wandb_available():
+    import wandb
+
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
 check_min_version("0.28.0.dev0")

@@ -64,6 +67,48 @@
 WANDB_TABLE_COL_NAMES = ["original_image", "edited_image", "edit_prompt"]


+def log_validation(
+    pipeline,
+    args,
+    accelerator,
+    generator,
+):
+    logger.info(
+        f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+        f" {args.validation_prompt}."
+    )
+    pipeline = pipeline.to(accelerator.device)
+    pipeline.set_progress_bar_config(disable=True)
+
+    # run inference
+    original_image = download_image(args.val_image_url)
+    edited_images = []
+    if torch.backends.mps.is_available():
+        autocast_ctx = nullcontext()
+    else:
+        autocast_ctx = torch.autocast(accelerator.device.type)
+
+    with autocast_ctx:
+        for _ in range(args.num_validation_images):
+            edited_images.append(
+                pipeline(
+                    args.validation_prompt,
+                    image=original_image,
+                    num_inference_steps=20,
+                    image_guidance_scale=1.5,
+                    guidance_scale=7,
+                    generator=generator,
+                ).images[0]
+            )
+
+    for tracker in accelerator.trackers:
+        if tracker.name == "wandb":
+            wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
+            for edited_image in edited_images:
+                wandb_table.add_data(wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt)
+            tracker.log({"validation": wandb_table})
+
+
 def parse_args():
     parser = argparse.ArgumentParser(description="Simple example of a training script for InstructPix2Pix.")
     parser.add_argument(
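Note on the autocast branch in `log_validation` above: when MPS is available the code falls back to `nullcontext()`, presumably because `torch.autocast` support on Apple Silicon is limited. A minimal, standalone sketch of the same device-dependent selection (the final `pass` is a placeholder for the pipeline call):

```python
from contextlib import nullcontext

import torch


def autocast_ctx_for(device: torch.device):
    # MPS autocast support is limited, so use a no-op context there.
    if device.type == "mps":
        return nullcontext()
    # Elsewhere (CUDA/CPU), enable mixed-precision autocast for the device type.
    return torch.autocast(device.type)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
with autocast_ctx_for(device):
    pass  # run pipeline inference here
```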
@@ -411,11 +456,6 @@ def main():

     generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)

-    if args.report_to == "wandb":
-        if not is_wandb_available():
-            raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
-        import wandb
-
     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -517,7 +557,8 @@ def save_model_hook(models, weights, output_dir):
                     model.save_pretrained(os.path.join(output_dir, "unet"))

                     # make sure to pop weight so that corresponding model is not saved again
-                    weights.pop()
+                    if weights:
+                        weights.pop()

         def load_model_hook(models, input_dir):
             if args.use_ema:
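The new `if weights:` guard above most likely protects against setups (e.g., DeepSpeed) where Accelerate hands the save hook an empty `weights` list because the backend serializes the weights itself; an unconditional pop would then raise. A tiny illustration of the failure mode being avoided:

```python
weights = []  # e.g., the backend already saved the weights and passed an empty list
# weights.pop()  # would raise IndexError: pop from empty list
if weights:
    weights.pop()  # guarded version is a harmless no-op when the list is empty
```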
@@ -923,11 +964,6 @@ def collate_fn(examples):
                     and (args.validation_prompt is not None)
                     and (epoch % args.validation_epochs == 0)
                 ):
-                    logger.info(
-                        f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
-                        f" {args.validation_prompt}."
-                    )
-                    # create pipeline
                     if args.use_ema:
                         # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
                         ema_unet.store(unet.parameters())
@@ -942,38 +978,14 @@ def collate_fn(examples):
                     variant=args.variant,
                     torch_dtype=weight_dtype,
                 )
-                pipeline = pipeline.to(accelerator.device)
-                pipeline.set_progress_bar_config(disable=True)
-
-                # run inference
-                original_image = download_image(args.val_image_url)
-                edited_images = []
-                if torch.backends.mps.is_available():
-                    autocast_ctx = nullcontext()
-                else:
-                    autocast_ctx = torch.autocast(accelerator.device.type)
-
-                with autocast_ctx:
-                    for _ in range(args.num_validation_images):
-                        edited_images.append(
-                            pipeline(
-                                args.validation_prompt,
-                                image=original_image,
-                                num_inference_steps=20,
-                                image_guidance_scale=1.5,
-                                guidance_scale=7,
-                                generator=generator,
-                            ).images[0]
-                        )
-
-                for tracker in accelerator.trackers:
-                    if tracker.name == "wandb":
-                        wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
-                        for edited_image in edited_images:
-                            wandb_table.add_data(
-                                wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt
-                            )
-                        tracker.log({"validation": wandb_table})
+
+                log_validation(
+                    pipeline,
+                    args,
+                    accelerator,
+                    generator,
+                )
+
                 if args.use_ema:
                     # Switch back to the original UNet parameters.
                     ema_unet.restore(unet.parameters())
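For context, the unchanged lines around this hunk swap the EMA weights in for validation and restore the training weights afterwards. A minimal sketch of that pattern around the new `log_validation` call, assuming `ema_unet` is the `EMAModel` tracking `unet` as elsewhere in this script:

```python
if args.use_ema:
    # Stash the current training weights and load the EMA weights for inference.
    ema_unet.store(unet.parameters())
    ema_unet.copy_to(unet.parameters())

log_validation(pipeline, args, accelerator, generator)

if args.use_ema:
    # Put the original training weights back before resuming optimization.
    ema_unet.restore(unet.parameters())
```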
@@ -984,15 +996,14 @@ def collate_fn(examples):
     # Create the pipeline using the trained modules and save it.
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
-        unet = unwrap_model(unet)
         if args.use_ema:
             ema_unet.copy_to(unet.parameters())

         pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
             args.pretrained_model_name_or_path,
             text_encoder=unwrap_model(text_encoder),
             vae=unwrap_model(vae),
-            unet=unet,
+            unet=unwrap_model(unet),
             revision=args.revision,
             variant=args.variant,
         )
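The change above stops rebinding `unet` to its unwrapped module and instead unwraps only when handing it to the pipeline. For reference, the `unwrap_model` helper defined earlier in this script (which the `is_compiled_module` import at the top supports) looks roughly like this, assuming `accelerator` is in scope: strip the Accelerate wrapper, then any `torch.compile` wrapper.

```python
def unwrap_model(model):
    # Remove the Accelerate wrapper (e.g., DistributedDataParallel).
    model = accelerator.unwrap_model(model)
    # If the module was compiled with torch.compile, return the original module.
    model = model._orig_mod if is_compiled_module(model) else model
    return model
```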
@@ -1006,31 +1017,13 @@ def collate_fn(examples):
                 ignore_patterns=["step_*", "epoch_*"],
             )

-        if args.validation_prompt is not None:
-            edited_images = []
-            pipeline = pipeline.to(accelerator.device)
-            with torch.autocast(str(accelerator.device).replace(":0", "")):
-                for _ in range(args.num_validation_images):
-                    edited_images.append(
-                        pipeline(
-                            args.validation_prompt,
-                            image=original_image,
-                            num_inference_steps=20,
-                            image_guidance_scale=1.5,
-                            guidance_scale=7,
-                            generator=generator,
-                        ).images[0]
-                    )
-
-            for tracker in accelerator.trackers:
-                if tracker.name == "wandb":
-                    wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
-                    for edited_image in edited_images:
-                        wandb_table.add_data(
-                            wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt
-                        )
-                    tracker.log({"test": wandb_table})
-
+        if (args.val_image_url is not None) and (args.validation_prompt is not None):
+            log_validation(
+                pipeline,
+                args,
+                accelerator,
+                generator,
+            )
     accelerator.end_training()
