Modified IncSubtensor so that it can now perform setting in addition

to incrementing.
上级 38d38f11
...@@ -2006,6 +2006,9 @@ class SubtensorPrinter: ...@@ -2006,6 +2006,9 @@ class SubtensorPrinter:
pprint.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, Subtensor), SubtensorPrinter()) pprint.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, Subtensor), SubtensorPrinter())
def incsubtensor(x, y, idx_list, inplace=False, set_instead_of_inc=False):
the_op = IncSubtensor(idx_list, inplace, set_instead_of_inc)
return the_op(x, y)
class IncSubtensor(Op): class IncSubtensor(Op):
"""Increment a subtensor. """Increment a subtensor.
...@@ -2015,18 +2018,23 @@ class IncSubtensor(Op): ...@@ -2015,18 +2018,23 @@ class IncSubtensor(Op):
z[i,j,k] += <something> z[i,j,k] += <something>
It is used internally to implement the gradient on SubTensor. It is used internally to implement the gradient on SubTensor.
:param set_instead_of_inc: if True set the subtensor to the value instead
of incrementing it by that value.
""" """
def __init__(self, idx_list, inplace=False): def __init__(self, idx_list, inplace=False, set_instead_of_inc=False):
self.idx_list = map(Subtensor.convert, idx_list) self.idx_list = map(Subtensor.convert, idx_list)
self.inplace = inplace self.inplace = inplace
if inplace: if inplace:
self.destroy_map = {0: [0]} self.destroy_map = {0: [0]}
self.set_instead_of_inc = set_instead_of_inc
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) \ return type(self) == type(other) \
and self.idx_list == other.idx_list \ and self.idx_list == other.idx_list \
and self.inplace == other.inplace and self.inplace == other.inplace \
and self.set_instead_of_inc == other.set_instead_of_inc
def __hash__(self): def __hash__(self):
msg = [] msg = []
...@@ -2042,7 +2050,8 @@ class IncSubtensor(Op): ...@@ -2042,7 +2050,8 @@ class IncSubtensor(Op):
# if isinstance(entry, slice) # if isinstance(entry, slice)
# else entry # else entry
# for entry in self.idx_list) # for entry in self.idx_list)
return hashtype(self) ^ hash(idx_list) ^ hash(self.inplace) return hashtype(self) ^ hash(idx_list) ^ hash(self.inplace) \
^ hash(self.set_instead_of_inc)
def __str__(self): def __str__(self):
indices = [] indices = []
...@@ -2055,6 +2064,10 @@ class IncSubtensor(Op): ...@@ -2055,6 +2064,10 @@ class IncSubtensor(Op):
msg = 'Inplace' msg = 'Inplace'
else: else:
msg = '' msg = ''
if not self.set_instead_of_inc:
msg += 'Inc'
else:
msg += 'Set'
return "%s%s{%s}" % (msg, return "%s%s{%s}" % (msg,
self.__class__.__name__, ", ".join(indices)) self.__class__.__name__, ", ".join(indices))
...@@ -2107,10 +2120,17 @@ class IncSubtensor(Op): ...@@ -2107,10 +2120,17 @@ class IncSubtensor(Op):
sub_x = x.__getitem__(cdata) sub_x = x.__getitem__(cdata)
if sub_x.shape: if sub_x.shape:
# we've sliced out an N-D tensor with N > 0 # we've sliced out an N-D tensor with N > 0
sub_x += y if not self.set_instead_of_inc:
sub_x += y
else:
#sub_x += -sub_x + y
x.__setitem__(cdata, y)
else: else:
# scalar case # scalar case
x.__setitem__(cdata, sub_x + y) if not self.set_instead_of_inc:
x.__setitem__(cdata, sub_x + y)
else:
x.__setitem__(cdata, y)
out[0] = x out[0] = x
def split(x, splits_size, n_splits, axis=0): def split(x, splits_size, n_splits, axis=0):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论