提交 a5d54c8e 作者: Ricardo Vieira 提交者: Ricardo Vieira

Check for runtime broadcasting in Blockwise Ops

上级 893dc18c
......@@ -355,12 +355,30 @@ class Blockwise(Op):
self._gufunc = np.vectorize(core_func, signature=self.signature)
return self._gufunc
def _check_runtime_broadcast(self, node, inputs):
    """Reject batch dimensions that would broadcast without being marked broadcastable.

    For every leading batch dimension (the count comes from
    ``self._batch_ndim_from_outputs(node.outputs)``), the runtime length of
    each input is paired with the static ``broadcastable`` flag of the
    matching entry in ``node.inputs``. The call fails when some input has a
    length other than 1 in that dimension while another input has length 1
    there and is *not* flagged as broadcastable.

    Parameters
    ----------
    node
        Node being evaluated; ``node.inputs[i].type.broadcastable`` and
        ``node.outputs`` are read.
    inputs
        Runtime values aligned with ``node.inputs``; only ``.shape`` is read.

    Raises
    ------
    ValueError
        If a batch dimension would be broadcast at runtime without having
        been declared broadcastable.
    """
    batch_ndim = self._batch_ndim_from_outputs(node.outputs)

    # One iterator per input, yielding (runtime length, static broadcastable
    # flag) for each batch dimension. Core dimensions are sliced away.
    per_input_batch_info = [
        zip(value.shape[:batch_ndim], variable.type.broadcastable[:batch_ndim])
        for value, variable in zip(inputs, node.inputs)
    ]

    # Transpose so each iteration sees one batch dimension across all inputs.
    for dims_and_bcast in zip(*per_input_batch_info):
        some_input_is_longer = any(length != 1 for length, _ in dims_and_bcast)
        if some_input_is_longer and (1, False) in dims_and_bcast:
            raise ValueError(
                "Runtime broadcasting not allowed. "
                "At least one input has a distinct batch dimension length of 1, but was not marked as broadcastable.\n"
                "If broadcasting was intended, use `specify_broadcastable` on the relevant input."
            )
def perform(self, node, inputs, output_storage):
gufunc = self._gufunc
if gufunc is None:
gufunc = self._create_gufunc(node)
self._check_runtime_broadcast(node, inputs)
res = gufunc(*inputs)
if not isinstance(res, tuple):
res = (res,)
......
......@@ -5,7 +5,7 @@ import numpy as np
import pytest
import pytensor
from pytensor import config
from pytensor import config, function
from pytensor.gradient import grad
from pytensor.graph import Apply, Op
from pytensor.graph.replace import vectorize_node
......@@ -38,6 +38,56 @@ def test_vectorize_blockwise():
assert new_vect_node.inputs[0] is tns4
def check_blockwise_runtime_broadcasting(mode):
    """Exercise a compiled batched matmul against matching and mismatched batch lengths.

    Compiles ``a @ b`` under ``mode``, asserts that the compiled output is
    produced by a ``Blockwise`` op, and then checks three situations:
    equal batch lengths evaluate correctly, a batch length of 1 paired with
    a batch length of 2 raises the runtime-broadcasting error, and
    incompatible batch lengths raise some ``ValueError``.
    """

    def ones(*shape):
        # All fixtures are arrays of ones in the configured float dtype.
        return np.ones(shape).astype(config.floatX)

    a = tensor("a", shape=(None, 3, 5))
    b = tensor("b", shape=(None, 5, 3))
    fn = function([a, b], a @ b, mode=mode)
    assert isinstance(fn.maker.fgraph.outputs[0].owner.op, Blockwise)

    # Same batch length on both sides (including 1 and 1) is valid; each
    # output entry is a sum of five ones.
    for batch_len in (2, 1):
        result = fn(ones(batch_len, 3, 5), ones(batch_len, 5, 3))
        np.testing.assert_allclose(result, np.full((batch_len, 3, 3), 5.0))

    # A length-1 batch dim against a longer one would need broadcasting,
    # which was not declared on either input.
    for a_batch, b_batch in ((1, 2), (2, 1)):
        with pytest.raises(ValueError, match="Runtime broadcasting not allowed"):
            fn(ones(a_batch, 3, 5), ones(b_batch, 5, 3))

    # Error message is backend specific
    with pytest.raises(ValueError):
        fn(ones(2, 3, 5), ones(3, 5, 3))
@pytest.mark.parametrize("mode", ["FAST_COMPILE", "FAST_RUN"])
def test_runtime_broadcast(mode):
    """Run the shared runtime-broadcasting checks once per compilation mode."""
    check_blockwise_runtime_broadcasting(mode=mode)
class TestOp(Op):
def make_node(self, *inputs):
return Apply(self, inputs, [i.type() for i in inputs])
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论