提交 70e866cd authored 作者: Frederic Bastien's avatar Frederic Bastien

Add opt Elemwise[EQ](Subtensor(Shape(x)), -N) -> Elemwise[zeros](X) that work…

Add opt Elemwise[EQ](Subtensor(Shape(x)), -N) -> Elemwise[zeros](X) that work around this comment by removing the re-inserted scan: https://github.com/Theano/Theano/pull/4356/files#r61901454
上级 cc28e44d
...@@ -4759,6 +4759,10 @@ def local_useless_elemwise_comparison(node): ...@@ -4759,6 +4759,10 @@ def local_useless_elemwise_comparison(node):
Elemwise[LT](add([anything that is shapes]), 0) -> Elemwise[zeros](X) Elemwise[LT](add([anything that is shapes]), 0) -> Elemwise[zeros](X)
Elemwise[GE](add([anything that is shapes]), 0) -> Elemwise[ones](X) Elemwise[GE](add([anything that is shapes]), 0) -> Elemwise[ones](X)
# Shapes are never negativ
# Needed by Reshape.infer_shape
Elemwise[EQ](Subtensor(Shape(x)), -N) -> Elemwise[zeros](X)
""" """
if not isinstance(node.op, T.Elemwise): if not isinstance(node.op, T.Elemwise):
return return
...@@ -4834,6 +4838,20 @@ def local_useless_elemwise_comparison(node): ...@@ -4834,6 +4838,20 @@ def local_useless_elemwise_comparison(node):
T.extract_constant(node.inputs[1], only_process_constants=True) == 0: T.extract_constant(node.inputs[1], only_process_constants=True) == 0:
return [T.ones_like(node.inputs[0], dtype=node.outputs[0].dtype)] return [T.ones_like(node.inputs[0], dtype=node.outputs[0].dtype)]
# Elemwise[EQ](Subtensor(Shape(x)), -N)
if (isinstance(node.op.scalar_op, scalar.EQ) and
node.inputs[0].owner and
isinstance(node.inputs[0].owner.op, Subtensor) and
node.inputs[0].owner.inputs[0].owner and
isinstance(node.inputs[0].owner.inputs[0].owner.op, T.Shape)):
try:
cst = get_scalar_constant_value(node.inputs[1],
only_process_constants=True)
if cst < 0:
return [T.zeros_like(node.inputs[0],
dtype=node.outputs[0].dtype)]
except NotScalarConstantError:
pass
return return
......
...@@ -3552,6 +3552,19 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase): ...@@ -3552,6 +3552,19 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase):
f = theano.function([x, y], T.ge(x.shape[0]+y.shape[0], 0), mode=mode) f = theano.function([x, y], T.ge(x.shape[0]+y.shape[0], 0), mode=mode)
self.assert_eqs_const(f, 1) self.assert_eqs_const(f, 1)
def test_subtensor_shape_equality(self):
x = T.vector('x', dtype=config.floatX)
mode = theano.compile.get_default_mode().including('local_useless_elemwise_comparison',
'local_shape_to_shape_i',
'local_track_shape_i',
'local_subtensor_make_vector')
f = theano.function([x], T.eq(x.shape[0], 0), mode=mode)
assert len(f.maker.fgraph.toposort()) == 2
f = theano.function([x], T.eq(x.shape[0], -1), mode=mode)
self.assert_eqs_const(f, 0)
def test_and(self): def test_and(self):
mode = theano.compile.get_default_mode().including('canonicalize') mode = theano.compile.get_default_mode().including('canonicalize')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论