提交 ee20da96 authored 作者: James Bergstra's avatar James Bergstra

Fixed bug in local_add_specialize that would create a zero ndarray w wrong number of dims.

上级 1ce4dd26
...@@ -2313,15 +2313,21 @@ def local_add_specialize(node): ...@@ -2313,15 +2313,21 @@ def local_add_specialize(node):
y = get_constant_value(input) y = get_constant_value(input)
except TypeError: except TypeError:
y = input y = input
if N.all(y == 0.0): if numpy.all(y == 0.0):
continue continue
new_inputs.append(input) new_inputs.append(input)
if len(new_inputs) < len(node.inputs): if len(new_inputs) < len(node.inputs):
if len(new_inputs) == 0: if len(new_inputs) == 0:
#we got rid of the entire expression! #we got rid of the entire expression!
return fill_chain(T.TensorConstant(T.TensorType(dtype=node.outputs[0].type.dtype, ndim = node.outputs[0].type.ndim
broadcastable = [True] * node.outputs[0].ndim), N.asarray(0))) dtype = node.outputs[0].type.dtype
return fill_chain(
T.TensorConstant(
T.TensorType(
dtype=dtype,
broadcastable = [True] * ndim),
numpy.zeros((1,)*ndim, dtype=dtype)))
if len(new_inputs) == 1: if len(new_inputs) == 1:
return fill_chain(new_inputs[0]) return fill_chain(new_inputs[0])
......
...@@ -2174,3 +2174,15 @@ def test_local_mul_to_neg(): ...@@ -2174,3 +2174,15 @@ def test_local_mul_to_neg():
aval = numpy.random.randint(0,10,(2,2)).astype('int32') aval = numpy.random.randint(0,10,(2,2)).astype('int32')
assert f1(aval).dtype == a.dtype assert f1(aval).dtype == a.dtype
assert f2(aval).dtype == 'float64' assert f2(aval).dtype == 'float64'
def test_local_add_specialize():
# test of non-zero dimension
a = TT.vector()
s = TT.add(TT.zeros_like(a))
assert local_add_specialize.transform(s.owner)
# test of 0-d
a = TT.scalar()
s = TT.add(TT.zeros_like(a))
assert local_add_specialize.transform(s.owner)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论