提交 85741b9e authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #2161 from paulfchristiano/dot_optimization

Dot optimization
......@@ -2175,6 +2175,53 @@ def local_subtensor_of_alloc(node):
return rval
@register_canonicalize
@register_stabilize
@register_specialize
@gof.local_optimizer([Subtensor])
def local_subtensor_of_dot(node):
"""
This optimization translates T.dot(A, B)[idxs] into T.dot(A[idxs_a], B[idxs_b]),
where idxs_a and idxs_b are defined appropriately.
idxs_a is the first A.ndim-1 entries of idxs,
and idxs_b is the remaining entries of idxs (if any),
modified to skip the second-to-last dimension of B
(because dot sums over this dimension)
"""
if not isinstance(node.op, Subtensor):
return
if (not node.inputs[0].owner or
not isinstance(node.inputs[0].owner.op, T.Dot)):
return
# If there is other node that use the outputs of the dot
# We don't want to compute twice the sub part.
if len(node.inputs[0].clients) > 1:
return
a = node.inputs[0].owner.inputs[0]
b = node.inputs[0].owner.inputs[1]
idx_list = theano.tensor.subtensor.get_idx_list(node.inputs, node.op.idx_list)
num_a_indices = min(a.ndim - 1, len(idx_list))
a_indices = idx_list[:num_a_indices]
b_indices = idx_list[num_a_indices:]
# This is necessary because np.dot sums the last index of a with the second to last of b
# so we want to skip the second-to-last index into b.
# This wasn't necessary for a, because we just ommitted the last index.
# We skip this if b.ndim = 1, since then we just want b_sub = b, not b_sub = b[:]
# (dot also handles b.ndim < 2 as a special case)
if b.ndim > 1 and len(b_indices) >= b.ndim - 1:
b_indices = b_indices[:b.ndim-2] + (slice(None, None, None),) + b_indices[b.ndim-2:]
a_sub = a.__getitem__(tuple(a_indices))
b_sub = b.__getitem__(tuple(b_indices)) if b_indices else b
return [T.dot(a_sub, b_sub)]
@register_canonicalize
@gof.local_optimizer([T.add])
def local_IncSubtensor_serialize(node):
......
......@@ -65,14 +65,20 @@ def make_constant(args):
return tuple(map(conv, args))
def get_idx_list(inputs, idx_list):
def get_idx_list(inputs, idx_list, get_count=False):
'''
Given a list of inputs to the subtensor and its idx_list reorders
the inputs according to the idx list to get the right values
the inputs according to the idx list to get the right values.
If get_counts=True, instead returns the number of inputs consumed
during this process.
'''
# The number of indices
n = len(inputs) - 1
# The subtensor (or idx_list) does not depend on the inputs.
if len(inputs) == 1:
if n == 0:
return tuple(idx_list)
indices = list(reversed(list(inputs[1:])))
......@@ -87,7 +93,10 @@ def get_idx_list(inputs, idx_list):
else:
return entry
cdata = tuple(map(convert, idx_list))
return cdata
if get_count:
return n - len(indices)
else:
return cdata
def get_canonical_form_slice(theslice, length):
......
......@@ -2477,6 +2477,41 @@ class Test_alloc_zero(unittest.TestCase):
_e1[2], _e2[1])
def test_local_subtensor_of_dot():
m1 = theano.tensor.matrix()
m2 = theano.tensor.matrix()
d1 = numpy.arange(6).reshape((3, 2)).astype(config.floatX)
d2 = numpy.arange(8).reshape((2, 4)).astype(config.floatX) + 10
mode = compile.get_default_mode().including("local_subtensor_of_dot")
def test_equality(a, b):
return a.shape == b.shape and numpy.allclose(a, b)
# [cst]
f = theano.function([m1, m2], theano.dot(m1, m2)[1], mode=mode)
topo = f.maker.fgraph.toposort()
assert test_equality(f(d1, d2), numpy.dot(d1, d2)[1])
# DimShuffle happen in FAST_COMPILE
assert isinstance(topo[-1].op, (T.blas_c.CGemv, T.blas.Gemv, T.DimShuffle))
# slice
f = theano.function([m1, m2], theano.dot(m1, m2)[1:2], mode=mode)
topo = f.maker.fgraph.toposort()
assert test_equality(f(d1, d2), numpy.dot(d1, d2)[1:2])
assert isinstance(topo[-1].op, (T.blas.Dot22))
m1 = theano.tensor.tensor3()
m2 = theano.tensor.tensor3()
idx = theano.tensor.iscalar()
d1 = numpy.arange(30).reshape(2,5,3).astype(config.floatX)
d2 = numpy.arange(72).reshape(4,3,6).astype(config.floatX) + 100
f = theano.function([m1, m2, idx], theano.dot(m1, m2)[idx,1:4,:,idx:], mode=mode)
assert test_equality(f(d1, d2, 1), numpy.dot(d1, d2)[1,1:4,:,1:])
f = theano.function([m1, m2, idx], theano.dot(m1, m2)[1:4,:,idx:,idx], mode=mode)
assert test_equality(f(d1, d2, 1), numpy.dot(d1, d2)[1:4,:,1:,1])
def test_local_subtensor_of_alloc():
# DebugMode should detect if something goes wrong.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论