提交 676308df authored 作者: John Salvatier's avatar John Salvatier

fixed problems nouiz pointed out

上级 61cb748b
...@@ -7227,6 +7227,10 @@ class MakeSlice(Op): ...@@ -7227,6 +7227,10 @@ class MakeSlice(Op):
def __hash__(self): def __hash__(self):
return hash(type(self)) return hash(type(self))
def grad(self, inputs, grads):
return [DiconnectedType()() for i in inputs]
make_slice = MakeSlice() make_slice = MakeSlice()
...@@ -7257,7 +7261,13 @@ slicetype = SliceType() ...@@ -7257,7 +7261,13 @@ slicetype = SliceType()
NoneConst = Constant(NoneTypeT(), None, name = 'None') NoneConst = Constant(NoneTypeT(), None, name = 'None')
def adv_broadcastable(a, idx): def adv_index_broadcastable_pattern(a, idx):
"""
This function is only used to determine the broardcast pattern for AdvancedSubtensor output variable.
For this, we make a fake ndarray and a fake idx and call use ask numpy the output. From this, we find the output broadcast pattern.
"""
def replace_slice(v): def replace_slice(v):
if isinstance(v, gof.Apply): if isinstance(v, gof.Apply):
if len(v.outputs) != 1: if len(v.outputs) != 1:
...@@ -7267,7 +7277,7 @@ def adv_broadcastable(a, idx): ...@@ -7267,7 +7277,7 @@ def adv_broadcastable(a, idx):
else: else:
v = v.outputs[0] v = v.outputs[0]
if v is NoneConst: if NoneConst.equals(v):
return None return None
if isinstance(v.type, SliceType): if isinstance(v.type, SliceType):
return slice(None,None) return slice(None,None)
...@@ -7306,7 +7316,7 @@ class AdvancedSubtensor(Op): ...@@ -7306,7 +7316,7 @@ class AdvancedSubtensor(Op):
return gof.Apply(self, return gof.Apply(self,
(x,) + index, (x,) + index,
[tensor(dtype = x.type.dtype, [tensor(dtype = x.type.dtype,
broadcastable = adv_broadcastable(x, index) )]) broadcastable = adv_index_broadcastable_pattern(x, index) )])
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
...@@ -7410,6 +7420,8 @@ class AdvancedIncSubtensor(Op): ...@@ -7410,6 +7420,8 @@ class AdvancedIncSubtensor(Op):
out, = out_ out, = out_
if not self.inplace: if not self.inplace:
out[0] = inputs[0].copy() out[0] = inputs[0].copy()
else:
out[0] = inputs[0]
if self.set_instead_of_inc: if self.set_instead_of_inc:
out[0][inputs[2:]] = inputs[1] out[0][inputs[2:]] = inputs[1]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论