提交 68e959df authored 作者: Iulian Vlad Serban's avatar Iulian Vlad Serban

Implemented comments from Pascal. Implemented some tests for local_useless_elemwise_comparison.

上级 bde718ca
...@@ -4226,7 +4226,7 @@ def local_flatten_lift(node): ...@@ -4226,7 +4226,7 @@ def local_flatten_lift(node):
# Copy over stacktrace from previous output node and from unary # Copy over stacktrace from previous output node and from unary
# elementwise output node since if there was an error, it would # elementwise output node since if there was an error, it would
# probably have come from that operation. # probably have come from that operation.
copy_stack_trace(node.outputs + node.inputs[0], e) copy_stack_trace(node.outputs + [node.inputs[0]], e)
return [e] return [e]
...@@ -4498,9 +4498,9 @@ def local_reshape_lift(node): ...@@ -4498,9 +4498,9 @@ def local_reshape_lift(node):
if e.type != node.outputs[0].type: if e.type != node.outputs[0].type:
re = T.patternbroadcast(e, node.outputs[0].broadcastable) re = T.patternbroadcast(e, node.outputs[0].broadcastable)
# We assume that the broadcast op cannot fail. Thus, if the # Copy over stack trace.
# graph fails it must be due to previous UnaryElemwise op, and # If the graph fails it is usually due to the fact that a dimension
# therefore we must copy its stacktrace over. # that should be broadcastable does not actually have length 1,
copy_stack_trace(e, re) copy_stack_trace(e, re)
else: else:
re = e re = e
...@@ -5200,7 +5200,7 @@ def local_sum_prod_mul_by_scalar(node): ...@@ -5200,7 +5200,7 @@ def local_sum_prod_mul_by_scalar(node):
# There are never errors in the negative op, thus # There are never errors in the negative op, thus
# we need only to copy over stacktrace from previous output node to # we need only to copy over stacktrace from previous output node to
# the two new ops. # the two new ops.
copy_stack_trace(node.outputs, s+ret) copy_stack_trace(node.outputs, [s, ret])
return [ret] return [ret]
...@@ -5217,7 +5217,7 @@ def local_elemwise_sub_zeros(node): ...@@ -5217,7 +5217,7 @@ def local_elemwise_sub_zeros(node):
node.inputs[0] == node.inputs[1]): node.inputs[0] == node.inputs[1]):
res = T.zeros_like(node.inputs[0]) res = T.zeros_like(node.inputs[0])
# Copy over stacktrace from previous output. # Copy over stacktrace from previous output.
# Julian: Pascal, is this really necessary? Is there anyway zeros_like can ever fail? # This could help for failures due to out-of-memory.
copy_stack_trace(node.outputs, res) copy_stack_trace(node.outputs, res)
return [res] return [res]
...@@ -5394,9 +5394,14 @@ def local_useless_elemwise_comparison(node): ...@@ -5394,9 +5394,14 @@ def local_useless_elemwise_comparison(node):
try: try:
cst = get_scalar_constant_value(node.inputs[1], cst = get_scalar_constant_value(node.inputs[1],
only_process_constants=True) only_process_constants=True)
# Copy over stacktrace from previous output.
res = T.zeros_like(node.inputs[0], dtype=dtype, opt=True)
copy_stack_trace(node.outputs, res)
if cst < 0: if cst < 0:
return [T.zeros_like(node.inputs[0], return [res]
dtype=dtype, opt=True)]
except NotScalarConstantError: except NotScalarConstantError:
pass pass
return return
......
...@@ -3566,7 +3566,6 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase): ...@@ -3566,7 +3566,6 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase):
assert isinstance(elem.inputs[0], T.TensorConstant), elem assert isinstance(elem.inputs[0], T.TensorConstant), elem
assert T.extract_constant(elem.inputs[0]) == val, val assert T.extract_constant(elem.inputs[0]) == val, val
def assert_identity(self, f): def assert_identity(self, f):
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
assert len(topo) == 1 assert len(topo) == 1
...@@ -3662,7 +3661,6 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase): ...@@ -3662,7 +3661,6 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase):
f = theano.function([x], T.eq(g, 0)) f = theano.function([x], T.eq(g, 0))
assert f([3, 3]) == 0 assert f([3, 3]) == 0
assert f([]) == 1 assert f([]) == 1
self.assertTrue(check_stack_trace(f, ops_to_check='last'))
f = theano.function([x], T.eq(g, -1)) f = theano.function([x], T.eq(g, -1))
self.assert_eqs_const(f, 0) self.assert_eqs_const(f, 0)
...@@ -3674,12 +3672,12 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase): ...@@ -3674,12 +3672,12 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase):
f = theano.function([x], T.eq(g, 0)) f = theano.function([x], T.eq(g, 0))
assert (f([3, 3]) == 0).all() assert (f([3, 3]) == 0).all()
assert (f([]) == 1).all() assert (f([]) == 1).all()
self.assertTrue(check_stack_trace(f, ops_to_check='last'))
f = theano.function([x], T.eq(g, -1)) f = theano.function([x], T.eq(g, -1))
self.assert_eqs_const(f, 0, op=T.alloc) self.assert_eqs_const(f, 0, op=T.alloc)
assert (f([3, 3]) == 0).all() assert (f([3, 3]) == 0).all()
def test_and(self): def test_and(self):
mode = theano.compile.get_default_mode().including('canonicalize') mode = theano.compile.get_default_mode().including('canonicalize')
...@@ -3727,6 +3725,22 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase): ...@@ -3727,6 +3725,22 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase):
self.assert_eqs_const(f, 0) self.assert_eqs_const(f, 0)
def test_stacktrace(self):
mode = theano.compile.get_default_mode().including(
'local_useless_elemwise_comparison')
x = T.vector('x', dtype=config.floatX)
f = theano.function([x], T.gt(x, x), mode=mode)
self.assertTrue(check_stack_trace(f, ops_to_check='last'))
f = theano.function([x], T.le(x, x), mode=mode)
self.assertTrue(check_stack_trace(f, ops_to_check='last'))
# Julian: I tried testing the stack trace for a bunch of different
# functions, including maximum and shapes, but other opts remove
# the stack traces in this case.
class Test_local_canonicalize_alloc(unittest.TestCase): class Test_local_canonicalize_alloc(unittest.TestCase):
def setUp(self): def setUp(self):
self.rng = numpy.random.RandomState(utt.fetch_seed()) self.rng = numpy.random.RandomState(utt.fetch_seed())
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论