Advertisement
creamygoat

AI: MDP Value Iteration (Naïve)

Nov 20th, 2011
613
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 4.94 KB | None | 0 0
  1. from types import ListType
  2. from types import FloatType
  3.  
  4. W = [
  5.   [' ', ' ', ' ', 'G'],
  6.   [' ', '#', ' ', 'H'],
  7.   [' ', ' ', ' ', ' ']
  8. ]
  9.  
  10. Cost = 3
  11.  
  12. def RewardOfCell(C):
  13.   Result = -Cost
  14.   if C == 'G':
  15.     Result = 100
  16.   elif C == 'H':
  17.     Result = -100
  18.   return Result
  19.  
  20. Dir_North = 0
  21. Dir_East = 1
  22. Dir_South = 2
  23. Dir_West = 3
  24.  
  25. ActionSet = [Dir_North, Dir_East, Dir_South, Dir_West]
  26.  
  27. DirDeltas = [
  28.   (0, -1), (1, 0), (0, 1), (-1, 0)
  29. ]
  30.  
  31. def Dim(A):
  32.   Z = A
  33.   Result = ()
  34.   while type(Z) == ListType:
  35.     Result += (len(Z),)
  36.     Z = Z[0]
  37.   return Result
  38.  
  39. def NewArrayOfDim(Dim):
  40.   class tRef (object):
  41.     pass
  42.   def Fill(A, k):
  43.     e = Dim[k]
  44.     if k + 1 < len(Dim):
  45.       for i in range(e):
  46.         A.append([])
  47.         Fill(A[i], k + 1)
  48.     else:
  49.       for i in range(e):
  50.         A.append(0)
  51.   if len(Dim) > 0:
  52.     A = []
  53.     Fill(A, 0)
  54.     return A
  55.   else:
  56.     return []
  57.  
  58. class tEnv (object):
  59.   def __init__(self, World, PForward, DiscountFactor):
  60.     self.World = World
  61.     self.PForward = PForward
  62.     self.DiscountFactor = DiscountFactor # 1 => no cost, 0.9 => 10% penalty
  63.  
  64. def CellAt(Env, State):
  65.   return Env.World[State[1]][State[0]]
  66.  
  67. def RewardAt(Env, State):
  68.   return RewardOfCell(CellAt(Env, State))
  69.  
  70. def LeftFrom(Dir):
  71.   return [Dir_West, Dir_North, Dir_East, Dir_South][Dir]
  72. def RightFrom(Dir):
  73.   return [Dir_East, Dir_South, Dir_West, Dir_North][Dir]
  74. def BackFrom(Dir):
  75.   return [Dir_South, Dir_West, Dir_North, Dir_East][Dir]
  76.  
  77. def InitalValueMap(Env):
  78.  Result = NewArrayOfDim(Dim(W))
  79.  return Result
  80.  
  81. def StateAfterAction(Env, State, Action):
  82.   Result = State
  83.   if CellAt(Env, State) in [' ']:
  84.     D = Dim(Env.World)
  85.     Delta = DirDeltas[Action]
  86.     NewState = (
  87.       max(0, min(D[1] - 1, State[0] + Delta[0])),
  88.       max(0, min(D[0] - 1, State[1] + Delta[1]))
  89.     )
  90.     if CellAt(Env, NewState) != '#':
  91.       Result = NewState
  92.     return Result
  93.  
  94. def ActionStates_LR(Env, State, Action):
  95.   PLeft = 0.5 * (1.0 - Env.PForward)
  96.   PRight = 0.5 * (1.0 - Env.PForward)
  97.   Result = []
  98.   if CellAt(Env, State) in [' ']:
  99.     if Env.PForward > 0:
  100.       Result.append((
  101.         Env.PForward, StateAfterAction(Env, State, Action), Action
  102.       ))
  103.     if PLeft > 0:
  104.       AltAction = LeftFrom(Action)
  105.       Result.append((
  106.         PLeft, StateAfterAction(Env, State, AltAction), AltAction
  107.       ))
  108.     if PRight > 0:
  109.       AltAction = RightFrom(Action)
  110.       Result.append((
  111.         PRight, StateAfterAction(Env, State, AltAction), AltAction
  112.       ))
  113.   return Result
  114.  
  115. def ActionStates_B(Env, State, Action):
  116.   PBack = 1.0 - Env.PForward
  117.   Result = []
  118.   if CellAt(Env, State) in [' ']:
  119.     if Env.PForward > 0:
  120.       Result.append((
  121.         Env.PForward, StateAfterAction(Env, State, Action), Action
  122.       ))
  123.     if PBack > 0:
  124.       AltAction = BackFrom(Action)
  125.       Result.append((
  126.         PBack, StateAfterAction(Env, State, AltAction), AltAction
  127.       ))
  128.   return Result
  129.  
  130. def ValueForAction(Env, V, State, Action):
  131.   ASRecs = ActionStates_LR(Env, State, Action)
  132.   Result = 0
  133.   for Probability, NewState, EffectiveAction in ASRecs:
  134.     Result += Probability * V[NewState[1]][NewState[0]]
  135.   Result = Result * Env.DiscountFactor + RewardAt(Env, State)
  136.   return Result
  137.  
  138. def ValueOfState(Env, V, State):
  139.   Result = None
  140.   if CellAt(Env, State) not in ['#']:
  141.     for Action in ActionSet:
  142.       ActionValue = ValueForAction(Env, V, State, Action)
  143.       if (Result == None) or (ActionValue > Result):
  144.         Result = ActionValue
  145.   return Result
  146.  
  147. def UpdateValueAt(Env, V, State):
  148.   x, y = State
  149.   Result = ValueOfState(Env, V, (x, y))
  150.   V[y][x] = Result
  151.   return Result
  152.  
  153. def UpdateValues(Env, V):
  154.   D = Dim(Env.World)
  155.   for y in range(D[0]):
  156.     for x in range(D[1]):
  157.       UpdateValueAt(Env, V, (x, y))
  158.  
  159. def PolicyAt(Env, V, State):
  160.   D = Dim(V)
  161.   x, y = State
  162.   Result = 'X'
  163.   if CellAt(Env, State) in [' ']:
  164.     BestValue = 0
  165.     for Action in ActionSet:
  166.       ActionValue = ValueForAction(Env, V, State, Action)
  167.       if (Result == 'X') or (ActionValue > BestValue):
  168.         BestValue = ActionValue
  169.         Result = 'NESW'[Action]
  170.   return Result
  171.  
  172. def Policy(Env, V):
  173.   D = Dim(V)
  174.   Result = NewArrayOfDim(D)
  175.   for y in range(D[0]):
  176.     for x in range(D[1]):
  177.       Result[y][x] = PolicyAt(Env, V, (x, y))
  178.   return Result
  179.  
  180.  
  181. def PrintNice(A):
  182.   D = Dim(A)
  183.   for y in range(D[0]):
  184.     PrefixStr = '[' if y == 0 else ' '
  185.     SuffixStr = ']' if y + 1 == D[0] else ''
  186.     LStr = ''
  187.     for x in range(D[1]):
  188.       Value = A[y][x]
  189.       if type(Value) == FloatType:
  190.         S = "%.4g" % (1.0 * Value)
  191.       else:
  192.         S = str(Value)
  193.       LStr += S if x == 0 else ", " + S
  194.     print PrefixStr + "[" + LStr + "]" + SuffixStr
  195.   return
  196.  
# Problem setup: the world W, an 80% chance that a move goes as commanded,
# and a discount factor of 1.0.
Env = tEnv(W, 0.8, 1.0)

# Start from an all-zero value map.
V = InitalValueMap(Env)

# Run a fixed number of in-place sweeps; there is no convergence test.
for i in range(500):
  UpdateValues(Env, V)

# Show the resulting values, then the action chosen for each cell.
PrintNice(V)
PrintNice(Policy(Env, V))

# Reference values left by the original author -- presumably the converged
# V for the settings above (TODO confirm against an actual run):
# 85.18  89.40  93.15   100
# 81.43  #####  68.37  -100
# 77.21  73.46  69.56  47.39
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement