提交 771cdc3d authored 作者: Frederic's avatar Frederic

test dot(1d, 1d) -> gemv.

上级 ae13d891
......@@ -873,6 +873,20 @@ def test_dot_w_self():
###############################################################################
class TestGemv(TestCase, unittest_tools.TestOptimizationMixin):
def test_dot_vv(self):
''' Currently we generate a gemv for that case'''
rng = numpy.random.RandomState(unittest_tools.fetch_seed())
v = theano.shared(numpy.array(rng.uniform(size=(2,)), dtype='float32'))
w = theano.shared(numpy.array(rng.uniform(size=(2,)), dtype='float32'))
f = theano.function([], theano.dot(v, w), mode=mode_blas_opt)
# Assert that the dot was optimized somehow
self.assertFunctionContains0(f, T.dot)
self.assertFunctionContains1(f, Gemv(False))
# Assert they produce the same output
assert numpy.allclose(f(), numpy.dot(v.get_value(), w.get_value()))
def test_dot_vm(self):
''' Test vector dot matrix '''
rng = numpy.random.RandomState(unittest_tools.fetch_seed())
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论