提交 a7d2157d authored 作者: Frederic's avatar Frederic

Make local_dot22_to_dot22scalar opt work in one more case.

上级 54f611e9
......@@ -1906,7 +1906,12 @@ def local_dot22_to_dot22scalar(node):
d = node.inputs[dot22_idx]
i_scalar = [_as_scalar(x, dtype=d.dtype) for x in node.inputs]
if not any(i_scalar):
i_mul = [x.owner and x.owner.op == T.mul for x in node.inputs]
# Check witch input of node is a mul with scalar inputs.
# We could reuse this scalar value.
i_mul = [x.owner and x.owner.op == T.mul and
any([_as_scalar(x_i, dtype=d.dtype)
for x_i in x.owner.inputs])
for x in node.inputs]
if not any(i_mul):
#no scalar in input and no multiplication
#if their was a multiplication we couls reorder the graph
......@@ -1916,11 +1921,11 @@ def local_dot22_to_dot22scalar(node):
#maybe we can reorder the graph as this mul have a mul in input.
#The canonizer should have merged those mul together.
#We support only 1 additional level of mul.
mul_idx = i_mul.index(True) # we take the first mul!
mul_idx = i_mul.index(True) # The first one should always work
m = node.inputs[mul_idx]
if any([_as_scalar(x, dtype=d.dtype)
for x in m.owner.inputs]):
for x in m.owner.inputs]): # This should be always True
scalar_idx = -1
for i, x in enumerate(m.owner.inputs):
if _as_scalar(x, dtype=d.dtype) and (theano.scalar.upcast(
......
......@@ -1067,8 +1067,8 @@ def test_local_dot22_to_dot22scalar():
T.mul(_dot22(A, A), T.mul(m, y, z), m),
T.mul(_dot22(A, A), m, T.mul(m, y, z)),
#Case that isn't opt for now.
#T.mul(_dot22(A, A), (r * m), (m * x)),
#Case that opt later in gh-1515
T.mul(_dot22(A, A), (r * m), (m * x)),
]):
node2 = theano.tensor.blas.local_dot22_to_dot22scalar.transform(
node.owner)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论