提交 48a14134 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Fix bug in local_div_to_inv

上级 6334f287
...@@ -2565,7 +2565,15 @@ register_canonicalize(local_mul_zero) ...@@ -2565,7 +2565,15 @@ register_canonicalize(local_mul_zero)
@gof.local_optimizer([T.true_div]) @gof.local_optimizer([T.true_div])
def local_div_to_inv(node): def local_div_to_inv(node):
if node.op == T.true_div and N.all(local_mul_canonizer.get_constant(node.inputs[0]) == 1.0): if node.op == T.true_div and N.all(local_mul_canonizer.get_constant(node.inputs[0]) == 1.0):
return [T.inv(local_mul_canonizer.merge_num_denum(node.inputs[1:], []))] out = node.outputs[0]
new_out = T.inv(local_mul_canonizer.merge_num_denum(node.inputs[1:], []))
# The ones could have forced upcasting
if new_out.dtype != out.dtype:
new_out = T.cast(new_out, dtype=out.dtype)
# The ones could have forced a specific length
if new_out.type != out.type:
new_out = broadcast_like(new_out, out, node.env)
return [new_out]
else: else:
return False return False
register_specialize(local_div_to_inv) register_specialize(local_div_to_inv)
......
...@@ -2862,6 +2862,23 @@ def test_local_scalar_tensor_scalar(): ...@@ -2862,6 +2862,23 @@ def test_local_scalar_tensor_scalar():
assert len(cast_nodes) == 0 assert len(cast_nodes) == 0
f(0) f(0)
def test_local_div_to_inv():
num_len_s = tensor.lscalar('num_len')
denom_s = tensor.scalar('denom')
num_v = tensor.alloc(1, num_len_s)
denom_m = denom_s.dimshuffle('x', 'x')
out = num_v / denom_m
theano.printing.debugprint(out, print_type=True)
print out.broadcastable
assert numpy.all(out.broadcastable == (True, False))
f = theano.function([num_len_s, denom_s], out)
out_val = f(3, 2.)
assert out_val.shape == (1, 3)
assert numpy.allclose(out_val, 0.5)
if __name__ == '__main__': if __name__ == '__main__':
# unittest.main() # unittest.main()
test_fusion().tes_memory_leak() test_fusion().tes_memory_leak()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论