Advertisement
Guest User

top algo

a guest
Nov 23rd, 2024
96
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 11.95 KB | None | 0 0
  1. """
  2. Top AEW Feed Generator with Caching
  3.  
  4. Same ranking algorithm as before, but with added caching layer to improve performance.
  5. Cache invalidates after 5 minutes or when blocked DIDs change.
  6. """
  7.  
  8. from datetime import datetime, timedelta, UTC
  9. from typing import Optional, List, Dict, Tuple
  10. import math
  11. from functools import lru_cache
  12. import json
  13. from server import config
  14. from server.database import Post, AEWNetwork, db
  15. from server.logger import logger
  16. from server.admin_blocks import get_blocked_dids
  17. from peewee import fn, SQL
  18.  
  19. # Feed identifier
  20. uri = f"at://did:plc:5kk6tru4mdazgfmgyflvz3k2/app.bsky.feed.generator/top-aew"
  21. CURSOR_EOF = 'eof'
  22. CACHE_DURATION = 300  # Cache duration in seconds (5 minutes)
  23.  
  24. # Enhanced cache structure
  25. _feed_cache = {
  26.     'timestamp': None,  # When the cache was last updated
  27.     'feeds': {},       # Dict of cursor -> feed data
  28.     'blocked_hash': None,  # Hash of blocked DIDs for cache invalidation
  29. }
  30.  
  31. def get_blocked_hash(blocked_dids: set) -> str:
  32.     """Create a hash of blocked DIDs for cache invalidation"""
  33.     return hash(frozenset(blocked_dids))
  34.  
  35. def is_cache_valid() -> bool:
  36.     """Check if cache is still valid"""
  37.     if not _feed_cache['timestamp']:
  38.         return False
  39.        
  40.     # Check cache age
  41.     age = (datetime.now(UTC) - _feed_cache['timestamp']).total_seconds()
  42.     if age > CACHE_DURATION:
  43.         logger.info("Cache expired due to age")
  44.         return False
  45.    
  46.     # Check if blocked DIDs have changed
  47.     current_blocked = get_blocked_dids()
  48.     current_hash = get_blocked_hash(current_blocked)
  49.     if current_hash != _feed_cache['blocked_hash']:
  50.         logger.info("Cache invalidated due to blocks change")
  51.         return False
  52.        
  53.     return True
  54.  
  55. def get_cached_feed(cursor: Optional[str], limit: int) -> Optional[dict]:
  56.     """Get feed from cache if available"""
  57.     if not is_cache_valid():
  58.         return None
  59.        
  60.     cache_key = f"{cursor}:{limit}"
  61.     if cache_key in _feed_cache['feeds']:
  62.         logger.info("Serving response from cache")
  63.         return _feed_cache['feeds'][cache_key]
  64.    
  65.     return None
  66.  
  67. def update_cache(cursor: Optional[str], limit: int, feed_data: dict) -> None:
  68.     """Update cache with new feed data"""
  69.     try:
  70.         # Initialize cache if needed
  71.         if not _feed_cache['timestamp']:
  72.             _feed_cache['feeds'] = {}
  73.            
  74.         # Update cache metadata
  75.         _feed_cache['timestamp'] = datetime.now(UTC)
  76.         _feed_cache['blocked_hash'] = get_blocked_hash(get_blocked_dids())
  77.        
  78.         # Store feed data
  79.         cache_key = f"{cursor}:{limit}"
  80.         _feed_cache['feeds'][cache_key] = feed_data
  81.        
  82.         # Cleanup old entries if cache is getting too large
  83.         if len(_feed_cache['feeds']) > 100:  # Arbitrary limit
  84.             oldest_keys = sorted(_feed_cache['feeds'].keys())[:-50]  # Keep newest 50
  85.             for key in oldest_keys:
  86.                 del _feed_cache['feeds'][key]
  87.                
  88.         logger.info(f"Updated cache with new feed data for cursor: {cursor}")
  89.        
  90.     except Exception as e:
  91.         logger.error(f"Error updating cache: {e}")
  92.  
  93. def ensure_tz_aware(dt: datetime | str) -> datetime:
  94.     """
  95.    Ensure datetime is timezone aware, handling both datetime and string inputs.
  96.    """
  97.     if isinstance(dt, str):
  98.         try:
  99.             # Try parsing ISO format first (preferred)
  100.             dt = datetime.fromisoformat(dt.replace('Z', '+00:00'))
  101.         except ValueError:
  102.             try:
  103.                 # Fallback to basic format
  104.                 dt = datetime.strptime(dt.split('.')[0], "%Y-%m-%d %H:%M:%S")
  105.             except Exception as e:
  106.                 logger.error(f"Error parsing datetime '{dt}': {e}")
  107.                 raise
  108.    
  109.     if dt and dt.tzinfo is None:
  110.         return dt.replace(tzinfo=UTC)
  111.     return dt
  112.  
  113. @lru_cache(maxsize=1000)
  114. def is_network_member(author_did: str) -> bool:
  115.     """Check if an author is part of the AEW network"""
  116.     try:
  117.         return AEWNetwork.select().where(
  118.             (AEWNetwork.did == author_did) &
  119.             (AEWNetwork.is_followed_by_aew == True)
  120.         ).exists()
  121.     except Exception:
  122.         return False
  123.  
  124. def calculate_engagement_score(post: Post, age_hours: float) -> Tuple[float, Dict[str, float]]:
  125.     """Calculate engagement score using HN-inspired algorithm with AEW customizations."""
  126.     try:
  127.         # Get engagement metrics
  128.         likes = getattr(post, 'likeCount', 0) or 0
  129.         reposts = getattr(post, 'repostCount', 0) or 0
  130.         replies = getattr(post, 'replyCount', 0) or 0
  131.        
  132.         # Base engagement calculations with HN-style power dampening
  133.         base_score = (likes * 1.0) + (reposts * 1.5) + (replies * 1.0)
  134.         if base_score > 1:
  135.             base_score = math.pow(base_score - 1, 0.8)  # HN's power dampening
  136.        
  137.         # Calculate time decay (HN-style)
  138.         time_base = 2.0  # 120 minutes in hours
  139.         decay_factor = age_hours / time_base
  140.         time_multiplier = 1.0 / ((decay_factor + 1) ** 1.8)  # HN's gravity
  141.        
  142.         # Network boost (reward AEW talent/official posts)
  143.         network_multiplier = 1.1 if is_network_member(post.author) else 1.0
  144.        
  145.         # Engagement balance boost (reward diverse engagement)
  146.         has_balance = (likes > 0 and reposts > 0 and replies > 0)
  147.         balance_multiplier = 1.1 if has_balance else 1.0
  148.        
  149.         # Viral boost for highly engaged content
  150.         viral_multiplier = 1.2 if (reposts > 10 or likes > 50) else 1.0
  151.        
  152.         # Calculate final score
  153.         score = base_score * time_multiplier * network_multiplier * balance_multiplier * viral_multiplier
  154.        
  155.         # Create detailed breakdown for logging
  156.         details = {
  157.             'base_score': base_score,
  158.             'age_hours': age_hours,
  159.             'time_multiplier': time_multiplier,
  160.             'network_multiplier': network_multiplier,
  161.             'balance_multiplier': balance_multiplier,
  162.             'viral_multiplier': viral_multiplier,
  163.             'final_score': score,
  164.             'likes': likes,
  165.             'reposts': reposts,
  166.             'replies': replies,
  167.             'is_network': is_network_member(post.author)
  168.         }
  169.        
  170.         return score, details
  171.        
  172.     except Exception as e:
  173.         logger.error(f"Error calculating score for post {post.cid}: {e}")
  174.         return 0.0, {}
  175.  
  176. def get_filtered_posts(blocked_dids: set, cursor_time: Optional[datetime] = None, cursor_cid: Optional[str] = None) -> List[Post]:
  177.     """Get posts filtered by time window and blocked users."""
  178.     try:
  179.         # Build base query
  180.         cutoff_time = datetime.now(UTC) - timedelta(hours=24)
  181.        
  182.         query = (Post
  183.             .select()
  184.             .where(
  185.                 (Post.indexed_at >= cutoff_time) &
  186.                 ~(Post.author << list(blocked_dids)) &
  187.                 (Post.reply_parent.is_null(True))
  188.             )
  189.             .order_by(Post.indexed_at.desc())
  190.         )
  191.  
  192.         # Apply cursor for pagination
  193.         if cursor_time and cursor_cid:
  194.             query = query.where(
  195.                 (Post.indexed_at < cursor_time) |
  196.                 ((Post.indexed_at == cursor_time) & (Post.cid < cursor_cid))
  197.             )
  198.            
  199.         return list(query)
  200.        
  201.     except Exception as e:
  202.         logger.error(f"Error in get_filtered_posts: {e}", exc_info=True)
  203.         return []
  204.  
def handler(cursor: Optional[str], limit: int) -> dict:
    """Main feed handler with caching.

    Args:
        cursor: Pagination cursor — None for the first page, CURSOR_EOF when
            the feed is exhausted, otherwise "<unix-ms>::<cid>" pointing at
            the last post of the previous page.
        limit: Maximum number of posts to return in this page.

    Returns:
        dict with 'cursor' (cursor for the next page, or CURSOR_EOF) and
        'feed' (list of {'post': <at-uri>} entries). Any internal error —
        including a malformed cursor — is caught by the outer handler and
        yields an empty EOF page instead of propagating.
    """
    try:
        # Serve from cache when a valid entry exists for this (cursor, limit).
        cached_feed = get_cached_feed(cursor, limit)
        if cached_feed:
            return cached_feed
           
        logger.info(f"Feed request - cursor: {cursor}, limit: {limit}")
       
        blocked_dids = get_blocked_dids()
        logger.info(f"Using {len(blocked_dids)} blocked DIDs")
       
        # Handle EOF cursor: client has already consumed the whole feed.
        if cursor == CURSOR_EOF:
            logger.info("Received EOF cursor")
            return {'cursor': CURSOR_EOF, 'feed': []}
           
        # Parse cursor for pagination ("<unix-ms>::<cid>").
        cursor_time = None
        cursor_cid = None
       
        if cursor:
            try:
                cursor_parts = cursor.split('::')
                if len(cursor_parts) != 2:
                    raise ValueError('Malformed cursor')
                   
                timestamp, cursor_cid = cursor_parts
                # Cursor timestamp is milliseconds since the epoch, UTC.
                cursor_time = datetime.fromtimestamp(int(timestamp) / 1000, UTC)
                logger.info(f"Parsed cursor - time: {cursor_time}, cid: {cursor_cid}")
               
            except Exception as e:
                logger.error(f"Error parsing cursor: {e}")
                # NOTE(review): this ValueError is swallowed by the outer
                # `except` below, so a bad cursor returns an empty EOF page
                # rather than an error response — confirm that is intended.
                raise ValueError('Invalid cursor format')
       
        # Get posts with pagination (last 24h, blocks applied, top-level only).
        posts = get_filtered_posts(blocked_dids, cursor_time, cursor_cid)
        logger.info(f"Retrieved {len(posts)} posts after filtering")
       
        if not posts:
            logger.info("No posts found")
            return {'cursor': CURSOR_EOF, 'feed': []}
           
        # Calculate scores and sort
        scored_posts = []
        current_time = datetime.now(UTC)
       
        # Log detailed scoring information. NOTE(review): this emits one
        # info-level entry per candidate post — verbose at scale.
        logger.info("\nScoring breakdown for posts:")
        for post in posts:
            try:
                age_hours = (current_time - ensure_tz_aware(post.indexed_at)).total_seconds() / 3600.0
                score, details = calculate_engagement_score(post, age_hours)
                scored_posts.append((post, score))
               
                # Detailed logging for debugging. NOTE(review): the stray '?'
                # characters below look like mojibake (likely emoji) — confirm.
                logger.info(
                    f"\nPost {post.cid}:\n"
                    f"- Age: {details['age_hours']:.1f} hours\n"
                    f"- Engagement: {details['likes']}? {details['reposts']}? {details['replies']}?\n"
                    f"- Base Score: {details['base_score']:.1f}\n"
                    f"- Time Mult: {details['time_multiplier']:.3f}\n"
                    f"- Network Mult: {details['network_multiplier']:.1f}\n"
                    f"- Balance Mult: {details['balance_multiplier']:.1f}\n"
                    f"- Viral Mult: {details['viral_multiplier']:.1f}\n"
                    f"- Final Score: {details['final_score']:.1f}\n"
                    f"- Network Member: {'Yes' if details['is_network'] else 'No'}"
                )
            except Exception as e:
                # A single bad post must not break the whole page.
                logger.error(f"Error scoring post {post.cid}: {e}")
                continue
       
        # Sort by score descending; ties broken by older indexed_at first.
        sorted_posts = sorted(scored_posts, key=lambda x: (-x[1], ensure_tz_aware(x[0].indexed_at)))
        result_posts = [post for post, _ in sorted_posts[:limit]]
       
        # Log selected posts
        logger.info("\nFinal post selection:")
        for i, post in enumerate(result_posts, 1):
            logger.info(
                f"{i}. Selected post {post.cid}: "
                f"engagement={post.likeCount}? {post.repostCount}? {post.replyCount}? "
                f"time={post.indexed_at} "
                f"network={'Yes' if is_network_member(post.author) else 'No'}"
            )
       
        # Generate feed entries
        feed = [{'post': post.uri} for post in result_posts]
       
        # Generate cursor for next page from the last post of this page
        # (milliseconds timestamp + cid, mirroring the parse logic above).
        next_cursor = CURSOR_EOF
        if result_posts:
            last_post = result_posts[-1]
            next_cursor = f"{int(ensure_tz_aware(last_post.indexed_at).timestamp() * 1000)}::{last_post.cid}"
            logger.info(f"Generated next cursor: {next_cursor}")
       
        # Store result in cache before returning
        result = {
            'cursor': next_cursor,
            'feed': feed
        }
        update_cache(cursor, limit, result)
        return result
       
    except Exception as e:
        logger.error(f"Error handling feed request: {e}", exc_info=True)
        return {'cursor': CURSOR_EOF, 'feed': []}
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement