提交 0d6b1a1c authored 作者: James Bergstra's avatar James Bergstra

tensor.basic - added set_subtensor and inc_subtensor functions

上级 757d0483
...@@ -2078,6 +2078,7 @@ def pow(a, b): ...@@ -2078,6 +2078,7 @@ def pow(a, b):
def clip(x, min, max): def clip(x, min, max):
"""clip x to be between min and max""" """clip x to be between min and max"""
# see decorator for function body # see decorator for function body
# for grep: clamp, bound
pprint.assign(add, printing.OperatorPrinter('+', -2, 'either')) pprint.assign(add, printing.OperatorPrinter('+', -2, 'either'))
pprint.assign(mul, printing.OperatorPrinter('*', -1, 'either')) pprint.assign(mul, printing.OperatorPrinter('*', -1, 'either'))
...@@ -2371,31 +2372,48 @@ class SubtensorPrinter: ...@@ -2371,31 +2372,48 @@ class SubtensorPrinter:
pprint.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, Subtensor), SubtensorPrinter()) pprint.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, Subtensor), SubtensorPrinter())
def setsubtensor(x, y, idx_list, inplace=False): def setsubtensor(x, y, idx_list, inplace=False):
""" print >> sys.stderr, "tensor.setsubtensor is deprecated - please use set_subtensor"
setsubtensor is meant to replicate the following numpy behaviour: x[i,j,k] = y the_op = IncSubtensor(idx_list, inplace, set_instead_of_inc=True)
return the_op(x, y, *Subtensor.collapse(idx_list, lambda entry: isinstance(entry, Variable)))
def incsubtensor(x, y, idx_list, inplace=False):
print >> sys.stderr, "tensor.setsubtensor is deprecated - please use set_subtensor"
the_op = IncSubtensor(idx_list, inplace, set_instead_of_inc=False)
return the_op(x, y, *Subtensor.collapse(idx_list, lambda entry: isinstance(entry, Variable)))
def set_subtensor(x, y):
"""Return x with the given subtensor overwritten by y.
Example: To replicate the numpy expression "r[10:] = 5", type
>>> new_r = set_subtensor(r[10:], 5)
:param x: symbolic variable for the lvalue of = operation :param x: symbolic variable for the lvalue of = operation
:param y: symbolic variable for the rvalue of = operation :param y: symbolic variable for the rvalue of = operation
:param idx_list: tuple of length x.dim, containing indices with which to index x.
:param inplace: boolean to declare whether the operation is in place or not (False unless
called from within an optimization)
:Details: idx_list can be a tuple containing a mixture of numeric constants, symbolic :see: theano.tensor.basic.setsubtensor
scalar values and standard numpy slice objects. i.e:
idx_list=(1,2,3), idx_list=(1,b,3) where b is an iscalar variable,
idx_list=(slice(start,stop,step),b,3) equivalent to x[start:stop:step, b, 3]
""" """
the_op = IncSubtensor(idx_list, inplace, True) return inc_subtensor(x, y, set_instead_of_inc=True)
return the_op(x, y, *Subtensor.collapse(idx_list, lambda entry: isinstance(entry, Variable)))
def incsubtensor(x, y, idx_list, inplace=False): def inc_subtensor(x, y, set_instead_of_inc=False):
""" """Return x with the given subtensor incremented by y.
incsubtensor is meant to replicate the following numpy behaviour: x[i,j,k] += y
:param x: the symbolic result of a Subtensor operation.
:param y: the amount by which to increment ths subtensor in question
Example: To replicate the numpy expression "r[10:] += 5", type
>>> new_r = inc_subtensor(r[10:], 5)
:see: theano.tensor.basic.setsubtensor :see: theano.tensor.basic.setsubtensor
""" """
the_op = IncSubtensor(idx_list, inplace, False) # retrieve idx_list from x.owner
return the_op(x, y, *Subtensor.collapse(idx_list, lambda entry: isinstance(entry, Variable))) if not isinstance(x.owner.op, Subtensor):
raise TypeError('x must be result of a subtensor operation')
the_op = IncSubtensor(x.owner.op.idx_list, inplace, set_instead_of_inc)
real_x = x.owner.inputs[0]
real_idxargs = x.owner.inputs[1:]
return the_op(real_x, y, *real_idxargs)
class IncSubtensor(Op): class IncSubtensor(Op):
"""Increment a subtensor. """Increment a subtensor.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论