提交 7e219e9c authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Use MakeVector in as_tensor_variable

上级 b83d0a64
......@@ -1111,6 +1111,18 @@ class TestAsTensorVariable:
a_vector = as_tensor_variable(x_vector)
assert x_vector is a_vector
def test_make_vector(self):
a = tt.iscalar()
x = tt.tile(a, (1, 1, 1))
y = (tt.constant(1, dtype="int64"), x.shape[2])
res = tt.as_tensor(y, ndim=1)
assert isinstance(res.owner.op, tt.opt.MakeVector)
assert tuple(res.owner.inputs) == y
y = (1, x.shape[2])
res = tt.as_tensor(y)
assert isinstance(res.owner.op, tt.opt.MakeVector)
class TestAlloc:
dtype = config.floatX
......
......@@ -101,7 +101,7 @@ def as_tensor_variable(x, name=None, ndim=None):
Parameters
----------
x : Apply instance, Variable instance, numpy.ndarray, or number
x : Apply or Variable or numpy.ndarray or number
This thing will be transformed into a `Variable` in a sensible way. An
ndarray argument will not be copied, but a list of numbers will be
copied to make an ndarray.
......@@ -185,6 +185,16 @@ def as_tensor_variable(x, name=None, ndim=None):
try:
x = [extract_constants(i) for i in x]
except TypeError:
if builtins.all(getattr(i, "ndim", None) == 0 for i in x) and (
ndim is None or ndim == 1
):
# In this instance, we can avoid making a `Join` `Op`, because
# we know that the result should be a vector.
# `MakeVector` is a better option due to its `get_scalar_constant_value`
# support.
dtype = scal.upcast(*[i.dtype for i in x if hasattr(i, "dtype")])
return theano.tensor.opt.MakeVector(dtype)(*x)
return stack(x)
elif isinstance(x, bool):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论