WareHouseHD
BART summarize long html documents
Oct 8th, 2024

from transformers import BartForConditionalGeneration, BartTokenizer
import torch
import re
import html
import unicodedata
from newspaper import Article
from readability import Document
from bs4 import BeautifulSoup
import requests
import math


# Check if a GPU is available
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using device: {device}')
# Number of tokens to overlap between consecutive chunks
overlap = 50
# Batch size for summarization (number of chunks to summarize at once)
batch_size = 25
# Maximum input length (in tokens) for the model; updated in init_bart()
model_max_length = 512
# Model tokenizer
tokenizer = None
# Model instance, loaded from the selected model name
model = None
# Model name
model_name = None


# Defines the supported summary lengths
class SummaryLength:
    SHORT = "short"
    MEDIUM = "medium"
    LONG = "long"


# Write content to a file, overwriting it if it already exists
def overwrite_file(file_path: str, new_content: str):
    """Overwrite the file with the new content.

    Args:
        file_path: The path to the file to overwrite.
        new_content: The new content to write to the file.
    """
    # Make sure the new content is not empty
    if not new_content:
        print("Empty content supplied!")
        return

    with open(file_path, 'w', encoding='utf-8') as file:
        file.write(new_content)


# Clean text of unnecessary characters and escape sequences
def clean_text(text_content: str):
    """Clean the text content of unnecessary characters and escape sequences, replacing them with their plain-text representation.

    Args:
        text_content: The text content to clean.

    Returns:
        The cleaned text content.
    """
    # Convert HTML-escaped characters to their normal text representation
    cleaned_text = html.unescape(text_content)

    # Normalize Unicode characters (replace special characters)
    cleaned_text = unicodedata.normalize('NFKC', cleaned_text)

    # Clean up remaining whitespace characters (non-breaking spaces, zero-width spaces, etc.)
    cleaned_text = cleaned_text.replace('\xa0', ' ')    # Non-breaking space (U+00A0) to normal space
    cleaned_text = cleaned_text.replace('\u200b', ' ')  # Zero-width space
    cleaned_text = cleaned_text.replace('\u200c', ' ')  # Zero-width non-joiner
    cleaned_text = cleaned_text.replace('\u200d', ' ')  # Zero-width joiner
    cleaned_text = cleaned_text.replace('\uFEFF', ' ')  # Zero-width no-break space
    cleaned_text = cleaned_text.replace('\u3000', ' ')  # Ideographic space to normal space
    cleaned_text = cleaned_text.replace('©', '(c)')   # © to "(c)"
    cleaned_text = cleaned_text.replace('®', '(R)')   # ® to "(R)"
    cleaned_text = cleaned_text.replace('™', '(TM)')  # ™ to "(TM)"
    # Handle smart quotes and other special punctuation
    cleaned_text = cleaned_text.replace('‘', "'").replace('’', "'")  # Curly single quotes to straight
    cleaned_text = cleaned_text.replace('“', '"').replace('”', '"')  # Curly double quotes to straight
    cleaned_text = cleaned_text.replace('–', '-').replace('—', '-')  # En dash and em dash to hyphen
    cleaned_text = cleaned_text.replace('…', '...')  # Ellipsis to three dots

    # Replace runs of tabs with a single space
    cleaned_text = re.sub(r'\t+', ' ', cleaned_text)

    # Replace multiple spaces with a single space
    cleaned_text = re.sub(r' {2,}', ' ', cleaned_text)

    # Lines containing only whitespace become empty lines
    cleaned_text = re.sub(r'^\s+$', '', cleaned_text, flags=re.MULTILINE)

    # Normalize carriage returns (\r\n and bare \r) to newlines (\n)
    cleaned_text = re.sub(r'\r\n?', '\n', cleaned_text)

    # Replace runs of three or more newlines with a single newline
    cleaned_text = re.sub(r'\n{3,}', '\n', cleaned_text)

    return cleaned_text.strip()

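# A quick illustration of what clean_text normalizes (hypothetical input, kept in
# a comment so nothing runs at import time):
#   clean_text('“Smart quotes”\xa0&amp; ellipsis…\r\n\r\n\r\nDone')
#   -> '"Smart quotes" & ellipsis...\nDone'
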

# Extract the main content from an HTML page
def extract_main_html_content(html_content: str):
    """Extract the main content from the HTML page: the readable part, without the menus, header, footer, ads, etc.

    Args:
        html_content: The HTML page content to extract the main content from.

    Returns:
        The main HTML content extracted from the HTML page content.
    """
    # If the HTML content is empty, return an empty string
    if not html_content:
        print("Empty HTML content supplied!")
        return ""

    try:
        # Try newspaper first; create an Article object (no URL is needed here since we are using raw HTML)
        article = Article(url="")

        # Set the article's HTML content
        article.set_html(html_content)

        # Parse the article (this step is necessary to extract information)
        article.parse()
        newspaper_text_main_content = article.text
    except Exception as err:
        print(f"Error parsing article: {err}")
        newspaper_text_main_content = None

    if newspaper_text_main_content:
        html_main_content = newspaper_text_main_content.strip()
        print("Newspaper3k worked!")
    else:
        try:
            # If newspaper fails, try the Python port of Ruby's Readability
            doc = Document(html_content)
            readability_html_main_content = doc.summary(html_partial=True)
        except Exception as err:
            print(f"Error parsing article with readability: {err}")
            readability_html_main_content = None

        if readability_html_main_content:
            html_main_content = readability_html_main_content.strip()
            print("Readability worked!")
        else:
            # If both approaches fail, there is nothing left to try
            print("Error parsing article with both approaches. Exiting.")
            exit(1)

    # Clean the extracted content
    cleaned_html_main_content = clean_text(html_main_content)

    return cleaned_html_main_content


# Convert HTML content to plain text
def html_to_text(html_content: str):
    """Convert HTML content to text content.

    Args:
        html_content: The HTML content to convert to text content.

    Returns:
        The text content extracted from the HTML content.
    """
    # If the HTML content is empty, return an empty string
    if not html_content:
        print("Empty HTML content supplied!")
        return ""

    # Extract the text content from the HTML using BeautifulSoup with the lxml parser
    soup = BeautifulSoup(html_content, "lxml")
    text_content = soup.get_text().strip()
    # Clean the text content of unnecessary characters and escape sequences
    cleaned_text_content = clean_text(text_content)

    return cleaned_text_content


def get_url_html(url):
    """Get the HTML content of the page at the given URL.

    Args:
        url: The URL of the page to fetch.

    Returns:
        The HTML content of the page as a string, or None if the request fails.
    """
    session = requests.Session()
    headers = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/90.0.4430.212 Safari/537.36',
        'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8',
        'Referer': 'https://www.google.com/',
        'Accept-Language': 'en-US,en;q=0.5'
    }
    try:
        # A timeout keeps the request from hanging indefinitely on an unresponsive host
        response = session.get(url, headers=headers, timeout=30)
        response.raise_for_status()  # Raise an error for bad responses
        return response.text
    except Exception as e:
        print(f"Error fetching URL {url}: {e}")
        return None

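# Usage sketch (hypothetical URL); callers must handle the None returned on failure:
#   page_html = get_url_html("https://example.com/article")
#   if page_html:
#       ...
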

# Initialize the BART model and tokenizer
def init_bart():
    """Initialize the BART model and tokenizer."""
    global model, tokenizer, device, model_max_length, model_name
    # Load the BART model and tokenizer
    model_name = 'facebook/bart-large-cnn'
    model_max_length = 1024
    tokenizer = BartTokenizer.from_pretrained(model_name, clean_up_tokenization_spaces=True)
    model = BartForConditionalGeneration.from_pretrained(model_name, torch_dtype=torch.float16, attn_implementation="flash_attention_2").to(device)

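# Note: attn_implementation="flash_attention_2" requires the flash-attn package
# and a CUDA GPU, and float16 weights are likewise GPU-oriented. A minimal
# portable fallback (an assumption, not part of the original script) would be:
#   model = BartForConditionalGeneration.from_pretrained(model_name).to(device)
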

# Calculate the minimum and maximum length of the output summary
def calc_min_max_length(original_num_tokens: int, summary_length: str):
    """
    Calculate the minimum and maximum length of the output summary based on the original text length.

    Args:
        original_num_tokens: The number of tokens in the original text.
        summary_length: Desired length of the summary ('short', 'medium', 'long').

    Returns:
        The minimum and maximum length of the output summary.
    """
    if summary_length == SummaryLength.SHORT:
        min_length = min(40, original_num_tokens * 0.1)
        max_length = min(150, original_num_tokens * 0.15)
    elif summary_length == SummaryLength.MEDIUM:
        min_length = min(150, original_num_tokens * 0.2)
        max_length = min(250, original_num_tokens * 0.35)
    else:  # SummaryLength.LONG (fall through so min_length/max_length are always bound)
        min_length = min(250, original_num_tokens * 0.4)
        max_length = min(450, original_num_tokens * 0.5)

    # Make sure min_length and max_length are integers
    min_length = int(min_length)
    max_length = int(max_length)

    # Ensure that the minimum length is at least 30
    min_length = max(30, min_length)

    # Ensure that the maximum length is at least the minimum length
    max_length = max(min_length, max_length)

    return min_length, max_length

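# Worked example of the length formulas above, for summary_length='medium':
#   2000-token article: min = min(150, 2000 * 0.2) = 150, max = min(250, 2000 * 0.35) = 250
#   200-token page:     min = min(150, 200 * 0.2)  = 40,  max = min(250, 200 * 0.35)  = 70
# (the final clamps then enforce min >= 30 and max >= min)
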

# Summarize the text content in chunks
def summarize_text(cleaned_text_content: str, summary_length: str, chunk_size: int):
    """
    Summarizes the cleaned text content using the BART model.

    Args:
        cleaned_text_content: The text content to summarize.
        summary_length: Desired length of the summary ('short', 'medium', 'long').
        chunk_size: The number of tokens in each chunk (including overlap) for summarization; must be less than or equal to the model's max length.

    Returns:
        The summarized text content.
    """
    global model_max_length, overlap, batch_size, tokenizer

    # Tokenize the input text
    tokens = tokenizer.encode(cleaned_text_content, return_tensors='pt')

    # Count the token length of the text
    original_num_tokens = tokens.shape[1]
    print(f'Original tokens length: {original_num_tokens}')

    # Calculate the min & max summary length based on the TOTAL text length
    output_min_length, output_max_length = calc_min_max_length(original_num_tokens, summary_length)

    # Directly summarize if the content is small enough for a single pass
    if original_num_tokens <= model_max_length:
        print('Summarizing single chunk...')
        combined_summary = summarize_single_chunk(tokens=tokens, output_min_length=output_min_length, output_max_length=output_max_length)
        combined_summary = combined_summary[0]
    else:
        # Initialize the summary variable
        final_summary = []

        # Chunks advance by 'stride' tokens so that consecutive chunks share 'overlap'
        # tokens, giving the model context across chunk boundaries
        stride = chunk_size - overlap

        # Calculate the total number of chunks and batches needed to summarize the full text
        num_chunks = max(1, math.ceil((original_num_tokens - overlap) / stride))
        num_batches = math.ceil(num_chunks / batch_size)
        print(f'Number of chunks: {num_chunks}, Number of batches: {num_batches}')

        # Prepare the chunks: the original text is split into chunks of chunk_size tokens,
        # each starting 'stride' tokens after the previous one
        chunks = []
        for i in range(num_chunks):
            start_index = i * stride
            end_index = min(start_index + chunk_size, original_num_tokens)
            chunk = tokens[:, start_index:end_index]
            # If the chunk is shorter than the model max length, pad it with the tokenizer's
            # pad token (padding with zeros would insert BART's <s> token instead)
            if chunk.shape[1] < model_max_length:
                padding_length = model_max_length - chunk.shape[1]
                padding = torch.full((1, padding_length), tokenizer.pad_token_id, dtype=torch.long)
                chunk = torch.cat([chunk, padding], dim=1)
            chunks.append(chunk)

        # Calculate the min & max summary length based on the CHUNK content length
        chunk_output_min_length, chunk_output_max_length = calc_min_max_length(chunk_size, summary_length)

        # Summarize each batch of chunks and collect the chunk summaries
        for i in range(0, len(chunks), batch_size):
            # Get the chunks for the current batch
            batch_chunks = chunks[i:i + batch_size]
            # Combine the batch of chunks into a single tensor
            batched_input = torch.cat(batch_chunks, dim=0)
            summarized_chunks = summarize_single_chunk(tokens=batched_input, output_min_length=chunk_output_min_length, output_max_length=chunk_output_max_length)
            # Append the summarized chunks to the final summary
            final_summary.extend(summarized_chunks)

        # Combine all summarized chunks into a single text
        combined_summary = ' '.join(final_summary)
        print(f'Summary tokens length: {len(tokenizer.tokenize(combined_summary))}')

    return combined_summary

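# Note: with many chunks, the concatenated chunk summaries can themselves exceed
# model_max_length. A common refinement (an extension, not part of the original
# script) is a second, recursive pass:
#   while len(tokenizer.tokenize(combined_summary)) > model_max_length:
#       combined_summary = summarize_text(combined_summary, summary_length, chunk_size)
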

# Summarize a single chunk (or batch of chunks) of text
def summarize_single_chunk(tokens: torch.Tensor, output_min_length: int, output_max_length: int):
    """
    Summarizes a single chunk of text tokens.

    Args:
        tokens: The tokenized input text as a PyTorch tensor.
        output_min_length: Minimum length of the output summary.
        output_max_length: Maximum length of the output summary.

    Returns:
        The summarized texts as a list of strings (one per batch item).
    """
    global model, tokenizer, device, model_max_length

    # Move the input tokens to the target device
    inputs = tokens.to(device)

    # Mask out the padding tokens so the model ignores them
    attention_mask = (inputs != tokenizer.pad_token_id).long()

    # Generate the summary using the model and torch autocast (automatic mixed precision)
    with torch.no_grad():
        with torch.autocast(device_type=device, dtype=torch.float16, enabled=True):
            summary_ids = model.generate(
                inputs,
                attention_mask=attention_mask,
                max_length=output_max_length,  # The maximum length of the output summary
                min_length=output_min_length,  # The minimum length of the output summary
                num_beams=4,  # Beam search width: the number of candidate sequences kept at each step
                early_stopping=True,  # Stop the beam search once num_beams finished candidates exist per batch item
                length_penalty=1.2,  # Values > 1.0 favor longer summaries in beam scoring
                repetition_penalty=2.0,  # Discourage repetition
                no_repeat_ngram_size=4,  # Never repeat the same 4-gram
            )

    # Decode the summaries
    summary = tokenizer.batch_decode(summary_ids, skip_special_tokens=True)

    return summary


# Main entry point: summarize the text content of a URL
if __name__ == '__main__':
    url = "https://www.atltranslate.com/ai/blog/text-summarization-tips"

    try:
        print(f"Fetching content from: {url}")
        html_content = get_url_html(url)
    except Exception as e:
        print(f"Error fetching URL {url}: {e}")
        html_content = None

    # Check whether the HTML content is None or empty
    if html_content:
        # Save the fetched content to a file
        overwrite_file('test_downloaded.html', html_content)

        # Extract the main text content from the HTML content
        cleaned_html_content = extract_main_html_content(html_content)
        cleaned_text_content = html_to_text(cleaned_html_content)

        # Save the cleaned text content to a new file
        overwrite_file('test_cleaned.txt', cleaned_text_content)

        # Check if there is any content to summarize
        if cleaned_text_content:
            init_bart()
            summary = summarize_text(cleaned_text_content, summary_length=SummaryLength.MEDIUM, chunk_size=model_max_length)
            summary = clean_text(summary)

            # Save the summary to a new file
            overwrite_file('test_summary.txt', summary)
            print("Summary saved to 'test_summary.txt'")
        else:
            print("No content remaining after extraction; nothing to summarize.")

Comments

WareHouseHD:

beautifulsoup4==4.12.3
bitsandbytes==0.43.3
markdownify==0.13.1
newspaper3k==0.2.8
readability_lxml==0.8.1
Requests==2.32.3
torch==2.4.1+cu124
transformers==4.45.0.dev0

run:
import nltk
nltk.download('punkt_tab')