提交 09c8c2d3 authored 作者: James Bergstra's avatar James Bergstra

added disabled-by-default experimental advanced indexing code

上级 9563676f
......@@ -4,7 +4,7 @@ __docformat__ = "restructuredtext en"
import __builtin__
import sys # for sys.maxint
from theano.configparser import config
from theano.configparser import config, AddConfigVar, BoolParam
import traceback #for overriding Op.__call__
if sys.version_info >= (2,5):
import functools
......@@ -943,7 +943,13 @@ class _tensor_py_operators:
break
if advanced:
return AdvancedSubtensor(args)(self, *args)
if config.experimental.advanced_indexing:
if len(args) == 1:
return AdvancedSubtensor1()(self, *args)
else:
return AdvancedSubtensor(args)(self, *args)
else:
return AdvancedSubtensor(args)(self, *args)
else:
return Subtensor(args)(self, *Subtensor.collapse(args, lambda entry: isinstance(entry, Variable)))
......@@ -3136,6 +3142,37 @@ def inverse_permutation(perm):
# Should reproduce numpy's behaviour:
# http://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing
AddConfigVar('experimental.advanced_indexing',
"enable not-well-tested advanced indexing functionality",
BoolParam(False))
class AdvancedSubtensor1(Op):
"""Implement x[ilist] where ilist is a vector of integers."""
def __hash__(self):
return hash(type(self))
def __eq__(self, other):
type(self) == type(other)
def make_node(self, x, ilist):
x_ = as_tensor_variable(x)
ilist_ = as_tensor_variable(ilist)
if ilist_.type.dtype[:3] not in ('int', 'uin'):
raise TypeError('index must be integers')
if ilist_.type.broadcastable != (False,):
raise TypeError('index must be vector')
if x_.type.ndim == 0:
raise TypeError('cannot index into a scalar')
if x_.type.broadcastable[0]:
# the caller should have made a copy of x len(ilist) times
raise TypeError('cannot index into a broadcastable dimension')
return gof.Apply(self, [x_, ilist_], [x_.type()])
def perform(self, node, (x,i), (out,)):
out[0] = x[i]
class AdvancedSubtensor(Op):
"""Return a subtensor copy, using advanced indexing.
"""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论