Commit 01f44306 authored by nouiz

Merge pull request #472 from lamblin/fix_local_upcast_elemwise_constant_inputs

Fix local_upcast_elemwise_constant_inputs optimization
@@ -9,6 +9,9 @@ Bug fixes (the result changed):
 Crashes fixes:
 * More cases supported in AdvancedIncSubtensor1. (Olivier D.)
+* Fix crash when a broadcasted constant was used as input of an
+  elemwise Op and needed to be upcasted to match the op's output.
+  (Reported by John Salvatier, fixed by Pascal L.)
 Interface change:
 * The Theano flag "nvcc.flags" is now included in the hard part of the key.
......
@@ -1443,7 +1443,9 @@ def local_upcast_elemwise_constant_inputs(node):
 # works only for scalars
 cval_i = get_constant_value(i)
 if all(i.broadcastable):
-    new_inputs.append(T.cast(cval_i, output_dtype))
+    new_inputs.append(T.shape_padleft(
+        T.cast(cval_i, output_dtype),
+        i.ndim))
 else:
     if shape_i is None:
         return
......
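The change to local_upcast_elemwise_constant_inputs above can be illustrated with a small NumPy analogue. This is only a sketch under stated assumptions: the helper name `upcast_broadcasted_constant` is hypothetical, and NumPy reshapes stand in for `get_constant_value`, `T.cast` and `T.shape_padleft`. The point it tries to show is that extracting and casting the scalar value yields a 0-d result, so the cast constant has to be padded back to `i.ndim` dimensions before it can stand in for the original input.

```python
import numpy as np


def upcast_broadcasted_constant(const, output_dtype):
    """Hypothetical NumPy analogue of the fixed branch (not Theano code)."""
    const = np.asarray(const)
    # Stand-in for get_constant_value(i): a fully broadcastable constant
    # holds exactly one value, so it can be viewed as a 0-d array.
    scalar = const.reshape(())
    # Stand-in for T.cast(cval_i, output_dtype): still 0-d after the cast.
    casted = scalar.astype(output_dtype)
    # Stand-in for T.shape_padleft(casted, i.ndim): prepend one length-1
    # dimension per dimension of the original input.
    return casted.reshape((1,) * const.ndim)


c = np.full((1, 1), 10, dtype="int8")
out = upcast_broadcasted_constant(c, "float64")
print(out.shape, out.dtype)
```

Without the final reshape, the result would have 0 dimensions while the input it replaces has 2, which is the kind of mismatch the patch appears to address.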
@@ -3532,6 +3532,12 @@ class Test_lift_transpose_through_dot(unittest.TestCase):
         assert str(g) == sg
+def test_local_upcast_elemwise_constant_inputs():
+    s = dvector("s")
+    x = tensor.sum(tensor.log(10**s))
+    f = function([s], [tensor.grad(x, s)])
+    f([-42, -2.1, -1, -0.5, 0, 0.2, 1, 2, 12])
 if __name__ == '__main__':
     # unittest.main()
     test_fusion().tes_memory_leak()
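The new test only checks that compiling and calling the gradient function does not crash; it makes no assertion on the values. As a rough cross-check of what that gradient should be, note that log(10**s) = s * log(10), so each gradient entry should equal log(10). The sketch below verifies this with central finite differences in plain NumPy. It is an independent illustration, not part of the commit, and the step size `h` and tolerance are assumptions.

```python
import numpy as np


def objective(s):
    # Same expression as in the test: sum(log(10**s)).
    return np.sum(np.log(10.0 ** s))


s = np.array([-42, -2.1, -1, -0.5, 0, 0.2, 1, 2, 12], dtype="float64")
h = 1e-6  # finite-difference step (assumed, not from the commit)

# Central difference along each coordinate direction.
numeric_grad = np.array([
    (objective(s + h * e) - objective(s - h * e)) / (2 * h)
    for e in np.eye(len(s))
])

# Analytic gradient: d/ds_k sum_j s_j * log(10) = log(10) for every k.
analytic_grad = np.full_like(s, np.log(10.0))

print(np.allclose(numeric_grad, analytic_grad, atol=1e-6))
```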