提交 41a7e608 authored 作者: Frederic's avatar Frederic

Refactor condition and use isinstance

上级 f4f68658
......@@ -1190,32 +1190,31 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip=True):
# it also might be the case that there is a dimshuffle between the +
# and the dot22. local_dot_to_dot22 in particular will put in such things.
if M.owner and isinstance(M.owner.op, T.DimShuffle):
if (M.owner and isinstance(M.owner.op, T.DimShuffle) and
M.owner.inputs[0].owner and
isinstance(M.owner.inputs[0].owner.op, Dot22)):
MM = M.owner.inputs[0]
if tuple(M.owner.op.new_order) == (0,):
if M.owner.op.new_order == (0,):
# it is making a column MM into a vector
if MM.owner and MM.owner.op == _dot22:
MMl, MMr = MM.owner.inputs
g = gemm_no_inplace(L.dimshuffle(0, 'x'),
alpha, MMl, MMr, beta)
rval = [g.dimshuffle(0)]
return rval, MM
if tuple(M.owner.op.new_order) == (1,):
MMl, MMr = MM.owner.inputs
g = gemm_no_inplace(L.dimshuffle(0, 'x'),
alpha, MMl, MMr, beta)
rval = [g.dimshuffle(0)]
return rval, MM
if M.owner.op.new_order == (1,):
# it is making a row MM into a vector
if MM.owner and MM.owner.op == _dot22:
MMl, MMr = MM.owner.inputs
g = gemm_no_inplace(L.dimshuffle('x', 0),
alpha, MMl, MMr, beta)
rval = [g.dimshuffle(1)]
return rval, MM
if tuple(M.owner.op.new_order) == ():
MMl, MMr = MM.owner.inputs
g = gemm_no_inplace(L.dimshuffle('x', 0),
alpha, MMl, MMr, beta)
rval = [g.dimshuffle(1)]
return rval, MM
if len(M.owner.op.new_order) == 0:
# it is making a row MM into a vector
if MM.owner and MM.owner.op == _dot22:
MMl, MMr = MM.owner.inputs
g = gemm_no_inplace(L.dimshuffle('x', 'x'),
alpha, MMl, MMr, beta)
rval = [g.dimshuffle()]
return rval, MM
MMl, MMr = MM.owner.inputs
g = gemm_no_inplace(L.dimshuffle('x', 'x'),
alpha, MMl, MMr, beta)
rval = [g.dimshuffle()]
return rval, MM
# this is False'd out because of inadequate testing.
# TODO see ticket #237
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论