提交 3c77754e authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Replace calls to 'gemm' by 'gemm_inplace' in tests.

上级 215bdcd2
......@@ -35,7 +35,7 @@ class t_gemm(TestCase):
z_orig = z.copy()
tz,ta,tx,ty,tb = [as_tensor_variable(p).type() for p in z,a,x,y,b]
f = inplace_func([tz,ta,tx,ty,tb], gemm(tz,ta,tx,ty,tb), mode=compile.Mode(optimizer = None, linker = l))
f = inplace_func([tz,ta,tx,ty,tb], gemm_inplace(tz,ta,tx,ty,tb), mode=compile.Mode(optimizer = None, linker = l))
new_z = f(z,a,x,y,b)
z_after = self._gemm(z_orig, a, x, y, b)
......@@ -55,7 +55,7 @@ class t_gemm(TestCase):
def test0a(self):
Gemm.debug = True
try:
g = gemm([1.], 1., [1.], [1.], 1.)
g = gemm_inplace([1.], 1., [1.], [1.], 1.)
except ValueError, e:
if e[0] is Gemm.E_rank:
return
......@@ -100,7 +100,7 @@ class t_gemm(TestCase):
"""test that only first input can be overwritten"""
Z = as_tensor_variable(self.rand(2,2))
try:
gemm(Z, 1.0, Z, Z, 1.0)
gemm_inplace(Z, 1.0, Z, Z, 1.0)
except ValueError, e:
if e[0] == Gemm.E_z_uniq:
return
......@@ -110,7 +110,7 @@ class t_gemm(TestCase):
Z = as_tensor_variable(self.rand(2,2))
A = as_tensor_variable(self.rand(2,2))
try:
gemm(Z, 1.0, A, inplace.transpose_inplace(Z), 1.0)
gemm_inplace(Z, 1.0, A, inplace.transpose_inplace(Z), 1.0)
except ValueError, e:
if e[0] == Gemm.E_z_uniq:
return
......@@ -120,7 +120,7 @@ class t_gemm(TestCase):
Z = as_tensor_variable(self.rand(2,2))
A = as_tensor_variable(self.rand(2,2))
try:
gemm(Z, 1.0, inplace.transpose_inplace(Z), A, 1.0)
gemm_inplace(Z, 1.0, inplace.transpose_inplace(Z), A, 1.0)
except ValueError, e:
if e[0] == Gemm.E_z_uniq:
return
......@@ -130,7 +130,7 @@ class t_gemm(TestCase):
Z = as_tensor_variable(self.rand(2,2))
A = as_tensor_variable(self.rand(2,2))
try:
gemm(Z, 1.0, Z, A, 1.0)
gemm_inplace(Z, 1.0, Z, A, 1.0)
except ValueError, e:
if e[0] == Gemm.E_z_uniq:
return
......@@ -140,9 +140,9 @@ class t_gemm(TestCase):
"""test that dot args can be aliased"""
Z = shared(self.rand(2,2))
A = shared(self.rand(2,2))
f = inplace_func([], gemm(Z, 1.0, A, A, 1.0))
f = inplace_func([], gemm_inplace(Z, 1.0, A, A, 1.0))
f()
f = inplace_func([], gemm(Z, 1.0, A, A.T, 1.0))
f = inplace_func([], gemm_inplace(Z, 1.0, A, A.T, 1.0))
f()
def test_transposes(self):
......@@ -158,9 +158,9 @@ class t_gemm(TestCase):
tz,ta,tx,ty,tb = [shared(p) for p in z,a,x,y,b]
#f = inplace_func([tz,ta,tx,ty,tb], gemm(tz,ta,tx,ty,tb), mode = compile.Mode(optimizer = None, linker=l))
#f = inplace_func([tz,ta,tx,ty,tb], gemm_inplace(tz,ta,tx,ty,tb), mode = compile.Mode(optimizer = None, linker=l))
#f(z, a, x, y, b)
f = inplace_func([], gemm(tz,ta,tx,ty,tb), mode = compile.Mode(optimizer = None, linker=l))
f = inplace_func([], gemm_inplace(tz,ta,tx,ty,tb), mode = compile.Mode(optimizer = None, linker=l))
f()
self.failUnless(_approx_eq(z_after, z), (z_orig, z_after, z, z_after - z))
f()
......@@ -271,11 +271,11 @@ def just_gemm(i, o, ishapes = [(4,3), (3,5), (4,5), (), ()]):
try:
f = inplace_func([Param(ii, mutable=True) for ii in i],o, mode='FAST_RUN')
for node in f.maker.env.nodes:
if node.op == T.dot: raise Warning('dot not changed to gemm in graph')
if node.op == _dot22: raise Warning('_dot22 not changed to gemm in graph')
if node.op == T.dot: raise Warning('dot not changed to gemm_inplace in graph')
if node.op == _dot22: raise Warning('_dot22 not changed to gemm_inplace in graph')
g = inplace_func(i, o, mode=compile.Mode(linker='py', optimizer=None))
for node in g.maker.env.nodes:
if node.op == gemm: raise Exception('gemm in original graph')
if node.op == gemm_inplace: raise Exception('gemm_inplace in original graph')
rng = numpy.random.RandomState(unittest_tools.fetch_seed(234))
r0 = f(*[rng.randn(*sh) for sh in ishapes])
......@@ -331,7 +331,7 @@ def test_gemm_opt_double_gemm():
ishapes=[(4,3), (3,5), (4,5), (), (), (5,9), (9,4), ()]
i = [X,Y,Z,a,b, R, S, c]
o = [a * T.dot(X,Y) + gemm(Z, b, S.T, R.T, 1.0)]
o = [a * T.dot(X,Y) + gemm_inplace(Z, b, S.T, R.T, 1.0)]
try:
f = inplace_func([Param(ii, mutable=True) for ii in i],o, mode='FAST_RUN')
for node in f.maker.env.nodes:
......@@ -339,7 +339,7 @@ def test_gemm_opt_double_gemm():
if node.op == _dot22: raise Failure('_dot22 in graph')
g = inplace_func(i, o, mode=compile.Mode(linker='py', optimizer=None))
#for node in g.maker.env.nodes:
# if node.op == gemm: raise Failure('gemm in graph')
# if node.op == gemm_inplace: raise Failure('gemm_inplace in graph')
rng = numpy.random.RandomState(unittest_tools.fetch_seed(234))
r0 = f(*[rng.randn(*sh) for sh in ishapes])
......@@ -393,38 +393,39 @@ def test_gemm_opt_vector_stuff():
u,v = T.dvector(), T.dvector()
f = inplace_func([a, u, v], a + T.dot(u,v), mode='FAST_RUN')
if gemm in [n.op for n in f.maker.env.nodes]:
raise Failure('gemm in graph')
if gemm_inplace in [n.op for n in f.maker.env.nodes]:
raise Failure('gemm_inplace in graph')
f = inplace_func([a, u, X,Y], a * u + T.dot(X,Y), mode='FAST_RUN')
if (gemm in [n.op for n in f.maker.env.nodes]):
raise Failure('gemm in graph')
if (gemm_inplace in [n.op for n in f.maker.env.nodes]):
raise Failure('gemm_inplace in graph')
def test_inplace0():
#should fail to insert gemm because gemm would create cycles
#should fail to insert gemm_inplace because gemm_inplace would create cycles
X,Y,Z,a,b = T.dmatrix('X'), T.dmatrix('Y'), T.dmatrix('Z'), T.dscalar('a'), T.dscalar('b')
R, S, c = T.dmatrix('R'), T.dmatrix('S'), T.dscalar('c')
f = inplace_func([X,Y,Z,a,b, R, S, c],
[Z * (Z + b * T.dot(R,S).T)], mode='FAST_RUN')
if (gemm in [n.op for n in f.maker.env.nodes]):
if (gemm_inplace in [n.op for n in f.maker.env.nodes]):
print pp(f.maker.env.outputs[0])
raise Failure('gemm in graph')
raise Failure('gemm_inplace in graph')
f = inplace_func([X,Y,Z,a,b, R, S, c],
[Z * (c*Z + a * T.dot(X,Y) + b * T.dot(R,S).T)], mode='FAST_RUN')
# gemm should be insertedd here, to work in-place on Z*c
if (not gemm in [n.op for n in f.maker.env.nodes]):
raise Failure('no gemm in graph')
# gemm_inplace should be insertedd here, to work in-place on Z*c
if (not gemm_inplace in [n.op for n in f.maker.env.nodes]):
print pp(f.maker.env.outputs[0])
raise Failure('no gemm_inplace in graph')
def test_inplace1():
X,Y,Z,a,b = XYZab()
# with > 2 terms in the overall addition
f = inplace_func([X,Y,Z,a,b],
[Z + Z + T.dot(X,Y)], mode='FAST_RUN')
# gemm should operate in-place on (Z+Z)
if (not gemm in [n.op for n in f.maker.env.nodes]):
raise Failure('no gemm in graph')
# gemm_inplace should operate in-place on (Z+Z)
if (not gemm_inplace in [n.op for n in f.maker.env.nodes]):
raise Failure('no gemm_inplace in graph')
def test_dot22():
if config.mode == 'FAST_COMPILE':
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论