提交 79ac284c authored 作者: Paul Christiano's avatar Paul Christiano

Cleaner implementation of local_subtensor_of_dot.

上级 4ba9d2f4
......@@ -2199,9 +2199,11 @@ def local_subtensor_of_dot(node):
a = node.inputs[0].owner.inputs[0]
b = node.inputs[0].owner.inputs[1]
num_a_indices = min(a.ndim - 1, len(node.op.idx_list))
a_indices = node.op.idx_list[:num_a_indices]
b_indices = node.op.idx_list[num_a_indices:]
idx_list = theano.tensor.subtensor.get_idx_list(node.inputs, node.op.idx_list)
num_a_indices = min(a.ndim - 1, len(idx_list))
a_indices = idx_list[:num_a_indices]
b_indices = idx_list[num_a_indices:]
# This is necessary because np.dot sums the last index of a with the second to last of b
# so we want to skip the second-to-last index into b.
......@@ -2213,14 +2215,17 @@ def local_subtensor_of_dot(node):
# This determines how many of the inputs need to be used to index into a.
# The remaining inputs are used to index into b.
num_a_inputs = theano.tensor.subtensor.get_idx_list(node.inputs,
a_indices,
get_count=True)
a_inputs = node.inputs[1:1+num_a_inputs]
b_inputs = node.inputs[1+num_a_inputs:]
a_sub = Subtensor(a_indices).make_node(a, *a_inputs)
b_sub = Subtensor(b_indices).make_node(b, *b_inputs) if b_indices else b
#num_a_inputs = theano.tensor.subtensor.get_idx_list(node.inputs,
# a_indices,
# get_count=True)
#a_inputs = node.inputs[1:1+num_a_inputs]
#b_inputs = node.inputs[1+num_a_inputs:]
#a_sub = Subtensor(a_indices).make_node(a, *a_inputs)
#b_sub = Subtensor(b_indices).make_node(b, *b_inputs) if b_indices else b
a_sub = a.__getitem__(tuple(a_indices))
b_sub = b.__getitem__(tuple(b_indices)) if b_indices else b
return [T.dot(a_sub, b_sub)]
......
......@@ -2506,9 +2506,11 @@ def test_local_subtensor_of_dot():
d2 = numpy.arange(72).reshape(4,3,6).astype(config.floatX) + 100
f = theano.function([m1, m2, idx], theano.dot(m1, m2)[idx,1:4,:,idx:], mode=mode)
assert test_equality(f(d1, d2, 1), numpy.dot(d1, d2)[1,1:4,:,1:])
f = theano.function([m1, m2, idx], theano.dot(m1, m2)[1:4,:,idx:,idx], mode=mode)
assert test_equality(f(d1, d2, 1), numpy.dot(d1, d2)[1:4,:,1:,1])
def test_local_subtensor_of_alloc():
# DebugMode should detect if something goes wrong.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论