Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- void gradient(INT e1_a, INT e2_a, INT rel_a, INT e1_b, INT e2_b, INT rel_b) {
- INT lasta1 = e1_a * dimension;
- INT lasta2 = e2_a * dimension;
- INT lastar = rel_a * dimension;
- INT lastb1 = e1_b * dimension;
- INT lastb2 = e2_b * dimension;
- INT lastbr = rel_b * dimension;
- for (INT ii=0; ii < dimension; ii++) {
- REAL x;
- x = (entityVec[lasta2 + ii] - entityVec[lasta1 + ii] - relationVec[lastar + ii]);
- if (x > 0)
- x = -alpha;
- else
- x = alpha;
- relationVec[lastar + ii] -= x;
- entityVec[lasta1 + ii] -= x;
- entityVec[lasta2 + ii] += x;
- x = (entityVec[lastb2 + ii] - entityVec[lastb1 + ii] - relationVec[lastbr + ii]);
- if (x > 0)
- x = alpha;
- else
- x = -alpha;
- relationVec[lastbr + ii] -= x;
- entityVec[lastb1 + ii] -= x;
- entityVec[lastb2 + ii] += x;
- }
- }
- void train_kb(INT e1_a, INT e2_a, INT rel_a, INT e1_b, INT e2_b, INT rel_b) {
- REAL sum1 = calc_sum(e1_a, e2_a, rel_a);
- REAL sum2 = calc_sum(e1_b, e2_b, rel_b);
- //printf(" %f %f\r", sum1, sum2);
- if (sum1 + margin > sum2) {
- //res += margin + sum1 - sum2;
- //gradient(e1_a, e2_a, rel_a, e1_b, e2_b, rel_b);
- res += (margin+sum1-sum2) * (margin+sum1-sum2) * 0.5;
- gradient_l2(e1_a, e2_a, rel_a, e1_b, e2_b, rel_b);
- }
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement