提交 0e2d7a79 authored 作者: Frederic Bastien's avatar Frederic Bastien

Alloc OpRemove to accept a class to remove. Tested by test_gradien.py:test_grad_clip

上级 fd05e3a3
......@@ -1542,9 +1542,11 @@ class OpRemove(LocalOptimizer):
return [self.op]
def transform(self, node):
if node.op != self.op:
return False
return node.inputs
if inspect.isclass(self.op):
if isinstance(node.op, self.op):
return node.inputs
elif node.op == self.op:
return node.inputs
def __str__(self):
return "%s(x) -> x" % (self.op)
......
......@@ -7502,11 +7502,9 @@ register_canonicalize(gof.OpRemove(theano.gradient.disconnected_grad_),
name='remove_disconnected_grad')
@register_canonicalize
@gof.local_optimizer([theano.gradient.GradClip])
def local_grad_clip(node):
if isinstance(node.op, theano.gradient.GradClip):
return node.inputs
register_canonicalize(gof.OpRemove(theano.gradient.GradClip),
'fast_compile', 'fast_run',
name='remove_grad_clip')
@register_useless
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论