提交 46000c76 authored 作者: Frederic Bastien's avatar Frederic Bastien

fix gh-5907. Make OpFromGraph.grad handle compute_test_value correctly.

上级 6187a1fa
...@@ -271,6 +271,7 @@ class OpFromGraph(gof.Op): ...@@ -271,6 +271,7 @@ class OpFromGraph(gof.Op):
is_inline = self.is_inline is_inline = self.is_inline
return '%(name)s{inline=%(is_inline)s}' % locals() return '%(name)s{inline=%(is_inline)s}' % locals()
@theano.configparser.change_flags(compute_test_value='off')
def _recompute_grad_op(self): def _recompute_grad_op(self):
''' '''
converts self._grad_op from user supplied form to type(self) instance converts self._grad_op from user supplied form to type(self) instance
......
...@@ -2,6 +2,7 @@ from __future__ import absolute_import, print_function, division ...@@ -2,6 +2,7 @@ from __future__ import absolute_import, print_function, division
from functools import partial from functools import partial
import numpy as np import numpy as np
import theano
from theano import config, shared from theano import config, shared
from theano.gradient import DisconnectedType from theano.gradient import DisconnectedType
...@@ -313,3 +314,14 @@ class T_OpFromGraph(unittest_tools.InferShapeTester): ...@@ -313,3 +314,14 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
[np.ones([3, 4], dtype=config.floatX), [np.ones([3, 4], dtype=config.floatX),
np.ones([3, 4], dtype=config.floatX)], np.ones([3, 4], dtype=config.floatX)],
OpFromGraph) OpFromGraph)
@theano.configparser.change_flags(compute_test_value='raise')
def test_compute_test_value(self):
x = T.scalar('x')
x.tag.test_value = np.array(1.)
op = OpFromGraph([x], [x ** 3])
y = T.scalar('y')
y.tag.test_value = np.array(1.)
f = op(y)
grad_f = T.grad(f, y)
assert grad_f.tag.test_value is not None
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论