Advertisement
Guest User

Untitled

a guest
Feb 22nd, 2017
95
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.07 KB | None | 0 0
  1. import tensorflow as tf
  2. import scipy.sparse as ss
  3. import numpy as np
  4.  
  5.  
  def scipy_run(sess, result, X, tf_indices, tf_id_values, tf_weight_values,
                tf_dense_shape):
      """Evaluate ``result`` in ``sess`` with a SciPy sparse matrix as the feed.

      The matrix is converted into the pieces the four placeholders receive:

      * ``tf_indices``       -- int64 array of shape ``(nnz, 2)`` holding
        ``[row, position_within_row]`` for every stored entry,
      * ``tf_id_values``     -- the column index of every stored entry (int64),
      * ``tf_weight_values`` -- the stored values themselves,
      * ``tf_dense_shape``   -- ``X.shape``.

      Args:
          sess: Object whose ``run(fetches, feed_dict)`` evaluates ``result``.
          result: The fetch passed through to ``sess.run``.
          X: SciPy sparse matrix; it is converted to CSR before use.
          tf_indices, tf_id_values, tf_weight_values, tf_dense_shape: Feed
              keys for the arrays described above.

      Returns:
          Whatever ``sess.run`` returns for ``result``.
      """
      # indptr/indices only have row-major meaning in CSR layout; tocsr() is a
      # no-op for matrices that are already CSR.
      X = X.tocsr()

      row_nnz = np.diff(X.indptr)
      # Row index of every stored entry: row i is repeated row_nnz[i] times.
      rows = np.repeat(np.arange(len(row_nnz), dtype=np.int64), row_nnz)
      # Position of every entry inside its own row: the running entry number
      # minus the offset at which that row starts.
      positions = np.arange(row_nnz.sum(), dtype=np.int64) - np.repeat(
          X.indptr[:-1], row_nnz)
      # Stacking keeps the (nnz, 2) shape even when there are no stored
      # entries, where a list comprehension would yield shape (0,).
      indices = np.stack([rows, positions], axis=1).astype(np.int64)

      ids = X.indices.astype(np.int64)
      weights = X.data
      feed = {
          tf_indices: indices,
          tf_id_values: ids,
          tf_weight_values: weights,
          tf_dense_shape: X.shape,
      }
      return sess.run(result, feed)
  19.  
  def main():
      """Compare ``X @ W + b`` computed by SciPy with a TensorFlow equivalent.

      The affine map is evaluated once directly on a small CSR matrix and once
      through ``tf.nn.embedding_lookup_sparse`` with ``combiner='sum'``. Both
      results are printed, followed by a line saying whether they agree.
      """
      X = ss.csr_matrix([[1, 0, 0],
                         [1, 0, 1],
                         [0, 0, 1],
                         [0, 2, 0]], dtype=np.float32)
      W = np.asarray([[.1, .2, .3],
                      [.2, .3, .1],
                      [.3, .2, .2]], dtype=np.float32)
      b = np.asarray([[.1, .2, .3]], dtype=np.float32)

      # scipy version
      direct_result = X @ W + b

      print(direct_result)

      # tensorflow version
      tf_indices = tf.placeholder(tf.int64, [None, 2])
      tf_id_values = tf.placeholder(tf.int64, [None])
      tf_weight_values = tf.placeholder(tf.float32, [None])
      tf_dense_shape = tf.placeholder(tf.int64, [2])
      sp_ids = tf.SparseTensor(tf_indices, tf_id_values, tf_dense_shape)
      sp_weights = tf.SparseTensor(tf_indices, tf_weight_values,
                                   tf_dense_shape)

      # The second positional parameter of tf.Variable is `trainable`, so the
      # dtype has to be passed by keyword.
      tf_W = tf.Variable(W, dtype=tf.float32)
      tf_b = tf.Variable(b, dtype=tf.float32)
      result = tf.nn.embedding_lookup_sparse(tf_W, sp_ids, sp_weights,
                                             combiner='sum') + tf_b

      init = tf.global_variables_initializer()
      # The context manager closes the session once the result is fetched.
      with tf.Session() as sess:
          sess.run(init)
          tf_result = scipy_run(sess, result, X, tf_indices, tf_id_values,
                                tf_weight_values, tf_dense_shape)

      print(tf_result)

      # The two libraries may accumulate float32 sums in a different order, so
      # compare within tolerance instead of bit-for-bit.
      if np.allclose(direct_result, tf_result):
          print("They are the same!")
      else:
          print("They are different!")
  61.  
  62.  
  # Run the comparison only when executed as a script, not when imported.
  if __name__ == '__main__':
      main()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement