提交 e74f035c authored 作者: Shawn Tan's avatar Shawn Tan

Modifications based on review.

上级 a099b54c
......@@ -593,13 +593,14 @@ class GpuAdvancedIncSubtensor(HideC, tensor.AdvancedIncSubtensor):
"""
def make_node(self, x, y, *inputs):
ctx_name = infer_context_name(x)
ctx_name = infer_context_name(x, y)
rval = tensor.AdvancedIncSubtensor.make_node(self, x, y, *inputs)
otype = GpuArrayType(dtype=rval.outputs[0].type.dtype,
broadcastable=rval.outputs[0].type.broadcastable,
context_name=ctx_name)
x = as_gpuarray_variable(x, ctx_name)
return gof.Apply(self, [x] + rval.inputs[1:], [otype()])
y = as_gpuarray_variable(y, ctx_name)
return gof.Apply(self, [x, y] + rval.inputs[2:], [otype()])
# We can't use the parent version that loops on each index
# as we also need to loop when set_instead_of_inc is True and the
......@@ -673,8 +674,6 @@ class GpuAdvancedIncSubtensor(HideC, tensor.AdvancedIncSubtensor):
# build the indices and use it
take_idx = sum((i * s for i, s in zip(nidx, strides))).flatten()
k = get_iadd(node.inputs[0], node.inputs[1])
y_flat = pygpu.asarray(y_flat, context=x_flat.context)
for j, i in enumerate(take_idx):
k(x_flat[i], y_flat[j], broadcast=True)
out[0] = x
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论