Commit 9a5e2e97 authored by Pascal Lamblin

Merge pull request #3 from nouiz/lamblin-fix_1507

Add tests and opt one more case.
...@@ -1906,21 +1906,22 @@ def local_dot22_to_dot22scalar(node): ...@@ -1906,21 +1906,22 @@ def local_dot22_to_dot22scalar(node):
d = node.inputs[dot22_idx] d = node.inputs[dot22_idx]
i_scalar = [_as_scalar(x, dtype=d.dtype) for x in node.inputs] i_scalar = [_as_scalar(x, dtype=d.dtype) for x in node.inputs]
if not any(i_scalar): if not any(i_scalar):
i_mul = [x.owner and x.owner.op == T.mul for x in node.inputs] # Check if we can reorder the graph as this mul have a mul in inputs.
# We support only 1 additional level of mul.
# The canonizer should have merged those mul together.
i_mul = [x.owner and x.owner.op == T.mul and
any([_as_scalar(x_i, dtype=d.dtype)
for x_i in x.owner.inputs])
for x in node.inputs]
if not any(i_mul): if not any(i_mul):
#no scalar in input and no multiplication #no scalar in input and no multiplication
#if their was a multiplication we couls reorder the graph #if their was a multiplication we couls reorder the graph
#by the associativity of the graph. #by the associativity of the graph.
return False return False
#maybe we can reorder the graph as this mul have a mul in input. mul_idx = i_mul.index(True) # The first one should always work
#The canonizer should have merged those mul together.
#We support only 1 additional level of mul.
mul_idx = i_mul.index(True) # we take the first mul!
m = node.inputs[mul_idx] m = node.inputs[mul_idx]
if any([_as_scalar(x, dtype=d.dtype)
for x in m.owner.inputs]):
scalar_idx = -1 scalar_idx = -1
for i, x in enumerate(m.owner.inputs): for i, x in enumerate(m.owner.inputs):
if _as_scalar(x, dtype=d.dtype) and (theano.scalar.upcast( if _as_scalar(x, dtype=d.dtype) and (theano.scalar.upcast(
...@@ -1952,13 +1953,6 @@ def local_dot22_to_dot22scalar(node): ...@@ -1952,13 +1953,6 @@ def local_dot22_to_dot22scalar(node):
if i != scalar_idx] if i != scalar_idx]
return [T.mul(dot, *(other_factors + other_m_inputs))] return [T.mul(dot, *(other_factors + other_m_inputs))]
elif m.owner and m.owner.op == T.mul:
_logger.info('Not optimizing dot22 with inputs %s %s %s %s. '
'we need to check in a recursive way in the mul if we can '
'reorder the graph. The canonizer should have done this.',
d, m, d.type, m.type)
else:
return False
scalar_idx = -1 scalar_idx = -1
for i, x in enumerate(node.inputs): for i, x in enumerate(node.inputs):
......
...@@ -1024,6 +1024,60 @@ def test_dot22scalar_cast(): ...@@ -1024,6 +1024,60 @@ def test_dot22scalar_cast():
assert _dot22scalar in [x.op for x in f.maker.fgraph.toposort()] assert _dot22scalar in [x.op for x in f.maker.fgraph.toposort()]
def test_local_dot22_to_dot22scalar():
    """
    Test that the bug in gh-1507 is really fixed.

    Builds a list of ``mul(dot22(...), ...)`` graphs covering every input
    pattern the ``local_dot22_to_dot22scalar`` optimization claims to
    handle (plain scalars, matrices, rows, and one extra nested level of
    ``mul``), then checks two things for each graph:

    1. the optimization's ``transform`` method accepts the node (returns
       a non-false replacement), and
    2. a function compiled with only this optimization enabled still
       executes without error on concrete inputs.
    """
    A = T.dmatrix()
    # Restrict the compilation mode to only the optimization under test,
    # so a failure cannot be masked (or caused) by other rewrites.
    mode = theano.compile.mode.get_default_mode()
    opt = theano.tensor.opt.in2out(
        theano.tensor.blas.local_dot22_to_dot22scalar)
    mode = mode.__class__(optimizer=opt)
    x = T.dscalar()
    y = T.dscalar()
    z = T.dscalar()
    # Use dmatrix/drow directly: the optimization does not handle
    # dimshuffled inputs, so we must not introduce any DimShuffle here.
    m = T.dmatrix()
    r = T.drow()
    cases = [
        # Old working cases
        T.mul(_dot22(A, A), x),
        T.mul(_dot22(A, A), x, y),
        T.mul(_dot22(A, A), x, r),
        T.mul(_dot22(A, A), m, x),
        T.mul(_dot22(A, A), x, m),
        T.mul(_dot22(A, A), x, (m * y)),
        T.mul(_dot22(A, A), (m * y), x),
        T.mul(_dot22(A, A), x, (r * y)),
        T.mul(_dot22(A, A), (r * y), x),
        T.mul(_dot22(A, A), (x * y), (m * x)),
        T.mul(_dot22(A, A), (r * y), (y * x)),
        # Cases that were raising an assert, fixed in gh-1507
        T.mul(_dot22(A, A), (m * y), m),
        T.mul(_dot22(A, A), m, (m * y)),
        T.mul(_dot22(A, A), (r * y), (m * x)),
        # Assert fixed in gh-1507 and optimization case added in gh-1515
        T.mul(_dot22(A, A), (m * y * z), m),
        T.mul(_dot22(A, A), m, (m * y * z)),
        # Optimization cases added in gh-1515
        T.mul(_dot22(A, A), T.mul(m, y, z), m),
        T.mul(_dot22(A, A), m, T.mul(m, y, z)),
        # Case that gets optimized later, in gh-1515
        T.mul(_dot22(A, A), (r * m), (m * x)),
    ]
    for node in cases:
        # The optimization must propose a replacement for every pattern.
        node2 = theano.tensor.blas.local_dot22_to_dot22scalar.transform(
            node.owner)
        assert node2
        # The optimized graph must also compile and run on real data.
        f = theano.function([x, y, z, m, r, A], node,
                            mode=mode, on_unused_input='ignore')
        f(.1, .2, .3, [[1, 2], [3, 4]], [[5, 6]], [[7, 8], [9, 10]])
def test_dot_w_self(): def test_dot_w_self():
# This can trigger problems in the optimization because what would # This can trigger problems in the optimization because what would
# normally be a gemm must not be because the output is aliased to # normally be a gemm must not be because the output is aliased to
......
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Register or sign in to comment