提交 eff08cb3 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix as_tensor_variable bug when ndim option drops all dimensions

上级 4e107721
...@@ -127,11 +127,14 @@ def _as_tensor_Variable(x, name, ndim, **kwargs): ...@@ -127,11 +127,14 @@ def _as_tensor_Variable(x, name, ndim, **kwargs):
return x return x
if x.type.ndim > ndim: if x.type.ndim > ndim:
# strip off leading broadcastable dimensions # Strip off leading broadcastable dimensions
first_non_broadcastable = [ non_broadcastables = [idx for idx in range(x.ndim) if not x.broadcastable[idx]]
idx for idx in range(x.ndim) if not x.broadcastable[idx]
][0] if non_broadcastables:
x = x.dimshuffle(list(range(x.ndim))[first_non_broadcastable:]) x = x.dimshuffle(list(range(x.ndim))[non_broadcastables[0] :])
else:
x = x.dimshuffle()
if x.ndim > ndim: if x.ndim > ndim:
raise ValueError( raise ValueError(
f"Tensor of type {x.type} could not be cast to have {ndim} dimensions" f"Tensor of type {x.type} could not be cast to have {ndim} dimensions"
......
...@@ -447,14 +447,20 @@ class TestAsTensorVariable: ...@@ -447,14 +447,20 @@ class TestAsTensorVariable:
with pytest.raises(ValueError): with pytest.raises(ValueError):
as_tensor_variable(bad_apply_var) as_tensor_variable(bad_apply_var)
def test_strip_leading_broadcastable(self): def test_ndim_strip_leading_broadcastable(self):
x = TensorType(config.floatX, (True, False))("x") x = TensorType(config.floatX, (True, False))("x")
x = as_tensor_variable(x, ndim=1) x = as_tensor_variable(x, ndim=1)
assert x.ndim == 1 assert x.ndim == 1
x = matrix("x", dtype=config.floatX) def test_ndim_all_broadcastable(self):
with pytest.raises(ValueError): x = TensorType(config.floatX, (True, True))("x")
as_tensor_variable(x, ndim=1) res = as_tensor_variable(x, ndim=0)
assert res.ndim == 0
def test_ndim_incompatible(self):
x = TensorType(config.floatX, (True, False))("x")
with pytest.raises(ValueError, match="^Tensor of type.*"):
as_tensor_variable(x, ndim=0)
def test_bool(self): def test_bool(self):
# We should not allow `as_tensor_variable` to accept `True` or `False`, # We should not allow `as_tensor_variable` to accept `True` or `False`,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论