提交 f2eb3f4b authored 作者: Frederic Bastien's avatar Frederic Bastien

Finally, the 'bug' wasn't a bug as the optimization didn't applied in the case…

Finally, the 'bug' wasn't a bug as the optimization didn't applied in the case that would cause problems.
上级 312e640c
...@@ -5064,6 +5064,8 @@ def local_useless_elemwise_comparison(node): ...@@ -5064,6 +5064,8 @@ def local_useless_elemwise_comparison(node):
isinstance(node.inputs[0].owner.op, Shape_i) and \ isinstance(node.inputs[0].owner.op, Shape_i) and \
T.extract_constant(node.inputs[1], only_process_constants=True) == 0: T.extract_constant(node.inputs[1], only_process_constants=True) == 0:
return [T.zeros_like(node.inputs[0], dtype=dtype, opt=True)] return [T.zeros_like(node.inputs[0], dtype=dtype, opt=True)]
# It don't detect case when the 0 is all zeros with ndim > 0.
# Elemwise[minimum](0, X.shape[i]) -> 0 # Elemwise[minimum](0, X.shape[i]) -> 0
if isinstance(node.op.scalar_op, scalar.Minimum) and \ if isinstance(node.op.scalar_op, scalar.Minimum) and \
T.extract_constant(node.inputs[0], only_process_constants=True) == 0 and \ T.extract_constant(node.inputs[0], only_process_constants=True) == 0 and \
......
...@@ -3432,6 +3432,9 @@ def test_local_fill_useless(): ...@@ -3432,6 +3432,9 @@ def test_local_fill_useless():
class Test_local_useless_elemwise_comparison(unittest.TestCase): class Test_local_useless_elemwise_comparison(unittest.TestCase):
def setUp(self):
self.rng = numpy.random.RandomState(utt.fetch_seed())
def test_local_useless_elemwise_comparison(self): def test_local_useless_elemwise_comparison(self):
# TODO: test each case individually. # TODO: test each case individually.
# The following case is what made me discover those cases. # The following case is what made me discover those cases.
...@@ -3469,6 +3472,8 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase): ...@@ -3469,6 +3472,8 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase):
mode = theano.compile.get_default_mode().excluding('fusion') mode = theano.compile.get_default_mode().excluding('fusion')
f = theano.function([X, Y], Z, mode=mode) f = theano.function([X, Y], Z, mode=mode)
f(self.rng.rand(2, 3).astype(config.floatX),
self.rng.rand(2).astype(config.floatX))
# theano.printing.debugprint(f, print_type=True) # theano.printing.debugprint(f, print_type=True)
# here is the output for the debug print: # here is the output for the debug print:
""" """
...@@ -3571,9 +3576,15 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase): ...@@ -3571,9 +3576,15 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase):
f = theano.function([x], T.minimum(x.shape[0], 0), mode=mode) f = theano.function([x], T.minimum(x.shape[0], 0), mode=mode)
self.assert_eqs_const(f, 0) self.assert_eqs_const(f, 0)
assert f(x_val) == 0
f = theano.function([x], T.minimum(0, x.shape[0]), mode=mode) f = theano.function([x], T.minimum(0, x.shape[0]), mode=mode)
self.assert_eqs_const(f, 0) self.assert_eqs_const(f, 0)
assert f(x_val) == 0
f = theano.function([x], T.minimum([0, 0], x.shape[0]), mode=mode)
# This case isn't optimized.
# self.assert_eqs_const(f, 0)
utt.assert_allclose(f(x_val), [0, 0])
def test_shape_add_inequality(self): def test_shape_add_inequality(self):
x = T.vector('x', dtype=config.floatX) x = T.vector('x', dtype=config.floatX)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论