Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import numpy as np
- import generator as gn
def Calc_alpha(A, B, Pi, V, O, T, N, K):
    """Compute unscaled forward variables for K observation sequences.

    For each sequence k the recursion is
        alpha[k][0][i] = Pi[i] * B[i][o_0]
        alpha[k][t][i] = B[i][o_t] * sum_j alpha[k][t-1][j] * A[j][i]
    where o_t is the column index ``V.index(O[k][t])``.

    Args:
        A: N x N transition matrix (list of rows).
        B: emission matrix; B[i][m] is indexed by state i and symbol index m.
        Pi: initial state distribution of length N.
        V: list of symbols; an observation's column is its position in V.
        O: observation sequences; O[k][t] is a symbol from V.
        T: number of time steps used from each sequence.
        N: number of hidden states.
        K: number of sequences used from O.

    Returns:
        Nested list alpha[k][t][i].

    Note: no scaling is applied, so very long sequences may underflow.
    """
    alpha = []
    for k in range(K):
        seq = O[k]
        obs = V.index(seq[0])
        rows = [[Pi[s] * B[s][obs] for s in range(N)]]
        for t in range(1, T):
            obs = V.index(seq[t])
            prev = rows[-1]
            rows.append([
                B[s][obs] * sum([prev[r] * A[r][s] for r in range(N)])
                for s in range(N)
            ])
        alpha.append(rows)
    return alpha
def Calc_betta(A, B, V, O, T, N, K):
    """Compute unscaled backward variables for K observation sequences.

    For each sequence k the recursion is
        betta[k][T-1][i] = 1
        betta[k][t][i]   = sum_j A[i][j] * B[j][o_{t+1}] * betta[k][t+1][j]
    where o_t is the column index ``V.index(O[k][t])``.

    Args:
        A: N x N transition matrix.
        B: emission matrix indexed [state][symbol index].
        V: list of symbols.
        O: observation sequences; O[k][t] is a symbol from V.
        T: number of time steps per sequence.
        N: number of hidden states.
        K: number of sequences.

    Returns:
        Nested list betta[k][t][i], in forward time order.
    """
    betta = []
    for k in range(K):
        # Terminal row is all ones, one entry per state (previously hard-coded
        # to exactly three states).
        seq_beta = [[1] * N]
        # Built backwards in time: entry t of seq_beta is time step T-1-t,
        # which uses the observation at time T-t.
        for t in range(1, T):
            column = V.index(O[k][T - t])
            prev = seq_beta[t - 1]
            seq_beta.append([
                sum([prev[j] * B[j][column] * A[i][j] for j in range(N)])
                for i in range(N)
            ])
        seq_beta.reverse()  # restore forward time order
        betta.append(seq_beta)
    return betta
def Calc_gamma(alpha, betta, T, N, K):
    """Compute state posteriors gamma[k][t][i] from forward/backward variables.

    gamma[k][t][i] = alpha[k][t][i] * betta[k][t][i] / P_k, where
    P_k = sum_i alpha[k][T-1][i] is the sequence likelihood.

    Args:
        alpha: forward variables alpha[k][t][i].
        betta: backward variables betta[k][t][i].
        T: number of time steps.
        N: number of hidden states.
        K: number of sequences.

    Returns:
        Nested list gamma[k][t][i].
    """
    gamma = []
    for k in range(K):
        a_k = alpha[k]
        b_k = betta[k]
        evidence = sum(a_k[T - 1])
        gamma.append([
            [a_k[t][s] * b_k[t][s] / evidence for s in range(N)]
            for t in range(T)
        ])
    return gamma
def Calc_ksi(alpha, betta, A, B, O, V, T, N, K):
    """Compute pairwise transition posteriors for each sequence and time step.

    For sequence k and time t in 0..T-2:
        xi_t(i, j) = alpha[k][t][i] * A[i][j] * B[j][o_{t+1}]
                     * betta[k][t+1][j] / P_k
    with P_k = sum_i alpha[k][T-1][i].

    Returns:
        Nested list ksi[k][t] == [matrix], where matrix[i][j] = xi_t(i, j).
        The single-element wrapper list around each matrix is intentional:
        Estimate_A reads it as ``step[0][i][j]``.
    """
    ksi = []
    for seq_idx in range(K):
        a_seq = alpha[seq_idx]
        b_seq = betta[seq_idx]
        evidence = sum(a_seq[T - 1])
        steps = []
        for t in range(T - 1):
            obs = V.index(O[seq_idx][t + 1])
            matrix = [
                [
                    (a_seq[t][src] * A[src][dst] * B[dst][obs] * b_seq[t + 1][dst]) / evidence
                    for dst in range(N)
                ]
                for src in range(N)
            ]
            steps.append([matrix])
        ksi.append(steps)
    return ksi
def Estimate_Pi(gamma, K, N):
    """Re-estimate the initial distribution as the mean of gamma at t = 0.

    Args:
        gamma: state posteriors gamma[k][t][i].
        K: number of sequences (the divisor of the mean).
        N: number of hidden states.

    Returns:
        List of length N.
    """
    first_step = [[seq[0][n] for seq in gamma] for n in range(N)]
    return [sum(column) / K for column in first_step]
def Estimate_A(ksi, gamma, N, K, T):
    """Re-estimate the transition matrix (Baum-Welch M-step).

        A[i][j] = sum_k sum_{t=0}^{T-2} xi_t(i, j)
                  / sum_k sum_{t=0}^{T-2} gamma_t(i)

    Args:
        ksi: transition posteriors; ksi[k][t][0][i][j] = xi_t(i, j).
        gamma: state posteriors gamma[k][t][i].
        N: number of hidden states.
        K, T: kept for interface compatibility (not needed by the formula).

    Returns:
        N x N list of rows.

    Raises:
        ZeroDivisionError: if some state has zero posterior mass over t < T-1.
    """
    est_A = []
    for i in range(N):
        # The final time step has no outgoing transition, so it is excluded from
        # the denominator. This matches the T-1 steps summed in ksi and makes
        # each row sum to 1.
        sum_gamma = sum(sum(step[i] for step in seq[:-1]) for seq in gamma)
        row = []
        for j in range(N):
            sum_ksi = sum(sum(step[0][i][j] for step in seq) for seq in ksi)
            row.append(sum_ksi / sum_gamma)
        est_A.append(row)
    return est_A
def Estimate_B(gamma, N, K, T, O, V):
    """Re-estimate the emission matrix (Baum-Welch M-step).

        B[i][m] = sum_k sum_{t=0}^{T-1} gamma_t(i) * [O[k][t] == V[m]]
                  / sum_k sum_{t=0}^{T-1} gamma_t(i)

    Args:
        gamma: state posteriors gamma[k][t][i].
        N: number of hidden states.
        K: number of sequences.
        T: number of time steps per sequence.
        O: observation sequences; O[k][t] is a symbol.
        V: list of symbols; column m of the result corresponds to V[m].

    Returns:
        N x len(V) list of rows.

    Raises:
        ZeroDivisionError: if some state has zero total posterior mass.
    """
    est_B = []
    for i in range(N):
        sum_gamma = sum(sum(step[i] for step in seq) for seq in gamma)
        row = []
        # One column per observable symbol (not per state).
        for j in range(len(V)):
            sum_gamma_V = 0
            for k in range(K):
                # Every time step, including the last, contributes an emission.
                for t in range(T):
                    if V[j] == O[k][t]:
                        sum_gamma_V += gamma[k][t][i]
            row.append(sum_gamma_V / sum_gamma)
        est_B.append(row)
    return est_B
def Calc_Log(alpha, K, T):
    """Return the total log-likelihood of the first K sequences.

    Each sequence's likelihood is the sum of its forward variables at the
    final time step T-1; their natural logs are summed.
    """
    per_sequence = (np.log(sum(alpha[k][T - 1])) for k in range(K))
    return sum(per_sequence)
def iter_algoritm(A, B, O, Pi, V, K, T, N, tol=1e-7, verbose=True):
    """Perform one Baum-Welch re-estimation step.

    Runs the forward/backward passes with the current parameters, re-estimates
    Pi, A and B, and compares the log-likelihood of the data under the old
    and new parameters.

    Args:
        A, B, Pi: current transition, emission and initial distributions.
        O: observation sequences; V: list of symbols.
        K: number of sequences; T: time steps per sequence; N: number of states.
        tol: convergence threshold on the absolute log-likelihood change.
        verbose: if True, print the re-estimated emission matrix (previous
            unconditional behaviour).

    Returns:
        (est_A, est_B, est_pi, flag, accuracy), where accuracy is the absolute
        log-likelihood change and flag is True while accuracy > tol.
    """
    # Use N consistently. Previously Pi used N while every other call used
    # len(B); the only caller passes N == len(B).
    alpha = Calc_alpha(A, B, Pi, V, O, T, N, K)
    betta = Calc_betta(A, B, V, O, T, N, K)
    gamma = Calc_gamma(alpha, betta, T, N, K)
    ksi = Calc_ksi(alpha, betta, A, B, O, V, T, N, K)
    est_pi = Estimate_Pi(gamma, K, N)
    est_A = Estimate_A(ksi, gamma, N, K, T)
    est_B = Estimate_B(gamma, N, K, T, O, V)
    IterL = Calc_Log(alpha, K, T)
    IterL_next = Calc_Log(Calc_alpha(est_A, est_B, est_pi, V, O, T, N, K), K, T)
    accuracy = abs(IterL_next - IterL)
    if verbose:
        print(est_B)
    flag = accuracy > tol
    return est_A, est_B, est_pi, flag, accuracy
def Baum_Welch(A, B, O, Pi, V, K, T, N, max_iter):
    """Iterate Baum-Welch re-estimation until convergence or max_iter steps.

    Each step is delegated to iter_algoritm. The loop stops when it reports
    that the log-likelihood change is within its tolerance, or after max_iter
    steps. Progress is printed after each extra step, and a final summary is
    printed at the end.

    Args:
        A, B, Pi: initial transition, emission and start distributions.
        O: observation sequences; V: list of symbols.
        K: number of sequences; T: time steps per sequence; N: number of states.
        max_iter: maximum number of re-estimation steps.

    Returns:
        (A, B, Pi) from the last re-estimation step. Previously nothing was
        returned, and the parameters from one step earlier were printed.
    """
    next_A, next_B, next_Pi, flag, accuracy = iter_algoritm(A, B, O, Pi, V, K, T, N)
    iteration = 1
    while flag and iteration < max_iter:
        iteration += 1
        next_A, next_B, next_Pi, flag, accuracy = iter_algoritm(
            next_A, next_B, O, next_Pi, V, K, T, N)
        print('Iteration:', iteration)
        print('Accuracy:', accuracy)
    print('Iteration:', iteration)
    print('Accuracy:', accuracy)
    print(next_B)  # the newest estimate, not the input to the last step
    return next_A, next_B, next_Pi
def _read_number_rows(file):
    """Read whitespace-separated float rows until a blank line or EOF.

    The terminating blank line, if any, is consumed. Returns a list of rows,
    each a list of floats.
    """
    rows = []
    for line in iter(file.readline, ''):
        if not line.strip():
            break
        rows.append([float(token) for token in line.split()])
    return rows


def get_params(path='input_test_1.txt'):
    """Load HMM parameters (A, B, Pi) from a text file.

    Expected layout:
        Pi
        <initial probabilities, one or more rows>
        <blank line>
        A
        <transition matrix rows>
        <blank line>
        B
        <emission matrix rows>

    Numbers within a row may be separated by any whitespace. If a section
    header does not match, an error message is printed and that section stays
    empty; the next header is then expected on the following line.

    Args:
        path: file to read; defaults to the previously hard-coded file name.

    Returns:
        (A, B, Pi): A and B as lists of float rows, Pi as a flat list of floats.
    """
    A = []
    B = []
    Pi = []
    # ``with`` guarantees the file is closed even if parsing raises.
    with open(path, 'r', encoding='utf-8') as file:
        if file.readline().strip() == 'Pi':
            for row in _read_number_rows(file):
                Pi.extend(row)
        else:
            print('Ошибка считывания файла в блоке Pi')
        if file.readline().strip() == 'A':
            A.extend(_read_number_rows(file))
        else:
            print('Ошибка считывания файла в блоке A')
        if file.readline().strip() == 'B':
            B.extend(_read_number_rows(file))
        else:
            print('Ошибка считывания файла в блоке B')
    return A, B, Pi
def main():
    """Generate test observations, load initial parameters and train the HMM.

    NOTE(review): the prompts are printed but no input is read; K, T and the
    iteration limit are all fixed at 100.
    """
    n_sequences = seq_length = iteration_cap = 100
    print('Введите K:')
    print('Введите T:')
    print('Введите максимально кол-ва итераций:')
    # gn is the project-local ``generator`` module; presumably gn.test builds
    # K sequences of length T and returns (observations, symbols). Verify there.
    observations, alphabet = gn.test(n_sequences, seq_length)
    trans, emit, start = get_params()
    Baum_Welch(trans, emit, observations, start, alphabet,
               n_sequences, seq_length, len(emit), iteration_cap)
# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement