提交 53756433 authored 作者: Sigurd Spieckermann's avatar Sigurd Spieckermann

modified local_subtensor_make_vector to support advanced indexing

上级 6a6e9914
......@@ -28,7 +28,8 @@ from theano.tensor.elemwise import Elemwise, DimShuffle
from theano.tensor.subtensor import (get_idx_list, get_canonical_form_slice,
Subtensor, IncSubtensor, make_constant,
AdvancedIncSubtensor1,
AdvancedIncSubtensor)
AdvancedIncSubtensor,
AdvancedSubtensor1)
from theano import scalar
from theano.tensor import basic as T
from theano import compile # to register the optimizer built by this file
......@@ -1330,16 +1331,24 @@ def local_track_shape_i(node):
@register_specialize
@register_canonicalize('fast_compile_gpu')
@gof.local_optimizer([Subtensor])
@gof.local_optimizer([Subtensor, AdvancedSubtensor1])
def local_subtensor_make_vector(node):
# replace all subtensor(make_vector) like:
# [a,b,c][0] -> a
# [a,b,c][0:2] -> [a,b]
# we can do this for constant indexes
"""
replace all subtensor(make_vector) like:
[a,b,c][0] -> a
[a,b,c][0:2] -> [a,b]
replace all AdvancedSubtensor1(make_vector) like:
[a,b,c][[0,2]] -> [a,c]
we can do this for constant indexes
"""
x = node.inputs[0]
if not x.owner or x.owner.op != make_vector:
return
if isinstance(node.op, Subtensor):
# This optimization needs ShapeOpt and fgraph.shape_feature
x = node.inputs[0]
if x.owner and x.owner.op == make_vector:
try:
idx, = node.op.idx_list
except Exception:
......@@ -1351,19 +1360,26 @@ def local_subtensor_make_vector(node):
# is contained in node.inputs[1]
old_idx, idx = idx, node.inputs[1]
assert idx.type == old_idx
elif isinstance(node.op, AdvancedSubtensor1):
idx = node.inputs[1]
else:
return
if isinstance(idx, (int, numpy.integer)):
return [x.owner.inputs[idx]]
elif isinstance(idx, Variable):
# if it is a constant we can do something with it
try:
v = get_scalar_constant_value(idx)
if isinstance(v, numpy.integer):
if isinstance(idx, T.Constant):
# make sure we have an ndarray to access the `ndim` attribute
idx = numpy.asarray(idx.value)
if idx.ndim == 0:
# Python 2.4 wants to index only with Python integers
v = int(v)
return [x.owner.inputs[v]]
except NotScalarConstantError:
pass
return [x.owner.inputs[int(idx)]]
elif idx.ndim == 1:
values = map(int, list(idx))
return [make_vector(*[x.owner.inputs[v] for v in values])]
else:
raise TypeError
else:
# it is a slice of ints and/or Variables
#TODO: check subtensor to see if it can contain
......@@ -1377,6 +1393,7 @@ def local_subtensor_make_vector(node):
_logger.error('failed to index with "%s"' % str(idx))
raise
#TODO: the other optimization for and, or, xor, le and ge see ticket #496.
@register_canonicalize('fast_compile')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论