Advertisement
Guest User

Untitled

a guest
Aug 26th, 2019
112
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.61 KB | None | 0 0
  1. import random
  2.  
  3.  
  4. class Environment:
  5.     def __init__(self, c_1, c_2):
  6.         self.c_1 = c_1
  7.         self.c_2 = c_2
  8.  
  9.     def penalty(self, action):
  10.         if action == 1:
  11.             if random.random() <= self.c_1:
  12.                 return True
  13.             else:
  14.                 return False
  15.         elif action == 2:
  16.             if random.random() <= self.c_2:
  17.                 return True
  18.             else:
  19.                 return False
  20.  
  21.  
  22. class Tsetlin:
  23.     def __init__(self, n):
  24.         # n is the number of states per action
  25.         self.n = n
  26.  
  27.         # Initial state selected randomly
  28.         self.state = random.choice([self.n, self.n+1])
  29.  
  30.     def reward(self):
  31.         if self.state <= self.n and self.state > 1:
  32.             self.state -= 1
  33.         elif self.state > self.n and self.state < 2*self.n:
  34.             self.state += 1
  35.  
  36.     def penalize(self):
  37.         if self.state <= self.n:
  38.             self.state += 1
  39.         elif self.state > self.n:
  40.             self.state -= 1
  41.  
  42.     def makeDecision(self):
  43.         if self.state <= self.n:
  44.             return 1
  45.         else:
  46.             return 2
  47.  
  48.  
  49. env = Environment(0.1, 0.3)
  50.  
  51. la = Tsetlin(3)
  52.  
  53. action_count = [0, 0]
  54.  
  55. for i in range(500):
  56.     action = la.makeDecision()
  57.     action_count[action - 1] += 1
  58.     penalty = env.penalty(action)
  59.     print("State:", la.state, "Action:", action)
  60.  
  61.     if penalty:
  62.         print("Penalty")
  63.         la.penalize()
  64.     else:
  65.         print("Reward")
  66.         la.reward()
  67.  
  68.     print("New state: ", la.state, "\n")
  69.  
  70.  
  71. print("#Action 1: ", action_count[0], " #Action 2: ", action_count[1])
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement