提交 9d5356d6 authored 作者: Frederic Bastien's avatar Frederic Bastien

Remove GradScale from optimized graph

上级 0e2d7a79
......@@ -7490,6 +7490,7 @@ def local_useless_composite(node):
# Although the ops ConsiderConstant, ZeroGrad and DisconnectedGrad
# just returns the input, it should be removed from the graph to
# make sure all possible optimizations can be applied.
# TODO: Add in useless too!
# ConsiderConstant is an identity at evaluation time (it only blocks
# gradient flow), so strip it during canonicalization to expose further
# optimizations.
register_canonicalize(gof.OpRemove(theano.gradient.consider_constant_),
                      'fast_compile', 'fast_run',
                      name='remove_consider_constant')
......@@ -7506,6 +7507,10 @@ register_canonicalize(gof.OpRemove(theano.gradient.GradClip),
'fast_compile', 'fast_run',
name='remove_grad_clip')
# GradScale only rescales gradients during symbolic differentiation; in
# the forward graph it is the identity, so remove it from the optimized
# graph just like GradClip above.
register_canonicalize(gof.OpRemove(theano.gradient.GradScale),
                      'fast_compile', 'fast_run',
                      name='remove_grad_scale')
@register_useless
@register_canonicalize
......
......@@ -784,5 +784,23 @@ def test_grad_clip():
assert np.allclose(out, (1, 4))
assert not np.allclose(out[0], out[1])
def test_grad_scale():
    """Check that grad_scale multiplies the gradient by its factor and
    that the GradScale op is removed from the optimized graph."""
    x = theano.tensor.scalar()
    scaled_grad = theano.tensor.grad(gradient.grad_scale(x, 2) ** 2, x)
    plain_grad = theano.tensor.grad(x ** 2, x)
    fn = theano.function([x], outputs=[scaled_grad, plain_grad])
    if theano.config.mode != "FAST_COMPILE":
        # The canonicalization pass should have stripped every GradScale.
        ops = [node.op for node in fn.maker.fgraph.toposort()]
        assert not any(isinstance(op, gradient.GradScale) for op in ops)
    result = fn(2.)
    # d/dx(x**2) at x=2 is 4; grad_scale(x, 2) doubles it to 8.
    assert np.allclose(result, (8, 4))
# Allow running this test module directly as a script.
if __name__ == '__main__':
    unittest.main()
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论