Advertisement
Guest User

Untitled

a guest
Jun 23rd, 2017
53
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 0.76 KB | None | 0 0
  def test_embedding_look_up():
      """Demonstrate ``tf.nn.embedding_lookup`` on a 5x5 identity table.

      Builds a ``[3, 2]`` int32 placeholder of ids, looks every id up in an
      identity-matrix variable, and prints the table, the looked-up values
      and the shape of the result. Takes no arguments and returns ``None``.
      """
      input_ids = tf.placeholder(dtype=tf.int32, shape=[3, 2])

      # An identity table makes the lookup easy to check by eye: row i is
      # the one-hot vector for id i.
      embedding = tf.Variable(np.identity(5, dtype=np.int32))
      input_embedding = tf.nn.embedding_lookup(embedding, input_ids)

      # InteractiveSession is kept (rather than a plain Session) because
      # ``embedding.eval()`` below relies on a default session being set.
      sess = tf.InteractiveSession()
      try:
          sess.run(tf.global_variables_initializer())
          print("embeding:\n", embedding.eval())
          result = sess.run(
              input_embedding,
              feed_dict={input_ids: [[1, 2], [2, 1], [3, 3]]},
          )
          print("结果:\n", result)
          print(result.shape)
      finally:
          # Release the session even if initialization or a run fails.
          sess.close()


  test_embedding_look_up()
  15.  
  16. outputs:
  17. embeding:
  18. [[1 0 0 0 0]
  19. [0 1 0 0 0]
  20. [0 0 1 0 0]
  21. [0 0 0 1 0]
  22. [0 0 0 0 1]]
  23. 结果:
  24. [[[0 1 0 0 0]
  25. [0 0 1 0 0]]
  26.  
  27. [[0 0 1 0 0]
  28. [0 1 0 0 0]]
  29.  
  30. [[0 0 0 1 0]
  31. [0 0 0 1 0]]]
  32. (3, 2, 5)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement