提交 c227d7c0 authored 作者: Frederic Bastien's avatar Frederic Bastien

tmp

上级 94abb92a
......@@ -4839,11 +4839,29 @@ def local_useless_elemwise_comparison(node):
return [T.ones_like(node.inputs[0], dtype=node.outputs[0].dtype)]
# Elemwise[EQ](Subtensor(Shape(x)), -N)
# Elemwise[EQ](somegraph that only depend of shape, -N)
# TODO: handle the case where the -N is on either side
"""
|Elemwise{eq,no_inplace} [id B] ''
| |Subtensor{int64} [id C] ''
| | |Join [id D] ''
| | | |TensorConstant{0} [id E]
| | | |Subtensor{int64:int64:} [id F] ''
| | | | |Shape [id G] ''
"""
def investigate(node):
" Return True if values will be shapes, so >= 0"
if isinstance(node.op, (T.Shape, Shape_i)):
return True
elif isinstance(node.op, Subtensor) and node.inputs[0].owner:
return investigate(node.inputs[0].owner)
elif isinstance(node.op, T.Join):
return all(v.owner and
investigate(v.owner) for v in node.inputs[1:])
if (isinstance(node.op.scalar_op, scalar.EQ) and
node.inputs[0].owner and
isinstance(node.inputs[0].owner.op, Subtensor) and
node.inputs[0].owner.inputs[0].owner and
isinstance(node.inputs[0].owner.inputs[0].owner.op, T.Shape)):
investigate(node.inputs[0].owner)):
try:
cst = get_scalar_constant_value(node.inputs[1],
only_process_constants=True)
......
......@@ -3552,18 +3552,27 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase):
f = theano.function([x, y], T.ge(x.shape[0]+y.shape[0], 0), mode=mode)
self.assert_eqs_const(f, 1)
def test_subtensor_shape_equality(self):
def test_equality_shapes(self):
# Test equality where one sides contain only shapes related
# stuff.
x = T.vector('x', dtype=config.floatX)
mode = theano.compile.get_default_mode().including('local_useless_elemwise_comparison',
'local_shape_to_shape_i',
'local_track_shape_i',
'local_subtensor_make_vector')
f = theano.function([x], T.eq(x.shape[0], 0), mode=mode)
assert len(f.maker.fgraph.toposort()) == 2
f = theano.function([x], T.eq(x.shape[0], -1), mode=mode)
self.assert_eqs_const(f, 0)
for g in [x.shape[0],
Shape_i(0)(x),
join(0,
x.shape[0:], # todo test reshape, dimshuffle
x.shape[0:1])]:
f = theano.function([x], T.eq(g, 0), mode=mode)
# assert len(f.maker.fgraph.toposort()) == 2, g
assert f([3, 3]) == 0
assert f([]) == 1
f = theano.function([x], T.eq(g, -1), mode=mode)
self.assert_eqs_const(f, 0)
assert f([3, 3]) == 0
def test_and(self):
mode = theano.compile.get_default_mode().including('canonicalize')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论