提交 bbda20ee authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #2996 from lamblin/scan_test_values

Scan test values
......@@ -64,6 +64,16 @@ def safe_new(x, tag='', dtype=None):
else:
nw_x = x.type()
nw_x.name = nw_name
if theano.config.compute_test_value != 'off':
# Copy test value, cast it if necessary
try:
x_test_value = gof.op.get_test_value(x)
except AttributeError:
# There is no test value
pass
else:
# This clause is executed if no exception was raised
nw_x.tag.test_value = nw_x.type.filter(x_test_value)
return nw_x
else:
try:
......@@ -73,11 +83,13 @@ def safe_new(x, tag='', dtype=None):
# want to avoid the convoluted logic that checks for cuda
# ndarrays
pass
# Cast x if needed. If x has a test value, this will also cast it.
if dtype and x.dtype != dtype:
x = x.astype(dtype)
nw_x = x.type()
if dtype and nw_x.dtype != dtype:
nw_x = nw_x.astype(dtype).type()
nw_x.name = nw_name
# Preserve test values so that the 'compute_test_value' option can be used.
# The test value is deep-copied to ensure there can be no interactions
# between test values, due to inplace operations for instance. This may
......
......@@ -4983,12 +4983,34 @@ def test_compute_test_value_grad():
)
loss = result_mi[-1]
grad = tensor.grad(loss, W_flat)
tensor.grad(loss, W_flat)
finally:
theano.config.compute_test_value = old_compute_test_val
theano.config.exception_verbosity = old_exception_verbosity
def test_compute_test_value_grad_cast():
# Test for test values when variables have to be casted
# Reported by Daniel Renshaw at
# https://groups.google.com/d/topic/theano-users/o4jK9xDe5WI/discussion
backup = theano.config.compute_test_value
theano.config.compute_test_value = 'raise'
try:
h = tensor.matrix('h')
h.tag.test_value = numpy.array([[1, 2, 3, 4], [5, 6, 7, 8]],
dtype=numpy.float64)
w = theano.shared(numpy.random.randn(4, 3).astype('float64'), name='w')
outputs, _ = theano.scan(lambda i, h, w: (theano.dot(h[i], w), i),
outputs_info=[None, 0L], non_sequences=[h, w],
n_steps=3)
theano.grad(outputs[0].sum(), w)
finally:
theano.config.compute_test_value = backup
def test_constant_folding_n_steps():
# The following code used to crash at revision 2060b8f, in the constant
# folding optimization step.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论