Advertisement
Guest User

MLE Gaussian Example (fixed)

a guest
Jan 28th, 2013
452
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 2.65 KB | None | 0 0
  1. import math
  2. import random
  3. from pylab import hist, show
  4.  
  5. #FUNCTIONS
  6.  
  7. def getNormalizerComponent(mu,sigma,dataRange):
  8. #make the resolution on this large.
  9.  n=0.
  10.  for count in range(dataRange[0]*100,dataRange[1]*100):
  11.   #density * volume
  12.   n+=(math.exp((-1.0/2.0)*pow(float(count)*0.01-mu,2)/pow(sigma,2)))*0.01
  13.  
  14.  return(n)
  15.  
  16.  
  17. #MAIN
  18.  
  19. print "Demonstration of MLE disentangling two mixed, truncated, unnormalized Gaussians"
  20.  
  21. #DECLARE THE TRUE DATA PARAMETERS (ANSWERS)
  22. trueSig1=0.9
  23. trueSig2=0.9
  24. trueMu1=4.
  25. trueMu2=8.
  26. trueRatio=0.2
  27. dataSize=100.
  28. range1=[-2,7]
  29. range2=[4,12]
  30.  
  31.  
  32. #DECLARE THE SEARCH SPACE
  33.  #brute grid
  34. sigma1List=[0.3,0.4,0.5,0.6,0.7,0.8,0.9]
  35. sigma2List=[0.3,0.4,0.5,0.6,0.7,0.8,0.9]
  36. mu1List=[1.,2.,3.,4.,5.,6.,7.,8.,9.,10.]
  37. mu2List=[1.,2.,3.,4.,5.,6.,7.,8.,9.,10.]
  38. ratioList=[0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9]
  39.  
  40. print "Generating Data..."
  41. #CREATE DATA ARRAY X
  42. x=[]
  43. #gauss1
  44. count=0
  45. while count < int(dataSize*trueRatio):
  46.  guess=random.gauss(trueMu1,trueSig1)
  47.  if guess > range1[0] and guess < range1[1]:
  48.   x.append(guess)
  49.   count+=1
  50.  
  51. #gauss2
  52. count=0
  53. while count < int(dataSize*(1.0-trueRatio)):
  54.  guess=random.gauss(trueMu2,trueSig2)
  55.  if guess > range2[0] and guess < range2[1]:
  56.   x.append(guess)
  57.   count+=1
  58.  
  59. #hist(x)
  60. #show()
  61. o=open("outfile.csv","w")
  62.  
  63. #LOOP THROUGH PARAMETERS IN A BRUTE FORCE GRID SEARCH
  64. maxProb=-99999999999
  65. for mu1 in mu1List:
  66.  print "Sanity Counter: "+str(mu1)
  67.  for mu2 in mu2List:
  68.   for sigma1 in sigma1List:
  69.    for sigma2 in sigma2List:
  70.     for ratio in ratioList:
  71.      workingProb=0.
  72.  
  73.      n1=getNormalizerComponent(mu1,sigma1,range1)
  74.      n2=getNormalizerComponent(mu2,sigma2,range2)
  75.        
  76.      for item in x:
  77.       prob1=0.; prob2=0.;
  78.       if range1[0] < item < range1[1]:
  79.        prob1=(ratio)*(1./n1)*math.exp((-1.0/2.0)*pow(item-mu1,2)/pow(sigma1,2))
  80.       if range2[0] < item < range2[1]:
  81.        prob2=(1.-ratio)*(1./n2)*math.exp((-1.0/2.0)*pow(item-mu2,2)/pow(sigma2,2))
  82.       workingProb+=math.log((prob1+prob2))
  83.  
  84.      o.write(str(mu1)+","+str(mu2)+","+str(sigma1)+","+str(sigma2)+","+str(ratio)+","+str(workingProb)+"\n")
  85.                  
  86.      if workingProb > maxProb:
  87.       maxProb = workingProb
  88.       maxSigma1 = sigma1
  89.       maxSigma2 = sigma2
  90.       maxMu1 = mu1
  91.       maxMu2 = mu2
  92.       maxRatio=ratio
  93.  
  94. o.close()
  95.  
  96. print "Values (true :: max)"
  97. print "Gaussian 1: ratio ("+str(trueRatio)+":"+str(maxRatio)+") // sigma ("+str(trueSig1)+":"+str(maxSigma1)+") // mu ("+str(trueMu1)+":"+str(maxMu1)+")."
  98. print "Gaussian 2: ratio ("+str(1.0-trueRatio)+":"+str(1.0-maxRatio)+") //sigma ("+str(trueSig2)+":"+str(maxSigma2)+") // mu ("+str(trueMu2)+":"+str(maxMu2)+")."
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement