提交 4ba9d2f4 authored 作者: Paul Christiano's avatar Paul Christiano

Tests for subtensor_of_dot, and corrected optimization for tensors with more than 2 dimensions.

上级 ea704d6b
......@@ -2181,8 +2181,10 @@ def local_subtensor_of_alloc(node):
@gof.local_optimizer([Subtensor])
def local_subtensor_of_dot(node):
"""
This optimization translates T.dot(A, B)[xs+ys] into T.dot(A[xs,:], B[:,ys]).
This optimization translates T.dot(A, B)[idxs] into T.dot(A[idxs_a], B[idxs_b]).
idxs_a is the first A.ndim-1 entries of idxs
idxs_b is the remaining entries of idxs (if any),
but with : inserted in the second-to-last dimension (because dot sums over this dimension)
"""
if not isinstance(node.op, Subtensor):
return
......@@ -2199,18 +2201,26 @@ def local_subtensor_of_dot(node):
num_a_indices = min(a.ndim - 1, len(node.op.idx_list))
a_indices = node.op.idx_list[:num_a_indices]
b_indices = (slice(None, None, None),) + node.op.idx_list[num_a_indices:]
b_indices = node.op.idx_list[num_a_indices:]
# This is necessary because np.dot sums the last index of a with the second to last of b
# so we want to skip the second-to-last index into b.
# This wasn't necessary for a, because we just omitted the last index.
# We skip this if b.ndim = 1, since then we just want b_sub = b, not b_sub = b[:]
# (dot also handles b.ndim < 2 as a special case)
if b.ndim > 1 and len(b_indices) >= b.ndim - 1:
b_indices = b_indices[:b.ndim-2] + (slice(None, None, None),) + b_indices[b.ndim-2:]
# This determines how many of the inputs need to be used to index into a.
# The remaining inputs are used to index into b.
num_a_inputs = theano.tensor.subtensor.get_idx_list(node.inputs,
a_indices,
get_count=True)
a_inputs = node.inputs[1:1+num_a_inputs]
b_inputs = node.inputs[1+num_a_inputs:]
import pdb; pdb.set_trace()  # NOTE(review): leftover debugger breakpoint — must be removed before merging
a_sub = Subtensor(a_indices).make_node(a, *a_inputs)
b_sub = b if len(b_indices) == 1 else Subtensor(b_indices).make_node(b, *b_inputs)
b_sub = Subtensor(b_indices).make_node(b, *b_inputs) if b_indices else b
return [T.dot(a_sub, b_sub)]
......
......@@ -2483,21 +2483,31 @@ def test_local_subtensor_of_dot():
d2 = numpy.arange(8).reshape((2, 4)).astype(config.floatX) + 10
mode = compile.get_default_mode().including("local_subtensor_of_dot")
def test_equality(a, b):
return a.shape == b.shape and numpy.allclose(a, b)
# [cst]
f = theano.function([m1, m2], theano.dot(m1, m2)[1], mode=mode)
topo = f.maker.fgraph.toposort()
assert numpy.allclose(f(d1, d2), numpy.dot(d1, d2)[1])
assert test_equality(f(d1, d2), numpy.dot(d1, d2)[1])
# DimShuffle happen in FAST_COMPILE
assert isinstance(topo[-1].op, (T.blas_c.CGemv, T.blas.Gemv, T.DimShuffle))
# slice
f = theano.function([m1, m2], theano.dot(m1, m2)[1:2], mode=mode)
topo = f.maker.fgraph.toposort()
assert numpy.allclose(f(d1, d2), numpy.dot(d1, d2)[1:2])
assert test_equality(f(d1, d2), numpy.dot(d1, d2)[1:2])
assert isinstance(topo[-1].op, (T.blas.Dot22))
# TODO: tests with vector and tensor3d.
# TODO: new opt for AdvancedSubtensor1
m1 = theano.tensor.tensor3()
m2 = theano.tensor.tensor3()
idx = theano.tensor.iscalar()
d1 = numpy.arange(30).reshape(2,5,3).astype(config.floatX)
d2 = numpy.arange(72).reshape(4,3,6).astype(config.floatX) + 100
f = theano.function([m1, m2, idx], theano.dot(m1, m2)[idx,1:4,:,idx:], mode=mode)
assert test_equality(f(d1, d2, 1), numpy.dot(d1, d2)[1,1:4,:,1:])
def test_local_subtensor_of_alloc():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论