Advertisement
sourav8256

Restless Bandit problem

Jan 31st, 2023
600
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.62 KB | None | 0 0
  1. import numpy as np
  2. import matplotlib.pyplot as plt
  3.  
  4. class RestlessBandit:
  5.     def __init__(self, n_arms, alpha=1.0):
  6.         self.n_arms = n_arms
  7.         self.alpha = alpha
  8.         self.t = 0
  9.         self.rewards = [0.0] * n_arms
  10.         self.plays = [0] * n_arms
  11.  
  12.     def play(self):
  13.         self.t += 1
  14.         ucb = [self.rewards[i] + self.alpha * np.sqrt(np.log(self.t) / self.plays[i]) for i in range(self.n_arms)]
  15.         return np.argmax(ucb)
  16.  
  17.     def update(self, arm, reward):
  18.         self.plays[arm] += 1
  19.         self.rewards[arm] = ((self.plays[arm] - 1) * self.rewards[arm] + reward) / self.plays[arm]
  20.  
  21. # Define the number of arms
  22. n_arms = 5
  23.  
  24. # Create an instance of the RestlessBandit class
  25. bandit = RestlessBandit(n_arms)
  26.  
  27. # Define the number of rounds to play
  28. n_rounds = 5000
  29.  
  30. # Define the true expected rewards for each arm
  31. true_rewards = [0.1, 0.2, 0.3, 0.4, 0.5]
  32.  
  33. # Store the results for each round
  34. results = []
  35.  
  36.  
  37.  
  38.  
  39. # Play the bandit for each round
  40. for i in range(n_rounds):
  41.  
  42.  
  43.  
  44.     # Choose an arm to play
  45.     arm = bandit.play()
  46.    
  47.     # Simulate the reward for the chosen arm
  48.     reward = np.random.normal(true_rewards[arm], 1.0)
  49.    
  50.     # Update the rewards for the chosen arm
  51.     bandit.update(arm, reward)
  52.    
  53.  
  54.  
  55.  
  56.  
  57.  
  58.     # Update the true expected rewards for each arm
  59.     true_rewards = [reward + np.random.normal(0, 0.01) for reward in true_rewards]
  60.    
  61.     # Store the results for each round
  62.     results.append(reward)
  63.  
  64.  
  65.  
  66.  
  67.  
  68.  
  69.  
  70. # Plot the results
  71. plt.plot(results)
  72. plt.xlabel("Round")
  73. plt.ylabel("Reward")
  74. plt.title("Restless Bandit Problem")
  75. plt.show()
  76.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement