提交 10a4b9f0 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Raise if inputs to vectorize_graph are not variables

They would be silently ignored otherwise
上级 ef5592ff
......@@ -294,6 +294,12 @@ def vectorize_graph(
else:
seq_outputs = [outputs]
if not all(
isinstance(key, Variable) and isinstance(value, Variable)
for key, value in replace.items()
):
raise ValueError(f"Some of the replaced items are not Variables: {replace}")
inputs = truncated_graph_inputs(seq_outputs, ancestors_to_include=replace.keys())
new_inputs = [replace.get(inp, inp) for inp in inputs]
......
......@@ -324,3 +324,35 @@ class TestVectorizeGraph:
new_beta1 * pt.exp(new_beta0 + 1),
]
assert equal_computations(new_outs, expected_new_outs)
def test_non_variable_raises(self):
x = pt.scalar("x", dtype=int)
y = pt.scalar("y", dtype=int)
non_variable_shape = (x, y)
variable_shape = pt.as_tensor(non_variable_shape)
non_variable_shape_out = pt.zeros(non_variable_shape)
variable_shape_out = pt.zeros(variable_shape)
non_variable_batch_shape = (non_variable_shape, non_variable_shape)
variable_batch_shape = pt.stacklists(non_variable_batch_shape)
msg = r"Some of the replaced items are not Variables"
with pytest.raises(ValueError, match=msg):
vectorize_graph(
non_variable_shape_out, {non_variable_shape: non_variable_batch_shape}
)
with pytest.raises(ValueError, match=msg):
vectorize_graph(
variable_shape_out, {variable_shape: non_variable_batch_shape}
)
batch_out = vectorize_graph(
variable_shape_out, {variable_shape: variable_batch_shape}
)
assert batch_out.type.shape == (2, None, None)
np.testing.assert_array_equal(
batch_out.eval({x: 3, y: 4}),
np.zeros((2, 3, 4)),
)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论