import os
import torch
from diffusers import DiffusionPipeline, StableDiffusionPipeline
import PyPDF2
from PIL import Image
import io
import streamlit as st
import time
from functools import wraps
from requests.exceptions import ChunkedEncodingError
from transformers import BartTokenizerFast, BartForConditionalGeneration, AutoTokenizer, AutoModelForSeq2SeqLM


# Load the saved tokenizer and model for text summarization
tokenizer = AutoTokenizer.from_pretrained("best_saved_token_10")
model = AutoModelForSeq2SeqLM.from_pretrained("best_saved_model_10")

# Function to extract text from a PDF file
def extract_text_from_pdf(pdf_path):
    ''' Extract text from a PDF file '''
    with open(pdf_path, 'rb') as file:
        reader = PyPDF2.PdfReader(file)
        text = ''
        for page in reader.pages:
            page_text = page.extract_text()
            if page_text:
                text += page_text + '\n'
    return text

# Function to summarize a given text
def summarize_text(text, max_length=100):
    ''' Summarize the given text using the loaded model '''
    # Note: the max_length argument is currently unused; generation is capped at 1000 tokens below.
    input_ids = tokenizer.encode(text, return_tensors='pt')
    output_ids = model.generate(input_ids, max_length=1000, num_beams=4, length_penalty=2.0, early_stopping=True)
    summary = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return summary
# Load the model for text-to-image generation.
# The Hugging Face Hub ID is kept for reference; the locally saved copy below overrides it.
# model_id = "stabilityai/stable-diffusion-xl-base-1.0"
model_id = 'saved_image_model'

def setup_mps_device():
    ''' Return the MPS device if available, otherwise fall back to CPU '''
    if torch.backends.mps.is_available():
        mps_device = torch.device("mps")
        return mps_device
    else:
        st.warning("MPS device is not available.")
        return torch.device("cpu")

def retry_on_chunked_encoding_error(max_retries=3, delay=1):
    ''' Decorator that retries the wrapped function when a ChunkedEncodingError is raised '''
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            retries = 0
            while retries < max_retries:
                try:
                    return func(*args, **kwargs)
                except ChunkedEncodingError:
                    retries += 1
                    if retries == max_retries:
                        st.warning(f"Failed to generate image after {max_retries} attempts.")
                        return None
                    time.sleep(delay * retries)  # Back off a little longer after each attempt
            return None  # This line should never be reached
        return wrapper
    return decorator

# Function to generate an image from a given text
@retry_on_chunked_encoding_error(max_retries=3, delay=1)
def image_generation(text):
    ''' Generate an image from the given text using the loaded model '''
    pipeline = DiffusionPipeline.from_pretrained(model_id)
    device = setup_mps_device()
    pipeline.to(device)
    with torch.inference_mode():
        image = pipeline(text,
                         height=1024,
                         width=1024,
                         guidance_scale=7.0,
                         num_inference_steps=50).images[0]

    return image

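# A hedged aside: image_generation() reloads the diffusion pipeline from disk on every
# call, which dominates per-paragraph latency. A minimal sketch of caching the pipeline
# with Streamlit's resource cache (assumes Streamlit >= 1.18, where st.cache_resource is
# available); image_generation() could then call load_pipeline() instead of from_pretrained():
#
# @st.cache_resource
# def load_pipeline():
#     pipe = DiffusionPipeline.from_pretrained(model_id)
#     return pipe.to(setup_mps_device())
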
# class LocalImageGenerator:
#     def __init__(self, model_path):
#         self.model_path = os.path.abspath(model_path)
#         self.pipeline = None
#         self.device = "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"
#
#     def load_model(self):
#         if self.pipeline is None:
#             if not os.path.exists(self.model_path):
#                 raise FileNotFoundError(f"Model path does not exist: {self.model_path}")
#
#             try:
#                 self.pipeline = StableDiffusionPipeline.from_pretrained(
#                     self.model_path,
#                     torch_dtype=torch.float32 if self.device == "mps" else torch.float16,
#                     local_files_only=True
#                 )
#                 self.pipeline = self.pipeline.to(self.device)
#             except Exception as e:
#                 raise Exception(f"Failed to load model from {self.model_path}: {str(e)}")
#
#     def generate_image(self, prompt, height=128, width=128, num_inference_steps=50, guidance_scale=7.5):
#         if self.pipeline is None:
#             self.load_model()
#
#         try:
#             with torch.no_grad():
#                 image = self.pipeline(prompt, height=height, width=width, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale).images[0]
#             return image
#         except Exception as e:
#             raise Exception(f"Error during image generation: {str(e)}")
#
# # Initialize the generator with the path to your locally saved model
# image_generator = LocalImageGenerator("saved_image_model_6.5_50")
#
# def image_generation(text):
#     try:
#         image = image_generator.generate_image(text)
#         return image
#     except Exception as e:
#         st.error(f"An error occurred during image generation: {str(e)}")
#         return None

# Streamlit UI
st.set_page_config(page_title="Easy Read Document")

# Custom CSS
st.markdown("""
<style>
    @import url('https://fonts.googleapis.com/css2?family=Roboto:wght@400;700&family=Poppins:wght@700&display=swap');
    @import url('https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0-beta3/css/all.min.css');

    :root {
        --bg-color: #f5f5f5;
        --text-color: #333;
        --input-bg: #ffffff;
        --border-color: #e0e0e0;
        --chat-bg: #ffffff;
        --file-details-bg: #f0f0f0;
    }

    .stApp {
        font-family: 'Roboto', sans-serif;
        background-color: var(--bg-color);  /* was var(--background-color), which is not defined in :root */
        color: var(--text-color);
    }
    .header {
        display: flex;
        justify-content: space-between;
        align-items: center;
        padding: 20px;
        background-color: var(--input-bg);
        border-bottom: 1px solid var(--border-color);
    }

    .logo {
        width: 100px;
    }

    .title {
        font-family: 'Poppins', sans-serif;
        font-size: 28px;
        font-weight: 700;
        text-align: center;
        background: linear-gradient(45deg, #007BFF, #00BFFF);
        -webkit-background-clip: text;
        -webkit-text-fill-color: transparent;
    }

    .file-details {
        background-color: var(--file-details-bg);
        padding: 15px;
        border-radius: 5px;
        box-shadow: 0 2px 5px rgba(0,0,0,0.1);
        color: var(--text-color);
        margin-top: 20px;
    }

    .stTextInput > div > div > input {
        background-color: var(--input-bg);
        color: var(--text-color);
    }

    .stButton > button {
        background-color: #007BFF;
        color: white;
    }

    .output {
        font-family: Arial, sans-serif;
        white-space: pre-wrap;
    }
</style>
""", unsafe_allow_html=True)

# Header (earlier HTML version kept commented out)
# st.markdown("""
# <div class="header">
#     <img src="static/USW.png" alt="University Logo" class="logo">
#     <h1 class="title">Easy Read Document</h1>
# </div>
# """, unsafe_allow_html=True)

st.image("static/USW.png", width=100)
st.markdown("<h1 style='text-align: center;'>Easy Read Document</h1>", unsafe_allow_html=True)

# File uploader
uploaded_file = st.file_uploader("Upload a PDF file", type="pdf")

if uploaded_file is not None:
    # File details
    st.markdown(f"""
<div class="file-details">
<h2>Uploaded file details:</h2>
<p>Filename: {uploaded_file.name}</p>
</div>
""", unsafe_allow_html=True)

    # Process and summarize the PDF
    with st.spinner("Processing..."):
        # Save the uploaded file
        os.makedirs("uploads", exist_ok=True)  # Make sure the upload directory exists
        file_path = os.path.join("uploads", uploaded_file.name)
        with open(file_path, "wb") as f:
            f.write(uploaded_file.getbuffer())

        # Extract text from the PDF and split it into paragraphs
        full_text = extract_text_from_pdf(file_path)
        paragraphs = full_text.split('\n')

        # Summarize each paragraph (only the first three, to keep processing time reasonable)
        st.markdown("<h2>Summarized Text:</h2>", unsafe_allow_html=True)

        # device = setup_mps_device()

        for i in range(min(3, len(paragraphs))):
            if paragraphs[i].strip():
                summary = summarize_text(paragraphs[i])

                image = image_generation(summary)

                # Create two columns
                col1, col2 = st.columns([3, 1])

                # Summary in the left column
                with col1:
                    st.markdown(f"<div class='output'>{summary}</div>", unsafe_allow_html=True)

                # Image in the right column
                with col2:
                    if image is not None:
                        img_byte_arr = io.BytesIO()
                        image.save(img_byte_arr, format='PNG')
                        img_byte_arr = img_byte_arr.getvalue()
                        st.image(img_byte_arr, use_column_width=True)
                    else:
                        st.warning("Failed to generate image for this summary")

        st.markdown(f"""
<div class="paragraph-details">
<h2>Total paragraphs:</h2>
<p>No. of paragraphs: {len(paragraphs)}</p>
</div>
""", unsafe_allow_html=True)

    st.markdown('</div>', unsafe_allow_html=True)

# Input container
user_input = st.text_input("You can start a conversation with me, or upload a PDF file above", key="user_input")

if st.button("Send"):
    if user_input:
        st.markdown(f"<div class='output'>User: {user_input}</div>", unsafe_allow_html=True)
        # Here you would process the user input and generate a response.
        # For now, we just echo the input.
        st.markdown(f"<div class='output'>Bot: You said: {user_input}</div>", unsafe_allow_html=True)
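
# A hedged usage note: to try the app locally, assuming this script is saved as app.py
# and the saved model directories referenced above ("best_saved_token_10",
# "best_saved_model_10", "saved_image_model") sit next to it:
#
#   pip install streamlit torch diffusers transformers PyPDF2 pillow requests
#   streamlit run app.py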