提交 0926a6c5 authored 作者: Frederic's avatar Frederic

Add support for step

上级 f31cc48f
...@@ -4250,16 +4250,23 @@ def get_vector_length(v): ...@@ -4250,16 +4250,23 @@ def get_vector_length(v):
v.owner.inputs, v.owner.op.idx_list)[0].start) v.owner.inputs, v.owner.op.idx_list)[0].start)
stop = extract_constant(theano.tensor.subtensor.get_idx_list( stop = extract_constant(theano.tensor.subtensor.get_idx_list(
v.owner.inputs, v.owner.op.idx_list)[0].stop) v.owner.inputs, v.owner.op.idx_list)[0].stop)
step = extract_constant(theano.tensor.subtensor.get_idx_list(
v.owner.inputs, v.owner.op.idx_list)[0].step)
if start is None: if start is None:
start = 0 start = 0
ndim = v.owner.inputs[0].owner.inputs[0].ndim ndim = v.owner.inputs[0].owner.inputs[0].ndim
types = (numbers.Integral, numpy.integer)
if stop is None: if stop is None:
stop = ndim stop = ndim
elif isinstance(stop, numbers.Integral) and stop > ndim : elif isinstance(stop, types) and stop > ndim:
stop = ndim stop = ndim
if ((isinstance(stop, numbers.Integral) and if step is None:
isinstance(start, numbers.Integral))): step = 1
return stop - start
if (isinstance(stop, types) and
isinstance(start, types) and
start >= 0 and stop >= 0):
return (stop - start - 1) // step + 1
if isinstance(v, Variable): if isinstance(v, Variable):
msg = theano.printing.debugprint(v, file='str') msg = theano.printing.debugprint(v, file='str')
else: else:
......
...@@ -3354,6 +3354,7 @@ class T_GetVectorLength(unittest.TestCase): ...@@ -3354,6 +3354,7 @@ class T_GetVectorLength(unittest.TestCase):
assert len(list(x.shape[1:4])) == 3 assert len(list(x.shape[1:4])) == 3
assert len(list(x.shape[1:5])) == 3 assert len(list(x.shape[1:5])) == 3
assert len(list(x.shape[1:10])) == 3 assert len(list(x.shape[1:10])) == 3
assert len(list(x.shape[1:10:2])) == 2
class T_Join_and_Split(unittest.TestCase): class T_Join_and_Split(unittest.TestCase):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论