Advertisement
Guest User

toy baum welch

a guest
Jan 19th, 2012
3,231
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 11.21 KB | None | 0 0
  1. #!/usr/bin/python
  2. # Trivial toy implementation of forward-backward
  3. # Follows Jason Eisner's SpreadSheet-based teaching tool,
  4. # "An Interactive Spreadsheet for Teaching the Forward-Backward Algorithm (2002)"
  5. #  http://www.cs.jhu.edu/~jason/papers/
  6. #  and produces the same results.
  7. import simplejson
  8. import pprint
  9.  
  10. class forwardBackward( ):
  11.  
  12.     def __init__( self, obs, init_file, state2obs_file, state2state_file, final_file ):
  13.         """Initialize the probabilitiy matrices."""
  14.         self.alphas      = []
  15.         self.betas       = []
  16.         self.init_mat    = self._load_json_from_file(init_file)
  17.         self.state2obs   = self._load_json_from_file(state2obs_file)
  18.         self.state2state = self._load_json_from_file(state2state_file)
  19.         self.final_mat   = self._load_json_from_file(final_file)
  20.         self.obs         = obs.split(" ")
  21.         #Temporary receptacles for re-estimated values
  22.         self.re_init_mat    = {}
  23.         self.re_state2obs   = {}
  24.         self.re_state2state = {}
  25.         self.re_final_mat   = {}
  26.        
  27.     def _load_json_from_file( self, infile ):
  28.         """Load a json object from a file."""
  29.         json_obj = simplejson.loads(open(infile,"r").read( ))
  30.         return json_obj
  31.  
  32.     def _init_alphas( self ):
  33.         """
  34.           Initialize the forward alpha probabilities.
  35.        """
  36.         for state in self.init_mat:
  37.             prob   = self.init_mat[state] * self.state2obs[state][self.obs[0]]
  38.             if len(self.alphas)>0:
  39.                 self.alphas[0][state] = [prob, "START %s" % (state)]
  40.             else:
  41.                 self.alphas = [ {} for i in xrange(len(self.obs)) ]
  42.                 self.alphas[0] = { state : [prob, "START %s" % (state)] }
  43.  
  44.         return
  45.  
  46.     def _init_betas( self ):
  47.         """
  48.           Initialize the beta backward probabilities.
  49.        """
  50.         self.betas = [ {} for i in xrange(len(self.obs)) ]
  51.         for state in self.final_mat:
  52.             self.betas[len(self.obs)-1][state] = [self.final_mat[state], "END %s %s" % (state, state)]
  53.         return
  54.  
  55.     def _calc_forward_alphas( self ):
  56.         """
  57.           Calculate all of the forward alpha probabilities
  58.           The partial best paths:
  59.           GIVEN the observation for stage N
  60.           FOREACH state Y AT stage N,
  61.             FOREACH state X AT stage N-1
  62.               CALCULATE prob(state Y | observation) *
  63.                         prob(state Y | previous state at stage N-1 was X) *
  64.                         prob(partial best path from stage N-1)
  65.             RECORD total probabiity of path to state Y AT stage N
  66.        """
  67.         for i,o in enumerate(self.obs):
  68.             if i==0: continue
  69.             max_prob  = 0; max_state = 0
  70.             for curr in self.state2obs:
  71.                 curr_prob = 0; max_prob = 0; max_state = 0
  72.                 xmax = 0; tmps = None
  73.                 for prev in self.state2state:
  74.                     val, states = self.alphas[i-1][prev]
  75.                     subtot = self.state2obs[curr][self.obs[i]] \
  76.                         * self.state2state[prev][curr] \
  77.                         * val
  78.                     curr_prob += subtot
  79.                     if subtot > xmax:
  80.                         xmax = subtot
  81.                         tmps = states
  82.                 if curr_prob==0.0:
  83.                     print "ERROR: curr_prob==0.0.  Failed to reach final state for the current iteration."
  84.                     sys.exit()
  85.                 if curr_prob > max_prob:
  86.                     max_state = tmps
  87.                     max_prob  = curr_prob
  88.                 self.alphas[i][curr] = [max_prob, "%s %s" % (max_state, curr)]
  89.         return
  90.  
  91.     def _calc_backward_betas( self ):
  92.         """
  93.           Same process as _calc_forward_alphas, but in reverse.
  94.        """
  95.         for i in xrange(len(self.obs)-2,-1,-1):
  96.             max_prob = 0; max_state = 0
  97.             for curr in self.state2obs:
  98.                 curr_prob = 0; max_prob = 0; max_state = 0
  99.                 xmax = 0; tmps = None
  100.                 for prev in self.state2state:
  101.                     val, states = self.betas[i+1][prev]
  102.                     subtot = self.state2obs[prev][self.obs[i+1]] \
  103.                         * self.state2state[curr][prev] \
  104.                         * val
  105.                     curr_prob += subtot
  106.                     if subtot > xmax:
  107.                         xmax = subtot
  108.                         tmps = states
  109.                 if curr_prob > max_prob:
  110.                     max_state = tmps
  111.                     max_prob  = curr_prob
  112.                 self.betas[i][curr] = [max_prob, "%s %s" % (max_state, curr)]
  113.         return
  114.  
  115.     def best_path( self ):
  116.         """Viterbi best path through the lattice."""
  117.         top_prob = 0.0
  118.         best_path = None
  119.         for state in self.alphas[len(self.alphas)-1]:
  120.             val, states = self.alphas[len(self.alphas)-1][state]
  121.             if val > top_prob:
  122.                 top_prob = val
  123.                 best_path = states
  124.         print "%s\t%s" %(best_path,str(top_prob))
  125.         return top_prob
  126.  
  127.     def _reestimate_probs( self ):
  128.         """
  129.           Reestimate all the transition probabilities.
  130.  
  131.           This is a 3-step process,
  132.  
  133.           1. Compute the alpha-beta values, normalize and store them
  134.           2. Re-normalize the init, s2s, s2o and final matrices using the
  135.               results of 1.
  136.        """
  137.        
  138.         reest = {}
  139.         gtot  = 0.0
  140.         ssum  = 0.0
  141.         state_totals = {}
  142.        
  143.         #Iterate through the alphas and betas and compute
  144.         # the combined alpha-beta values
  145.         for j,val in enumerate(self.alphas):
  146.             sum = 0.0
  147.             ab_vals = {}
  148.             for key in self.alphas[j]:
  149.                 alpha = self.alphas[j][key]
  150.                 beta  = self.betas[j][key]
  151.                 alphaBeta = alpha[0] * beta[0]
  152.                 ab_vals[key]  = alphaBeta
  153.                 sum      += alphaBeta
  154.             ssum = sum
  155.             for key in ab_vals:
  156.                 total = ab_vals[key] / sum
  157.                 if reest.has_key(key):
  158.                     if reest[key].has_key(self.obs[j]):
  159.                         reest[key][self.obs[j]] += total
  160.                     else:
  161.                         reest[key][self.obs[j]]  = total
  162.                 else:
  163.                     reest[key] = { self.obs[j]:total }
  164.                    
  165.                 if state_totals.has_key(key):
  166.                     state_totals[key] += total
  167.                 else:
  168.                     state_totals[key]  = total
  169.                 gtot += total
  170.                
  171.         #Compute the reestimated state-2-observation matrix
  172.         for key in reest:
  173.             for s1 in reest[key]:
  174.                 reestimated = reest[key][s1] / state_totals[key]
  175.                 if self.re_state2obs.has_key(key):
  176.                     self.re_state2obs[key][s1] = reestimated
  177.                 else:
  178.                     self.re_state2obs[key] = {s1:reestimated}
  179.  
  180.         #Compute the reestimated init matrix
  181.         for key in self.alphas[0]:
  182.             alpha     = self.alphas[0][key]
  183.             beta      = self.betas[0][key]
  184.             alphaBeta = alpha[0] * beta[0]
  185.             tri       = alphaBeta / ssum
  186.             self.re_init_mat[key] = tri
  187.  
  188.         #Compute the reestimated state-2-state matrix
  189.         transitions = {}
  190.         for j,o in enumerate(self.obs):
  191.             if j==0: continue
  192.             for s1 in self.state2state:
  193.                 for s2 in  self.state2state[s1]:
  194.                     al    = self.alphas[j-1][s1][0]
  195.                     be    = self.betas[j][s2][0]
  196.                     trans = self.state2state[s1][s2]
  197.                     conf  = self.state2obs[s2][self.obs[j]]
  198.                     stot  = al * be * trans * conf / ssum
  199.                     if transitions.has_key(s1):
  200.                         if transitions[s1].has_key(s2):
  201.                             transitions[s1][s2] += stot
  202.                         else:
  203.                             transitions[s1][s2] = stot
  204.                     else:
  205.                         transitions[s1] = {s2:stot}
  206.  
  207.         for s1 in transitions:
  208.             for s2 in transitions[s1]:
  209.                 newP = transitions[s1][s2] / state_totals[s1]
  210.                 if self.re_state2state.has_key(s1):
  211.                     self.re_state2state[s1][s2] = newP
  212.                 else:
  213.                     self.re_state2state[s1] = {s2:newP}
  214.  
  215.         #Compute the reestimated final matrix
  216.         for key in self.alphas[len(self.alphas)-1]:
  217.             alpha     = self.alphas[len(self.alphas)-1][key]
  218.             beta      = self.betas[len(self.alphas)-1][key]
  219.             alphaBeta = alpha[0] * beta[0]
  220.             tri = alphaBeta / ssum / state_totals[key]
  221.             self.re_final_mat[key] = tri
  222.  
  223.         self.init_mat    = self.re_init_mat
  224.         self.state2obs   = self.re_state2obs
  225.         self.state2state = self.re_state2state
  226.         self.final_mat   = self.re_final_mat
  227.         #Reset the alphas and betas
  228.         self.alphas      = []
  229.         self.betas       = []
  230.         return
  231.  
  232.     def iterate( self, n_iter=5, ratio=4e-20,verbose=False ):
  233.         """Run the algorithm iteratively until convergence."""
  234.         prev_likelihood = 9999999
  235.         for i in range(n_iter):
  236.             if verbose:
  237.                 print "Iteration: %d" % i
  238.                 print "S2O"
  239.                 pprint.pprint(fb.state2obs)
  240.                 print "S2S"
  241.                 pprint.pprint(fb.state2state)
  242.             fb._init_alphas()
  243.             fb._init_betas()
  244.             fb._calc_forward_alphas()
  245.             fb._calc_backward_betas()
  246.  
  247.             likelihood = fb.best_path()
  248.             if abs(likelihood - prev_likelihood)<ratio:
  249.                 print "Achieved convergence ratio:", ratio,"; Stopping at iteration: %d." % i
  250.                 break
  251.             elif verbose:
  252.                 print "Likelihood change:", abs(likelihood - prev_likelihood)
  253.             fb._reestimate_probs()            
  254.             prev_likelihood = likelihood
  255.         return
  256.  
  257. if __name__=="__main__":
  258.     import sys, argparse
  259.  
  260.     parser = argparse.ArgumentParser()
  261.     parser.add_argument('--init',    "-i", help='2D Initialization probability matrix. JSON format.', required=True)
  262.     parser.add_argument('--s2o',     "-o", help='3D state-to-observation transition matrix. JSON format.', required=True)
  263.     parser.add_argument('--s2s',     "-s", help='3d state-to-state transition matrix. JSON format.', required=True)
  264.     parser.add_argument('--final',   "-f", help='2D Final probability matrix. JSON format.', required=True)
  265.     parser.add_argument('--obs',     "-b", help='Observation sequence.  Must correspond to observations values in state-to-observation matrix.')
  266.     parser.add_argument('--n_iter',  "-n", type=int,   default=10,    help='Maximum number of iterations for forward-backward algorithm.')
  267.     parser.add_argument('--ratio',   "-r", type=float, default=4e-20, help='Convergence ratio.')
  268.     parser.add_argument('--verbose', "-v", type=bool,  default=False, help="Verbosity.")
  269.     args = parser.parse_args()
  270.  
  271.     fb = forwardBackward(args.obs, args.init, args.s2o, args.s2s, args.final)
  272.     fb.iterate(n_iter=args.n_iter, ratio=args.ratio, verbose=args.verbose)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement