提交 26ed389c authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Fix a case in a rarely-applied optimization, add test.

上级 e32e9d5f
...@@ -2742,10 +2742,10 @@ def local_add_specialize(node): ...@@ -2742,10 +2742,10 @@ def local_add_specialize(node):
new_inputs.append(input) new_inputs.append(input)
if len(new_inputs) < len(node.inputs): if len(new_inputs) < len(node.inputs):
dtype = node.outputs[0].type.dtype
if len(new_inputs) == 0: if len(new_inputs) == 0:
#we got rid of the entire expression! #we got rid of the entire expression!
ndim = node.outputs[0].type.ndim ndim = node.outputs[0].type.ndim
dtype = node.outputs[0].type.dtype
return fill_chain( return fill_chain(
T.TensorConstant( T.TensorConstant(
T.TensorType( T.TensorType(
...@@ -2754,9 +2754,14 @@ def local_add_specialize(node): ...@@ -2754,9 +2754,14 @@ def local_add_specialize(node):
numpy.zeros((1,)*ndim, dtype=dtype))) numpy.zeros((1,)*ndim, dtype=dtype)))
if len(new_inputs) == 1: if len(new_inputs) == 1:
return fill_chain(new_inputs[0]) ret = fill_chain(new_inputs[0])
else: else:
return fill_chain(T.add(*new_inputs)) ret = fill_chain(T.add(*new_inputs))
# The dtype should not be changed. It can happen if the input
# that was forcing upcasting was equal to 0.
if ret[0].dtype != dtype:
ret = [T.cast(ret[0], dtype)]
return ret
else: else:
return False return False
register_specialize(local_add_specialize) register_specialize(local_add_specialize)
......
...@@ -2753,6 +2753,14 @@ def test_local_add_specialize(): ...@@ -2753,6 +2753,14 @@ def test_local_add_specialize():
s = tensor.add(tensor.zeros_like(a)) s = tensor.add(tensor.zeros_like(a))
assert local_add_specialize.transform(s.owner) assert local_add_specialize.transform(s.owner)
# Test when the 0 input is forcing upcasting
a = tensor.constant(0, dtype='int64')
b = tensor.constant(1, dtype='int32')
s = a + b
transformed = local_add_specialize.transform(s.owner)
assert transformed
assert transformed[0].type == s.type
def test_local_tensor_scalar_tensor(): def test_local_tensor_scalar_tensor():
dtypes = ['int8', 'int16', 'int32', 'int64', dtypes = ['int8', 'int16', 'int32', 'int64',
'uint8', 'uint16', 'uint32', 'uint64', 'uint8', 'uint16', 'uint32', 'uint64',
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论