提交 36b45acb · 作者: Frederic Bastien

copy stack trace in new opt.

上级 86842018
......@@ -3,6 +3,7 @@ import numpy
import theano
from theano import Apply, Op
from theano.gof import local_optimizer
from theano.gof.opt import copy_stack_trace
from theano.tensor import as_tensor_variable, TensorType
from theano.tensor import basic as T
from theano.tensor.opt import register_specialize_device
......@@ -630,7 +631,9 @@ def local_abstract_batch_norm_train(node):
results = [T.patternbroadcast(r, r_orig.broadcastable)
for (r, r_orig) in zip(results, node.outputs)]
# TODO copy_stack_trace?
for var in theano.gof.graph.variables(node.inputs, results):
if var not in node.inputs:
copy_stack_trace(node.outputs[0], var)
return results
......@@ -663,7 +666,9 @@ def local_abstract_batch_norm_train_grad(node):
results = [T.patternbroadcast(r, r_orig.broadcastable)
for (r, r_orig) in zip(results, node.outputs)]
# TODO copy_stack_trace?
for var in theano.gof.graph.variables(node.inputs, results):
if var not in node.inputs:
copy_stack_trace(node.outputs[0], var)
return results
......@@ -685,7 +690,9 @@ def local_abstract_batch_norm_inference(node):
result = (x - estimated_mean) * (scale / T.sqrt(estimated_variance + epsilon)) + bias
result = T.patternbroadcast(result, node.outputs[0].broadcastable)
# TODO copy_stack_trace?
for var in theano.gof.graph.variables(node.inputs, [result]):
if var not in node.inputs:
copy_stack_trace(node.outputs[0], var)
return [result]
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论