zakhar_azg

Untitled

Nov 15th, 2024
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

# Retrieval corpus: the question strings from the SQuAD v2 train split.
loaded_dataset = load_dataset("squad_v2")
dataset = [item['question'] for item in loaded_dataset['train']]


def find_similar_texts(query, dataset, model, top_n=3):
    # Embed the corpus and the query, then rank by cosine similarity.
    # Note: this re-encodes the whole corpus on every call.
    dataset_embeddings = model.encode(dataset)
    query_embedding = model.encode([query])

    similarities = cosine_similarity(query_embedding, dataset_embeddings).flatten()

    # Indices of the top_n most similar entries, highest similarity first.
    top_n_indices = similarities.argsort()[-top_n:][::-1]

    return [(dataset[i], similarities[i]) for i in top_n_indices]


def enhance_prompt_with_retrieved_text(query, dataset, model, top_n=3):
    # Prepend the retrieved texts to the user query to build a RAG-style prompt.
    similar_texts = find_similar_texts(query, dataset, model, top_n)

    retrieved_texts = "\n".join(f"Relevant text {i+1}: {text}" for i, (text, _) in enumerate(similar_texts))

    enhanced_prompt = f"User query: {query}\n\n" \
                      f"Chunks relevant to the user query:\n{retrieved_texts}\n\n" \
                      f"Please answer the user query."

    return enhanced_prompt

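
# find_similar_texts above re-encodes the entire corpus on every call, which is
# slow for the ~130k questions in the squad_v2 train split. A minimal sketch of
# a cached variant, assuming the corpus embeddings were precomputed once, e.g.
# corpus_embeddings = embed_model.encode(dataset); the function name and the
# corpus_embeddings parameter are illustrative assumptions.
def find_similar_texts_cached(query, dataset, corpus_embeddings, model, top_n=3):
    # Only the query is encoded per call; the corpus embeddings are reused.
    query_embedding = model.encode([query])
    similarities = cosine_similarity(query_embedding, corpus_embeddings).flatten()
    top_n_indices = similarities.argsort()[-top_n:][::-1]
    return [(dataset[i], float(similarities[i])) for i in top_n_indices]
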
if __name__ == "__main__":
    # Generator LM: a small BLOOM checkpoint.
    model_name = "bigscience/bloom-560m"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    generator = pipeline('text-generation', model=model, tokenizer=tokenizer)

    # Embedding model used for retrieval.
    embed_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

    user_query = "how to make money on cryptocurrency?"
    prompt_with_context = enhance_prompt_with_retrieved_text(user_query, dataset, embed_model)

    print(prompt_with_context)

    # do_sample=True is required for temperature/top_p to take effect;
    # max_length counts the prompt tokens plus the generated ones.
    output = generator(
        prompt_with_context,
        max_length=2048,
        num_return_sequences=1,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.2,
    )
    print(output[0]['generated_text'])
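

# sentence-transformers also ships a built-in helper for this ranking step; a
# minimal sketch using util.semantic_search, assuming tensor corpus embeddings
# precomputed via embed_model.encode(dataset, convert_to_tensor=True). The
# function name find_similar_texts_st is an illustrative assumption.
from sentence_transformers import util

def find_similar_texts_st(query, dataset, corpus_embeddings, model, top_n=3):
    query_embedding = model.encode([query], convert_to_tensor=True)
    # semantic_search returns one hit list per query; each hit carries the
    # corpus index and its cosine-similarity score.
    hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=top_n)[0]
    return [(dataset[hit['corpus_id']], hit['score']) for hit in hits]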