
Commit da8d2fb

satani99 and sayakpaul committed
Modularize instruct_pix2pix SD inferencing during and after training in examples (#7603)
* Modularize instruct_pix2pix code
* quality check
* quality check

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent 4dc77c2 commit da8d2fb

1 file changed: examples/instruct_pix2pix/train_instruct_pix2pix.py (+63 -70 lines)

@@ -53,6 +53,9 @@
 from diffusers.utils.torch_utils import is_compiled_module
 
 
+if is_wandb_available():
+    import wandb
+
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
 check_min_version("0.28.0.dev0")
 
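Note: the guard added above uses diffusers' own availability check, so the script still imports cleanly when wandb is absent. A minimal sketch of the same pattern outside the training script; the maybe_log helper is hypothetical and not part of this diff:

# Minimal sketch of the is_wandb_available() guard pattern from the hunk above.
# maybe_log is a hypothetical helper, not something train_instruct_pix2pix.py defines.
from diffusers.utils import is_wandb_available

if is_wandb_available():
    import wandb


def maybe_log(metrics):
    # Only touch wandb when it is installed and a run has been initialized.
    if is_wandb_available() and wandb.run is not None:
        wandb.log(metrics)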

@@ -64,6 +67,48 @@
 WANDB_TABLE_COL_NAMES = ["original_image", "edited_image", "edit_prompt"]
 
 
+def log_validation(
+    pipeline,
+    args,
+    accelerator,
+    generator,
+):
+    logger.info(
+        f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+        f" {args.validation_prompt}."
+    )
+    pipeline = pipeline.to(accelerator.device)
+    pipeline.set_progress_bar_config(disable=True)
+
+    # run inference
+    original_image = download_image(args.val_image_url)
+    edited_images = []
+    if torch.backends.mps.is_available():
+        autocast_ctx = nullcontext()
+    else:
+        autocast_ctx = torch.autocast(accelerator.device.type)
+
+    with autocast_ctx:
+        for _ in range(args.num_validation_images):
+            edited_images.append(
+                pipeline(
+                    args.validation_prompt,
+                    image=original_image,
+                    num_inference_steps=20,
+                    image_guidance_scale=1.5,
+                    guidance_scale=7,
+                    generator=generator,
+                ).images[0]
+            )
+
+    for tracker in accelerator.trackers:
+        if tracker.name == "wandb":
+            wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
+            for edited_image in edited_images:
+                wandb_table.add_data(wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt)
+            tracker.log({"validation": wandb_table})
+
+
 def parse_args():
     parser = argparse.ArgumentParser(description="Simple example of a training script for InstructPix2Pix.")
     parser.add_argument(
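For context, the extracted log_validation helper can also be driven on its own. A hedged sketch, assuming it is imported from the script module (so download_image, logger, and the wandb globals resolve) and run from examples/instruct_pix2pix; the checkpoint, prompt, and image URL are illustrative:

# Hedged sketch: calling the extracted helper outside the training loop.
# Assumes train_instruct_pix2pix.py is importable from the current directory.
from argparse import Namespace

import torch
from accelerate import Accelerator
from diffusers import StableDiffusionInstructPix2PixPipeline
from train_instruct_pix2pix import log_validation

args = Namespace(
    validation_prompt="turn the sky into a sunset",  # illustrative
    val_image_url="https://example.com/input.png",   # illustrative
    num_validation_images=4,
)
accelerator = Accelerator(log_with="wandb")
accelerator.init_trackers("instruct-pix2pix-validation")
pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained("timbrooks/instruct-pix2pix")
generator = torch.Generator(device=accelerator.device).manual_seed(0)
log_validation(pipeline, args, accelerator, generator)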
@@ -411,11 +456,6 @@ def main():
 
     generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
 
-    if args.report_to == "wandb":
-        if not is_wandb_available():
-            raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
-        import wandb
-
     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -517,7 +557,8 @@ def save_model_hook(models, weights, output_dir):
                 model.save_pretrained(os.path.join(output_dir, "unet"))
 
                 # make sure to pop weight so that corresponding model is not saved again
-                weights.pop()
+                if weights:
+                    weights.pop()
 
         def load_model_hook(models, input_dir):
             if args.use_ema:
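The new if weights: guard matters when accelerate hands the hook an empty weights list (for example when the distributed backend serializes the weights itself). A small self-contained sketch of the save-state hook pattern with a toy model; the script's real hook also handles EMA and saves with save_pretrained:

# Sketch of accelerate's save-state pre-hook with the guarded pop from this hunk.
import os

import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = accelerator.prepare(torch.nn.Linear(4, 4))


def save_model_hook(models, weights, output_dir):
    for m in models:
        torch.save(m.state_dict(), os.path.join(output_dir, "model.pt"))
        # Pop the weight so accelerate does not serialize the model a second time;
        # guard against an empty list, which some backends pass.
        if weights:
            weights.pop()


accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.save_state("checkpoint-0")  # illustrative output directory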
@@ -923,11 +964,6 @@ def collate_fn(examples):
                 and (args.validation_prompt is not None)
                 and (epoch % args.validation_epochs == 0)
             ):
-                logger.info(
-                    f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
-                    f" {args.validation_prompt}."
-                )
-                # create pipeline
                 if args.use_ema:
                     # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
                     ema_unet.store(unet.parameters())
@@ -942,38 +978,14 @@ def collate_fn(examples):
                     variant=args.variant,
                     torch_dtype=weight_dtype,
                 )
-                pipeline = pipeline.to(accelerator.device)
-                pipeline.set_progress_bar_config(disable=True)
-
-                # run inference
-                original_image = download_image(args.val_image_url)
-                edited_images = []
-                if torch.backends.mps.is_available():
-                    autocast_ctx = nullcontext()
-                else:
-                    autocast_ctx = torch.autocast(accelerator.device.type)
-
-                with autocast_ctx:
-                    for _ in range(args.num_validation_images):
-                        edited_images.append(
-                            pipeline(
-                                args.validation_prompt,
-                                image=original_image,
-                                num_inference_steps=20,
-                                image_guidance_scale=1.5,
-                                guidance_scale=7,
-                                generator=generator,
-                            ).images[0]
-                        )
-
-                for tracker in accelerator.trackers:
-                    if tracker.name == "wandb":
-                        wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
-                        for edited_image in edited_images:
-                            wandb_table.add_data(
-                                wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt
-                            )
-                        tracker.log({"validation": wandb_table})
+
+                log_validation(
+                    pipeline,
+                    args,
+                    accelerator,
+                    generator,
+                )
+
                 if args.use_ema:
                     # Switch back to the original UNet parameters.
                     ema_unet.restore(unet.parameters())
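Around the new log_validation call the script still swaps EMA weights in for inference and back out afterwards. A self-contained sketch of that store / copy_to / restore sequence with diffusers' EMAModel, using a toy module instead of the UNet:

# Sketch of the EMA swap surrounding validation: stash the live parameters,
# copy the EMA average in for inference, then restore the trained weights.
import torch
from diffusers.training_utils import EMAModel

net = torch.nn.Linear(8, 8)                  # stands in for the UNet
ema = EMAModel(net.parameters(), decay=0.999)

for _ in range(10):                          # pretend training steps
    ema.step(net.parameters())

ema.store(net.parameters())                  # stash current (trained) weights
ema.copy_to(net.parameters())                # run inference with EMA weights
with torch.no_grad():
    _ = net(torch.randn(1, 8))
ema.restore(net.parameters())                # put the trained weights back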
@@ -984,15 +996,14 @@ def collate_fn(examples):
     # Create the pipeline using the trained modules and save it.
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
-        unet = unwrap_model(unet)
         if args.use_ema:
             ema_unet.copy_to(unet.parameters())
 
         pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
             args.pretrained_model_name_or_path,
             text_encoder=unwrap_model(text_encoder),
             vae=unwrap_model(vae),
-            unet=unet,
+            unet=unwrap_model(unet),
             revision=args.revision,
             variant=args.variant,
         )
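Passing unet=unwrap_model(unet) brings the UNet in line with the other components: accelerate-prepared (and possibly torch.compile-d) modules must be unwrapped before they go into the saved pipeline. A sketch of that unwrapping, mirroring the unwrap_model helper the script defines, with a toy module standing in for the UNet:

# Sketch of unwrapping an accelerate-prepared module before saving it.
import torch
from accelerate import Accelerator
from diffusers.utils.torch_utils import is_compiled_module

accelerator = Accelerator()
unet = accelerator.prepare(torch.nn.Linear(4, 4))  # stands in for the UNet


def unwrap_model(model):
    # Strip the accelerate wrapper and, if torch.compile was used, the _orig_mod indirection.
    model = accelerator.unwrap_model(model)
    model = model._orig_mod if is_compiled_module(model) else model
    return model


raw_unet = unwrap_model(unet)  # safe to pass to the pipeline / save_pretrained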
@@ -1006,31 +1017,13 @@ def collate_fn(examples):
                 ignore_patterns=["step_*", "epoch_*"],
             )
 
-        if args.validation_prompt is not None:
-            edited_images = []
-            pipeline = pipeline.to(accelerator.device)
-            with torch.autocast(str(accelerator.device).replace(":0", "")):
-                for _ in range(args.num_validation_images):
-                    edited_images.append(
-                        pipeline(
-                            args.validation_prompt,
-                            image=original_image,
-                            num_inference_steps=20,
-                            image_guidance_scale=1.5,
-                            guidance_scale=7,
-                            generator=generator,
-                        ).images[0]
-                    )
-
-            for tracker in accelerator.trackers:
-                if tracker.name == "wandb":
-                    wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
-                    for edited_image in edited_images:
-                        wandb_table.add_data(
-                            wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt
-                        )
-                    tracker.log({"test": wandb_table})
-
+        if (args.val_image_url is not None) and (args.validation_prompt is not None):
+            log_validation(
+                pipeline,
+                args,
+                accelerator,
+                generator,
+            )
     accelerator.end_training()
 
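After the final log_validation pass, the pipeline saved to --output_dir can be reloaded for inference like any other InstructPix2Pix checkpoint. A hedged sketch with an illustrative output directory, image URL, and edit prompt (assumes a CUDA device for fp16):

# Hedged sketch: reloading the trained pipeline for inference after training.
import torch
from diffusers import StableDiffusionInstructPix2PixPipeline
from diffusers.utils import load_image

pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
    "instruct-pix2pix-model",                     # illustrative --output_dir
    torch_dtype=torch.float16,
).to("cuda")

image = load_image("https://example.com/input.png")  # illustrative URL
edited = pipeline(
    "make it look like winter",                   # illustrative edit prompt
    image=image,
    num_inference_steps=20,
    image_guidance_scale=1.5,                     # same knobs as log_validation
    guidance_scale=7,
).images[0]
edited.save("edited.png")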
