Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import tensorflow as tf
def resolve_if_node(child1, child2, child3, x_size, y_size):
    """Select element-wise between two tensors using a numeric mask.

    Args:
        child1: Tensor supplying the result wherever ``child3`` is non-zero.
        child2: Tensor supplying the result wherever ``child3`` is zero.
        child3: Numeric mask tensor; any non-zero entry counts as true.
        x_size: Unused. Retained so existing callers keep working.
        y_size: Unused. Retained so existing callers keep working.

    Returns:
        A tensor holding ``child1`` where the mask is non-zero and
        ``child2`` everywhere else.
    """
    # tf.cond needs a scalar predicate, so it cannot express a per-element
    # choice. The earlier workaround blended the branches arithmetically
    # (mask * child1 + (1 - mask) * child2), which lets NaN/Inf values from
    # the branch that was NOT selected leak into the result (0 * inf is NaN).
    # tf.where picks elements directly, so the unselected branch is ignored.
    #
    # NOTE(review): the arithmetic form broadcast its operands; on TF 1.x
    # tf.where expects matching shapes -- confirm callers pass same-shaped
    # tensors (the visible call site does).
    condition = tf.not_equal(child3, tf.zeros_like(child3))
    return tf.where(condition, child1, child2)
# Smoke test for resolve_if_node: build a mask from (aux1 > aux2) and use it
# to choose between an all-zeros tensor (true branch) and an all-ones tensor
# (false branch).
aux1 = tf.convert_to_tensor([[1.0, 2.0], [4.0, 5.0]])
aux2 = tf.constant(3, shape=[2, 2], dtype=tf.float32)
mask = tf.cast(tf.math.greater(aux1, aux2), tf.float32)  # [[0, 0], [1, 1]]
foo = n.resolve_if_node(
    tf.constant(0, shape=[2, 2], dtype=tf.float32),
    tf.constant(1, shape=[2, 2], dtype=tf.float32),
    mask,
    2,
    2,
)
# Where the mask is set the true branch (zeros) is taken, elsewhere the
# false branch (ones), so the expected result is the inverse of the mask.
expected = tf.convert_to_tensor([[1.0, 1.0], [0.0, 0.0]])
# Compare element-wise first and reduce afterwards. Comparing the two
# `.all()` reductions only checks that both arrays agree on "every element
# is truthy", which is also true for many unequal arrays.
print((run_tensor(foo) == run_tensor(expected)).all())
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement