提交 e829f389 authored 作者: James Bergstra's avatar James Bergstra

added set_instead_of_inc to AdvSubtensor1, and fixed bug in eq and hash

上级 ee8fcd4f
...@@ -3505,9 +3505,10 @@ def inc_subtensor(x, y, inplace=False, set_instead_of_inc=False, ...@@ -3505,9 +3505,10 @@ def inc_subtensor(x, y, inplace=False, set_instead_of_inc=False,
real_x = x.owner.inputs[0] real_x = x.owner.inputs[0]
ilist = x.owner.inputs[1] ilist = x.owner.inputs[1]
if set_instead_of_inc: if set_instead_of_inc:
raise NotImplementedError() the_op = AdvancedIncSubtensor1(inplace, set_instead_of_inc=True)
else: else:
return AdvancedIncSubtensor1(inplace)(real_x, y, ilist) the_op = AdvancedIncSubtensor1(inplace, set_instead_of_inc=False)
return the_op(real_x, y, ilist)
elif isinstance(x.owner.op, AdvancedSubtensor): elif isinstance(x.owner.op, AdvancedSubtensor):
raise NotImplementedError() raise NotImplementedError()
else: else:
...@@ -4892,15 +4893,19 @@ advanced_subtensor1 = AdvancedSubtensor1() ...@@ -4892,15 +4893,19 @@ advanced_subtensor1 = AdvancedSubtensor1()
class AdvancedIncSubtensor1(Op): class AdvancedIncSubtensor1(Op):
"""Increments a subtensor using advanced slicing (list of index)""" """Increments a subtensor using advanced slicing (list of index)"""
def __init__(self, inplace=False): def __init__(self, inplace=False, set_instead_of_inc=False):
self.inplace = inplace self.inplace = inplace
self.set_instead_of_inc = set_instead_of_inc
if inplace: if inplace:
self.destroy_map = {0: [0]} self.destroy_map = {0: [0]}
def __hash__(self): def __hash__(self):
return hash(type(self)) return hash((type(self), self.inplace, self.set_instead_of_inc))
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) return (type(self) == type(other)
and self.inplace == other.inplace
and self.set_instead_of_inc == other.set_instead_of_inc)
def make_node(self, x, y, ilist): def make_node(self, x, y, ilist):
x_ = as_tensor_variable(x) x_ = as_tensor_variable(x)
...@@ -4927,6 +4932,14 @@ class AdvancedIncSubtensor1(Op): ...@@ -4927,6 +4932,14 @@ class AdvancedIncSubtensor1(Op):
# x[idx] += y don't work if the same index is present many times. # x[idx] += y don't work if the same index is present many times.
# It do it only once # It do it only once
# -- Numpy also behaves this way, is it a bug in numpy? # -- Numpy also behaves this way, is it a bug in numpy?
if self.set_instead_of_inc:
if y.ndim:
for (j,i) in enumerate(idx):
x[i] = y[j]
else:
for i in idx:
x[i] = y
else:
if y.ndim: if y.ndim:
for (j,i) in enumerate(idx): for (j,i) in enumerate(idx):
x[i] += y[j] x[i] += y[j]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论