Skip to content

Commit

Permalink
Add support for Hunyuan video
Browse files Browse the repository at this point in the history
  • Loading branch information
Nerogar committed Jan 4, 2025
1 parent e6bd96b commit 972ccb1
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 10 deletions.
80 changes: 80 additions & 0 deletions src/mgds/pipelineModules/EncodeLlamaText.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from contextlib import nullcontext

import torch
from transformers import LlamaModel

from mgds.PipelineModule import PipelineModule
from mgds.pipelineModuleTypes.RandomAccessPipelineModule import RandomAccessPipelineModule


class EncodeLlamaText(
    PipelineModule,
    RandomAccessPipelineModule,
):
    """Runs pre-tokenized text through a Llama text encoder and exposes one of
    its hidden states, together with the (optionally cropped) attention mask,
    as pipeline outputs.
    """

    def __init__(
            self,
            tokens_in_name: str,
            tokens_attention_mask_in_name: str | None,
            hidden_state_out_name: str,
            tokens_attention_mask_out_name: str | None,
            text_encoder: LlamaModel,
            hidden_state_output_index: int | None = None,
            crop_start: int | None = None,
            autocast_contexts: list[torch.autocast | None] = None,
            dtype: torch.dtype | None = None,
    ):
        """
        Args:
            tokens_in_name: input name of the token id tensor.
            tokens_attention_mask_in_name: input name of the attention mask,
                or None to call the encoder without a mask.
            hidden_state_out_name: output name for the selected hidden state.
            tokens_attention_mask_out_name: output name for the attention
                mask, or None to not expose the mask as an output.
            text_encoder: the Llama model used for encoding.
            hidden_state_output_index: index into the encoder's
                ``hidden_states`` tuple; None selects the final hidden state.
            crop_start: if set, this many leading tokens (e.g. an instruction
                template prefix) are cropped from the hidden state and mask.
            autocast_contexts: autocast contexts entered around the encoder
                call; defaults to a single nullcontext.
            dtype: if set, the attention mask is cast to this dtype before
                the encoder call.
        """
        super(EncodeLlamaText, self).__init__()
        self.tokens_in_name = tokens_in_name
        self.tokens_attention_mask_in_name = tokens_attention_mask_in_name
        self.hidden_state_out_name = hidden_state_out_name
        self.tokens_attention_mask_out_name = tokens_attention_mask_out_name
        self.text_encoder = text_encoder
        self.hidden_state_output_index = hidden_state_output_index
        self.crop_start = crop_start

        self.autocast_contexts = [nullcontext()] if autocast_contexts is None else autocast_contexts
        self.dtype = dtype

    def length(self) -> int:
        return self._get_previous_length(self.tokens_in_name)

    def get_inputs(self) -> list[str]:
        # only report the mask input when it is actually configured,
        # so the pipeline graph never sees a None name
        if self.tokens_attention_mask_in_name is not None:
            return [self.tokens_in_name, self.tokens_attention_mask_in_name]
        return [self.tokens_in_name]

    def get_outputs(self) -> list[str]:
        # same None-filtering as get_inputs, for the optional mask output
        if self.tokens_attention_mask_out_name is not None:
            return [self.hidden_state_out_name, self.tokens_attention_mask_out_name]
        return [self.hidden_state_out_name]

    def get_item(self, variation: int, index: int, requested_name: str = None) -> dict:
        tokens = self._get_previous_item(variation, self.tokens_in_name, index)
        tokens = tokens.unsqueeze(0)  # add a batch dimension for the encoder

        if self.tokens_attention_mask_in_name is not None:
            tokens_attention_mask = self._get_previous_item(variation, self.tokens_attention_mask_in_name, index)
            tokens_attention_mask = tokens_attention_mask.unsqueeze(0)
        else:
            tokens_attention_mask = None

        with self._all_contexts(self.autocast_contexts):
            if tokens_attention_mask is not None and self.dtype:
                tokens_attention_mask = tokens_attention_mask.to(dtype=self.dtype)

            text_encoder_output = self.text_encoder(
                tokens,
                attention_mask=tokens_attention_mask,
                output_hidden_states=True,
                return_dict=True,
            )

        # fix: a None index previously raised TypeError; treat it as
        # "final hidden state". also only squeeze the selected state
        # instead of squeezing every entry of the tuple.
        output_index = -1 if self.hidden_state_output_index is None else self.hidden_state_output_index
        hidden_state = text_encoder_output.hidden_states[output_index].squeeze()

        # fix: guard the squeeze — the mask is None when no mask input
        # is configured, which previously crashed with AttributeError
        if tokens_attention_mask is not None:
            tokens_attention_mask = tokens_attention_mask.squeeze()

        if self.crop_start is not None:
            hidden_state = hidden_state[self.crop_start:]
            if tokens_attention_mask is not None:
                tokens_attention_mask = tokens_attention_mask[self.crop_start:]

        item = {
            self.hidden_state_out_name: hidden_state,
        }
        # avoid emitting a None dictionary key when the mask output is unset
        if self.tokens_attention_mask_out_name is not None:
            item[self.tokens_attention_mask_out_name] = tokens_attention_mask
        return item
3 changes: 2 additions & 1 deletion src/mgds/pipelineModules/EncodeVAE.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import torch
from diffusers import AutoencoderKL, AutoencoderDC
from diffusers.models.autoencoders.autoencoder_kl_hunyuan_video import AutoencoderKLHunyuanVideo

from mgds.PipelineModule import PipelineModule
from mgds.pipelineModuleTypes.RandomAccessPipelineModule import RandomAccessPipelineModule
Expand All @@ -15,7 +16,7 @@ def __init__(
self,
in_name: str,
out_name: str,
vae: AutoencoderKL | AutoencoderDC,
vae: AutoencoderKL | AutoencoderDC | AutoencoderKLHunyuanVideo,
autocast_contexts: list[torch.autocast | None] = None,
dtype: torch.dtype | None = None,
):
Expand Down
37 changes: 37 additions & 0 deletions src/mgds/pipelineModules/ImageToVideo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from transformers import CLIPTokenizer, T5Tokenizer, T5TokenizerFast, GemmaTokenizer, LlamaTokenizer

from mgds.PipelineModule import PipelineModule
from mgds.pipelineModuleTypes.RandomAccessPipelineModule import RandomAccessPipelineModule


class ImageToVideo(
    PipelineModule,
    RandomAccessPipelineModule,
):
    """Promotes a single image tensor to a single-frame video tensor by
    inserting a frame axis at dimension 1. Tensors that already have more
    than three dimensions are passed through unchanged.
    """

    def __init__(
            self,
            in_name: str,
            out_name: str,
    ):
        """
        Args:
            in_name: input name of the image tensor.
            out_name: output name of the resulting video tensor.
        """
        super(ImageToVideo, self).__init__()
        self.in_name = in_name
        self.out_name = out_name

    def length(self) -> int:
        return self._get_previous_length(self.in_name)

    def get_inputs(self) -> list[str]:
        return [self.in_name]

    def get_outputs(self) -> list[str]:
        return [self.out_name]

    def get_item(self, variation: int, index: int, requested_name: str = None) -> dict:
        frames = self._get_previous_item(variation, self.in_name, index)

        # a 3-dim tensor is a single image — presumably (channels, height,
        # width), so unsqueeze(1) yields (channels, 1, height, width);
        # TODO(review): confirm the layout against upstream modules
        if frames.ndim == 3:
            frames = frames.unsqueeze(1)

        return {
            self.out_name: frames,
        }
32 changes: 23 additions & 9 deletions src/mgds/pipelineModules/Tokenize.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from transformers import CLIPTokenizer, T5Tokenizer, T5TokenizerFast, GemmaTokenizer
from transformers import CLIPTokenizer, T5Tokenizer, T5TokenizerFast, GemmaTokenizer, LlamaTokenizer

from mgds.PipelineModule import PipelineModule
from mgds.pipelineModuleTypes.RandomAccessPipelineModule import RandomAccessPipelineModule
Expand All @@ -13,15 +13,19 @@ def __init__(
in_name: str,
tokens_out_name: str,
mask_out_name: str,
tokenizer: CLIPTokenizer | T5Tokenizer | T5TokenizerFast | GemmaTokenizer,
tokenizer: CLIPTokenizer | T5Tokenizer | T5TokenizerFast | GemmaTokenizer | LlamaTokenizer,
max_token_length: int,
format_text: str | None = None,
additional_format_text_tokens: int | None = None,
):
super(Tokenize, self).__init__()
self.in_name = in_name
self.tokens_out_name = tokens_out_name
self.mask_out_name = mask_out_name
self.tokenizer = tokenizer
self.max_token_length = max_token_length
self.format_text = format_text
self.additional_format_text_tokens = additional_format_text_tokens

def length(self) -> int:
return self._get_previous_length(self.in_name)
Expand All @@ -35,13 +39,23 @@ def get_outputs(self) -> list[str]:
def get_item(self, variation: int, index: int, requested_name: str = None) -> dict:
text = self._get_previous_item(variation, self.in_name, index)

tokenizer_output = self.tokenizer(
text,
padding='max_length',
truncation=True,
max_length=self.max_token_length,
return_tensors="pt",
)
if self.format_text is not None:
text = self.format_text.format(text)
tokenizer_output = self.tokenizer(
text,
padding='max_length',
truncation=True,
max_length=self.max_token_length + self.additional_format_text_tokens,
return_tensors="pt",
)
else:
tokenizer_output = self.tokenizer(
text,
padding='max_length',
truncation=True,
max_length=self.max_token_length,
return_tensors="pt",
)

tokens = tokenizer_output.input_ids.to(self.pipeline.device)
mask = tokenizer_output.attention_mask.to(self.pipeline.device)
Expand Down

0 comments on commit 972ccb1

Please sign in to comment.