import torch
import torch.nn as nn
from transformers import T5EncoderModel, T5Tokenizer
from typing import Optional, Tuple


class T5ToSDXLAdapter(nn.Module):
    def __init__(
        self,
        t5_model_name: str = "t5-base",
        sdxl_embed_dim: int = 1280,  # SDXL text embedding dimension
        max_length: int = 77,        # SDXL default sequence length
    ):
        super().__init__()
        # Only the encoder is needed for text conditioning; the full T5Model
        # would require decoder inputs on every forward pass.
        self.t5 = T5EncoderModel.from_pretrained(t5_model_name)
        self.t5_tokenizer = T5Tokenizer.from_pretrained(t5_model_name)

        # T5-base hidden size is 768; project up to SDXL's 1280
        self.projection = nn.Sequential(
            nn.Linear(self.t5.config.hidden_size, sdxl_embed_dim),
            nn.LayerNorm(sdxl_embed_dim),
        )

        self.max_length = max_length
        self.sdxl_embed_dim = sdxl_embed_dim

    def forward(
        self,
        text: list[str],
        negative_prompt: Optional[list[str]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        device = next(self.t5.parameters()).device

        # Tokenize the positive prompts
        tokenized = self.t5_tokenizer(
            text,
            padding="max_length",
            max_length=self.max_length,
            truncation=True,
            return_tensors="pt",
        ).to(device)

        # Get T5 encoder embeddings
        t5_output = self.t5(
            input_ids=tokenized.input_ids,
            attention_mask=tokenized.attention_mask,
            return_dict=True,
        )

        # Project to the SDXL embedding dimension
        cond_embeds = self.projection(t5_output.last_hidden_state)

        # Handle negative prompts
        if negative_prompt is not None:
            neg_tokenized = self.t5_tokenizer(
                negative_prompt,
                padding="max_length",
                max_length=self.max_length,
                truncation=True,
                return_tensors="pt",
            ).to(device)

            neg_output = self.t5(
                input_ids=neg_tokenized.input_ids,
                attention_mask=neg_tokenized.attention_mask,
                return_dict=True,
            )
            uncond_embeds = self.projection(neg_output.last_hidden_state)
        else:
            # Use a zero embedding if no negative prompt is given
            uncond_embeds = torch.zeros(
                (cond_embeds.shape[0], self.max_length, self.sdxl_embed_dim),
                device=cond_embeds.device,
                dtype=cond_embeds.dtype,
            )

        return cond_embeds, uncond_embeds

    def freeze_t5(self):
        """Freeze T5 parameters so only the projection layer is trained."""
        for param in self.t5.parameters():
            param.requires_grad = False
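

# --- Minimal usage sketch (not part of the original paste) ---
# Assumes you only want to train the projection layer on top of a frozen
# T5 encoder and inspect the resulting embedding shapes. Feeding these
# embeddings into an actual SDXL UNet (which normally expects concatenated
# CLIP-L / OpenCLIP-bigG features) would need additional plumbing.
if __name__ == "__main__":
    adapter = T5ToSDXLAdapter(t5_model_name="t5-base")
    adapter.freeze_t5()

    cond, uncond = adapter(
        ["a photo of an astronaut riding a horse"],
        negative_prompt=["blurry, low quality"],
    )
    print(cond.shape, uncond.shape)  # both: torch.Size([1, 77, 1280])

    # Only the projection parameters receive gradients
    trainable = [p for p in adapter.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable, lr=1e-4)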