提交 e829f389 authored 作者: James Bergstra's avatar James Bergstra

added set_instead_of_inc to AdvSubtensor1, and fixed bug in eq and hash

上级 ee8fcd4f
......@@ -3505,9 +3505,10 @@ def inc_subtensor(x, y, inplace=False, set_instead_of_inc=False,
real_x = x.owner.inputs[0]
ilist = x.owner.inputs[1]
if set_instead_of_inc:
raise NotImplementedError()
the_op = AdvancedIncSubtensor1(inplace, set_instead_of_inc=True)
else:
return AdvancedIncSubtensor1(inplace)(real_x, y, ilist)
the_op = AdvancedIncSubtensor1(inplace, set_instead_of_inc=False)
return the_op(real_x, y, ilist)
elif isinstance(x.owner.op, AdvancedSubtensor):
raise NotImplementedError()
else:
......@@ -4892,15 +4893,19 @@ advanced_subtensor1 = AdvancedSubtensor1()
class AdvancedIncSubtensor1(Op):
"""Increments a subtensor using advanced slicing (list of index)"""
def __init__(self, inplace=False):
def __init__(self, inplace=False, set_instead_of_inc=False):
self.inplace = inplace
self.set_instead_of_inc = set_instead_of_inc
if inplace:
self.destroy_map = {0: [0]}
def __hash__(self):
return hash(type(self))
return hash((type(self), self.inplace, self.set_instead_of_inc))
def __eq__(self, other):
return type(self) == type(other)
return (type(self) == type(other)
and self.inplace == other.inplace
and self.set_instead_of_inc == other.set_instead_of_inc)
def make_node(self, x, y, ilist):
x_ = as_tensor_variable(x)
......@@ -4927,6 +4932,14 @@ class AdvancedIncSubtensor1(Op):
# x[idx] += y don't work if the same index is present many times.
# It do it only once
# -- Numpy also behaves this way, is it a bug in numpy?
if self.set_instead_of_inc:
if y.ndim:
for (j,i) in enumerate(idx):
x[i] = y[j]
else:
for i in idx:
x[i] = y
else:
if y.ndim:
for (j,i) in enumerate(idx):
x[i] += y[j]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论