提交 e2e65f54 authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #5443 from Thrandis/ccw

Added broadcast check in scan.
......@@ -403,6 +403,36 @@ class Scan(PureOp):
'by using dimshuffle or shape_padleft. '
)
def check_broadcast(v1, v2):
""" Checks that the broadcast pattern of v1 and v2.
Controls that the broadcast pattern of the variable provided as
input to `scan` matches the broadcast pattern provided in
`output_info`. It raises an error when they don't match. The
typical case is when the user provides either the input or the
`output_info` (but not both) with a dimension fixed to 1,
which may wrongly be interpreted as broadcastable.
"""
if (not hasattr(v1, 'broadcastable') and
not hasattr(v2, 'broadcastable')):
return
msg = ("The broadcast pattern of the output of scan (%s) is "
"inconsistent with the one provided in `output_info` "
"(%s). The output on axis %d is `%r`, but it is `%r` on "
"axis %d in `output_info`. This can happen if one of the "
"dimension is fixed to 1 in the input, while it is still "
"variable in the output, or vice-verca. You have to make "
"them consistent, e.g. using theano.tensor."
"{patternbroadcast,unbroadcast,addbroadcast}.")
size = min(len(v1.broadcastable), len(v2.broadcastable))
for n, (b1, b2) in enumerate(zip(v1.broadcastable[-size:],
v2.broadcastable[-size:])):
if b1 != b2:
a1 = n + size - len(v1.broadcastable) + 1
a2 = n + size - len(v2.broadcastable) + 1
raise TypeError(msg % (v1.type, v2.type, a1, b1, b2, a2))
def format(var, as_var):
"""
This functions ensures that ``out`` has the same dtype as
......@@ -430,6 +460,7 @@ class Scan(PureOp):
argoffset = 0
for inner_seq, outer_seq in zip(self.inner_seqs(self.inputs),
self.outer_seqs(inputs)):
check_broadcast(outer_seq, inner_seq)
new_inputs.append(format(outer_seq, as_var=inner_seq))
argoffset += len(self.outer_seqs(inputs))
......
......@@ -5498,3 +5498,23 @@ def test_default_value_broadcasted():
updates = [(W_x, W_x - 0.1 * gW_x)]
f = theano.function([X], outputs=cost, updates=updates)
f(numpy.random.rand(10, in_size).astype(X.dtype))
class TestInconsistentBroadcast(unittest.TestCase):
def test_raise_error(self):
x = tensor.tensor3()
initial_x = tensor.constant(numpy.zeros((1, 10)))
y, updates = theano.scan(fn=lambda x, prev_x: x + prev_x,
sequences=x,
outputs_info=[dict(initial=initial_x)])
# Error, because the broadcast patterns are inconsistent.
with self.assertRaises(TypeError):
gs = tensor.grad(y.sum(), x)
# No error here, because the broadcast patterns are consistent.
initial_x = tensor.unbroadcast(initial_x, 0, 1)
y, updates = theano.scan(fn=lambda x, prev_x: x + prev_x,
sequences=x,
outputs_info=[dict(initial=initial_x)])
gs = tensor.grad(y.sum(), x)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论