提交 a2ea8184 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Fix KnownFailure related to gemm (and partially solve #415).

This does not change the logic of the Gemm optimizer, so #415 might still be needed in the future. This change:
- applies the following optimization whether or not the transpose is inplace: T.dot(X,Y).T -> T.dot(Y.T, X.T). The latter form is expected by the Gemm optimizer;
- removes the Warning class raised in test_blas.py:just_gemm, because it was caught silently, hiding the fact that another case was failing. A "Failure" is now raised instead;
- adds another test case for transposition;
- makes the first case in test_blas.py:test_inplace0 expect a "gemm_no_inplace" instead of _dot22.
上级 7c854bef
......@@ -2824,10 +2824,21 @@ register_canonicalize(constant_folding, 'fast_compile')
register_stabilize(constant_folding) # because
register_specialize(constant_folding)
## dot(x,y).T -> dot(y.T, x.T)
inplace_matrix_transpose = T.DimShuffle([False,False], [1,0], inplace=True)
local_transposed_dot = gof.PatternSub((inplace_matrix_transpose, (T.dot, 'x', 'y')),
(T.dot, (inplace_matrix_transpose, 'y'), (inplace_matrix_transpose, 'x')))
matrix_transpose = T.DimShuffle([False,False], [1,0], inplace=False)
# The transformation should be applied whether or not the transpose is inplace.
# The newly-introduced transpositions are not inplace; this will be taken care
# of in a later optimization phase.
# First optimization: inplace
local_transposed_dot_inplace = gof.PatternSub(
(inplace_matrix_transpose, (T.dot, 'x', 'y')),
(T.dot, (matrix_transpose, 'y'), (matrix_transpose, 'x')))
register_canonicalize(local_transposed_dot_inplace, name='local_transposed_dot_inplace')
# Second optimization: not inplace
local_transposed_dot = gof.PatternSub(
(matrix_transpose, (T.dot, 'x', 'y')),
(T.dot, (matrix_transpose, 'y'), (matrix_transpose, 'x')))
register_canonicalize(local_transposed_dot, name='local_transposed_dot')
def _is_1(expr):
......
......@@ -309,9 +309,6 @@ def XYZab():
class Failure(Exception):
pass
class Warning(Exception):
pass
def just_gemm(i, o, ishapes = [(4,3), (3,5), (4,5), (), ()], max_graphlen=0):
try:
f = inplace_func(
......@@ -320,14 +317,18 @@ def just_gemm(i, o, ishapes = [(4,3), (3,5), (4,5), (), ()], max_graphlen=0):
mode='FAST_RUN')
at_least_one_gemm = False
for node in f.maker.env.nodes:
if node.op == T.dot: raise Warning('dot not changed to gemm_inplace in graph')
if node.op == _dot22: raise Warning('_dot22 not changed to gemm_inplace in graph')
if node.op == gemm_inplace: at_least_one_gemm = True
if node.op == T.dot:
raise Failure('dot not changed to gemm_inplace in graph')
if node.op == _dot22:
raise Failure('_dot22 not changed to gemm_inplace in graph')
if node.op == gemm_inplace:
at_least_one_gemm = True
assert at_least_one_gemm
g = inplace_func(i, o, mode=compile.Mode(linker='py', optimizer=None),
allow_input_downcast=True)
for node in g.maker.env.nodes:
if node.op == gemm_inplace: raise Exception('gemm_inplace in original graph')
if node.op == gemm_inplace:
raise Exception('gemm_inplace in original graph')
graphlen = len(f.maker.env.toposort())
if max_graphlen and (graphlen <= max_graphlen):
......@@ -345,11 +346,6 @@ def just_gemm(i, o, ishapes = [(4,3), (3,5), (4,5), (), ()], max_graphlen=0):
for node in f.maker.env.toposort():
print 'GRAPH', node
raise
except Warning, e:
#for node in f.maker.env.toposort():
# print 'GRAPH', node
print 'WARNING:', e
#traceback.print_exc()
def test_gemm_opt0():
......@@ -366,6 +362,8 @@ def test_gemm_opt0():
#with transposes (transposes should be pushed through dot in canonicalize)
just_gemm([X,Y,Z,a,b], [b * Z.T - a * T.dot(Y.T,X.T)])
just_gemm([X,Y,Z,a,b], [b * Z.T + a * b * T.dot(X,Y).T])
just_gemm([X,Y,Z,a,b], [b * Z + a * T.dot(X,Y).T],
ishapes=[(5,3), (3,4), (4,5), (), ()])
#with N multiplications instead of just one
just_gemm([X,Y,Z,a,b], [(b * b) * Z * a + (a * a) * T.dot(X,Y) * b])
......@@ -390,7 +388,8 @@ def test_gemm_opt_double_gemm():
i = [X,Y,Z,a,b, R, S, c]
o = [a * T.dot(X,Y) + gemm_inplace(Z, b, S.T, R.T, 1.0)]
try:
f = inplace_func([Param(ii, mutable=True) for ii in i],o, mode='FAST_RUN')
f = inplace_func([Param(ii, mutable=True) for ii in i],o,
mode='FAST_RUN')
for node in f.maker.env.nodes:
if node.op == T.dot: raise Failure('dot in graph')
if node.op == _dot22: raise Failure('_dot22 in graph')
......@@ -547,15 +546,15 @@ def test_inplace0():
if (gemm_inplace in [n.op for n in f.maker.env.nodes]):
print pp(f.maker.env.outputs[0])
raise Failure('gemm_inplace in graph')
assert _dot22 in [n.op for n in f.maker.env.nodes]
assert gemm_no_inplace in [n.op for n in f.maker.env.nodes]
f = inplace_func([X,Y,Z,a,b, R, S, c],
[Z * (c*Z + a * T.dot(X,Y) + b * T.dot(R,S).T)], mode='FAST_RUN')
# gemm_inplace should be inserted here, to work in-place on Z*c
f = inplace_func([X,Y,Z,a,b, R, S, c],
[Z * (c*Z + a * T.dot(X,Y) + b * T.dot(R,S).T)],
mode='FAST_RUN')
if (not gemm_inplace in [n.op for n in f.maker.env.nodes]):
print pp(f.maker.env.outputs[0])
#raise Failure('no gemm_inplace in graph')
raise KnownFailureTest("gemm not always inserted, see #415")
theano.printing.debugprint(f)
raise Failure('no gemm_inplace in graph')
def test_inplace1():
X,Y,Z,a,b = XYZab()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论