提交 a0a112b4 authored 作者: Frederic Bastien's avatar Frederic Bastien

Remove UndefinedGrad nodes from the optimized graph

上级 9d5356d6
......@@ -7502,6 +7502,9 @@ register_canonicalize(gof.OpRemove(theano.gradient.disconnected_grad_),
'fast_compile', 'fast_run',
name='remove_disconnected_grad')
# Strip undefined_grad_ identity ops from the graph during canonicalization,
# mirroring the remove_disconnected_grad registration above. Tagged for both
# 'fast_compile' and 'fast_run' so the op never survives into a compiled graph.
register_canonicalize(gof.OpRemove(theano.gradient.undefined_grad_),
'fast_compile', 'fast_run',
name='remove_undefined_grad')
register_canonicalize(gof.OpRemove(theano.gradient.GradClip),
'fast_compile', 'fast_run',
......
......@@ -16,6 +16,7 @@ from theano.tests import unittest_tools as utt
from theano import gradient
from theano import config
from theano.gof.null_type import NullType
from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams
one = theano.tensor.as_tensor_variable(1.)
......@@ -802,5 +803,20 @@ def test_grad_scale():
assert np.allclose(out, (8, 4))
def test_undefined_grad_opt():
    """Ensure UndefinedGrad nodes are removed from the optimized graph.

    Builds a graph whose gradient pass introduces UndefinedGrad applies
    (the gradient through multinomial samples is undefined) and checks
    that the 'remove_undefined_grad' canonicalization strips every such
    node from the compiled function's graph.
    """
    random = RandomStreams(np.random.randint(1, 2147462579))

    pvals = theano.shared(np.random.rand(10, 20).astype(theano.config.floatX))
    pvals = pvals / pvals.sum(axis=1)
    pvals = gradient.zero_grad(pvals)

    samples = random.multinomial(pvals=pvals, n=1)
    samples = theano.tensor.cast(samples, pvals.dtype)
    samples = gradient.zero_grad(samples)

    cost = theano.tensor.sum(samples + pvals)
    grad = theano.tensor.grad(cost, samples)

    f = theano.function([], grad)
    # Removed a leftover theano.printing.debugprint(f) call: tests should
    # not dump graphs to stdout. Use a generator (not a list) with any().
    assert not any(isinstance(node.op, gradient.UndefinedGrad)
                   for node in f.maker.fgraph.apply_nodes)
# Allow running this test module directly from the command line.
if __name__ == '__main__':
    unittest.main()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论