提交 ba320eeb authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Small refactoring of tests.

上级 dcfe260e
......@@ -699,69 +699,69 @@ def test_dot_mv():
assert sum([isinstance(node.op, T.Dot) for node in
f.maker.env.toposort() ]) == 1
def test_gemv1():
''' test vector1+dot(matrix,vector2) '''
rng = numpy.random.RandomState(unittest_tools.fetch_seed())
v1 = theano.shared(numpy.array(rng.uniform(size=(2,)), dtype='float32'))
v2_orig = numpy.array(rng.uniform(size=(2,)), dtype='float32')
v2 = theano.shared(v2_orig)
m = theano.shared(numpy.array(rng.uniform(size=(2,2)), dtype='float32'))
class TestGemv(TestCase):
def test_gemv1(self):
''' test vector1+dot(matrix,vector2) '''
rng = numpy.random.RandomState(unittest_tools.fetch_seed())
v1 = theano.shared(numpy.array(rng.uniform(size=(2,)), dtype='float32'))
v2_orig = numpy.array(rng.uniform(size=(2,)), dtype='float32')
v2 = theano.shared(v2_orig)
m = theano.shared(numpy.array(rng.uniform(size=(2,2)), dtype='float32'))
f = theano.function([], v2+theano.dot(m,v1), mode = mode_blas_opt)
f = theano.function([], v2+theano.dot(m,v1), mode = mode_blas_opt)
# Assert they produce the same output
assert numpy.allclose(f(),
numpy.dot(m.get_value(), v1.get_value()) + v2_orig)
topo = f.maker.env.toposort()
assert len(topo)==1
assert isinstance(topo[0].op, Gemv)
assert topo[0].op.inplace==False
# Assert they produce the same output
assert numpy.allclose(f(),
numpy.dot(m.get_value(), v1.get_value()) + v2_orig)
topo = f.maker.env.toposort()
assert len(topo)==1
assert isinstance(topo[0].op, Gemv)
assert topo[0].op.inplace==False
#test the inplace version
f = theano.function([], [], updates={v2:v2+theano.dot(m,v1)}
, mode = mode_blas_opt)
#test the inplace version
f = theano.function([], [], updates={v2:v2+theano.dot(m,v1)}
, mode = mode_blas_opt)
# Assert they produce the same output
f()
assert numpy.allclose(v2.get_value(),
numpy.dot(m.get_value(), v1.get_value()) + v2_orig)
topo = f.maker.env.toposort()
assert len(topo)==1
assert isinstance(topo[0].op, Gemv)
if config.mode != 'FAST_COMPILE':
assert topo[0].op.inplace==True
# Assert they produce the same output
f()
assert numpy.allclose(v2.get_value(),
numpy.dot(m.get_value(), v1.get_value()) + v2_orig)
topo = f.maker.env.toposort()
assert len(topo)==1
assert isinstance(topo[0].op, Gemv)
if config.mode != 'FAST_COMPILE':
assert topo[0].op.inplace==True
def test_gemv2():
''' test vector1+dot(vector2,matrix) '''
rng = numpy.random.RandomState(unittest_tools.fetch_seed())
v1 = theano.shared(numpy.array(rng.uniform(size=(2,)), dtype='float32'))
v2_orig = numpy.array(rng.uniform(size=(2,)), dtype='float32')
v2 = theano.shared(v2_orig )
m = theano.shared(numpy.array(rng.uniform(size=(2,2)), dtype='float32'))
def test_gemv2(self):
''' test vector1+dot(vector2,matrix) '''
rng = numpy.random.RandomState(unittest_tools.fetch_seed())
v1 = theano.shared(numpy.array(rng.uniform(size=(2,)), dtype='float32'))
v2_orig = numpy.array(rng.uniform(size=(2,)), dtype='float32')
v2 = theano.shared(v2_orig )
m = theano.shared(numpy.array(rng.uniform(size=(2,2)), dtype='float32'))
f = theano.function([], v2+theano.dot(v1,m), mode = mode_blas_opt)
f = theano.function([], v2+theano.dot(v1,m), mode = mode_blas_opt)
# Assert they produce the same output
assert numpy.allclose(f(),
numpy.dot(v1.get_value(), m.get_value()) + v2.get_value())
topo = f.maker.env.toposort()
assert sum(isinstance(node.op, Gemv) for node in topo)==1
assert topo[-1].op.inplace==False
# Assert they produce the same output
assert numpy.allclose(f(),
numpy.dot(v1.get_value(), m.get_value()) + v2.get_value())
topo = f.maker.env.toposort()
assert sum(isinstance(node.op, Gemv) for node in topo)==1
assert topo[-1].op.inplace==False
#test the inplace version
f = theano.function([], [], updates={v2:v2+theano.dot(v1,m)}
, mode = mode_blas_opt)
#test the inplace version
f = theano.function([], [], updates={v2:v2+theano.dot(v1,m)}
, mode = mode_blas_opt)
# Assert they produce the same output
f()
assert numpy.allclose(v2.get_value(),
numpy.dot(v1.get_value(), m.get_value()) + v2_orig)
topo = f.maker.env.toposort()
assert sum(isinstance(node.op, Gemv) for node in topo)==1
if config.mode != 'FAST_COMPILE':
assert topo[-1].op.inplace==True
# Assert they produce the same output
f()
assert numpy.allclose(v2.get_value(),
numpy.dot(v1.get_value(), m.get_value()) + v2_orig)
topo = f.maker.env.toposort()
assert sum(isinstance(node.op, Gemv) for node in topo)==1
if config.mode != 'FAST_COMPILE':
assert topo[-1].op.inplace==True
class TestGemv(TestCase):
def test_gemv_dimensions(self):
A = T.matrix('A')
x, y = T.vectors('x', 'y')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论