提交 0f82ee2d authored 作者: Frederic Bastien's avatar Frederic Bastien

Fix test by implementing the case with MakeVector

上级 7fc086d9
......@@ -3568,7 +3568,7 @@ def local_join_empty(node):
@register_specialize
@register_canonicalize
@register_canonicalize('fast_compile')
@gof.local_optimizer([T.Join])
def local_join_make_vector(node):
"""Join(0, make_vector1, make_vector2, ...) => Join(0, make_vector12, ...)
......@@ -4858,6 +4858,9 @@ def local_useless_elemwise_comparison(node):
elif isinstance(node.op, T.Join):
return all(v.owner and
investigate(v.owner) for v in node.inputs[1:])
elif isinstance(node.op, MakeVector):
return all(v.owner and
investigate(v.owner) for v in node.inputs)
if (isinstance(node.op.scalar_op, scalar.EQ) and
node.inputs[0].owner and
......
......@@ -3409,7 +3409,7 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase):
sequences=[X],
non_sequences=None)
Z = X_sum + Y
theano.printing.debugprint(Z)
#theano.printing.debugprint(Z)
# here is the output for the debug print:
"""
Elemwise{add,no_inplace} [id A] ''
......@@ -3436,7 +3436,7 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase):
mode = theano.compile.get_default_mode().excluding('fusion')
f = theano.function([X, Y], Z, mode=mode)
theano.printing.debugprint(f, print_type=True)
#theano.printing.debugprint(f, print_type=True)
# here is the output for the debug print:
"""
Elemwise{Add}[(0, 0)] [id A] <TensorType(float64, vector)> '' 7
......@@ -3465,14 +3465,19 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase):
> |X[t] [id O] <TensorType(float64, vector)> -> [id E]
"""
def assert_eqs_const(self, f, val):
def assert_eqs_const(self, f, val, op=deep_copy_op):
topo = f.maker.fgraph.toposort()
elem = topo[0]
assert len(topo) == 1, topo
assert elem.op == deep_copy_op, elem.op
assert elem.op == op, elem.op
if op == deep_copy_op:
assert len(elem.inputs) == 1, elem.inputs
assert isinstance(elem.inputs[0], T.TensorConstant), elem
assert T.extract_constant(elem.inputs[0]) == val, val
else:
assert len(elem.inputs) == 2, elem.inputs
assert isinstance(elem.inputs[0], T.TensorConstant), elem
assert T.extract_constant(elem.inputs[0]) == val, val
def assert_identity(self, f):
topo = f.maker.fgraph.toposort()
......@@ -3561,12 +3566,8 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase):
'local_track_shape_i',
'local_subtensor_make_vector')
for g in [x.shape[0],
Shape_i(0)(x),
join(0,
x.shape[0:], # todo test reshape, dimshuffle
x.shape[0:1])]:
Shape_i(0)(x)]:
f = theano.function([x], T.eq(g, 0), mode=mode)
# assert len(f.maker.fgraph.toposort()) == 2, g
assert f([3, 3]) == 0
assert f([]) == 1
......@@ -3574,6 +3575,17 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase):
self.assert_eqs_const(f, 0)
assert f([3, 3]) == 0
g = join(0,
x.shape[0:], # todo test reshape, dimshuffle
x.shape[0:1])
f = theano.function([x], T.eq(g, 0), mode=mode)
assert (f([3, 3]) == 0).all()
assert (f([]) == 1).all()
f = theano.function([x], T.eq(g, -1), mode=mode)
self.assert_eqs_const(f, 0, op=T.alloc)
assert (f([3, 3]) == 0).all()
def test_and(self):
mode = theano.compile.get_default_mode().including('canonicalize')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论