Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import random
- import math
class Bandit:
    """An n-armed bandit whose arms pay out normally distributed rewards.

    Each arm's hidden mean reward is drawn once, at construction, from
    ``random.normalvariate(exp_avg, exp_var)``.  Pulling an arm returns a
    sample from ``random.normalvariate(<that arm's mean>, var)``.

    NOTE: despite their names, ``exp_var`` and ``var`` are passed as the
    *sigma* (standard deviation) argument of ``random.normalvariate``, not
    as variances.
    """

    def __init__(self, n=10, exp_avg=0.0, exp_var=1.0, var=1.0):
        self.n = n
        self.exp_avg = exp_avg
        self.exp_var = exp_var
        self.var = var
        # Hidden true mean reward of every arm, drawn in arm order.
        means = []
        for _ in range(n):
            means.append(random.normalvariate(exp_avg, exp_var))
        self.p = means

    def pull(self, arm):
        """Return one noisy reward sample for ``arm`` (an index into ``p``)."""
        mean = self.p[arm]
        return random.normalvariate(mean, self.var)
class GreedyMethod:
    """Epsilon-greedy action selection for an n-armed bandit.

    With probability ``e`` a uniformly random arm is explored; otherwise the
    arm with the highest running-mean reward is exploited, with ties going
    to the lowest index.  ``e=0.0`` (the default) yields a purely greedy
    agent.
    """

    def __init__(self, n=10, e=0.0):
        self.n = n
        self.e = e
        self.times = [0 for _ in range(n)]   # number of pulls per arm
        self.values = [0 for _ in range(n)]  # sample-average reward per arm

    def select(self):
        """Return the index of the arm to pull next."""
        explore = random.random() < self.e
        if explore:
            return random.randint(0, self.n - 1)
        estimates = self.values
        # max() keeps the first maximal index it meets, i.e. the same arm
        # that list.index(max(estimates)) would report.
        return max(range(len(estimates)), key=estimates.__getitem__)

    def reflect(self, arm, value):
        """Fold the observed reward ``value`` into ``arm``'s running mean."""
        pulls = self.times[arm] + 1
        self.times[arm] = pulls
        old = self.values[arm]
        # Incremental mean: new = old + (sample - old) / count.
        self.values[arm] = old + (value - old) / pulls
class SoftmaxMethod:
    """Softmax (Boltzmann) action selection for an n-armed bandit.

    Arm ``i`` is chosen with probability proportional to
    ``exp(values[i] / t)``, where ``values`` holds the sample-average reward
    observed for each arm.

    Args:
        n: number of arms.
        t: temperature; used as a divisor, so it must be non-zero.  Smaller
            positive values concentrate probability on the best-valued arms.
    """

    def __init__(self, n=10, t=0.2):
        self.n = n
        self.t = t
        self.times = [0] * n   # number of pulls per arm
        self.values = [0] * n  # sample-average reward per arm

    def select(self):
        """Sample an arm index from the softmax distribution over ``values``.

        Raises:
            ValueError: if there are no arms (``values`` is empty).
            ZeroDivisionError: if ``t`` is zero.
        """
        scaled = [v / self.t for v in self.values]
        # Subtract the largest exponent before exp(): the normalized
        # probabilities are unchanged, but exp() can no longer overflow
        # (math.exp raises OverflowError once its argument exceeds ~709).
        peak = max(scaled)
        weights = [math.exp(s - peak) for s in scaled]
        # random.choices does the cumulative-weight lookup and always returns
        # a member of the population, so floating-point rounding can never
        # make selection fall through without picking an arm.
        return random.choices(range(len(weights)), weights=weights)[0]

    def reflect(self, arm, value):
        """Fold the observed reward ``value`` into ``arm``'s running mean."""
        self.times[arm] += 1
        # Incremental mean: new = old + (sample - old) / count.
        self.values[arm] += (value - self.values[arm]) / self.times[arm]
# --- Experiment: train each strategy against one shared bandit, then compare
# each strategy's best-estimated arm with the bandit's true best arm. ---

# Number of pulls each strategy gets.
count = 1000

# One shared bandit, so every strategy faces the same hidden arm means.
bandit = Bandit(10)

# Strategies under comparison:
#   greedy         - always exploits the best current estimate (e = 0)
#   epsilon_greedy - picks a uniformly random arm 10% of the time
#   softmax        - picks arm i with probability ~ exp(value_i / 0.4)
greedy = GreedyMethod(10)
epsilon_greedy = GreedyMethod(10, 0.1)
softmax = SoftmaxMethod(10, 0.4)

# Train the strategies one after another with the same select/pull/reflect
# loop.
for agent in (greedy, epsilon_greedy, softmax):
    for i in range(count):
        selected_arm = agent.select()
        bandit_value = bandit.pull(selected_arm)
        agent.reflect(selected_arm, bandit_value)

# Report each strategy's best-estimated arm and its estimated value.
for label, agent in (
    ("greedy max: ", greedy),
    ("epsilon greedy max: ", epsilon_greedy),
    ("softmax max: ", softmax),
):
    best_value = max(agent.values)
    print(label, agent.values.index(best_value), best_value)

# The arm with the highest true mean reward.
print("solution: ", bandit.p.index(max(bandit.p)))
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement