提交 e50c2ca6 authored 作者: Iulian Vlad Serban's avatar Iulian Vlad Serban

Included changes in tensor/opt.py for #2454. These were previously lost due to…

Included changes in tensor/opt.py for #2454. These were previously lost due to problems committing to GitHub.
上级 9efff971
...@@ -1881,6 +1881,8 @@ def _is_zero(x): ...@@ -1881,6 +1881,8 @@ def _is_zero(x):
class ConsiderConstant(ViewOp): class ConsiderConstant(ViewOp):
def grad(self, args, g_outs): def grad(self, args, g_outs):
return [g_out.zeros_like(g_out) for g_out in g_outs] return [g_out.zeros_like(g_out) for g_out in g_outs]
consider_constant_ = ConsiderConstant() consider_constant_ = ConsiderConstant()
...@@ -1916,6 +1918,7 @@ class ZeroGrad(ViewOp): ...@@ -1916,6 +1918,7 @@ class ZeroGrad(ViewOp):
def grad(self, args, g_outs): def grad(self, args, g_outs):
return [g_out.zeros_like(g_out) for g_out in g_outs] return [g_out.zeros_like(g_out) for g_out in g_outs]
zero_grad_ = ZeroGrad() zero_grad_ = ZeroGrad()
......
...@@ -5587,11 +5587,17 @@ else: ...@@ -5587,11 +5587,17 @@ else:
# # Remove consider_constant # # # Remove consider_constant #
# ############################ # ############################
# Although the op just returns its input, it should be removed from # Although the ops ConsiderConstant, ZeroGrad and DisconnectedGrad
# the graph to make sure all possible optimizations can be applied. # just returns the input, it should be removed from the graph to
# make sure all possible optimizations can be applied.
register_canonicalize(gof.OpRemove(theano.gradient.consider_constant_), register_canonicalize(gof.OpRemove(theano.gradient.consider_constant_),
'fast_compile', 'fast_run', name='remove_consider_constant') 'fast_compile', 'fast_run', name='remove_consider_constant')
register_canonicalize(gof.OpRemove(theano.gradient.zero_grad_),
'fast_compile', 'fast_run', name='remove_zero_grad')
register_canonicalize(gof.OpRemove(theano.gradient.disconnected_grad_),
'fast_compile', 'fast_run', name='remove_disconnected_grad')
@register_canonicalize @register_canonicalize
@gof.local_optimizer([theano.gradient.GradClip]) @gof.local_optimizer([theano.gradient.GradClip])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论