提交 4ed23539 authored 作者: Frederic's avatar Frederic

code cleanup.

上级 a7d2157d
...@@ -1906,8 +1906,9 @@ def local_dot22_to_dot22scalar(node): ...@@ -1906,8 +1906,9 @@ def local_dot22_to_dot22scalar(node):
d = node.inputs[dot22_idx] d = node.inputs[dot22_idx]
i_scalar = [_as_scalar(x, dtype=d.dtype) for x in node.inputs] i_scalar = [_as_scalar(x, dtype=d.dtype) for x in node.inputs]
if not any(i_scalar): if not any(i_scalar):
# Check witch input of node is a mul with scalar inputs. # Check if we can reorder the graph as this mul have a mul in inputs.
# We could reuse this scalar value. # We support only 1 additional level of mul.
# The canonizer should have merged those mul together.
i_mul = [x.owner and x.owner.op == T.mul and i_mul = [x.owner and x.owner.op == T.mul and
any([_as_scalar(x_i, dtype=d.dtype) any([_as_scalar(x_i, dtype=d.dtype)
for x_i in x.owner.inputs]) for x_i in x.owner.inputs])
...@@ -1918,52 +1919,40 @@ def local_dot22_to_dot22scalar(node): ...@@ -1918,52 +1919,40 @@ def local_dot22_to_dot22scalar(node):
#by the associativity of the graph. #by the associativity of the graph.
return False return False
#maybe we can reorder the graph as this mul have a mul in input.
#The canonizer should have merged those mul together.
#We support only 1 additional level of mul.
mul_idx = i_mul.index(True) # The first one should always work mul_idx = i_mul.index(True) # The first one should always work
m = node.inputs[mul_idx] m = node.inputs[mul_idx]
if any([_as_scalar(x, dtype=d.dtype) scalar_idx = -1
for x in m.owner.inputs]): # This should be always True for i, x in enumerate(m.owner.inputs):
scalar_idx = -1 if _as_scalar(x, dtype=d.dtype) and (theano.scalar.upcast(
for i, x in enumerate(m.owner.inputs): x.type.dtype, d.type.dtype)
if _as_scalar(x, dtype=d.dtype) and (theano.scalar.upcast( == d.type.dtype):
x.type.dtype, d.type.dtype) scalar_idx = i
== d.type.dtype): break
scalar_idx = i
break if scalar_idx < 0:
_logger.info('Not optimizing dot22 with inputs %s %s, as the'
if scalar_idx < 0: ' type of the scalar cannot be upcasted to the'
_logger.info('Not optimizing dot22 with inputs %s %s, as the' ' matrix type',
' type of the scalar cannot be upcasted to the' node.inputs, [x.type for x in node.inputs])
' matrix type',
node.inputs, [x.type for x in node.inputs])
return False
a = T.cast(_as_scalar(m.owner.inputs[scalar_idx],
dtype=d.dtype), d.type.dtype)
assert not a.type.ndim
dot = _dot22scalar(d.owner.inputs[0], d.owner.inputs[1], a)
# The other inputs to the original node that were
# neither part of the dot22 or this mul should be
# factors in the returned "mul" node.
assert dot22_idx != mul_idx
other_factors = [inpt
for i, inpt in enumerate(node.inputs)
if i not in (dot22_idx, mul_idx)]
other_m_inputs = [inpt
for i, inpt in enumerate(m.owner.inputs)
if i != scalar_idx]
return [T.mul(dot, *(other_factors + other_m_inputs))]
elif m.owner and m.owner.op == T.mul:
_logger.info('Not optimizing dot22 with inputs %s %s %s %s. '
'we need to check in a recursive way in the mul if we can '
'reorder the graph. The canonizer should have done this.',
d, m, d.type, m.type)
else:
return False return False
a = T.cast(_as_scalar(m.owner.inputs[scalar_idx],
dtype=d.dtype), d.type.dtype)
assert not a.type.ndim
dot = _dot22scalar(d.owner.inputs[0], d.owner.inputs[1], a)
# The other inputs to the original node that were
# neither part of the dot22 or this mul should be
# factors in the returned "mul" node.
assert dot22_idx != mul_idx
other_factors = [inpt
for i, inpt in enumerate(node.inputs)
if i not in (dot22_idx, mul_idx)]
other_m_inputs = [inpt
for i, inpt in enumerate(m.owner.inputs)
if i != scalar_idx]
return [T.mul(dot, *(other_factors + other_m_inputs))]
scalar_idx = -1 scalar_idx = -1
for i, x in enumerate(node.inputs): for i, x in enumerate(node.inputs):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论