提交 d9b990de authored 作者: Sigurd Spieckermann's avatar Sigurd Spieckermann

added support for slices with constant start/stop/step variables in local_subtensor_make_vector

上级 b1259699
...@@ -1383,11 +1383,18 @@ def local_subtensor_make_vector(node): ...@@ -1383,11 +1383,18 @@ def local_subtensor_make_vector(node):
return [make_vector(*[x.owner.inputs[v] for v in values])] return [make_vector(*[x.owner.inputs[v] for v in values])]
else: else:
raise TypeError('case not expected') raise TypeError('case not expected')
else: elif isinstance(idx, slice):
# it is a slice of ints and/or Variables # it is a slice of ints and/or Variables
#TODO: check subtensor to see if it can contain # check subtensor to see if it can contain constant variables, and if
# constant variables, and if it can, then try to # it can, then try to unpack them.
# unpack them. try:
const_slice = node.op.get_constant_idx(node.inputs,
allow_partial=False)[0]
return [make_vector(*x.owner.inputs[const_slice])]
except NotScalarConstantError:
pass
# there was at least one non-constant variable in the slice
try: try:
return [make_vector(*x.owner.inputs.__getitem__(idx))] return [make_vector(*x.owner.inputs.__getitem__(idx))]
except TypeError: except TypeError:
...@@ -1395,6 +1402,8 @@ def local_subtensor_make_vector(node): ...@@ -1395,6 +1402,8 @@ def local_subtensor_make_vector(node):
except Exception: except Exception:
_logger.error('failed to index with "%s"' % str(idx)) _logger.error('failed to index with "%s"' % str(idx))
raise raise
else:
raise TypeError('case not expected')
#TODO: the other optimization for and, or, xor, le and ge see ticket #496. #TODO: the other optimization for and, or, xor, le and ge see ticket #496.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论