
Commit f9be58d

yiyixuxu, okotaku, sayakpaul, and patrickvonplaten authored

[feat] IP Adapters (author @okotaku ) (huggingface#5713)

* add ip-adapter

---------

Co-authored-by: okotaku <to78314910@gmail.com>
Co-authored-by: sayakpaul <spsayakpaul@gmail.com>
Co-authored-by: yiyixuxu <yixu310@gmail,com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

1 parent 6583599 commit f9be58d

20 files changed: +972 -58 lines

Diff for: loaders/__init__.py

+2 lines

@@ -62,6 +62,7 @@ def text_encoder_attn_modules(text_encoder):
     _import_structure["single_file"].extend(["FromSingleFileMixin"])
     _import_structure["lora"] = ["LoraLoaderMixin", "StableDiffusionXLLoraLoaderMixin"]
     _import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
+    _import_structure["ip_adapter"] = ["IPAdapterMixin"]


 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
@@ -72,6 +73,7 @@ def text_encoder_attn_modules(text_encoder):
     from .utils import AttnProcsLayers

     if is_transformers_available():
+        from .ip_adapter import IPAdapterMixin
         from .lora import LoraLoaderMixin, StableDiffusionXLLoraLoaderMixin
         from .single_file import FromSingleFileMixin
         from .textual_inversion import TextualInversionLoaderMixin
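
With this entry in place, `IPAdapterMixin` resolves through the same lazy-import machinery as the other loader mixins. A minimal sketch of the resulting import path, assuming a diffusers build that includes this commit:

    # The lazy module maps the "ip_adapter" entry to IPAdapterMixin
    # at attribute-access time, like the other loader mixins above.
    from diffusers.loaders import IPAdapterMixin

    # Pipelines that inherit the mixin pick up both of the new methods.
    assert hasattr(IPAdapterMixin, "load_ip_adapter")
    assert hasattr(IPAdapterMixin, "set_ip_adapter_scale")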

Diff for: loaders/ip_adapter.py

+157 lines

@@ -0,0 +1,157 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+from typing import Dict, Union
+
+import torch
+from safetensors import safe_open
+
+from ..utils import (
+    DIFFUSERS_CACHE,
+    HF_HUB_OFFLINE,
+    _get_model_file,
+    is_transformers_available,
+    logging,
+)
+
+
+if is_transformers_available():
+    from transformers import (
+        CLIPImageProcessor,
+        CLIPVisionModelWithProjection,
+    )
+
+    from ..models.attention_processor import (
+        IPAdapterAttnProcessor,
+        IPAdapterAttnProcessor2_0,
+    )
+
+logger = logging.get_logger(__name__)
+
+
+class IPAdapterMixin:
+    """Mixin for handling IP Adapters."""
+
+    def load_ip_adapter(
+        self,
+        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+        subfolder: str,
+        weight_name: str,
+        **kwargs,
+    ):
+        """
+        Parameters:
+            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                Can be either:
+
+                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+                      the Hub.
+                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+                      with [`ModelMixin.save_pretrained`].
+                    - A [torch state
+                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+            cache_dir (`Union[str, os.PathLike]`, *optional*):
+                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+                is not used.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                cached versions if they exist.
+            resume_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
+                incompletely downloaded files are deleted.
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                Whether to only load local model weights and configuration files or not. If set to `True`, the model
+                won't be downloaded from the Hub.
+            use_auth_token (`str` or *bool*, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+                `diffusers-cli login` (stored in `~/.huggingface`) is used.
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+                allowed by Git.
+            subfolder (`str`, *optional*, defaults to `""`):
+                The subfolder location of a model file within a larger model repository on the Hub or locally.
+        """
+
+        # Load the main state dict first.
+        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+        force_download = kwargs.pop("force_download", False)
+        resume_download = kwargs.pop("resume_download", False)
+        proxies = kwargs.pop("proxies", None)
+        local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
+        use_auth_token = kwargs.pop("use_auth_token", None)
+        revision = kwargs.pop("revision", None)
+
+        user_agent = {
+            "file_type": "attn_procs_weights",
+            "framework": "pytorch",
+        }
+
+        if not isinstance(pretrained_model_name_or_path_or_dict, dict):
+            model_file = _get_model_file(
+                pretrained_model_name_or_path_or_dict,
+                weights_name=weight_name,
+                cache_dir=cache_dir,
+                force_download=force_download,
+                resume_download=resume_download,
+                proxies=proxies,
+                local_files_only=local_files_only,
+                use_auth_token=use_auth_token,
+                revision=revision,
+                subfolder=subfolder,
+                user_agent=user_agent,
+            )
+            if weight_name.endswith(".safetensors"):
+                state_dict = {"image_proj": {}, "ip_adapter": {}}
+                with safe_open(model_file, framework="pt", device="cpu") as f:
+                    for key in f.keys():
+                        if key.startswith("image_proj."):
+                            state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
+                        elif key.startswith("ip_adapter."):
+                            state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
+            else:
+                state_dict = torch.load(model_file, map_location="cpu")
+        else:
+            state_dict = pretrained_model_name_or_path_or_dict
+
+        keys = list(state_dict.keys())
+        if keys != ["image_proj", "ip_adapter"]:
+            raise ValueError("Required keys (`image_proj` and `ip_adapter`) are missing from the state dict.")
+
+        # load CLIP image encoder here if it has not been registered to the pipeline yet
+        if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
+            if not isinstance(pretrained_model_name_or_path_or_dict, dict):
+                logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}")
+                image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+                    pretrained_model_name_or_path_or_dict,
+                    subfolder=os.path.join(subfolder, "image_encoder"),
+                ).to(self.device, dtype=self.dtype)
+                self.image_encoder = image_encoder
+            else:
+                raise ValueError("`image_encoder` cannot be None when using IP Adapters.")
+
+        # create feature extractor if it has not been registered to the pipeline yet
+        if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
+            self.feature_extractor = CLIPImageProcessor()
+
+        # load ip-adapter into unet
+        self.unet._load_ip_adapter_weights(state_dict)
+
+    def set_ip_adapter_scale(self, scale):
+        for attn_processor in self.unet.attn_processors.values():
+            if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
+                attn_processor.scale = scale
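
Together these two methods are the whole user-facing surface: `load_ip_adapter` downloads and splits the checkpoint, registers a CLIP image encoder and feature extractor if the pipeline lacks them, and hands the weights to the UNet, while `set_ip_adapter_scale` controls how strongly the image prompt conditions generation. A minimal usage sketch, assuming a CUDA machine and a pipeline class that inherits `IPAdapterMixin`; the repo id `h94/IP-Adapter`, subfolder `models`, and weight name `ip-adapter_sd15.bin` refer to the separately published IP-Adapter checkpoints and are not part of this diff:

    import torch
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    ).to("cuda")

    # Fetches the adapter weights, loads the image encoder from the repo's
    # `models/image_encoder` subfolder, and swaps IP-Adapter attention
    # processors into the UNet.
    pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")

    # 1.0 keeps full-strength image conditioning; lower values favor the text prompt.
    pipe.set_ip_adapter_scale(0.6)

Because `set_ip_adapter_scale` mutates the processors in place, the scale can be changed between generations without reloading any weights.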

Diff for: loaders/unet.py

+70 lines

@@ -18,8 +18,10 @@

 import safetensors
 import torch
+import torch.nn.functional as F
 from torch import nn

+from ..models.embeddings import ImageProjection
 from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
 from ..utils import (
     DIFFUSERS_CACHE,
@@ -662,4 +664,72 @@ def delete_adapters(self, adapter_names: Union[List[str], str]):
             if hasattr(self, "peft_config"):
                 self.peft_config.pop(adapter_name, None)

+    def _load_ip_adapter_weights(self, state_dict):
+        from ..models.attention_processor import (
+            AttnProcessor,
+            AttnProcessor2_0,
+            IPAdapterAttnProcessor,
+            IPAdapterAttnProcessor2_0,
+        )
+
+        # set ip-adapter cross-attention processors & load state_dict
+        attn_procs = {}
+        key_id = 1
+        for name in self.attn_processors.keys():
+            cross_attention_dim = None if name.endswith("attn1.processor") else self.config.cross_attention_dim
+            if name.startswith("mid_block"):
+                hidden_size = self.config.block_out_channels[-1]
+            elif name.startswith("up_blocks"):
+                block_id = int(name[len("up_blocks.")])
+                hidden_size = list(reversed(self.config.block_out_channels))[block_id]
+            elif name.startswith("down_blocks"):
+                block_id = int(name[len("down_blocks.")])
+                hidden_size = self.config.block_out_channels[block_id]
+            if cross_attention_dim is None or "motion_modules" in name:
+                attn_processor_class = (
+                    AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor
+                )
+                attn_procs[name] = attn_processor_class()
+            else:
+                attn_processor_class = (
+                    IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor
+                )
+                attn_procs[name] = attn_processor_class(
+                    hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0
+                ).to(dtype=self.dtype, device=self.device)
+
+                value_dict = {}
+                for k, w in attn_procs[name].state_dict().items():
+                    value_dict.update({f"{k}": state_dict["ip_adapter"][f"{key_id}.{k}"]})
+
+                attn_procs[name].load_state_dict(value_dict)
+                key_id += 2
+
+        self.set_attn_processor(attn_procs)
+
+        # create image projection layers.
+        clip_embeddings_dim = state_dict["image_proj"]["proj.weight"].shape[-1]
+        cross_attention_dim = state_dict["image_proj"]["proj.weight"].shape[0] // 4
+
+        image_projection = ImageProjection(
+            cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim, num_image_text_embeds=4
+        )
+        image_projection.to(dtype=self.dtype, device=self.device)
+
+        # load image projection layer weights
+        image_proj_state_dict = {}
+        image_proj_state_dict.update(
+            {
+                "image_embeds.weight": state_dict["image_proj"]["proj.weight"],
+                "image_embeds.bias": state_dict["image_proj"]["proj.bias"],
+                "norm.weight": state_dict["image_proj"]["norm.weight"],
+                "norm.bias": state_dict["image_proj"]["norm.bias"],
+            }
+        )
+
+        image_projection.load_state_dict(image_proj_state_dict)
+
+        self.encoder_hid_proj = image_projection.to(device=self.device, dtype=self.dtype)
+        self.config.encoder_hid_dim_type = "ip_image_proj"
 delete_adapter_layers
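
The loader assumes a fixed checkpoint layout: an `image_proj` dict holding the projection weights (`proj.weight`, `proj.bias`, `norm.weight`, `norm.bias`) and an `ip_adapter` dict keyed by attention-processor index. That numbering is why `key_id` starts at 1 and advances by 2: the original IP-Adapter export enumerates every attention processor in order, and only the cross-attention (`attn2`) slots, which land on the odd indices, carry `to_k_ip`/`to_v_ip` weights. The `// 4` when sizing `ImageProjection` follows from the projection emitting `num_image_text_embeds=4` image tokens, so `proj.weight` has `4 * cross_attention_dim` output rows. A small inspection sketch under those assumptions (the file name is illustrative):

    import torch

    # Any IP-Adapter `.bin` checkpoint with the layout described above.
    state_dict = torch.load("ip-adapter_sd15.bin", map_location="cpu")

    print(sorted(state_dict.keys()))
    # ['image_proj', 'ip_adapter']

    print(sorted(state_dict["image_proj"].keys()))
    # ['norm.bias', 'norm.weight', 'proj.bias', 'proj.weight']

    # Odd-numbered processor ids, two new projections each.
    print(list(state_dict["ip_adapter"].keys())[:4])
    # e.g. ['1.to_k_ip.weight', '1.to_v_ip.weight', '3.to_k_ip.weight', '3.to_v_ip.weight']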
