提交 8b7ae921 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Fix two small bugs and improve doc for get_constant_idx().

上级 518b6f43
...@@ -1649,7 +1649,7 @@ def local_useless_subtensor(node): ...@@ -1649,7 +1649,7 @@ def local_useless_subtensor(node):
pass pass
if isinstance(idx.stop, (int, numpy.integer)): if isinstance(idx.stop, (int, numpy.integer)):
if stop < length_pos_data: if idx.stop < length_pos_data:
return False return False
elif isinstance(idx.stop, gof.Variable): elif isinstance(idx.stop, gof.Variable):
length_pos_shape_i = idx.stop length_pos_shape_i = idx.stop
......
...@@ -374,6 +374,16 @@ class Subtensor(Op): ...@@ -374,6 +374,16 @@ class Subtensor(Op):
exception. exception.
None entries are always left as-is. None entries are always left as-is.
Example usage (where v, a are appropriately typed theano variables):
>>> b = a[v, 1:3]
>>> b.owner.op.idx_list
(Scalar(int64), slice(Scalar(int64), Scalar(int64), None))
>>> b.owner.op.get_constant_idx(b.owner.inputs, allow_partial=True)
[v, slice(1, 3, None)]
>>> b.owner.op.get_constant_idx(b.owner.inputs)
NotScalarConstantError: v
""" """
real_idx = get_idx_list(inputs, self.idx_list) real_idx = get_idx_list(inputs, self.idx_list)
def conv(val): def conv(val):
...@@ -444,7 +454,9 @@ class Subtensor(Op): ...@@ -444,7 +454,9 @@ class Subtensor(Op):
start = p.start start = p.start
if start is None: if start is None:
start = 0 start = 0
if p.stop is None or p.stop > start: if (p.stop is None or
(isinstance(p.stop, (int, numpy.integer)) and
p.stop > start)):
broadcastable.append(True) broadcastable.append(True)
continue continue
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论