提交 11e1a823 authored 作者: Frederic Bastien's avatar Frederic Bastien

Preserve the nan_guard_mode_check

上级 578ad4fd
......@@ -199,7 +199,7 @@ def std_fgraph(input_specs, output_specs, accept_inplace=False):
return fgraph, list(map(SymbolicOutput, updates))
std_fgraph.features = [gof.toolbox.PreserveNames]
std_fgraph.features = [gof.toolbox.PreserveVariableAttributes]
class AliasedMemoryError(Exception):
......
......@@ -323,7 +323,7 @@ class NanGuardMode(Mode):
fn()
outputs = fn.outputs
for x, var in zip(outputs, node.outputs):
if getattr(var, 'nan_guard_mode_check', True):
if getattr(var.tag, 'nan_guard_mode_check', True):
do_check_on(x[0], node, fn, False)
wrap_linker = theano.gof.WrapLinker([theano.gof.OpWiseCLinker()],
......
......@@ -454,11 +454,17 @@ class PrintListener(Feature):
node, i, r, new_r))
class PreserveNames(Feature):
class PreserveVariableAttributes(Feature):
"""
    This preserves some variable attributes and tags during optimization.
"""
def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
if r.name is not None and new_r.name is None:
new_r.name = r.name
if getattr(r.tag, 'nan_guard_mode_check', False) and getattr(
new_r.tag, 'nan_guard_mode_check', False) is False:
new_r.tag.nan_guard_mode_check = r.tag.nan_guard_mode_check
class NoOutputFromInplace(Feature):
......
......@@ -621,7 +621,7 @@ def expand_empty(tensor_var, size):
empty = tensor.AllocEmpty(tensor_var.dtype)(*new_shape)
ret = tensor.set_subtensor(empty[:shapes[0]], tensor_var)
ret.nan_guard_mode_check = False
ret.tag.nan_guard_mode_check = False
return ret
......
......@@ -6242,7 +6242,9 @@ class AllocEmpty(gof.Op):
# We can't reuse filter_checks_isfinite as by default it is
# False and it is set to true only in DebugMode.
# We can't set it in the type as other make_node can reuse the type.
output.nan_guard_mode_check = False
# We can't set it in the variable as it isn't copied when we copy
# the variable. So we set it in the tag.
output.tag.nan_guard_mode_check = False
return Apply(self, shape, [output])
def perform(self, node, inputs, out_):
......
......@@ -3076,9 +3076,6 @@ def local_inplace_setsubtensor(node):
set_instead_of_inc=node.op.set_instead_of_inc,
destroyhandler_tolerate_aliased=dta)
new_node = new_op(*node.inputs)
# Keep the information needed for NanGuardMode.
new_node.nan_guard_mode_check = getattr(
node.op, 'nan_guard_mode_check', True)
# Copy stacktrace from original outputs to new outputs.
# This is sensible, because the new operation is the
# same as the old one, but now with different attributes.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论