提交 94d648d9 authored 作者: notoraptor's avatar notoraptor

Add test (from issue #4834) to PR #5336.

上级 ddc08582
...@@ -5466,3 +5466,30 @@ def test_outputs_taps_check(): ...@@ -5466,3 +5466,30 @@ def test_outputs_taps_check():
outputs_info = {'initial': y, 'taps': [-1, -1]} outputs_info = {'initial': y, 'taps': [-1, -1]}
assert_raises(ValueError, theano.scan, f, x, outputs_info) assert_raises(ValueError, theano.scan, f, x, outputs_info)
print('done') print('done')
def test_default_value_broadcasted():
def floatx(X):
return numpy.asarray(X, dtype=theano.config.floatX)
def init_weights(shape, name):
return theano.shared(floatx(numpy.random.randn(*shape) * 0.1), name)
X = theano.tensor.matrix('X')
in_size = 2
out_size = 4
W_x = init_weights((in_size, out_size), "W_x")
def _active(x, pre_h):
x = theano.tensor.reshape(x, (1, in_size))
pre_h = theano.tensor.dot(x, W_x)
return pre_h
value, scan_updates = theano.scan(_active, sequences=X,
outputs_info=[theano.tensor.alloc(floatx(0.), 1, out_size)])
cost = theano.tensor.mean(value)
gW_x = theano.tensor.grad(cost, W_x)
updates = [(W_x, W_x - 0.1 * gW_x)]
f = theano.function([X], outputs=cost, updates=updates)
test = f(numpy.random.rand(10, in_size))
print(test)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论