提交 53756433 authored 作者: Sigurd Spieckermann's avatar Sigurd Spieckermann

modified local_subtensor_make_vector to support advanced indexing

上级 6a6e9914
...@@ -28,7 +28,8 @@ from theano.tensor.elemwise import Elemwise, DimShuffle ...@@ -28,7 +28,8 @@ from theano.tensor.elemwise import Elemwise, DimShuffle
from theano.tensor.subtensor import (get_idx_list, get_canonical_form_slice, from theano.tensor.subtensor import (get_idx_list, get_canonical_form_slice,
Subtensor, IncSubtensor, make_constant, Subtensor, IncSubtensor, make_constant,
AdvancedIncSubtensor1, AdvancedIncSubtensor1,
AdvancedIncSubtensor) AdvancedIncSubtensor,
AdvancedSubtensor1)
from theano import scalar from theano import scalar
from theano.tensor import basic as T from theano.tensor import basic as T
from theano import compile # to register the optimizer built by this file from theano import compile # to register the optimizer built by this file
...@@ -1330,52 +1331,68 @@ def local_track_shape_i(node): ...@@ -1330,52 +1331,68 @@ def local_track_shape_i(node):
@register_specialize @register_specialize
@register_canonicalize('fast_compile_gpu') @register_canonicalize('fast_compile_gpu')
@gof.local_optimizer([Subtensor]) @gof.local_optimizer([Subtensor, AdvancedSubtensor1])
def local_subtensor_make_vector(node): def local_subtensor_make_vector(node):
# replace all subtensor(make_vector) like: """
# [a,b,c][0] -> a replace all subtensor(make_vector) like:
# [a,b,c][0:2] -> [a,b] [a,b,c][0] -> a
# we can do this for constant indexes [a,b,c][0:2] -> [a,b]
replace all AdvancedSubtensor1(make_vector) like:
[a,b,c][[0,2]] -> [a,c]
we can do this for constant indexes
"""
x = node.inputs[0]
if not x.owner or x.owner.op != make_vector:
return
if isinstance(node.op, Subtensor): if isinstance(node.op, Subtensor):
# This optimization needs ShapeOpt and fgraph.shape_feature # This optimization needs ShapeOpt and fgraph.shape_feature
x = node.inputs[0] try:
if x.owner and x.owner.op == make_vector: idx, = node.op.idx_list
try: except Exception:
idx, = node.op.idx_list #'how can you have multiple indexes into a shape?'
except Exception: raise
#'how can you have multiple indexes into a shape?'
raise if isinstance(idx, (scalar.Scalar, T.TensorType)):
# The idx is a Scalar, ie a Type. This means the actual index
if isinstance(idx, (scalar.Scalar, T.TensorType)): # is contained in node.inputs[1]
# The idx is a Scalar, ie a Type. This means the actual index old_idx, idx = idx, node.inputs[1]
# is contained in node.inputs[1] assert idx.type == old_idx
old_idx, idx = idx, node.inputs[1] elif isinstance(node.op, AdvancedSubtensor1):
assert idx.type == old_idx idx = node.inputs[1]
else:
if isinstance(idx, (int, numpy.integer)): return
return [x.owner.inputs[idx]]
elif isinstance(idx, Variable): if isinstance(idx, (int, numpy.integer)):
# if it is a constant we can do something with it return [x.owner.inputs[idx]]
try: elif isinstance(idx, Variable):
v = get_scalar_constant_value(idx) # if it is a constant we can do something with it
if isinstance(v, numpy.integer): if isinstance(idx, T.Constant):
# Python 2.4 wants to index only with Python integers # make sure we have an ndarray to access the `ndim` attribute
v = int(v) idx = numpy.asarray(idx.value)
return [x.owner.inputs[v]] if idx.ndim == 0:
except NotScalarConstantError: # Python 2.4 wants to index only with Python integers
pass return [x.owner.inputs[int(idx)]]
elif idx.ndim == 1:
values = map(int, list(idx))
return [make_vector(*[x.owner.inputs[v] for v in values])]
else: else:
# it is a slice of ints and/or Variables raise TypeError
#TODO: check subtensor to see if it can contain else:
# constant variables, and if it can, then try to # it is a slice of ints and/or Variables
# unpack them. #TODO: check subtensor to see if it can contain
try: # constant variables, and if it can, then try to
return [make_vector(*x.owner.inputs.__getitem__(idx))] # unpack them.
except TypeError: try:
pass return [make_vector(*x.owner.inputs.__getitem__(idx))]
except Exception: except TypeError:
_logger.error('failed to index with "%s"' % str(idx)) pass
raise except Exception:
_logger.error('failed to index with "%s"' % str(idx))
raise
#TODO: the other optimization for and, or, xor, le and ge see ticket #496. #TODO: the other optimization for and, or, xor, le and ge see ticket #496.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论