提交 676308df authored 作者: John Salvatier's avatar John Salvatier

fixed problems nouiz pointed out

上级 61cb748b
......@@ -7227,6 +7227,10 @@ class MakeSlice(Op):
def __hash__(self):
return hash(type(self))
def grad(self, inputs, grads):
return [DiconnectedType()() for i in inputs]
make_slice = MakeSlice()
......@@ -7257,7 +7261,13 @@ slicetype = SliceType()
NoneConst = Constant(NoneTypeT(), None, name = 'None')
def adv_broadcastable(a, idx):
def adv_index_broadcastable_pattern(a, idx):
"""
This function is only used to determine the broardcast pattern for AdvancedSubtensor output variable.
For this, we make a fake ndarray and a fake idx and call use ask numpy the output. From this, we find the output broadcast pattern.
"""
def replace_slice(v):
if isinstance(v, gof.Apply):
if len(v.outputs) != 1:
......@@ -7267,7 +7277,7 @@ def adv_broadcastable(a, idx):
else:
v = v.outputs[0]
if v is NoneConst:
if NoneConst.equals(v):
return None
if isinstance(v.type, SliceType):
return slice(None,None)
......@@ -7306,7 +7316,7 @@ class AdvancedSubtensor(Op):
return gof.Apply(self,
(x,) + index,
[tensor(dtype = x.type.dtype,
broadcastable = adv_broadcastable(x, index) )])
broadcastable = adv_index_broadcastable_pattern(x, index) )])
def R_op(self, inputs, eval_points):
......@@ -7410,6 +7420,8 @@ class AdvancedIncSubtensor(Op):
out, = out_
if not self.inplace:
out[0] = inputs[0].copy()
else:
out[0] = inputs[0]
if self.set_instead_of_inc:
out[0][inputs[2:]] = inputs[1]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论