import torch
import torch.nn as nn
from transformers import T5EncoderModel, T5Tokenizer
from typing import Optional, Tuple


class T5ToSDXLAdapter(nn.Module):
    def __init__(
        self,
        t5_model_name: str = "t5-base",
        sdxl_embed_dim: int = 1280,  # SDXL default text embedding dimension
        max_length: int = 77,        # SDXL default sequence length
    ):
        super().__init__()
        # Use the encoder-only variant: the full T5Model requires decoder_input_ids,
        # which are not needed for extracting text embeddings.
        self.t5 = T5EncoderModel.from_pretrained(t5_model_name)
        self.t5_tokenizer = T5Tokenizer.from_pretrained(t5_model_name)

        # T5-base hidden size is 768; project up to SDXL's 1280
        self.projection = nn.Sequential(
            nn.Linear(self.t5.config.hidden_size, sdxl_embed_dim),
            nn.LayerNorm(sdxl_embed_dim),
        )

        self.max_length = max_length
        self.sdxl_embed_dim = sdxl_embed_dim

    def forward(
        self,
        text: list[str],
        negative_prompt: Optional[list[str]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        device = next(self.parameters()).device

        # Tokenize the conditioning text and move tensors to the model's device
        tokenized = self.t5_tokenizer(
            text,
            padding="max_length",
            max_length=self.max_length,
            truncation=True,
            return_tensors="pt",
        ).to(device)

        # Get T5 encoder hidden states
        t5_output = self.t5(
            input_ids=tokenized.input_ids,
            attention_mask=tokenized.attention_mask,
            return_dict=True,
        )

        # Project to the SDXL embedding dimension
        cond_embeds = self.projection(t5_output.last_hidden_state)

        # Handle negative prompts
        if negative_prompt is not None:
            neg_tokenized = self.t5_tokenizer(
                negative_prompt,
                padding="max_length",
                max_length=self.max_length,
                truncation=True,
                return_tensors="pt",
            ).to(device)
            neg_output = self.t5(
                input_ids=neg_tokenized.input_ids,
                attention_mask=neg_tokenized.attention_mask,
                return_dict=True,
            )
            uncond_embeds = self.projection(neg_output.last_hidden_state)
        else:
            # Use a zero embedding when no negative prompt is given
            uncond_embeds = torch.zeros(
                (cond_embeds.shape[0], self.max_length, self.sdxl_embed_dim),
                device=cond_embeds.device,
            )

        return cond_embeds, uncond_embeds

    def freeze_t5(self):
        """Freeze T5 parameters so that only the projection layer is trained."""
        for param in self.t5.parameters():
            param.requires_grad = False