提交 52845167 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Update tests following changes in gemm optimizer (allowing vectors)

上级 b3800f6d
...@@ -409,6 +409,8 @@ def test_gemm_canonicalize(): ...@@ -409,6 +409,8 @@ def test_gemm_canonicalize():
X,Y,Z,a,b = T.dmatrix('X'), T.dmatrix('Y'), T.dmatrix('Z'), T.dscalar('a'), T.dscalar('b') X,Y,Z,a,b = T.dmatrix('X'), T.dmatrix('Y'), T.dmatrix('Z'), T.dscalar('a'), T.dscalar('b')
R,S,U,c,d = T.dmatrix('R'), T.dmatrix('S'), T.dmatrix('U'), T.dscalar('c'), T.dscalar('d') R,S,U,c,d = T.dmatrix('R'), T.dmatrix('S'), T.dmatrix('U'), T.dscalar('c'), T.dscalar('d')
u = T.row('u') u = T.row('u')
v = T.vector('v')
w = T.col('w')
can = [] can = []
_gemm_canonicalize(X + Y + Z, 1.0, can, 0) _gemm_canonicalize(X + Y + Z, 1.0, can, 0)
...@@ -416,7 +418,22 @@ def test_gemm_canonicalize(): ...@@ -416,7 +418,22 @@ def test_gemm_canonicalize():
can = [] can = []
_gemm_canonicalize(X + Y + u, 1.0, can, 0) _gemm_canonicalize(X + Y + u, 1.0, can, 0)
assert can == [(1.0, X), (1.0, Y), u], can assert can == [(1.0, X), (1.0, Y), (1.0, u)], can
can = []
_gemm_canonicalize(X + Y + v, 1.0, can, 0)
# [(1.0, X), (1.0, Y), (1.0, InplaceDimShuffle{x,0}(v))]
assert can[:2] == [(1.0, X), (1.0, Y)]
assert isinstance(can[2], tuple)
assert len(can[2]) == 2
assert can[2][0] == 1.0
assert can[2][1].owner
assert isinstance(can[2][1].owner.op, T.DimShuffle)
assert can[2][1].owner.inputs == [v]
can = []
_gemm_canonicalize(X + Y + w, 1.0, can, 0)
assert can == [(1.0, X), (1.0, Y), (1.0, w)], can
can = [] can = []
_gemm_canonicalize(a*X + Y - b*Z*c, 1.0, can, 0) _gemm_canonicalize(a*X + Y - b*Z*c, 1.0, can, 0)
...@@ -442,16 +459,14 @@ def test_gemm_canonicalize(): ...@@ -442,16 +459,14 @@ def test_gemm_canonicalize():
def test_gemm_factor(): def test_gemm_factor():
X,Y,Z,a,b = T.dmatrix('X'), T.dmatrix('Y'), T.dmatrix('Z'), T.dscalar('a'), T.dscalar('b') X,Y,Z,a,b = T.dmatrix('X'), T.dmatrix('Y'), T.dmatrix('Z'), T.dscalar('a'), T.dscalar('b')
R,S,U,c,d = T.dmatrix('R'), T.dmatrix('S'), T.dmatrix('U'), T.dscalar('c'), T.dscalar('d') R,S,U,c,d = T.dmatrix('R'), T.dmatrix('S'), T.dmatrix('U'), T.dscalar('c'), T.dscalar('d')
u = T.row('u')
assert [(1.0, X), (1.0, Y), u] == _factor_canonicalized([(1.0, X), (1.0, Y), u]) assert [(1.0, X), (1.0, Y)] == _factor_canonicalized([(1.0, X), (1.0, Y)])
assert [(2.0, X), u] == _factor_canonicalized([(1.0, X),(1.0, X), u]) assert [(2.0, X)] == _factor_canonicalized([(1.0, X),(1.0, X)])
def test_gemm_nested(): def test_gemm_nested():
X,Y,Z,a,b = T.dmatrix('X'), T.dmatrix('Y'), T.dmatrix('Z'), T.dscalar('a'), T.dscalar('b') X,Y,Z,a,b = T.dmatrix('X'), T.dmatrix('Y'), T.dmatrix('Z'), T.dscalar('a'), T.dscalar('b')
R,S,U,c,d = T.dmatrix('R'), T.dmatrix('S'), T.dmatrix('U'), T.dscalar('c'), T.dscalar('d') R,S,U,c,d = T.dmatrix('R'), T.dmatrix('S'), T.dmatrix('U'), T.dscalar('c'), T.dscalar('d')
u = T.row('u')
just_gemm([X,Y,Z,R,S,U,a,b,c,d], just_gemm([X,Y,Z,R,S,U,a,b,c,d],
[a * Z - b * (c*T.dot(X,Y) + d*Z)], [a * Z - b * (c*T.dot(X,Y) + d*Z)],
...@@ -529,7 +544,7 @@ def test_inplace0(): ...@@ -529,7 +544,7 @@ def test_inplace0():
f = inplace_func([X,Y,Z,a,b, R, S, c], f = inplace_func([X,Y,Z,a,b, R, S, c],
[Z * (c*Z + a * T.dot(X,Y) + b * T.dot(R,S).T)], mode='FAST_RUN') [Z * (c*Z + a * T.dot(X,Y) + b * T.dot(R,S).T)], mode='FAST_RUN')
# gemm_inplace should be insertedd here, to work in-place on Z*c # gemm_inplace should be inserted here, to work in-place on Z*c
if (not gemm_inplace in [n.op for n in f.maker.env.nodes]): if (not gemm_inplace in [n.op for n in f.maker.env.nodes]):
print pp(f.maker.env.outputs[0]) print pp(f.maker.env.outputs[0])
raise Failure('no gemm_inplace in graph') raise Failure('no gemm_inplace in graph')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论