提交 ed85af6c authored 作者: James Bergstra's avatar James Bergstra

new impl of GetSubtensor uses cuda_ndarray mapping protocol

上级 cee77d13
......@@ -375,6 +375,25 @@ class GpuSubtensor(tensor.Subtensor):
return rval
def perform(self, node, inputs, (out, )):
x = inputs[0]
indices = list(reversed(inputs[1:]))
def convert(entry):
if isinstance(entry, Type):
return indices.pop()
elif isinstance(entry, slice):
return slice(convert(entry.start),
convert(entry.stop),
convert(entry.step))
else:
return entry
cdata = tuple(map(convert, self.idx_list))
if len(cdata) == 1:
cdata = cdata[0]
out[0] = x.__getitem__(cdata)
def old_perform(self, node, inputs, (out, )):
indices = list(reversed(inputs[1:]))
def convert(entry):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论