提交 d6cd1459 authored 作者: Paul Christiano's avatar Paul Christiano

Preliminary version of dot+subtensor optimization, supporting arbitrary indexing.

上级 098261ed
......@@ -2193,13 +2193,25 @@ def local_subtensor_of_dot(node):
# We don't want to compute twice the sub part.
if len(node.inputs[0].clients) > 1:
return
# Do we index/slice on the outer dimensions only?
if node.inputs[0].ndim >= 1 and len(node.op.idx_list) == 1:
a = node.inputs[0].owner.inputs[0]
b = node.inputs[0].owner.inputs[1]
a_sub = node.op.make_node(a, *node.inputs[1:]).outputs[0]
return [T.dot(a_sub, b)]
return
a = node.inputs[0].owner.inputs[0]
b = node.inputs[0].owner.inputs[1]
num_a_indices = min(a.ndim - 1, len(node.op.idx_list))
a_indices = node.op.idx_list[:num_a_indices]
b_indices = (slice(None, None, None),) + node.op.idx_list[num_a_indices:]
num_a_inputs = theano.tensor.subtensor.get_idx_list(node.inputs,
a_indices,
get_count=True)
a_inputs = node.inputs[1:1+num_a_inputs]
b_inputs = node.inputs[1+num_a_inputs:]
import pdb; pdb.set_trace()
a_sub = Subtensor(a_indices).make_node(a, *a_inputs)
b_sub = Subtensor(b_indices).make_node(b, *b_inputs)
return [T.dot(a_sub, b_sub)]
@register_canonicalize
......
......@@ -65,14 +65,20 @@ def make_constant(args):
return tuple(map(conv, args))
def get_idx_list(inputs, idx_list):
def get_idx_list(inputs, idx_list, get_count=False):
'''
Given a list of inputs to the subtensor and its idx_list reorders
the inputs according to the idx list to get the right values
the inputs according to the idx list to get the right values.
If get_counts=True, instead returns the number of inputs consumed
during this process.
'''
# The number of indices
n = len(inputs) - 1
# The subtensor (or idx_list) does not depend on the inputs.
if len(inputs) == 1:
if n == 0:
return tuple(idx_list)
indices = list(reversed(list(inputs[1:])))
......@@ -87,7 +93,10 @@ def get_idx_list(inputs, idx_list):
else:
return entry
cdata = tuple(map(convert, idx_list))
return cdata
if get_count:
return n - len(indices)
else:
return cdata
def get_canonical_form_slice(theslice, length):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论