提交 b60cf724 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix static shape inference in DimShuffle

上级 ea5401cc
......@@ -202,14 +202,14 @@ class DimShuffle(ExternalCOp):
# else, expected == b or expected is False and b is True
# Both case are good.
ob = []
for value in self.new_order:
if value == "x":
ob.append(True)
out_static_shape = []
for dim_idx in self.new_order:
if dim_idx == "x":
out_static_shape.append(1)
else:
ob.append(ib[value])
out_static_shape.append(input.type.shape[dim_idx])
output = TensorType(dtype=input.type.dtype, shape=ob)()
output = TensorType(dtype=input.type.dtype, shape=out_static_shape)()
return Apply(self, [input], [output])
......
......@@ -27,6 +27,7 @@ from aesara.tensor.type import (
discrete_dtypes,
matrix,
scalar,
tensor,
vector,
vectors,
)
......@@ -178,6 +179,11 @@ class TestDimShuffle(unittest_tools.InferShapeTester):
tracemalloc.stop()
assert np.allclose(np.mean(block_diffs), 0)
def test_static_shape(self):
x = tensor(np.float64, shape=(1, 2), name="x")
y = x.dimshuffle([0, 1, "x"])
assert y.type.shape == (1, 2, 1)
class TestBroadcast:
# this is to allow other types to reuse this class to test their ops
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论