提交 ce2e1562 authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #5336 from lamblin/fix_scan_bcastable

Use a default value compatible with broadcastable of type
...@@ -773,8 +773,11 @@ class Scan(PureOp): ...@@ -773,8 +773,11 @@ class Scan(PureOp):
# function exectution. Also, since an update is # function exectution. Also, since an update is
# defined, a default value must also be (this is # defined, a default value must also be (this is
# verified by DebugMode). Use an array of size 0 but # verified by DebugMode). Use an array of size 0 but
# the right ndim and dtype. # the right ndim and dtype (use a shape of 1 on
default_val = numpy.zeros([0] * inp.ndim, # broadcastable dimensions, 0 on the others).
default_shape = [1 if _b else 0
for _b in inp.broadcastable]
default_val = numpy.zeros(default_shape,
dtype=inp.dtype) dtype=inp.dtype)
wrapped_inp = In(variable=inp, value=default_val, wrapped_inp = In(variable=inp, value=default_val,
update=self.outputs[output_idx]) update=self.outputs[output_idx])
......
...@@ -5466,3 +5466,29 @@ def test_outputs_taps_check(): ...@@ -5466,3 +5466,29 @@ def test_outputs_taps_check():
outputs_info = {'initial': y, 'taps': [-1, -1]} outputs_info = {'initial': y, 'taps': [-1, -1]}
assert_raises(ValueError, theano.scan, f, x, outputs_info) assert_raises(ValueError, theano.scan, f, x, outputs_info)
print('done') print('done')
def test_default_value_broadcasted():
def floatx(X):
return numpy.asarray(X, dtype=theano.config.floatX)
def init_weights(shape, name):
return theano.shared(floatx(numpy.random.randn(*shape) * 0.1), name)
X = theano.tensor.matrix('X')
in_size = 2
out_size = 4
W_x = init_weights((in_size, out_size), "W_x")
def _active(x, pre_h):
x = theano.tensor.reshape(x, (1, in_size))
pre_h = theano.tensor.dot(x, W_x)
return pre_h
value, scan_updates = theano.scan(_active, sequences=X,
outputs_info=[theano.tensor.alloc(floatx(0.), 1, out_size)])
cost = theano.tensor.mean(value)
gW_x = theano.tensor.grad(cost, W_x)
updates = [(W_x, W_x - 0.1 * gW_x)]
f = theano.function([X], outputs=cost, updates=updates)
f(numpy.random.rand(10, in_size).astype(X.dtype))
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论