SubTensor: make perform faster when idx_list contains only constants (thank you Pascal).

上级 8ea6f992
......@@ -2133,6 +2133,7 @@ class Subtensor(Op):
def __init__(self, idx_list):
self.idx_list = map(self.convert, idx_list)
self.perform_cache_cdata = None
@staticmethod
def my_as_scalar(a):
......@@ -2172,21 +2173,38 @@ class Subtensor(Op):
def perform(self, node, inputs, (out, )):
x = inputs[0]
# The subtensor (or idx_list) does not depend on the inputs.
# (and cdata was cached on initial call)
if self.perform_cache_cdata is not None:
out[0] = numpy.asarray(x.__getitem__(self.perform_cache_cdata))
return
indices = list(reversed(inputs[1:]))
def convert(entry):
if isinstance(entry, gof.Type):
return indices.pop()
elif isinstance(entry, slice):
return slice(convert(entry.start),
# The subtensor (or idx_list) does not depend on the inputs.
# (first call caches cdata here)
if len(indices) == 0:
cdata = tuple(self.idx_list)
if len(cdata) == 1:
cdata = cdata[0]
self.perform_cache_cdata = cdata
# General case
else:
def convert(entry):
if isinstance(entry, gof.Type):
return indices.pop()
elif isinstance(entry, slice):
return slice(convert(entry.start),
convert(entry.stop),
convert(entry.step))
else:
return entry
else:
return entry
cdata = tuple(map(convert, self.idx_list))
if len(cdata) == 1:
cdata = cdata[0]
cdata = tuple(map(convert, self.idx_list))
if len(cdata) == 1:
cdata = cdata[0]
out[0] = numpy.asarray(x.__getitem__(cdata))
def infer_shape(self, node, shapes):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论