Advertisement
Not a member of Pastebin yet?
Sign up; it unlocks many cool features!
def compute_transition_matrix(order, l0, l1, l2, C_t, C_tt, C_ttt, num_tokens):
    """Build a linearly interpolated tag-transition table from n-gram counts.

    For ``order == 2`` each entry is::

        P(t2 | t1) = l1 * C_tt[t1][t2] / C_t[t1] + l0 * C_t[t2] / num_tokens

    For any other ``order`` the trigram form is used::

        P(t3 | t1, t2) = l2 * C_ttt[t1][t2][t3] / C_tt[t1][t2]
                       + l1 * C_tt[t2][t3] / C_t[t2]
                       + l0 * C_t[t3] / num_tokens

    Only n-grams present in the highest-order count table (``C_tt`` for
    order 2, ``C_ttt`` otherwise) get an entry.  Every other lookup on the
    returned nested ``defaultdict`` yields 0.  The weights are applied as
    given; they are not renormalised.  Histories whose conditioning count is
    zero (``C_t[t1]`` for order 2, ``C_t[t2]`` otherwise) are skipped rather
    than divided by.  A trigram whose pair count ``C_tt[t1][t2]`` is zero
    contributes no trigram term.

    Args:
        order: 2 for bigram transitions; any other value selects trigrams.
        l0, l1, l2: interpolation weights for the unigram, bigram and
            trigram terms (``l2`` is unused when ``order == 2``).
        C_t: mapping tag -> count.
        C_tt: nested mapping t1 -> t2 -> count.
        C_ttt: nested mapping t1 -> t2 -> t3 -> count (unused when
            ``order == 2``).
        num_tokens: total token count used to normalise unigram counts.

    Returns:
        ``transition_matrix[t1][t2]`` (order 2) or
        ``transition_matrix[t1][t2][t3]`` (otherwise) as nested
        ``defaultdict`` objects whose missing leaves default to 0.
        The input count tables are not modified.
    """

    # Read counts with .get() so that a missing key neither raises (plain
    # dicts) nor gets silently inserted into the caller's tables (defaultdicts).
    def unigram(tag):
        return C_t.get(tag, 0)

    def bigram(prev, tag):
        return C_tt.get(prev, {}).get(tag, 0)

    if order == 2:
        transition_matrix = defaultdict(lambda: defaultdict(int))
        for key1, followers in C_tt.items():
            history_count = unigram(key1)
            if history_count == 0:
                # Cannot condition on an unseen tag; leave its row at 0.
                continue
            for key2, count in followers.items():
                transition_matrix[key1][key2] = (
                    l1 * (count / float(history_count))
                    + l0 * (unigram(key2) / float(num_tokens))
                )
        return transition_matrix

    transition_matrix = defaultdict(lambda: defaultdict(lambda: defaultdict(int)))
    for key1, middles in C_ttt.items():
        for key2, followers in middles.items():
            middle_count = unigram(key2)
            if middle_count == 0:
                # The bigram term would divide by zero; skip this history.
                continue
            pair_count = bigram(key1, key2)
            # Iterate the continuations of the (key1, key2) history itself.
            for key3, count in followers.items():
                trigram_prob = count / float(pair_count) if pair_count else 0.0
                transition_matrix[key1][key2][key3] = (
                    l2 * trigram_prob
                    + l1 * (bigram(key2, key3) / float(middle_count))
                    + l0 * (unigram(key3) / float(num_tokens))
                )
    return transition_matrix
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement