提交 36b45acb authored 作者: Frederic Bastien's avatar Frederic Bastien

copy stack trace in new opt.

上级 86842018
...@@ -3,6 +3,7 @@ import numpy ...@@ -3,6 +3,7 @@ import numpy
import theano import theano
from theano import Apply, Op from theano import Apply, Op
from theano.gof import local_optimizer from theano.gof import local_optimizer
from theano.gof.opt import copy_stack_trace
from theano.tensor import as_tensor_variable, TensorType from theano.tensor import as_tensor_variable, TensorType
from theano.tensor import basic as T from theano.tensor import basic as T
from theano.tensor.opt import register_specialize_device from theano.tensor.opt import register_specialize_device
...@@ -630,7 +631,9 @@ def local_abstract_batch_norm_train(node): ...@@ -630,7 +631,9 @@ def local_abstract_batch_norm_train(node):
results = [T.patternbroadcast(r, r_orig.broadcastable) results = [T.patternbroadcast(r, r_orig.broadcastable)
for (r, r_orig) in zip(results, node.outputs)] for (r, r_orig) in zip(results, node.outputs)]
# TODO copy_stack_trace? for var in theano.gof.graph.variables(node.inputs, results):
if var not in node.inputs:
copy_stack_trace(node.outputs[0], var)
return results return results
...@@ -663,7 +666,9 @@ def local_abstract_batch_norm_train_grad(node): ...@@ -663,7 +666,9 @@ def local_abstract_batch_norm_train_grad(node):
results = [T.patternbroadcast(r, r_orig.broadcastable) results = [T.patternbroadcast(r, r_orig.broadcastable)
for (r, r_orig) in zip(results, node.outputs)] for (r, r_orig) in zip(results, node.outputs)]
# TODO copy_stack_trace? for var in theano.gof.graph.variables(node.inputs, results):
if var not in node.inputs:
copy_stack_trace(node.outputs[0], var)
return results return results
...@@ -685,7 +690,9 @@ def local_abstract_batch_norm_inference(node): ...@@ -685,7 +690,9 @@ def local_abstract_batch_norm_inference(node):
result = (x - estimated_mean) * (scale / T.sqrt(estimated_variance + epsilon)) + bias result = (x - estimated_mean) * (scale / T.sqrt(estimated_variance + epsilon)) + bias
result = T.patternbroadcast(result, node.outputs[0].broadcastable) result = T.patternbroadcast(result, node.outputs[0].broadcastable)
# TODO copy_stack_trace? for var in theano.gof.graph.variables(node.inputs, [result]):
if var not in node.inputs:
copy_stack_trace(node.outputs[0], var)
return [result] return [result]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论