Advertisement
Guest User

Untitled

a guest
Jan 17th, 2018
71
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 4.44 KB | None | 0 0
  1. import random
  2. from copy import deepcopy
  3.  
  4. PT = 0.8
  5. PO = 0.8
  6. C = 4
  7.  
  8. RED = 1
  9. GREEN = 2
  10. BLUE = 3
  11. BLACK = 0
  12.  
  13. COLORS = [BLACK, RED, GREEN, BLUE]
  14.  
  15. def weighted_choice(choices):
  16.    total = sum(w for c, w in choices)
  17.    r = random.uniform(0, total)
  18.    upto = 0
  19.    for c, w in choices:
  20.       if upto + w >= r:
  21.          return c
  22.       upto += w
  23.  
  24. class Grid(object):
  25.     def __init__(self, name, heights, colors):
  26.         self.name = name
  27.         self.heights = heights
  28.         self.colors = colors
  29.         self.size = len(heights)
  30.  
  31.     def height(self, state):
  32.         return self.heights[state[0]][state[1]]
  33.  
  34.     def color(self, state):
  35.         return self.colors[state[0]][state[1]]
  36.  
  37.     def get_neighbors(self, state):
  38.         neighbors = [self.left(state), self.right(state), self.up(state), self.down(state)]
  39.         neighbors = [(n, self.height(n)) for n in neighbors if n is not None]
  40.         max_height = max([n[1] for n in neighbors])
  41.         n_max = [n[1] for n in neighbors].count(max_height)
  42.  
  43.         prob = (1 - PT) / len(neighbors)
  44.         return {n[0]:(prob + PT / n_max if (n[1] == max_height) else prob) for n in neighbors}
  45.  
  46.     def get_colors(self, state):
  47.         color = self.color(state)
  48.  
  49.         prob = (1 - PO) / C
  50.         return {c:(PO + prob if c == color else prob) for c in COLORS}
  51.  
  52.     def transition_prob(self, src, dst):
  53.         return self.get_neighbors(src).get(dst, 0)
  54.  
  55.     def emission_prob(self, state, color):
  56.         return self.get_colors(state).get(color, 0)
  57.  
  58.     def get_sequence(self, length):
  59.         states = []
  60.         observations = []
  61.         state = None
  62.         for t in range(length):
  63.             if t == 0:
  64.                 state = (random.randint(0, self.size - 1), random.randint(0, self.size - 1))
  65.             else:
  66.                 state = weighted_choice([(k,v) for k,v in self.get_neighbors(state).iteritems()])
  67.             observation = weighted_choice([(k,v) for k,v in self.get_colors(state).iteritems()])
  68.  
  69.             states.append(state)
  70.             observations.append(observation)
  71.  
  72.         return states, observations
  73.  
  74.     def forward(self, observations):
  75.         alpha = []
  76.         for _ in range(self.size):
  77.             alpha.append([1 / (1.0 * self.size * self.size)] * self.size)
  78.         for row in range(self.size):
  79.             for col in range(self.size):
  80.                 alpha[row][col] *= self.emission_prob((row, col), observations[0])
  81.  
  82.         for idx in range(len(observations)):
  83.             prev_alpha = deepcopy(alpha)
  84.             alpha = []
  85.             for _ in range(self.size):
  86.                 alpha.append([0] * self.size)
  87.  
  88.             for row in range(self.size):
  89.                 for col in range(self.size):
  90.                     for state, prob in self.get_neighbors((row, col)).iteritems():
  91.                         alpha[row][col] += (prev_alpha[state[0]][state[1]] * prob * self.emission_prob((row, col), observations[idx]))
  92.  
  93.         return sum([sum(s) for s in prev_alpha]), prev_alpha
  94.  
  95.  
  96.     def left(self, state):
  97.         if state[1] == 0:
  98.             return None
  99.         return state[0], state[1] - 1
  100.  
  101.     def right(self, state):
  102.         if state[1] == self.size - 1:
  103.             return None
  104.         return state[0], state[1] + 1
  105.  
  106.     def up(self, state):
  107.         if state[0] == 0:
  108.             return None
  109.         return state[0] - 1, state[1]
  110.  
  111.     def down(self, state):
  112.         if state[0] == self.size - 1:
  113.             return None
  114.         return state[0] + 1, state[1]
  115.  
  116.  
  117. def main():
  118.     grid1 = Grid("Grid 1",
  119.                  [[1, 2, 3, 5], [2, 2, 1, 2], [3, 2, 1, 1], [0, 0, 0, 0]],  # elevation
  120.                  [[0, 3, 1, 2], [3, 1, 2, 0], [2, 2, 0, 0], [3, 0, 3, 1]])  # color
  121.     grid2 = Grid("Grid 2",
  122.                  [[0, 0, 1, 1], [2, 1, 0, 2], [1, 0, 0, 2], [4, 4, 3, 3]],  # elevation
  123.                  [[0, 3, 1, 2], [3, 1, 2, 0], [2, 2, 0, 0], [3, 0, 3, 1]])  # color
  124.     grid3 = Grid("Grid 3",
  125.                  [[2, 1, 2, 3], [1, 1, 2, 2], [1, 0, 1, 1], [2, 1, 1, 2]],  # elevation
  126.                  [[2, 3, 1, 0], [1, 3, 3, 1], [0, 2, 0, 2], [2, 1, 1, 2]])  # color
  127.     GRIDS = [grid1, grid2, grid3]
  128.  
  129.     # print grid.get_neighbors((3,1))
  130.     # print grid.get_colors((0, 0))
  131.     # print grid.transition_prob((3, 3), (3, 2))
  132.     # print grid.emission_prob((0, 0), GREEN)
  133.     # print grid.get_sequence(2)
  134.     for grid in GRIDS:
  135.         print grid.forward([0, 0])
  136.         print ''
  137.  
  138. if __name__ == '__main__':
  139.     main()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement