提交 f671bca6 authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #3202 from craffel/symbolic_slicing

Allow slicing of tensor variables to return symbolic shape refs
......@@ -5,6 +5,7 @@ import warnings
import numpy
from six.moves import xrange
import numbers
import theano
from theano.compat import izip
......@@ -3922,15 +3923,22 @@ def get_vector_length(v):
return len(v.owner.inputs)
if v.owner and isinstance(v.owner.op, Shape):
return v.owner.inputs[0].type.ndim
# If we take this slice: var[:0], we know it will have 0 elements.
# If we take a slice, we know how many elements it will result in
if ((v.owner and
isinstance(v.owner.op, theano.tensor.subtensor.Subtensor) and
isinstance(v.owner.op.idx_list[0], slice) and
v.owner.op.idx_list[0].start in [None, 0])):
stop = theano.tensor.subtensor.get_idx_list(
v.owner.inputs, v.owner.op.idx_list)[0].stop
if extract_constant(stop) == 0:
return 0
isinstance(v.owner.op.idx_list[0], slice))):
start = extract_constant(theano.tensor.subtensor.get_idx_list(
v.owner.inputs, v.owner.op.idx_list)[0].start)
stop = extract_constant(theano.tensor.subtensor.get_idx_list(
v.owner.inputs, v.owner.op.idx_list)[0].stop)
if start is None:
start = 0
if stop is None:
stop = 0
if ((isinstance(stop, numbers.Integral) and
isinstance(start, numbers.Integral))):
return stop - start
raise ValueError("length not known")
......
......@@ -7729,6 +7729,13 @@ def test_allocempty():
assert out.shape == (2, 3)
assert out.dtype == 'float32'
def test_symbolic_slice():
x = theano.tensor.tensor4('x')
a, b = x.shape[:2]
output = a.eval({x: numpy.zeros((5, 4, 3, 2), dtype=theano.config.floatX)})
assert output == numpy.array(5)
"""
if __name__ == '__main__':
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论