提交 09f4ff8e authored 作者: Frederic Bastien's avatar Frederic Bastien

added test to gemv when they are inplace.

上级 a92d2350
...@@ -687,27 +687,55 @@ def test_dot_mv(): ...@@ -687,27 +687,55 @@ def test_dot_mv():
def test_gemv1(): def test_gemv1():
''' test vector1+dot(matrix,vector2) ''' ''' test vector1+dot(matrix,vector2) '''
v1 = theano.shared( numpy.array(numpy.random.rand(2) , dtype='float32')) v1 = theano.shared( numpy.array(numpy.random.rand(2) , dtype='float32'))
v2 = theano.shared( numpy.array(numpy.random.rand(2) , dtype='float32')) v2_orig = numpy.array(numpy.random.rand(2), dtype='float32')
v2 = theano.shared( v2_orig )
m = theano.shared( numpy.array(numpy.random.rand(2,2), dtype='float32')) m = theano.shared( numpy.array(numpy.random.rand(2,2), dtype='float32'))
f = theano.function([], v2+theano.dot(m,v1), mode = mode_blas_opt) f = theano.function([], v2+theano.dot(m,v1), mode = mode_blas_opt)
# Assert they produce the same output # Assert they produce the same output
assert numpy.allclose(f(), numpy.dot(m.value,v1.value)+v2.value) assert numpy.allclose(f(), numpy.dot(m.value,v1.value)+v2_orig)
topo = f.maker.env.toposort()
assert len(topo)==1
assert isinstance(topo[0].op, Gemv)
assert topo[0].op.inplace==False
assert sum([isinstance(node.op, Gemv) for node in #test the inplace version
f.maker.env.toposort() ]) == 1 f = theano.function([], [], updates={v2:v2+theano.dot(m,v1)}
, mode = mode_blas_opt)
# Assert they produce the same output
f()
assert numpy.allclose(v2.value, numpy.dot(m.value,v1.value)+v2_orig)
topo = f.maker.env.toposort()
assert len(topo)==1
assert isinstance(topo[0].op, Gemv)
assert topo[0].op.inplace==True
def test_gemv2(): def test_gemv2():
''' test vector1+dot(vector2,matrix) ''' ''' test vector1+dot(vector2,matrix) '''
v1 = theano.shared( numpy.array(numpy.random.rand(2) , dtype='float32')) v1 = theano.shared( numpy.array(numpy.random.rand(2) , dtype='float32'))
v2 = theano.shared( numpy.array(numpy.random.rand(2) , dtype='float32')) v2_orig = numpy.array(numpy.random.rand(2), dtype='float32')
v2 = theano.shared( v2_orig )
m = theano.shared( numpy.array(numpy.random.rand(2,2), dtype='float32')) m = theano.shared( numpy.array(numpy.random.rand(2,2), dtype='float32'))
f = theano.function([], v2+theano.dot(v1,m), mode = mode_blas_opt) f = theano.function([], v2+theano.dot(v1,m), mode = mode_blas_opt)
# Assert they produce the same output # Assert they produce the same output
assert numpy.allclose(f(), numpy.dot(v1.value,m.value)+v2.value) assert numpy.allclose(f(), numpy.dot(v1.value,m.value)+v2.value)
topo = f.maker.env.toposort()
assert sum(isinstance(node.op, Gemv) for node in topo)==1
assert topo[-1].op.inplace==False
#test the inplace version
f = theano.function([], [], updates={v2:v2+theano.dot(v1,m)}
, mode = mode_blas_opt)
# Assert they produce the same output
f()
assert numpy.allclose(v2.value, numpy.dot(v1.value, m.value)+v2_orig)
topo = f.maker.env.toposort()
assert sum(isinstance(node.op, Gemv) for node in topo)==1
assert topo[0].op.inplace==True
assert sum([isinstance(node.op, Gemv) for node in
f.maker.env.toposort() ]) == 1
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论