Advertisement
Guest User

Untitled

a guest
Jan 21st, 2017
83
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 3.45 KB | None | 0 0
  1. import numpy as np
  2. import collections
  3. import matplotlib.pyplot as plt
  4. #import nltk
  5.  
  # Upper bound on the number of update sweeps performed by the main loop.
  ITERATIONS = 100

  # 2-D corner coordinates, one per topic.  The plots place every document
  # (or word) at the combination of these corners weighted by its topic
  # weights, so the three topics form a triangle.
  px = [0, 1, 2]
  py = [0, 1, 0]

  n = 50  # Number of documents
  k = 3  # Number of topics
  v = 100  # Number of distinct words

  # Dirichlet parameters
  alpha = .01
  lamb = .01

  # How much each document pertains to a particular topic (n x k)
  a = np.matrix(np.zeros((n, k)))

  # How much each word pertains to a particular topic (k x v)
  b = np.matrix(np.zeros((k, v)))

  # Generate random documents for testing
  minLength = 25
  maxLength = 50

  # One independently drawn length per document, each in
  # [minLength, maxLength).  Drawing a single length and reusing it would
  # give every document the same size and make the bounds pointless.
  sentenceLengths = [int(np.random.uniform(minLength, maxLength))
                     for _ in range(n)]

  # Labeled dataset, each word is replaced with its topic index
  topicAssignment = [[int(np.random.uniform(0, k)) for _ in range(length)]
                     for length in sentenceLengths]

  # Labeled dataset, each word is replaced with its bag of words index.
  # Built from the same lengths so it stays aligned with topicAssignment.
  wordIndexes = [[int(np.random.uniform(0, v)) for _ in range(length)]
                 for length in sentenceLengths]
  39.  
  def rouletteArg(vector):
      """Sample an index with probability proportional to its weight.

      Roulette-wheel selection: index ``i`` is returned with probability
      ``vector[i] / sum(vector)``.  (A greedy alternative would simply be
      ``np.argmax(vector)``.)

      Args:
          vector: One-dimensional sequence of non-negative weights.  It is
              read only; the caller's data is never modified.

      Returns:
          The selected position as a plain ``int`` in ``[0, len(vector))``.

      Raises:
          ValueError: If ``vector`` is empty, since no index can be chosen.
      """
      # Work on a float view/copy so integer input is accepted and the
      # normalization below never writes back into the caller's array.
      weights = np.asarray(vector, dtype=float).ravel()
      if weights.size == 0:
          raise ValueError("rouletteArg() requires at least one weight")

      total = weights.sum()
      if not np.isfinite(total) or total <= 0:
          # No usable distribution (all zeros, NaN or inf): choose uniformly
          # instead of dividing by zero and propagating NaNs.
          return int(np.random.randint(weights.size))

      cumulative = np.cumsum(weights / total)
      val = np.random.uniform()
      # side='right' finds the first slot whose cumulative mass is strictly
      # greater than the draw, so zero-weight entries can never be selected.
      index = int(np.searchsorted(cumulative, val, side='right'))
      # Rounding can leave the final cumulative value slightly below 1.
      return min(index, weights.size - 1)
  51.  
  def update(a, b, topicAssignment):
      """Run one sweep: recount both weight matrices, then resample topics.

      Reads the module-level ``n``, ``k``, ``alpha``, ``lamb`` and
      ``wordIndexes``, and draws new labels through ``rouletteArg``.

      Args:
          a: Document-by-topic matrix, overwritten in place with each
              document's topic counts divided by ``sum(counts + alpha)``.
          b: Topic-by-word matrix, overwritten in place with each word's
              per-topic counts divided by ``sum(counts + lamb)``.
          topicAssignment: Per-document lists of topic indices, aligned
              with ``wordIndexes``; every entry is resampled in place.

      Returns:
          The same three objects ``(a, b, topicAssignment)`` after the sweep.
      """
      # Count the number of times words from each topic are used
      # in the document
      for document in range(n):
          occurrence = collections.Counter(topicAssignment[document])
          for topic in range(k):
              a[document, topic] = occurrence[topic]
          a[document] /= np.sum(a[document] + alpha)

      # Count the number of times a word is used in a particular topic.
      # b still holds the previous sweep's normalized weights, so it must be
      # cleared first; otherwise the stale fractions leak into the counts.
      b[:, :] = 0
      for document in range(n):
          pairs = zip(topicAssignment[document], wordIndexes[document])
          for topic, word in pairs:
              b[topic, word] += 1

      # Normalize each word's column over the topics
      for word in range(b.shape[1]):
          b[:, word] /= np.sum(b[:, word] + lamb)

      # Update the assignment of topics
      for document in range(n):
          labels = topicAssignment[document]
          for position, word in enumerate(wordIndexes[document]):
              vec = np.zeros(k)
              for topic in range(k):
                  vec[topic] = ((a[document, topic] + alpha)
                                * (b[topic, word] + lamb))
              labels[position] = rouletteArg(vec)

      return a, b, topicAssignment
  83.  
  # Prime both weight matrices with one sweep before tracking changes.
  a, b, topicAssignment = update(a, b, topicAssignment)

  # costs[i] holds the total absolute change in a and b made by sweep i.
  costs = np.zeros(ITERATIONS)
  for iteration in range(ITERATIONS):
      # update() writes into a and b, so snapshot them before the sweep.
      lastA, lastB = a.copy(), b.copy()
      a, b, topicAssignment = update(a, b, topicAssignment)

      cost = np.abs(a - lastA).sum() + np.abs(b - lastB).sum()
      costs[iteration] = cost

      # Stop as soon as a sweep leaves the matrices essentially unchanged.
      if cost <= 1e-7:
          break

      # Trace this sweep's positions in both figures; later sweeps are
      # drawn progressively fainter.
      fade = (1.0 - iteration / ITERATIONS) * .5
      for figureNumber, weights in ((1, a), (2, b.T)):
          plt.figure(figureNumber)
          x = weights * np.matrix(px).T
          y = weights * np.matrix(py).T
          plt.plot(x, y, 'bo', alpha=fade)

      print(iteration, cost)
  108.  
  # Final state of the documents: large red markers on top of the blue
  # per-iteration traces, plus the outline joining the three topic corners.
  fig = plt.figure(1)
  # The window title is set through the figure manager; the canvas-level
  # set_window_title() was deprecated in Matplotlib 3.4 and later removed.
  fig.canvas.manager.set_window_title('Document Distribution')
  x = a * np.matrix(px).T
  y = a * np.matrix(py).T
  plt.plot(x, y, 'ro', ms=10)
  plt.plot(px, py, 'k-')

  # Same view for the words, using the transposed topic-by-word matrix.
  fig = plt.figure(2)
  fig.canvas.manager.set_window_title('Word Distribution')
  x = b.T * np.matrix(px).T
  y = b.T * np.matrix(py).T
  plt.plot(x, y, 'ro', ms=10)
  plt.plot(px, py, 'k-')

  fig = plt.figure(3)
  fig.canvas.manager.set_window_title('Cost Function')
  # Plot only the sweeps that actually ran: after an early break the rest
  # of the preallocated costs array is still zero and would be misleading.
  plt.plot(costs[:iteration + 1])
  plt.show()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement