Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- # Add this to modules/prompt_parser.py:
- import torch
- from torch import nn
class VectorAdjustPrior(nn.Module):
    """Adjust per-token conditioning vectors using their prompt's mean vector.

    For every token, the mean of its own prompt's token vectors is
    concatenated with the token vector and projected to ``inter_dim``
    features. Those features are concatenated back onto the token vector
    and projected to ``hidden_size``.

    Args:
        hidden_size: Width of each token vector.
        inter_dim: Width of the intermediate projection.
    """

    def __init__(self, hidden_size, inter_dim=64):
        super().__init__()
        self.vector_proj = nn.Linear(hidden_size * 2, inter_dim, bias=True)
        self.out_proj = nn.Linear(hidden_size + inter_dim, hidden_size, bias=True)

    def forward(self, z):
        """Return the adjusted conditioning.

        Args:
            z: Tensor of shape ``(batch, seq, ...)`` whose trailing dims
                flatten to ``hidden_size``.

        Returns:
            Tensor of shape ``(batch, seq, hidden_size)``.
        """
        b, s = z.shape[0:2]
        tokens = z.reshape(b, s, -1)
        # Pair every token with *its own* prompt's mean. expand() keeps the
        # batch axis aligned with the tokens. A flat repeat over b*s rows
        # would tile the means batch-major while the tokens are prompt-major,
        # mixing prompts whenever b > 1.
        pooled = tokens.mean(dim=1, keepdim=True).expand(-1, s, -1)
        x = self.vector_proj(torch.cat((pooled, tokens), dim=-1))
        return self.out_proj(torch.cat((tokens, x), dim=-1))

    @classmethod
    def load_model(cls, model_path, hidden_size=768, inter_dim=64):
        """Build a model and load its weights from a checkpoint.

        Args:
            model_path: Path or file-like object passed to ``torch.load``.
                The checkpoint must be a mapping with a ``"state_dict"`` key.
            hidden_size: Width of each token vector.
            inter_dim: Width of the intermediate projection.

        Returns:
            The constructed model with the loaded weights.
        """
        model = cls(hidden_size=hidden_size, inter_dim=inter_dim)
        # map_location="cpu" lets checkpoints saved from a GPU load on
        # CPU-only machines. The tensors are then copied into the freshly
        # built module's parameters anyway.
        # NOTE(review): torch.load unpickles the file; only load trusted
        # checkpoints (consider weights_only=True on torch >= 1.13).
        checkpoint = torch.load(model_path, map_location="cpu")
        model.load_state_dict(checkpoint["state_dict"])
        return model
# Load the adjuster once, at import time; 'v2.pt' is resolved relative to the
# current working directory. It is used only for inference, so put it in eval
# mode and freeze its parameters. Otherwise each call records autograd history
# unless the caller runs under torch.no_grad() (NOTE(review): not visible
# from here whether it does). Both calls return the module itself.
vap = VectorAdjustPrior.load_model('v2.pt').eval().requires_grad_(False)
# Then find `def get_learned_conditioning(model, prompts, steps):` and add
# one line inside it:
conds = vap.to(device=conds.device, dtype=conds.dtype)(conds)
# The line must be placed right after the conditioning is computed:
texts = [x[1] for x in prompt_schedule]
conds = model.get_learned_conditioning(texts)
# load_model builds vap on the default device (normally CPU) in float32.
# Align it with conds first so the Linear layers see a matching device and
# dtype (presumably conds live on the GPU and may be half precision — confirm).
# Module.to() converts the parameters in place, so later calls find them
# already converted.
conds = vap.to(device=conds.device, dtype=conds.dtype)(conds)
cond_schedule = []
for i, (end_at_step, text) in enumerate(prompt_schedule):
    cond_schedule.append(ScheduledPromptConditioning(end_at_step, conds[i]))
Advertisement
Add Comment
Please, Sign In to add comment