# %%
import os

import torch

# Collect the set of token ids that actually occur in the dataset prompts
# (plus the label strings), caching it on disk so this only runs once.
filename = model_name.split("/")[1] + "_tokens.pt"
if os.path.exists(filename):
    tokens = torch.load(filename)
else:
    tokens = set(
        tokenizer(
            [prompt.format(text, positivelabel) for text in data["text"]] + [negativelabel],
            padding="max_length",
            add_special_tokens=False,
            return_tensors="pt",
        ).input_ids.numpy().reshape(-1).tolist()
    )
    torch.save(tokens, filename)

# %%
print("unique tokens:", len(tokens))

# %%
# Shrink model.base_model.model.model.embed_tokens (an Embedding(152064, 3584))
# so that only the tokens that appear in the dataset are kept.
from torch.nn import Embedding


class CustomEmbedding(Embedding):
    def __init__(self, num_embeddings, embedding_dim, old_to_new_indices, **kwargs):
        super().__init__(num_embeddings, embedding_dim, **kwargs)
        self.old_to_new_indices = old_to_new_indices
        # The last row is reserved for tokens that never appeared in the dataset;
        # its weight is zero-initialized below.
        self.default_index = num_embeddings - 1

    def forward(self, input):
        # Map old (full-vocabulary) indices to new (compacted) indices.
        new_input = input.clone()
        for old_idx in torch.unique(input):
            if old_idx.item() in self.old_to_new_indices:
                new_input[input == old_idx] = self.old_to_new_indices[old_idx.item()]
            else:
                new_input[input == old_idx] = self.default_index
                print(f"Token {old_idx.item()} not found in the dataset, using default index {self.default_index}")
        # Debug output: show what the layer received.
        print("obtained input", input)
        print("obtained input text", tokenizer.decode(input.cpu().numpy().reshape(-1).tolist()))
        return super().forward(new_input)

# Create a mapping from old token indices to new (compacted) token indices.
old_to_new_indices = {old_idx: new_idx for new_idx, old_idx in enumerate(sorted(tokens))}

# Create a new embedding layer with the smaller weight tensor
# (one extra row at the end for unknown tokens), matching the old layer's dtype.
old_embedding = model.base_model.model.model.embed_tokens
new_embedding = CustomEmbedding(
    len(tokens) + 1,
    old_embedding.embedding_dim,
    old_to_new_indices,
    dtype=old_embedding.weight.dtype,
)

# Copy the relevant rows from the old embedding into the new one.
for old_idx, new_idx in old_to_new_indices.items():
    new_embedding.weight.data[new_idx] = old_embedding.weight.data[old_idx]

# Initialize the weight for the default index (unknown tokens) to zero.
new_embedding.weight.data[-1] = torch.zeros_like(new_embedding.weight.data[0])

# Replace the old embedding layer with the new one.
model.base_model.model.model.embed_tokens = new_embedding.cuda()

# %%
# print(list(tokens)[-2:])
# print(model.base_model.model.model.embed_tokens(torch.ones(1).long().cuda()*98810))
# print(model.base_model.model.model.embed_tokens.weight.shape)
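
# %%
# Optional sanity check (a minimal sketch; assumes `tokens`, `tokenizer`, and the
# replaced embedding from above are still in scope). An id that never occurred in
# the dataset should hit the zero-initialized default row, while an id from
# `tokens` should return a copied, non-zero row.
emb = model.base_model.model.model.embed_tokens
known_id = max(tokens)  # a token id that occurs in the dataset
unknown_id = next(i for i in range(tokenizer.vocab_size) if i not in tokens)
print(emb(torch.tensor([[known_id]]).cuda()).abs().sum())    # expected: non-zero
print(emb(torch.tensor([[unknown_id]]).cuda()).abs().sum())  # expected: zero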