提交 2ebc24ff authored 作者: Iulian Vlad Serban's avatar Iulian Vlad Serban

Continued work on #3018. Fixed stack trace copy over for additional…

Continued work on #3018. Fixed stack trace copy over for additional optimizations, fixed errors in previous commit and included additional tests.
上级 daf7ea76
...@@ -77,6 +77,7 @@ def add_tag_trace(thing, user_line=1): ...@@ -77,6 +77,7 @@ def add_tag_trace(thing, user_line=1):
if limit == -1: if limit == -1:
limit = None limit = None
tr = simple_extract_stack(limit=limit)[:-1] tr = simple_extract_stack(limit=limit)[:-1]
# Different python version use different sementic for # Different python version use different sementic for
# limit. python 2.7 include the call to extrack_stack. The -1 get # limit. python 2.7 include the call to extrack_stack. The -1 get
# rid of it. # rid of it.
...@@ -93,7 +94,11 @@ def add_tag_trace(thing, user_line=1): ...@@ -93,7 +94,11 @@ def add_tag_trace(thing, user_line=1):
"theano/sparse/", "theano\\sparse\\", "theano/sparse/", "theano\\sparse\\",
"theano/typed_list/", "theano\\typed_list\\", "theano/typed_list/", "theano\\typed_list\\",
]: ]:
if p in file_path: # Julian: I added the 'tests' exception together with Arnaud.
# Otherwise, we'd lose the stack trace in our test cases
# (e.g. in test_opt.py). We're not sure this is the right way to
# do it though.
if p in file_path and 'tests' not in file_path:
tr = tr[:-1] tr = tr[:-1]
rm = True rm = True
break break
......
...@@ -91,8 +91,8 @@ def copy_stack_trace(from_var, to_var): ...@@ -91,8 +91,8 @@ def copy_stack_trace(from_var, to_var):
tr += getattr(v.tag, 'trace', []) tr += getattr(v.tag, 'trace', [])
else: else:
# If from_var is not a list, it must be a single tensor # If from_var is not a list, it must be a single tensor variable,
# variable, so just store that particular stack trace # so just store that particular stack trace
tr = getattr(from_var.tag, 'trace', []) tr = getattr(from_var.tag, 'trace', [])
# Copy over stack traces to to_var # Copy over stack traces to to_var
...@@ -2565,7 +2565,7 @@ def local_subtensor_lift(node): ...@@ -2565,7 +2565,7 @@ def local_subtensor_lift(node):
ret = u.owner.op(x_idx) ret = u.owner.op(x_idx)
# Copy over previous output stacktrace # Copy over previous output stacktrace
# and stacktrace from previous unary operation # and stacktrace from previous unary operation
copy_stack_trace([node.outputs, node.inputs[0]], ret) copy_stack_trace([node.outputs[0], node.inputs[0]], ret)
return [ret] return [ret]
if isinstance(u.owner.op, T.Elemwise): if isinstance(u.owner.op, T.Elemwise):
...@@ -2574,7 +2574,14 @@ def local_subtensor_lift(node): ...@@ -2574,7 +2574,14 @@ def local_subtensor_lift(node):
# There is no broadcastable in the inputs # There is no broadcastable in the inputs
idx = node.inputs[1:] idx = node.inputs[1:]
new_inputs = [node.op(i, *idx) for i in u.owner.inputs] new_inputs = [node.op(i, *idx) for i in u.owner.inputs]
return [u.owner.op(*new_inputs)] # Copy over previous output stacktrace
copy_stack_trace(node.outputs[0], new_inputs)
ret = u.owner.op(*new_inputs)
# Copy over previous output stacktrace
# and stacktrace from previous unary operation
copy_stack_trace([node.outputs[0], node.inputs[0]], ret)
return [ret]
elif all([sum(i.type.broadcastable) in [i.ndim, 0] elif all([sum(i.type.broadcastable) in [i.ndim, 0]
for i in u.owner.inputs]): for i in u.owner.inputs]):
# There is no broadcastable in the inputs or it is scalar # There is no broadcastable in the inputs or it is scalar
...@@ -2591,7 +2598,15 @@ def local_subtensor_lift(node): ...@@ -2591,7 +2598,15 @@ def local_subtensor_lift(node):
else: else:
new_inputs.append( new_inputs.append(
i.dimshuffle(['x'] * node.outputs[0].ndim)) i.dimshuffle(['x'] * node.outputs[0].ndim))
return [u.owner.op(*new_inputs)]
# Copy over previous output stacktrace
copy_stack_trace(node.outputs[0], new_inputs)
ret = u.owner.op(*new_inputs)
# Copy over previous output stacktrace
# and stacktrace from previous unary operation
copy_stack_trace([node.outputs[0], node.inputs[0]], ret)
return [ret]
if isinstance(u.owner.op, T.Rebroadcast): if isinstance(u.owner.op, T.Rebroadcast):
# make sure that Rebroadcast has only 1 input # make sure that Rebroadcast has only 1 input
...@@ -2617,7 +2632,13 @@ def local_subtensor_lift(node): ...@@ -2617,7 +2632,13 @@ def local_subtensor_lift(node):
j += 1 j += 1
subt_x = node.op(u.owner.inputs[0], *node.inputs[1:]) subt_x = node.op(u.owner.inputs[0], *node.inputs[1:])
# Copy over previous output stacktrace
copy_stack_trace(node.outputs[0], subt_x)
rbcast_subt_x = T.Rebroadcast(*new_axis)(subt_x) rbcast_subt_x = T.Rebroadcast(*new_axis)(subt_x)
# Copy over previous output stacktrace
# and stacktrace from previous unary operation
copy_stack_trace([node.outputs[0], node.inputs[0]], rbcast_subt_x)
return [rbcast_subt_x] return [rbcast_subt_x]
...@@ -2809,11 +2830,18 @@ def local_subtensor_merge(node): ...@@ -2809,11 +2830,18 @@ def local_subtensor_merge(node):
merged_slices = make_constant(merged_slices) merged_slices = make_constant(merged_slices)
subtens = Subtensor(merged_slices) subtens = Subtensor(merged_slices)
sl_ins = Subtensor.collapse( sl_ins = Subtensor.collapse(
merged_slices, merged_slices,
lambda x: isinstance(x, T.Variable)) lambda x: isinstance(x, T.Variable))
# Do not call make_node for test_value # Do not call make_node for test_value
out = subtens(x, *sl_ins) out = subtens(x, *sl_ins)
# Copy over previous output stacktrace
# and stacktrace from previous slicing operation.
# Why? Because, the merged slicing operation could have failed
# because of either of the two original slicing operations
copy_stack_trace([node.outputs[0], node.inputs[0]], out)
return [out] return [out]
...@@ -2821,6 +2849,7 @@ def local_subtensor_merge(node): ...@@ -2821,6 +2849,7 @@ def local_subtensor_merge(node):
@register_specialize @register_specialize
@gof.local_optimizer([Subtensor]) @gof.local_optimizer([Subtensor])
def local_subtensor_of_alloc(node): def local_subtensor_of_alloc(node):
#TODO Julian: Document this better!
"""alloc[x:y] -> alloc""" """alloc[x:y] -> alloc"""
if not isinstance(node.op, Subtensor): if not isinstance(node.op, Subtensor):
return False return False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论