提交 252c57ca authored 作者: Frederic Bastien's avatar Frederic Bastien

small opt in subtensor.

上级 24102962
...@@ -2589,9 +2589,9 @@ def get_idx_list(inputs, idx_list): ...@@ -2589,9 +2589,9 @@ def get_idx_list(inputs, idx_list):
''' '''
# The subtensor (or idx_list) does not depend on the inputs. # The subtensor (or idx_list) does not depend on the inputs.
indices = list(reversed(list(inputs[1:]))) if len(inputs) == 1:
if len(indices) == 0:
return tuple(idx_list) return tuple(idx_list)
indices = list(reversed(list(inputs[1:])))
# General case # General case
def convert(entry): def convert(entry):
...@@ -2779,7 +2779,7 @@ class Subtensor(Op): ...@@ -2779,7 +2779,7 @@ class Subtensor(Op):
raise TypeError(Subtensor.e_indextype, entry) raise TypeError(Subtensor.e_indextype, entry)
def __init__(self, idx_list): def __init__(self, idx_list):
self.idx_list = map(self.convert, idx_list) self.idx_list = tuple(map(self.convert, idx_list))
self.perform_cache_cdata = None self.perform_cache_cdata = None
@staticmethod @staticmethod
...@@ -2834,7 +2834,7 @@ class Subtensor(Op): ...@@ -2834,7 +2834,7 @@ class Subtensor(Op):
if len(cdata) == 1: if len(cdata) == 1:
cdata = cdata[0] cdata = cdata[0]
# (first call caches cdata here) # (first call caches cdata here)
if len(inputs[1:]) == 0: if len(inputs) == 1:
self.perform_cache_cdata = cdata self.perform_cache_cdata = cdata
out[0] = numpy.asarray(x.__getitem__(cdata)) out[0] = numpy.asarray(x.__getitem__(cdata))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论