提交 27ecbc00 authored 作者: Colin Raffel's avatar Colin Raffel

Allow slicing of tensor variables to return symbolic shape refs

上级 6a39a8e3
...@@ -5,6 +5,7 @@ import warnings ...@@ -5,6 +5,7 @@ import warnings
import numpy import numpy
from six.moves import xrange from six.moves import xrange
import numbers
import theano import theano
from theano.compat import izip from theano.compat import izip
...@@ -3922,15 +3923,22 @@ def get_vector_length(v): ...@@ -3922,15 +3923,22 @@ def get_vector_length(v):
return len(v.owner.inputs) return len(v.owner.inputs)
if v.owner and isinstance(v.owner.op, Shape): if v.owner and isinstance(v.owner.op, Shape):
return v.owner.inputs[0].type.ndim return v.owner.inputs[0].type.ndim
# If we take this slice: var[:0], we know it will have 0 elements. # If we take a slice, we know how many elements it will result in
if ((v.owner and if ((v.owner and
isinstance(v.owner.op, theano.tensor.subtensor.Subtensor) and isinstance(v.owner.op, theano.tensor.subtensor.Subtensor) and
isinstance(v.owner.op.idx_list[0], slice) and isinstance(v.owner.op.idx_list[0], slice))):
v.owner.op.idx_list[0].start in [None, 0])): start = extract_constant(theano.tensor.subtensor.get_idx_list(
stop = theano.tensor.subtensor.get_idx_list( v.owner.inputs, v.owner.op.idx_list)[0].start)
v.owner.inputs, v.owner.op.idx_list)[0].stop stop = extract_constant(theano.tensor.subtensor.get_idx_list(
if extract_constant(stop) == 0: v.owner.inputs, v.owner.op.idx_list)[0].stop)
return 0 if start is None:
start = 0
if stop is None:
stop = 0
if ((isinstance(stop, numbers.Integral) and
isinstance(start, numbers.Integral))):
return stop - start
raise ValueError("length not known") raise ValueError("length not known")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论