Unverified 提交 9e603cf4 authored 作者: Abhinav Khot's avatar Abhinav Khot 提交者: GitHub

Provide static output shape for constant arange

上级 cb0758c1
......@@ -3215,13 +3215,29 @@ class ARange(Op):
self.dtype = dtype
def make_node(self, start, stop, step):
from math import ceil
start, stop, step = map(as_tensor_variable, (start, stop, step))
assert start.ndim == 0
assert stop.ndim == 0
assert step.ndim == 0
# if it is possible to directly determine the shape i.e static shape is present, we find it.
if (
isinstance(start, TensorConstant)
and isinstance(stop, TensorConstant)
and isinstance(step, TensorConstant)
):
length = max(
ceil((float(stop.data) - float(start.data)) / float(step.data)), 0
)
shape = (length,)
else:
shape = (None,)
inputs = [start, stop, step]
outputs = [tensor(dtype=self.dtype, shape=(None,))]
outputs = [tensor(dtype=self.dtype, shape=shape)]
return Apply(self, inputs, outputs)
......
......@@ -2861,6 +2861,13 @@ class TestARange:
assert np.all(f(2) == len(np.arange(0, 2)))
assert np.all(f(0) == len(np.arange(0, 0)))
def test_static_shape(self):
assert np.arange(1, 10).shape == arange(1, 10).type.shape
assert np.arange(10, 1, -1).shape == arange(10, 1, -1).type.shape
assert np.arange(1, -9, 2).shape == arange(1, -9, 2).type.shape
assert np.arange(1.3, 17.48, 2.67).shape == arange(1.3, 17.48, 2.67).type.shape
assert np.arange(-64, 64).shape == arange(-64, 64).type.shape
class TestNdGrid:
def setup_method(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论