SubTensor: make perform faster when idx_list contains only constants (thank you Pascal).

Parent commit: 8ea6f992
...@@ -2133,6 +2133,7 @@ class Subtensor(Op): ...@@ -2133,6 +2133,7 @@ class Subtensor(Op):
def __init__(self, idx_list):
    """Build a Subtensor op from a list of index entries.

    :param idx_list: sequence of index descriptors (constants, slices,
        or scalar Types standing for runtime inputs); each entry is
        normalized through ``self.convert``.
    """
    # Materialize with list(): on Python 3 a bare map() is a one-shot
    # iterator, but self.idx_list is re-iterated on every perform()
    # call.  On Python 2 this is byte-for-byte equivalent behavior.
    self.idx_list = list(map(self.convert, idx_list))
    # Cache used by perform(): when idx_list contains only constants
    # the computed index tuple does not depend on the runtime inputs,
    # so perform() stores it here on first call and short-circuits on
    # subsequent calls (see the fast path at the top of perform).
    self.perform_cache_cdata = None
@staticmethod @staticmethod
def my_as_scalar(a): def my_as_scalar(a):
...@@ -2172,8 +2173,24 @@ class Subtensor(Op): ...@@ -2172,8 +2173,24 @@ class Subtensor(Op):
def perform(self, node, inputs, (out, )): def perform(self, node, inputs, (out, )):
x = inputs[0] x = inputs[0]
# The subtensor (or idx_list) does not depend on the inputs.
# (and cdata was cached on initial call)
if self.perform_cache_cdata is not None:
out[0] = numpy.asarray(x.__getitem__(self.perform_cache_cdata))
return
indices = list(reversed(inputs[1:])) indices = list(reversed(inputs[1:]))
# The subtensor (or idx_list) does not depend on the inputs.
# (first call caches cdata here)
if len(indices) == 0:
cdata = tuple(self.idx_list)
if len(cdata) == 1:
cdata = cdata[0]
self.perform_cache_cdata = cdata
# General case
else:
def convert(entry): def convert(entry):
if isinstance(entry, gof.Type): if isinstance(entry, gof.Type):
return indices.pop() return indices.pop()
...@@ -2183,10 +2200,11 @@ class Subtensor(Op): ...@@ -2183,10 +2200,11 @@ class Subtensor(Op):
convert(entry.step)) convert(entry.step))
else: else:
return entry return entry
cdata = tuple(map(convert, self.idx_list)) cdata = tuple(map(convert, self.idx_list))
if len(cdata) == 1: if len(cdata) == 1:
cdata = cdata[0] cdata = cdata[0]
out[0] = numpy.asarray(x.__getitem__(cdata)) out[0] = numpy.asarray(x.__getitem__(cdata))
def infer_shape(self, node, shapes): def infer_shape(self, node, shapes):
......
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Please finish editing this comment first!
Register or sign in to comment