提交 68e959df authored 作者: Iulian Vlad Serban's avatar Iulian Vlad Serban

Implemented comments from Pascal. Implemented some tests for local_useless_elemwise_comparison.

上级 bde718ca
......@@ -4226,7 +4226,7 @@ def local_flatten_lift(node):
# Copy over stacktrace from previous output node and from unary
# elementwise output node since if there was an error, it would
# probably have come from that operation.
copy_stack_trace(node.outputs + node.inputs[0], e)
copy_stack_trace(node.outputs + [node.inputs[0]], e)
return [e]
......@@ -4498,9 +4498,9 @@ def local_reshape_lift(node):
if e.type != node.outputs[0].type:
re = T.patternbroadcast(e, node.outputs[0].broadcastable)
# We assume that the broadcast op cannot fail. Thus, if the
# graph fails it must be due to previous UnaryElemwise op, and
# therefore we must copy its stacktrace over.
# Copy over stack trace.
# If the graph fails it is usually due to the fact that a dimension
# that should be broadcastable does not actually have length 1,
copy_stack_trace(e, re)
else:
re = e
......@@ -5200,7 +5200,7 @@ def local_sum_prod_mul_by_scalar(node):
# There are never errors in the negative op, thus
# we need only to copy over stacktrace from previous output node to
# the two new ops.
copy_stack_trace(node.outputs, s+ret)
copy_stack_trace(node.outputs, [s, ret])
return [ret]
......@@ -5217,7 +5217,7 @@ def local_elemwise_sub_zeros(node):
node.inputs[0] == node.inputs[1]):
res = T.zeros_like(node.inputs[0])
# Copy over stacktrace from previous output.
# Julian: Pascal, is this really necessary? Is there anyway zeros_like can ever fail?
# This could help for failures due to out-of-memory.
copy_stack_trace(node.outputs, res)
return [res]
......@@ -5394,9 +5394,14 @@ def local_useless_elemwise_comparison(node):
try:
cst = get_scalar_constant_value(node.inputs[1],
only_process_constants=True)
# Copy over stacktrace from previous output.
res = T.zeros_like(node.inputs[0], dtype=dtype, opt=True)
copy_stack_trace(node.outputs, res)
if cst < 0:
return [T.zeros_like(node.inputs[0],
dtype=dtype, opt=True)]
return [res]
except NotScalarConstantError:
pass
return
......
......@@ -3566,7 +3566,6 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase):
assert isinstance(elem.inputs[0], T.TensorConstant), elem
assert T.extract_constant(elem.inputs[0]) == val, val
def assert_identity(self, f):
topo = f.maker.fgraph.toposort()
assert len(topo) == 1
......@@ -3662,7 +3661,6 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase):
f = theano.function([x], T.eq(g, 0))
assert f([3, 3]) == 0
assert f([]) == 1
self.assertTrue(check_stack_trace(f, ops_to_check='last'))
f = theano.function([x], T.eq(g, -1))
self.assert_eqs_const(f, 0)
......@@ -3674,12 +3672,12 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase):
f = theano.function([x], T.eq(g, 0))
assert (f([3, 3]) == 0).all()
assert (f([]) == 1).all()
self.assertTrue(check_stack_trace(f, ops_to_check='last'))
f = theano.function([x], T.eq(g, -1))
self.assert_eqs_const(f, 0, op=T.alloc)
assert (f([3, 3]) == 0).all()
def test_and(self):
mode = theano.compile.get_default_mode().including('canonicalize')
......@@ -3727,6 +3725,22 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase):
self.assert_eqs_const(f, 0)
def test_stacktrace(self):
mode = theano.compile.get_default_mode().including(
'local_useless_elemwise_comparison')
x = T.vector('x', dtype=config.floatX)
f = theano.function([x], T.gt(x, x), mode=mode)
self.assertTrue(check_stack_trace(f, ops_to_check='last'))
f = theano.function([x], T.le(x, x), mode=mode)
self.assertTrue(check_stack_trace(f, ops_to_check='last'))
# Julian: I tried testing the stack trace for a bunch of different
# functions, including maximum and shapes, but other opts remove
# the stack traces in this case.
class Test_local_canonicalize_alloc(unittest.TestCase):
def setUp(self):
self.rng = numpy.random.RandomState(utt.fetch_seed())
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论