Modified IncSubtensor so that it can now perform setting in addition

to incrementing.
上级 38d38f11
......@@ -2006,6 +2006,9 @@ class SubtensorPrinter:
pprint.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, Subtensor), SubtensorPrinter())
def incsubtensor(x, y, idx_list, inplace=False, set_instead_of_inc=False):
the_op = IncSubtensor(idx_list, inplace, set_instead_of_inc)
return the_op(x, y)
class IncSubtensor(Op):
"""Increment a subtensor.
......@@ -2015,18 +2018,23 @@ class IncSubtensor(Op):
z[i,j,k] += <something>
It is used internally to implement the gradient on SubTensor.
:param set_instead_of_inc: if True set the subtensor to the value instead
of incrementing it by that value.
"""
def __init__(self, idx_list, inplace=False):
def __init__(self, idx_list, inplace=False, set_instead_of_inc=False):
self.idx_list = map(Subtensor.convert, idx_list)
self.inplace = inplace
if inplace:
self.destroy_map = {0: [0]}
self.set_instead_of_inc = set_instead_of_inc
def __eq__(self, other):
return type(self) == type(other) \
and self.idx_list == other.idx_list \
and self.inplace == other.inplace
and self.inplace == other.inplace \
and self.set_instead_of_inc == other.set_instead_of_inc
def __hash__(self):
msg = []
......@@ -2042,7 +2050,8 @@ class IncSubtensor(Op):
# if isinstance(entry, slice)
# else entry
# for entry in self.idx_list)
return hashtype(self) ^ hash(idx_list) ^ hash(self.inplace)
return hashtype(self) ^ hash(idx_list) ^ hash(self.inplace) \
^ hash(self.set_instead_of_inc)
def __str__(self):
indices = []
......@@ -2055,6 +2064,10 @@ class IncSubtensor(Op):
msg = 'Inplace'
else:
msg = ''
if not self.set_instead_of_inc:
msg += 'Inc'
else:
msg += 'Set'
return "%s%s{%s}" % (msg,
self.__class__.__name__, ", ".join(indices))
......@@ -2107,10 +2120,17 @@ class IncSubtensor(Op):
sub_x = x.__getitem__(cdata)
if sub_x.shape:
# we've sliced out an N-D tensor with N > 0
if not self.set_instead_of_inc:
sub_x += y
else:
#sub_x += -sub_x + y
x.__setitem__(cdata, y)
else:
# scalar case
if not self.set_instead_of_inc:
x.__setitem__(cdata, sub_x + y)
else:
x.__setitem__(cdata, y)
out[0] = x
def split(x, splits_size, n_splits, axis=0):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论