提交 b2b7bbc8 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Generalize optimization to more inputs in mul

上级 61a7aac9
......@@ -1919,8 +1919,8 @@ def local_dot22_to_dot22scalar(node):
mul_idx = i_mul.index(True) # we take the first mul!
m = node.inputs[mul_idx]
if len(m.owner.inputs) == 2 and any([_as_scalar(x, dtype=d.dtype)
for x in m.owner.inputs]):
if any([_as_scalar(x, dtype=d.dtype)
for x in m.owner.inputs]):
scalar_idx = -1
for i, x in enumerate(m.owner.inputs):
if _as_scalar(x, dtype=d.dtype) and (theano.scalar.upcast(
......@@ -1947,8 +1947,11 @@ def local_dot22_to_dot22scalar(node):
other_factors = [inpt
for i, inpt in enumerate(node.inputs)
if i not in (dot22_idx, mul_idx)]
other_m_inputs = [inpt
for i, inpt in enumerate(m.owner.inputs)
if i != scalar_idx]
return [T.mul(m.owner.inputs[1 - i], dot, *other_factors)]
return [T.mul(dot, *(other_factors + other_m_inputs))]
elif m.owner and m.owner.op == T.mul:
_logger.info('Not optimizing dot22 with inputs %s %s %s %s. '
'we need to check in a recursive way in the mul if we can '
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论