Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
import html
import io
import os
import time
from functools import wraps

import PyPDF2
import streamlit as st
import torch
from diffusers import DiffusionPipeline
from diffusers import StableDiffusionPipeline
from PIL import Image
from requests.exceptions import ChunkedEncodingError
from transformers import BartTokenizerFast, BartForConditionalGeneration, AutoTokenizer, AutoModelForSeq2SeqLM
# Load the saved tokenizer and model for text summarization.
# Streamlit re-executes this script from the top on every widget interaction,
# so the checkpoints are loaded through a cached factory instead of being
# re-read from disk on each rerun.
# show_spinner=False: this call runs before st.set_page_config, and the default
# cache spinner would otherwise be emitted as an earlier Streamlit element.
@st.cache_resource(show_spinner=False)
def _load_summarizer():
    """Return the ``(tokenizer, model)`` pair used for text summarization."""
    return (
        AutoTokenizer.from_pretrained("best_saved_token_10"),
        AutoModelForSeq2SeqLM.from_pretrained("best_saved_model_10"),
    )


tokenizer, model = _load_summarizer()
# Function to extract text from a PDF file
def extract_text_from_pdf(pdf_path):
    """Return the concatenated text of every page in the PDF at ``pdf_path``.

    Pages are visited in document order. Each page whose extracted text is
    non-empty contributes that text followed by a newline; pages that yield
    nothing are skipped.
    """
    with open(pdf_path, 'rb') as pdf_file:
        pages = PyPDF2.PdfReader(pdf_file).pages
        extracted = (page.extract_text() for page in pages)
        return ''.join(chunk + '\n' for chunk in extracted if chunk)
# Function to summarize a given text
def summarize_text(text, max_length=100):
    """Summarize ``text`` with the module-level seq2seq ``model``.

    Args:
        text: Source text to summarize.
        max_length: Upper bound, in tokens, on the generated summary.

    Returns:
        The decoded summary string with special tokens removed.
    """
    # truncation=True clips over-long inputs to the tokenizer's configured
    # maximum instead of feeding the model more positions than it supports.
    input_ids = tokenizer.encode(text, return_tensors='pt', truncation=True)
    output_ids = model.generate(
        input_ids,
        max_length=max_length,  # honour the caller's limit
        num_beams=4,
        length_penalty=2.0,
        early_stopping=True,
    )
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
# Load the model for text to image generation.
# Points at a locally saved checkpoint directory. To pull the public weights
# from the Hugging Face Hub instead, set this to
# "stabilityai/stable-diffusion-xl-base-1.0".
model_id = 'saved_image_model'
def setup_mps_device():
    """Return the torch device to run on: MPS when available, otherwise CPU.

    When MPS is not available a Streamlit warning is shown before the CPU
    device is returned.
    """
    if not torch.backends.mps.is_available():
        st.warning("MPS device is not available.")
        return torch.device("cpu")
    return torch.device("mps")
def retry_on_chunked_encoding_error(max_retries=3, delay=1):
    """Build a decorator that retries a call on ``ChunkedEncodingError``.

    Args:
        max_retries: Total number of attempts. Values below 1 are treated as
            1 so the wrapped function is always called at least once.
        delay: Base pause in seconds. The wait before the next attempt grows
            linearly: ``delay * <number of failed attempts so far>``.

    Returns:
        A decorator. The wrapped function returns the original result on
        success, or ``None`` (after a Streamlit warning) once every attempt
        has failed with ``ChunkedEncodingError``. Any other exception
        propagates unchanged.
    """
    # Guard against max_retries <= 0, which would otherwise skip the call.
    attempts = max(1, max_retries)

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            for attempt in range(1, attempts + 1):
                try:
                    return func(*args, **kwargs)
                except ChunkedEncodingError:
                    if attempt == attempts:
                        break
                    time.sleep(delay * attempt)  # linear backoff
            st.warning(f"Failed to generate image after {attempts} attempts.")
            return None
        return wrapper
    return decorator
@st.cache_resource(show_spinner=False)
def _load_image_pipeline():
    """Load the diffusion pipeline for ``model_id`` once and reuse it.

    Cached so the weights are not re-read for every generated image or on
    every Streamlit rerun. A failed load raises and is therefore not cached.
    """
    return DiffusionPipeline.from_pretrained(model_id)


# Function to generate an image from a given text
@retry_on_chunked_encoding_error(max_retries=3, delay=1)
def image_generation(text):
    """Generate one 1024x1024 image for the prompt ``text``.

    Args:
        text: Prompt passed to the diffusion pipeline.

    Returns:
        The first image produced by the pipeline, or ``None`` when the retry
        decorator gives up after repeated ``ChunkedEncodingError`` failures.
    """
    device = setup_mps_device()
    pipeline = _load_image_pipeline()
    # Re-applied on every call so the shared pipeline sits on the device
    # selected for this call.
    pipeline.to(device)
    with torch.inference_mode():
        result = pipeline(
            text,
            height=1024,
            width=1024,
            guidance_scale=7.0,
            num_inference_steps=50,
        )
    return result.images[0]
- # class LocalImageGenerator:
- # def __init__(self, model_path):
- # self.model_path = os.path.abspath(model_path)
- # self.pipeline = None
- # self.device = "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"
- # def load_model(self):
- # if self.pipeline is None:
- # if not os.path.exists(self.model_path):
- # raise FileNotFoundError(f"Model path does not exist: {self.model_path}")
- # try:
- # self.pipeline = StableDiffusionPipeline.from_pretrained(
- # self.model_path,
- # torch_dtype=torch.float32 if self.device == "mps" else torch.float16,
- # local_files_only=True
- # )
- # self.pipeline = self.pipeline.to(self.device)
- # except Exception as e:
- # raise Exception(f"Failed to load model from {self.model_path}: {str(e)}")
- # def generate_image(self, prompt, height=128, width=128, num_inference_steps=50, guidance_scale=7.5):
- # if self.pipeline is None:
- # self.load_model()
- # try:
- # with torch.no_grad():
- # image = self.pipeline(prompt, height=height, width=width, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale).images[0]
- # return image
- # except Exception as e:
- # raise Exception(f"Error during image generation: {str(e)}")
- # # Initialize the generator with the path to your locally saved model
- # image_generator = LocalImageGenerator("saved_image_model_6.5_50")
- # def image_generation(text):
- # try:
- # image = image_generator.generate_image(text)
- # return image
- # except Exception as e:
- # st.error(f"An error occurred during image generation: {str(e)}")
- # return None
# Streamlit UI
st.set_page_config(page_title="Easy Read Document")
# Custom CSS: fonts, colour variables and the classes used by the HTML
# snippets rendered further down (.file-details, .output, ...).
# .stApp reads --bg-color, the variable declared in :root.
st.markdown("""
<style>
@import url('https://fonts.googleapis.com/css2?family=Roboto:wght@400;700&family=Poppins:wght@700&display=swap');
@import url('https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0-beta3/css/all.min.css');
:root {
    --bg-color: #f5f5f5;
    --text-color: #333;
    --input-bg: #ffffff;
    --border-color: #e0e0e0;
    --chat-bg: #ffffff;
    --file-details-bg: #f0f0f0;
}
.stApp {
    font-family: 'Roboto', sans-serif;
    background-color: var(--bg-color);
    color: var(--text-color);
}
.header {
    display: flex;
    justify-content: space-between;
    align-items: center;
    padding: 20px;
    background-color: var(--input-bg);
    border-bottom: 1px solid var(--border-color);
}
.logo {
    width: 100px;
}
.title {
    font-family: 'Poppins', sans-serif;
    font-size: 28px;
    font-weight: 700;
    text-align: center;
    background: linear-gradient(45deg, #007BFF, #00BFFF);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
}
.file-details {
    background-color: var(--file-details-bg);
    padding: 15px;
    border-radius: 5px;
    box-shadow: 0 2px 5px rgba(0,0,0,0.1);
    color: var(--text-color);
    margin-top: 20px;
}
.stTextInput > div > div > input {
    background-color: var(--input-bg);
    color: var(--text-color);
}
.stButton > button {
    background-color: #007BFF;
    color: white;
}
.output {
    font-family: Arial, sans-serif;
    white-space: pre-wrap;
}
</style>
""", unsafe_allow_html=True)
# Header
# NOTE: a custom HTML header (logo image plus <h1> using the .header, .logo
# and .title classes) used to be rendered here; it is disabled and the page
# shows the logo and a centred title through the two calls below.
_page_heading = "<h1 style='text-align: center;'>Easy Read Document</h1>"
st.image("static/USW.png", width=100)
st.markdown(_page_heading, unsafe_allow_html=True)
# File uploader
uploaded_file = st.file_uploader("Upload a PDF file", type="pdf")
if uploaded_file is not None:
    # The file name is supplied by the client: strip any directory components
    # before using it in a path, and escape it before embedding it in HTML.
    safe_name = os.path.basename(uploaded_file.name) or "upload.pdf"

    # File details
    st.markdown(
        '<div class="file-details">\n'
        '<h2>Uploaded file details:</h2>\n'
        f'<p>Filename: {html.escape(safe_name)}</p>\n'
        '</div>',
        unsafe_allow_html=True,
    )

    # Process and summarize the PDF
    with st.spinner("Processing..."):
        # Save the uploaded file, creating the target directory on first use.
        os.makedirs("uploads", exist_ok=True)
        file_path = os.path.join("uploads", safe_name)
        with open(file_path, "wb") as f:
            f.write(uploaded_file.getbuffer())

        # Extract text from PDF; "paragraphs" are the newline-separated
        # pieces of the extracted text.
        full_text = extract_text_from_pdf(file_path)
        paragraphs = full_text.split('\n')

        # Summarize each paragraph (only the first three are processed).
        st.markdown("<h2>Summarized Text:</h2>", unsafe_allow_html=True)
        for paragraph in paragraphs[:3]:
            if not paragraph.strip():
                continue
            summary = summarize_text(paragraph)
            image = image_generation(summary)

            # Two columns: summary on the left, image on the right.
            col1, col2 = st.columns([3, 1])
            with col1:
                # Model output is untrusted text; escape it before it is
                # rendered as raw HTML.
                st.markdown(
                    f"<div class='output'>{html.escape(summary)}</div>",
                    unsafe_allow_html=True,
                )
            with col2:
                if image is not None:
                    img_byte_arr = io.BytesIO()
                    image.save(img_byte_arr, format='PNG')
                    st.image(img_byte_arr.getvalue(), use_column_width=True)
                else:
                    st.warning("Failed to generate image for this summary")

    st.markdown(
        '<div class="paragraph-details">\n'
        '<h2>Total paragraphs:</h2>\n'
        f'<p>No of paragraphs: {len(paragraphs)}</p>\n'
        '</div>',
        unsafe_allow_html=True,
    )
# Input container
user_input = st.text_input("You can start any conversation with me or upload only PDF files", key="user_input")
if st.button("Send") and user_input:
    # The text is rendered with unsafe_allow_html, so escape it first to keep
    # user-typed markup from being interpreted as HTML.
    escaped_input = html.escape(user_input)
    st.markdown(f"<div class='output'>User: {escaped_input}</div>", unsafe_allow_html=True)
    # Here you would process the user input and generate a response.
    # For now, we'll just echo the input.
    st.markdown(f"<div class='output'>Bot: You said: {escaped_input}</div>", unsafe_allow_html=True)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement