提交 53756433 authored 作者: Sigurd Spieckermann's avatar Sigurd Spieckermann

modified local_subtensor_make_vector to support advanced indexing

上级 6a6e9914
...@@ -28,7 +28,8 @@ from theano.tensor.elemwise import Elemwise, DimShuffle ...@@ -28,7 +28,8 @@ from theano.tensor.elemwise import Elemwise, DimShuffle
from theano.tensor.subtensor import (get_idx_list, get_canonical_form_slice, from theano.tensor.subtensor import (get_idx_list, get_canonical_form_slice,
Subtensor, IncSubtensor, make_constant, Subtensor, IncSubtensor, make_constant,
AdvancedIncSubtensor1, AdvancedIncSubtensor1,
AdvancedIncSubtensor) AdvancedIncSubtensor,
AdvancedSubtensor1)
from theano import scalar from theano import scalar
from theano.tensor import basic as T from theano.tensor import basic as T
from theano import compile # to register the optimizer built by this file from theano import compile # to register the optimizer built by this file
...@@ -1330,16 +1331,24 @@ def local_track_shape_i(node): ...@@ -1330,16 +1331,24 @@ def local_track_shape_i(node):
@register_specialize @register_specialize
@register_canonicalize('fast_compile_gpu') @register_canonicalize('fast_compile_gpu')
@gof.local_optimizer([Subtensor]) @gof.local_optimizer([Subtensor, AdvancedSubtensor1])
def local_subtensor_make_vector(node): def local_subtensor_make_vector(node):
# replace all subtensor(make_vector) like: """
# [a,b,c][0] -> a replace all subtensor(make_vector) like:
# [a,b,c][0:2] -> [a,b] [a,b,c][0] -> a
# we can do this for constant indexes [a,b,c][0:2] -> [a,b]
replace all AdvancedSubtensor1(make_vector) like:
[a,b,c][[0,2]] -> [a,c]
we can do this for constant indexes
"""
x = node.inputs[0]
if not x.owner or x.owner.op != make_vector:
return
if isinstance(node.op, Subtensor): if isinstance(node.op, Subtensor):
# This optimization needs ShapeOpt and fgraph.shape_feature # This optimization needs ShapeOpt and fgraph.shape_feature
x = node.inputs[0]
if x.owner and x.owner.op == make_vector:
try: try:
idx, = node.op.idx_list idx, = node.op.idx_list
except Exception: except Exception:
...@@ -1351,19 +1360,26 @@ def local_subtensor_make_vector(node): ...@@ -1351,19 +1360,26 @@ def local_subtensor_make_vector(node):
# is contained in node.inputs[1] # is contained in node.inputs[1]
old_idx, idx = idx, node.inputs[1] old_idx, idx = idx, node.inputs[1]
assert idx.type == old_idx assert idx.type == old_idx
elif isinstance(node.op, AdvancedSubtensor1):
idx = node.inputs[1]
else:
return
if isinstance(idx, (int, numpy.integer)): if isinstance(idx, (int, numpy.integer)):
return [x.owner.inputs[idx]] return [x.owner.inputs[idx]]
elif isinstance(idx, Variable): elif isinstance(idx, Variable):
# if it is a constant we can do something with it # if it is a constant we can do something with it
try: if isinstance(idx, T.Constant):
v = get_scalar_constant_value(idx) # make sure we have an ndarray to access the `ndim` attribute
if isinstance(v, numpy.integer): idx = numpy.asarray(idx.value)
if idx.ndim == 0:
# Python 2.4 wants to index only with Python integers # Python 2.4 wants to index only with Python integers
v = int(v) return [x.owner.inputs[int(idx)]]
return [x.owner.inputs[v]] elif idx.ndim == 1:
except NotScalarConstantError: values = map(int, list(idx))
pass return [make_vector(*[x.owner.inputs[v] for v in values])]
else:
raise TypeError
else: else:
# it is a slice of ints and/or Variables # it is a slice of ints and/or Variables
#TODO: check subtensor to see if it can contain #TODO: check subtensor to see if it can contain
...@@ -1377,6 +1393,7 @@ def local_subtensor_make_vector(node): ...@@ -1377,6 +1393,7 @@ def local_subtensor_make_vector(node):
_logger.error('failed to index with "%s"' % str(idx)) _logger.error('failed to index with "%s"' % str(idx))
raise raise
#TODO: the other optimization for and, or, xor, le and ge see ticket #496. #TODO: the other optimization for and, or, xor, le and ge see ticket #496.
@register_canonicalize('fast_compile') @register_canonicalize('fast_compile')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论