Guest User

Untitled

a guest
Oct 7th, 2022
2,907
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.44 KB | None | 0 0
  1. # Add this to modules/prompt_parser.py:
  2.  
  3. import torch
  4. from torch import nn
  5.  
  6. class VectorAdjustPrior(nn.Module):
  7. def __init__(self, hidden_size, inter_dim=64):
  8. super().__init__()
  9. self.vector_proj = nn.Linear(hidden_size*2, inter_dim, bias=True)
  10. self.out_proj = nn.Linear(hidden_size+inter_dim, hidden_size, bias=True)
  11.  
  12. def forward(self, z):
  13. b, s = z.shape[0:2]
  14. x1 = torch.mean(z, dim=1).repeat(s, 1)
  15. x2 = z.reshape(b*s, -1)
  16. x = torch.cat((x1, x2), dim=1)
  17. x = self.vector_proj(x)
  18. x = torch.cat((x2, x), dim=1)
  19. x = self.out_proj(x)
  20. x = x.reshape(b, s, -1)
  21. return x
  22.  
  23. @classmethod
  24. def load_model(cls, model_path, hidden_size=768, inter_dim=64):
  25. model = cls(hidden_size=hidden_size, inter_dim=inter_dim)
  26. model.load_state_dict(torch.load(model_path)["state_dict"])
  27. return model
  28.  
  29.  
  30. vap = VectorAdjustPrior.load_model('v2.pt')
  31.  
  32.  
  33. # then find def get_learned_conditioning(model, prompts, steps):
  34. # and add one line inside it:
  35.  
  36. conds = vap(conds)
  37.  
  38.  
  39. # this line must be placed here:
  40.  
  41. texts = [x[1] for x in prompt_schedule]
  42. conds = model.get_learned_conditioning(texts)
  43.  
  44. conds = vap(conds)
  45.  
  46. cond_schedule = []
  47. for i, (end_at_step, text) in enumerate(prompt_schedule):
  48. cond_schedule.append(ScheduledPromptConditioning(end_at_step, conds[i]))
  49.  
Advertisement
Add Comment
Please, Sign In to add comment