提交 41a7e608 authored 作者: Frederic's avatar Frederic

Refactor condition and use isinstance

上级 f4f68658
...@@ -1190,32 +1190,31 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip=True): ...@@ -1190,32 +1190,31 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip=True):
# it also might be the case that there is a dimshuffle between the + # it also might be the case that there is a dimshuffle between the +
# and the dot22. local_dot_to_dot22 in particular will put in such things. # and the dot22. local_dot_to_dot22 in particular will put in such things.
if M.owner and isinstance(M.owner.op, T.DimShuffle): if (M.owner and isinstance(M.owner.op, T.DimShuffle) and
M.owner.inputs[0].owner and
isinstance(M.owner.inputs[0].owner.op, Dot22)):
MM = M.owner.inputs[0] MM = M.owner.inputs[0]
if tuple(M.owner.op.new_order) == (0,): if M.owner.op.new_order == (0,):
# it is making a column MM into a vector # it is making a column MM into a vector
if MM.owner and MM.owner.op == _dot22: MMl, MMr = MM.owner.inputs
MMl, MMr = MM.owner.inputs g = gemm_no_inplace(L.dimshuffle(0, 'x'),
g = gemm_no_inplace(L.dimshuffle(0, 'x'), alpha, MMl, MMr, beta)
alpha, MMl, MMr, beta) rval = [g.dimshuffle(0)]
rval = [g.dimshuffle(0)] return rval, MM
return rval, MM if M.owner.op.new_order == (1,):
if tuple(M.owner.op.new_order) == (1,):
# it is making a row MM into a vector # it is making a row MM into a vector
if MM.owner and MM.owner.op == _dot22: MMl, MMr = MM.owner.inputs
MMl, MMr = MM.owner.inputs g = gemm_no_inplace(L.dimshuffle('x', 0),
g = gemm_no_inplace(L.dimshuffle('x', 0), alpha, MMl, MMr, beta)
alpha, MMl, MMr, beta) rval = [g.dimshuffle(1)]
rval = [g.dimshuffle(1)] return rval, MM
return rval, MM if len(M.owner.op.new_order) == 0:
if tuple(M.owner.op.new_order) == ():
# it is making a row MM into a vector # it is making a row MM into a vector
if MM.owner and MM.owner.op == _dot22: MMl, MMr = MM.owner.inputs
MMl, MMr = MM.owner.inputs g = gemm_no_inplace(L.dimshuffle('x', 'x'),
g = gemm_no_inplace(L.dimshuffle('x', 'x'), alpha, MMl, MMr, beta)
alpha, MMl, MMr, beta) rval = [g.dimshuffle()]
rval = [g.dimshuffle()] return rval, MM
return rval, MM
# this is False'd out because of inadequate testing. # this is False'd out because of inadequate testing.
# TODO see ticket #237 # TODO see ticket #237
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论