提交 ae614647，作者：abergeron

Merge pull request #4089 from nouiz/gemv_broadcast

Fix crash with gemv and some broadcast pattern
......@@ -431,14 +431,6 @@ class Gemv(Op):
raise TypeError('gemv requires vector for x', x.type)
if y.ndim != 1:
raise TypeError('gemv requires vector for y', y.type)
if y.broadcastable[0] != A.broadcastable[0]:
raise TypeError('broadcastable mismatch between y and A',
(y.type, A.type))
# The following is not grounds for error because as long as
# sizes are 1 at time of perform() there is no problem
# if x.broadcastable[0] != A.broadcastable[1]:
# raise TypeError('broadcastable mismatch between x and A',
# (x.type, A.type))
return Apply(self, [y, alpha, A, x, beta], [y.type()])
def perform(self, node, inputs, out_storage):
......
......@@ -1235,6 +1235,35 @@ class TestGemv(TestCase, unittest_tools.TestOptimizationMixin):
assert numpy.allclose(v2.get_value(),
numpy.dot(v1.get_value(), m.get_value()) + v2_orig)
def test_gemv_broadcast(self):
    """Check gemv when some of its inputs carry a broadcastable dimension.

    Two graphs are compiled with ``mode_blas_opt``: ``dot(m, v1) + v2`` and
    a direct ``gemv_no_inplace`` call whose matrix is declared
    broadcastable along its first axis while the vector ``y`` is not.
    Each must match the NumPy reference and contain exactly one Gemv node.
    """
    def count_gemv(compiled):
        # Number of Gemv apply nodes in the compiled function's graph.
        nodes = compiled.maker.fgraph.toposort()
        return len([n for n in nodes if isinstance(n.op, Gemv)])

    # NOTE: the three uniform() draws must stay in this order (v1, v2, m)
    # so the generated values are the same for a given seed.
    seed = unittest_tools.fetch_seed()
    rng = numpy.random.RandomState(seed)
    v1_value = numpy.array(rng.uniform(size=(2,)), dtype='float32')
    v1 = theano.shared(v1_value)
    v2_orig = numpy.array(rng.uniform(size=(1,)), dtype='float32')
    v2 = theano.shared(v2_orig)
    m_value = numpy.array(rng.uniform(size=(1, 2)), dtype='float32')
    m = theano.shared(m_value, broadcastable=(True, False))

    # Case 1: dot(m, v1) + v2 compiled with the BLAS optimizations.
    dot_plus_v2 = theano.dot(m, v1) + v2
    f = theano.function([], dot_plus_v2, mode=mode_blas_opt)
    result = f()
    reference = numpy.dot(m.get_value(), v1.get_value()) + v2.get_value()
    assert numpy.allclose(result, reference)
    assert count_gemv(f) == 1

    # Case 2: call gemv directly for a mixed broadcast pattern.
    direct = theano.tensor.blas.gemv_no_inplace(v2, 0.5, m, v1, 0.25)
    f = theano.function([], direct, mode=mode_blas_opt)
    result = f()
    scaled_dot = 0.5 * numpy.dot(m.get_value(), v1.get_value())
    reference = scaled_dot + 0.25 * v2.get_value()
    assert numpy.allclose(result, reference)
    assert count_gemv(f) == 1
def test_gemv_dimensions(self):
A = T.matrix('A')
x, y = T.vectors('x', 'y')
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请注册或者登录后发表评论