提交 0f82ee2d authored 作者: Frederic Bastien's avatar Frederic Bastien

Fix test by implementing the case with MakeVector

上级 7fc086d9
...@@ -3568,7 +3568,7 @@ def local_join_empty(node): ...@@ -3568,7 +3568,7 @@ def local_join_empty(node):
@register_specialize @register_specialize
@register_canonicalize @register_canonicalize('fast_compile')
@gof.local_optimizer([T.Join]) @gof.local_optimizer([T.Join])
def local_join_make_vector(node): def local_join_make_vector(node):
"""Join(0, make_vector1, make_vector2, ...) => Join(0, make_vector12, ...) """Join(0, make_vector1, make_vector2, ...) => Join(0, make_vector12, ...)
...@@ -4858,6 +4858,9 @@ def local_useless_elemwise_comparison(node): ...@@ -4858,6 +4858,9 @@ def local_useless_elemwise_comparison(node):
elif isinstance(node.op, T.Join): elif isinstance(node.op, T.Join):
return all(v.owner and return all(v.owner and
investigate(v.owner) for v in node.inputs[1:]) investigate(v.owner) for v in node.inputs[1:])
elif isinstance(node.op, MakeVector):
return all(v.owner and
investigate(v.owner) for v in node.inputs)
if (isinstance(node.op.scalar_op, scalar.EQ) and if (isinstance(node.op.scalar_op, scalar.EQ) and
node.inputs[0].owner and node.inputs[0].owner and
......
...@@ -3409,7 +3409,7 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase): ...@@ -3409,7 +3409,7 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase):
sequences=[X], sequences=[X],
non_sequences=None) non_sequences=None)
Z = X_sum + Y Z = X_sum + Y
theano.printing.debugprint(Z) #theano.printing.debugprint(Z)
# here is the output for the debug print: # here is the output for the debug print:
""" """
Elemwise{add,no_inplace} [id A] '' Elemwise{add,no_inplace} [id A] ''
...@@ -3436,7 +3436,7 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase): ...@@ -3436,7 +3436,7 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase):
mode = theano.compile.get_default_mode().excluding('fusion') mode = theano.compile.get_default_mode().excluding('fusion')
f = theano.function([X, Y], Z, mode=mode) f = theano.function([X, Y], Z, mode=mode)
theano.printing.debugprint(f, print_type=True) #theano.printing.debugprint(f, print_type=True)
# here is the output for the debug print: # here is the output for the debug print:
""" """
Elemwise{Add}[(0, 0)] [id A] <TensorType(float64, vector)> '' 7 Elemwise{Add}[(0, 0)] [id A] <TensorType(float64, vector)> '' 7
...@@ -3465,14 +3465,19 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase): ...@@ -3465,14 +3465,19 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase):
> |X[t] [id O] <TensorType(float64, vector)> -> [id E] > |X[t] [id O] <TensorType(float64, vector)> -> [id E]
""" """
def assert_eqs_const(self, f, val): def assert_eqs_const(self, f, val, op=deep_copy_op):
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
elem = topo[0] elem = topo[0]
assert len(topo) == 1, topo assert len(topo) == 1, topo
assert elem.op == deep_copy_op, elem.op assert elem.op == op, elem.op
if op == deep_copy_op:
assert len(elem.inputs) == 1, elem.inputs assert len(elem.inputs) == 1, elem.inputs
assert isinstance(elem.inputs[0], T.TensorConstant), elem assert isinstance(elem.inputs[0], T.TensorConstant), elem
assert T.extract_constant(elem.inputs[0]) == val, val assert T.extract_constant(elem.inputs[0]) == val, val
else:
assert len(elem.inputs) == 2, elem.inputs
assert isinstance(elem.inputs[0], T.TensorConstant), elem
assert T.extract_constant(elem.inputs[0]) == val, val
def assert_identity(self, f): def assert_identity(self, f):
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
...@@ -3561,12 +3566,8 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase): ...@@ -3561,12 +3566,8 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase):
'local_track_shape_i', 'local_track_shape_i',
'local_subtensor_make_vector') 'local_subtensor_make_vector')
for g in [x.shape[0], for g in [x.shape[0],
Shape_i(0)(x), Shape_i(0)(x)]:
join(0,
x.shape[0:], # todo test reshape, dimshuffle
x.shape[0:1])]:
f = theano.function([x], T.eq(g, 0), mode=mode) f = theano.function([x], T.eq(g, 0), mode=mode)
# assert len(f.maker.fgraph.toposort()) == 2, g
assert f([3, 3]) == 0 assert f([3, 3]) == 0
assert f([]) == 1 assert f([]) == 1
...@@ -3574,6 +3575,17 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase): ...@@ -3574,6 +3575,17 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase):
self.assert_eqs_const(f, 0) self.assert_eqs_const(f, 0)
assert f([3, 3]) == 0 assert f([3, 3]) == 0
g = join(0,
x.shape[0:], # todo test reshape, dimshuffle
x.shape[0:1])
f = theano.function([x], T.eq(g, 0), mode=mode)
assert (f([3, 3]) == 0).all()
assert (f([]) == 1).all()
f = theano.function([x], T.eq(g, -1), mode=mode)
self.assert_eqs_const(f, 0, op=T.alloc)
assert (f([3, 3]) == 0).all()
def test_and(self): def test_and(self):
mode = theano.compile.get_default_mode().including('canonicalize') mode = theano.compile.get_default_mode().including('canonicalize')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论