Commit 4d5e8de0, authored by Frederic

Remove the useless Advanced[Inc]Subtensor constructor args and add the missing functions (__eq__, __hash__, __str__).

Parent d79dd27b
@@ -1377,7 +1377,7 @@ class _tensor_py_operators:
                      theano.tensor.sharedvar.TensorSharedVariable))):
             return advanced_subtensor1(self, *args)
         else:
-            return AdvancedSubtensor(args)(self, *args)
+            return AdvancedSubtensor()(self, *args)
     else:
         return Subtensor(args)(self, *Subtensor.collapse(args, lambda entry: isinstance(entry, Variable)))
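This first hunk is the dispatch point in _tensor_py_operators.__getitem__: slices and scalar indices still go through Subtensor, whose constructor encodes the static index pattern, while advanced-indexing patterns go through AdvancedSubtensor, which after this commit takes no constructor arguments and receives the index variables purely as node inputs. A minimal NumPy sketch of the basic/advanced distinction Theano mirrors (illustrative only, not the Theano API):

    import numpy as np

    x = np.arange(12).reshape(3, 4)

    # Basic indexing: slices and integers. In Theano this routes to
    # Subtensor, whose constructor still encodes the static index pattern.
    basic = x[1:3, 0]              # -> array([4, 8])

    # Advanced indexing: an integer array used as an index. In Theano this
    # routes to AdvancedSubtensor (or advanced_subtensor1 for the simple
    # vector case), where the index array is a symbolic input to the node,
    # so the Op instance itself needs no per-call state.
    idx = np.array([0, 2])
    advanced = x[idx]              # -> rows 0 and 2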
@@ -5272,20 +5272,17 @@ class AdvancedSubtensor(Op):
     """Return a subtensor copy, using advanced indexing.
     """
     # Should be used by __getitem__ and __getslice__, as follow:
-    # AdvancedSubtensor(args)(self, *args),
+    # AdvancedSubtensor()(self, *args),
     # if args contains and advanced indexing pattern
-    def __init__(self, args): # idx_list?
-        # For the moment, __init__ will be passed the whole list of arguments
-        #TODO: see what's the best solution
-        self.args = args # ?
-        #FIXME: do not store variables in the class instance
-
-        #FIXME
-        #if len(args) != 2:
-        #    print >>sys.stderr, 'WARNING: Advanced indexing with %i arguments not supported yet' % len(args)
-        #    print >>sys.stderr, '   arguments are:', args
+    def __eq__(self, other):
+        return self.__class__ == other.__class__
+
+    def __hash__(self):
+        return hash(self.__class__)
+
+    def __str__(self):
+        return self.__class__.__name__
 
     def make_node(self, x, *inputs):
         x = as_tensor_variable(x)
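The deleted __init__ stored the raw indexing arguments on the Op instance, which the old FIXME itself flagged ("do not store variables in the class instance"). A hedged sketch (hypothetical classes, not Theano code) of why per-instance args defeat equality-based node merging:

    class StatefulOp:
        """Old pattern: per-call arguments live on the instance."""
        def __init__(self, args):
            self.args = args

    a = StatefulOp(([0, 2],))
    b = StatefulOp(([0, 2],))

    # Default equality and hashing are identity-based, so two Ops built
    # from identical indexing patterns never compare equal and can never
    # be merged by the graph optimizer:
    assert a != b

    # Folding the stored args into __eq__/__hash__ is no fix either: index
    # arguments may contain lists or symbolic variables, which are unhashable.
    try:
        hash(([0, 2],))
    except TypeError as err:
        print("unhashable:", err)  # unhashable type: 'list'

With the replacement __eq__/__hash__/__str__ defined purely on the class, every AdvancedSubtensor instance is interchangeable with every other, which is the contract Theano's merge optimization relies on.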
@@ -5348,14 +5345,21 @@ class AdvancedSubtensor(Op):
         gz, = grads
         x = inputs[0]
         rest = inputs[1:]
-        return [AdvancedIncSubtensor(self.args)(zeros_like(x), gz, *rest)] + [None]*len(rest)
+        return [AdvancedIncSubtensor()(zeros_like(x), gz, *rest)] + [None]*len(rest)
 
 
 class AdvancedIncSubtensor(Op):
     """Increments a subtensor using advanced indexing.
     """
-    def __init__(self, args): #idx_list? inplace=False?
-        self.args = args
+    def __eq__(self, other):
+        return self.__class__ == other.__class__
+
+    def __hash__(self):
+        return hash(self.__class__)
+
+    def __str__(self):
+        return self.__class__.__name__
 
     def make_node(self, x, y, *inputs):
         x = as_tensor_variable(x)
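The grad hunk above keeps the usual adjoint pairing: the gradient of an advanced read x[idxs] scatters the output gradient back into a zero tensor at the same indices, which is what AdvancedIncSubtensor computes. A NumPy sketch of that relationship, assuming the accumulate-on-repeated-indices semantics (np.add.at) that a correct gradient requires:

    import numpy as np

    x = np.arange(12.0).reshape(3, 4)
    idx = np.array([0, 2, 0])      # note the repeated row index 0

    y = x[idx]                     # forward pass: AdvancedSubtensor
    gz = np.ones_like(y)           # incoming gradient on y

    gx = np.zeros_like(x)          # grad wrt x: increment zeros at idx
    np.add.at(gx, idx, gz)         # accumulates, so row 0 is counted twice

    print(gx[0])                   # [2. 2. 2. 2.]
    print(gx[1])                   # [0. 0. 0. 0.]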
@@ -5393,7 +5397,7 @@ class AdvancedIncSubtensor(Op):
         idxs = inpt[2:]
         outgrad, = output_gradients
         d_x_wrt_C = outgrad
-        d_y_wrt_C = AdvancedSubtensor(self.args)(outgrad, *idxs)
+        d_y_wrt_C = AdvancedSubtensor()(outgrad, *idxs)
         return [d_x_wrt_C, d_y_wrt_C] + [None for _ in idxs]
 
     def R_op(self, inputs, eval_points):
...
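Taken together, the commit converts both Ops to a stateless pattern. A minimal, self-contained sketch (not Theano itself) of why the three added methods make instances interchangeable:

    class StatelessOp:
        """New pattern: identity is the class, not the instance."""
        def __eq__(self, other):
            return self.__class__ == other.__class__

        def __hash__(self):
            return hash(self.__class__)

        def __str__(self):
            return self.__class__.__name__

    a = StatelessOp()
    b = StatelessOp()
    assert a == b                  # equal without sharing identity
    assert hash(a) == hash(b)      # usable as dict/set keys
    assert len({a, b}) == 1        # duplicates collapse, enabling merging
    print(a)                       # "StatelessOp"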