提交 d2406c02 authored 作者: James Bergstra's avatar James Bergstra

A few changes to AdvancedSubtensor

上级 0b59bab3
......@@ -2982,6 +2982,8 @@ class AdvancedSubtensor(Op):
#TODO: see what's the best solution
self.args = args #?
#FIXME: do not store variables in the class instance
#FIXME
#if len(args) != 2:
# print >>sys.stderr, 'WARNING: Advanced indexing with %i arguments not supported yet' % len(args)
......@@ -2993,6 +2995,11 @@ class AdvancedSubtensor(Op):
if x.ndim == 2 and len(inputs) == 2:
ind1 = as_tensor_variable(inputs[0])
ind2 = as_tensor_variable(inputs[1])
if not (ind1.type.dtype.startswith('int') or ind1.type.dtype.startswith('uint')):
raise TypeError()
if not (ind2.type.dtype.startswith('int') or ind2.type.dtype.startswith('uint')):
raise TypeError()
if ind1.ndim == 1 and ind2.ndim == 1:
return gof.Apply(self,
(x,) + inputs,
......@@ -3004,7 +3011,11 @@ class AdvancedSubtensor(Op):
% ','.join(str(input) for input in inputs))
def perform(self, node, inputs, (out,)):
pass
# TODO: in general, we need to re-pack the inputs into a valid index, just like
# subtensor
out[0] = inputs[0].__getitem__(inputs[1:])
#return
#raise NotImplementedError()
def grad(self, inputs, (gz,)):
x = inputs[0]
......@@ -3035,10 +3046,13 @@ class AdvancedIncSubtensor(Op):
raise NotImplementedError('Advanced indexing increment of x by y with arguments (%s) not supported yet'\
% ','.join(str(input) for input in inputs))
def perform(self, node, inputs, (out,)):
pass
#def perform(self, node, inputs, (out,)):
#raise NotImplementedError()
#def grad?
# grad on x is grad on output
# grad on y is grad_output[idx_list]
# grad on rest is None
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论