提交 4d5e8de0 authored 作者: Frederic's avatar Frederic

Remove the useless Avanced[Inc]Subtensor args and added missing fct.

上级 d79dd27b
......@@ -1377,7 +1377,7 @@ class _tensor_py_operators:
theano.tensor.sharedvar.TensorSharedVariable))):
return advanced_subtensor1(self, *args)
else:
return AdvancedSubtensor(args)(self, *args)
return AdvancedSubtensor()(self, *args)
else:
return Subtensor(args)(self, *Subtensor.collapse(args, lambda entry: isinstance(entry, Variable)))
......@@ -5272,20 +5272,17 @@ class AdvancedSubtensor(Op):
"""Return a subtensor copy, using advanced indexing.
"""
# Should be used by __getitem__ and __getslice__, as follow:
# AdvancedSubtensor(args)(self, *args),
# AdvancedSubtensor()(self, *args),
# if args contains and advanced indexing pattern
def __init__(self, args): # idx_list?
# For the moment, __init__ will be passed the whole list of arguments
#TODO: see what's the best solution
self.args = args # ?
def __eq__(self, other):
return self.__class__ == other.__class__
#FIXME: do not store variables in the class instance
def __hash__(self):
return hash(self.__class__)
#FIXME
#if len(args) != 2:
# print >>sys.stderr, 'WARNING: Advanced indexing with %i arguments not supported yet' % len(args)
# print >>sys.stderr, ' arguments are:', args
def __str__(self):
return self.__class__.__name__
def make_node(self, x, *inputs):
x = as_tensor_variable(x)
......@@ -5348,14 +5345,21 @@ class AdvancedSubtensor(Op):
gz, = grads
x = inputs[0]
rest = inputs[1:]
return [AdvancedIncSubtensor(self.args)(zeros_like(x), gz, *rest)] + [None]*len(rest)
return [AdvancedIncSubtensor()(zeros_like(x), gz, *rest)] + [None]*len(rest)
class AdvancedIncSubtensor(Op):
"""Increments a subtensor using advanced indexing.
"""
def __init__(self, args): #idx_list? inplace=False?
self.args = args
def __eq__(self, other):
return self.__class__ == other.__class__
def __hash__(self):
return hash(self.__class__)
def __str__(self):
return self.__class__.__name__
def make_node(self, x, y, *inputs):
x = as_tensor_variable(x)
......@@ -5393,7 +5397,7 @@ class AdvancedIncSubtensor(Op):
idxs = inpt[2:]
outgrad, = output_gradients
d_x_wrt_C = outgrad
d_y_wrt_C = AdvancedSubtensor(self.args)(outgrad, *idxs)
d_y_wrt_C = AdvancedSubtensor()(outgrad, *idxs)
return [d_x_wrt_C, d_y_wrt_C] + [None for _ in idxs]
def R_op(self, inputs, eval_points):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论