提交 74efc965 authored 作者: Adrian Seyboldt's avatar Adrian Seyboldt 提交者: Ricardo Vieira

fix(jax): Specify shape should ignore None axes

上级 790b46fd
......@@ -96,12 +96,11 @@ def jax_funcify_Shape_i(op, **kwargs):
def jax_funcify_SpecifyShape(op, node, **kwargs):
def specifyshape(x, *shape):
assert x.ndim == len(shape)
assert x.shape == tuple(shape), (
"got shape",
x.shape,
"expected",
shape,
)
for actual, expected in zip(x.shape, shape):
if expected is None:
continue
if actual != expected:
raise ValueError(f"Invalid shape: Expected {shape} but got {x.shape}")
return x
return specifyshape
......
......@@ -25,7 +25,7 @@ def test_jax_shape_ops():
def test_jax_specify_shape():
in_at = at.matrix("in")
x = at.specify_shape(in_at, (4, 5))
x = at.specify_shape(in_at, (4, None))
x_fg = FunctionGraph([in_at], [x])
compare_jax_and_py(x_fg, [np.ones((4, 5)).astype(config.floatX)])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论