# %%
import os
import torch

filename = model_name.split("/")[1] + "_tokens.pt"
if os.path.exists(filename):
    tokens = torch.load(filename)
else:
    # Encode every prompt once and collect the set of unique token ids the dataset
    # can produce (padding ensures the pad token id ends up in the set as well).
    tokens = set(tokenizer(
        [prompt.format(text, positivelabel) for text in data["text"]] + [negativelabel],
        padding=True,
        add_special_tokens=False, return_tensors="pt").input_ids.numpy().reshape(-1).tolist())
    torch.save(tokens, filename)
# %%
print("unique tokens:", len(tokens))
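# %%
# Optional back-of-the-envelope check (a sketch, not part of the original paste):
# compare the full embedding table against the pruned one before it gets swapped out.
# Assumes model/tokens from the earlier cells are in scope.
full_vocab, dim = model.base_model.model.model.embed_tokens.weight.shape
pruned_vocab = len(tokens) + 1  # +1 extra row reserved for unseen tokens below
print(f"full embedding:   {full_vocab} x {dim} = {full_vocab * dim / 1e6:.1f}M params")
print(f"pruned embedding: {pruned_vocab} x {dim} = {pruned_vocab * dim / 1e6:.1f}M params")
print(f"rows kept: {100 * pruned_vocab / full_vocab:.2f}%")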
# %%
# Modify the embedding layer so that only the tokens that occur in the dataset are kept.
# model.base_model.model.model.embed_tokens is an Embedding(152064, 3584) here.
from torch.nn import Embedding

class CustomEmbedding(Embedding):
    def __init__(self, num_embeddings, embedding_dim, old_to_new_indices, **kwargs):
        super().__init__(num_embeddings, embedding_dim, **kwargs)
        self.old_to_new_indices = old_to_new_indices
        # self.default_index = num_embeddings  # last (extra) row reserved for unknown tokens
        # Note: using 0 means unknown tokens share a row with the smallest kept token id.
        self.default_index = 0

    def forward(self, input):
        # Remap full-vocabulary token ids to indices into the pruned table
        new_input = input.clone()
        for old_idx in torch.unique(input):
            if old_idx.item() in self.old_to_new_indices:
                new_input[input == old_idx] = self.old_to_new_indices[old_idx.item()]
            else:
                # Fallback plus debug output for a token id that never appeared in the dataset
                new_input[input == old_idx] = self.default_index
                print(f"Token {old_idx.item()} not found in the dataset, using default index {self.default_index}")
                print("obtained input", input)
                print("obtained input text", tokenizer.decode(input.cpu().numpy().reshape(-1).tolist()))
        return super().forward(new_input)
# Create a mapping from old (full-vocabulary) token ids to new (pruned) indices
old_to_new_indices = {old_idx: new_idx for new_idx, old_idx in enumerate(sorted(tokens))}
# Create a new embedding layer backed by the smaller weight tensor
new_embedding = CustomEmbedding(len(tokens) + 1, model.base_model.model.model.embed_tokens.embedding_dim, old_to_new_indices)
# Copy the relevant rows from the old embedding into the new one
for old_idx, new_idx in old_to_new_indices.items():
    new_embedding.weight.data[new_idx] = model.base_model.model.model.embed_tokens.weight.data[old_idx]
# Zero the extra last row (the slot reserved for unknown tokens; unused while default_index is 0)
new_embedding.weight.data[-1] = torch.zeros_like(new_embedding.weight.data[0])
# Replace the old embedding layer with the new one
model.base_model.model.model.embed_tokens = new_embedding.cuda()
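# %%
# Optional alternative (a sketch, not part of the original paste): the Python loop in
# CustomEmbedding.forward scales with the number of unique token ids per batch. A faster
# variant precomputes a full-vocabulary lookup tensor once and remaps ids with a single
# index operation. LookupEmbedding and full_vocab_size are illustrative names.
class LookupEmbedding(Embedding):
    def __init__(self, num_embeddings, embedding_dim, old_to_new_indices, full_vocab_size, **kwargs):
        super().__init__(num_embeddings, embedding_dim, **kwargs)
        lookup = torch.zeros(full_vocab_size, dtype=torch.long)  # unseen ids fall back to row 0
        for old_idx, new_idx in old_to_new_indices.items():
            lookup[old_idx] = new_idx
        self.register_buffer("lookup", lookup)  # moves to the GPU together with the module

    def forward(self, input):
        # Single gather instead of looping over torch.unique(input)
        return super().forward(self.lookup[input])
# Construction would mirror CustomEmbedding above, additionally passing the full
# vocabulary size (152064 per the comment above).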
# %%
# print(list(tokens)[-2:])
# print(model.base_model.model.model.embed_tokens(torch.ones(1).long().cuda()*98810))
# print(model.base_model.model.model.embed_tokens.weight.shape)
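# %%
# Quick sanity check (a sketch, not part of the original paste; assumes the cells above ran):
# ids that were in the dataset should each hit their own row of the pruned table, and an id
# outside the set should take the fallback path in CustomEmbedding.forward.
seen_ids = torch.tensor(sorted(tokens)[:4]).long().cuda()
print(model.base_model.model.model.embed_tokens(seen_ids).shape)  # expect (4, embedding_dim)
unseen_id = next(i for i in range(len(tokens) + 2) if i not in tokens)  # some id outside the set
model.base_model.model.model.embed_tokens(torch.tensor([unseen_id]).long().cuda())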