提交 a5d54c8e authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Check for runtime broadcasting in Blockwise Ops

上级 893dc18c
@@ -355,12 +355,30 @@ class Blockwise(Op):
        self._gufunc = np.vectorize(core_func, signature=self.signature)
        return self._gufunc
def _check_runtime_broadcast(self, node, inputs):
    """Reject implicit (runtime) broadcasting along batch dimensions.

    For every batch dimension, compares the runtime lengths of all inputs
    against their statically declared broadcastable patterns. A dimension
    of length 1 may only be broadcast against a longer dimension if it was
    statically marked broadcastable.

    Parameters
    ----------
    node : Apply
        The node this Op was built for; ``node.inputs`` carry the static
        broadcastable pattern of each symbolic input.
    inputs : sequence of ndarray
        Runtime input values, aligned one-to-one with ``node.inputs``.

    Raises
    ------
    ValueError
        If some input has a batch dimension of length 1 that was not marked
        broadcastable while another input has a different length there.
    """
    batch_ndim = self._batch_ndim_from_outputs(node.outputs)

    # Build (runtime length, static broadcastable) pairs per input, then
    # transpose with zip(*...) so each iteration sees one batch dimension
    # across all inputs.
    for dims_and_bcast in zip(
        *[
            zip(value.shape[:batch_ndim], sym.type.broadcastable[:batch_ndim])
            for value, sym in zip(inputs, node.inputs)
        ]
    ):
        # Runtime broadcasting occurs when some input is longer than 1 on
        # this dimension while another is exactly 1 without having been
        # statically declared broadcastable.
        if any(d != 1 for d, _ in dims_and_bcast) and (1, False) in dims_and_bcast:
            raise ValueError(
                "Runtime broadcasting not allowed. "
                "At least one input has a distinct batch dimension length of 1, but was not marked as broadcastable.\n"
                "If broadcasting was intended, use `specify_broadcastable` on the relevant input."
            )
    def perform(self, node, inputs, output_storage):
        gufunc = self._gufunc
        if gufunc is None:
            gufunc = self._create_gufunc(node)
        self._check_runtime_broadcast(node, inputs)
        res = gufunc(*inputs)
        if not isinstance(res, tuple):
            res = (res,)
......
@@ -5,7 +5,7 @@ import numpy as np
import pytest

import pytensor
from pytensor import config, function
from pytensor.gradient import grad
from pytensor.graph import Apply, Op
from pytensor.graph.replace import vectorize_node
@@ -38,6 +38,56 @@ def test_vectorize_blockwise():
    assert new_vect_node.inputs[0] is tns4
def check_blockwise_runtime_broadcasting(mode):
    """Compile a batched matmul and verify runtime broadcasting is rejected.

    Valid calls (equal batch lengths, including a shared length of 1) must
    succeed; mismatched batch lengths where one side is an undeclared 1 must
    raise the dedicated runtime-broadcasting error; other mismatches raise a
    backend-specific ValueError.
    """
    a = tensor("a", shape=(None, 3, 5))
    b = tensor("b", shape=(None, 5, 3))
    out = a @ b
    fn = function([a, b], out, mode=mode)
    assert isinstance(fn.maker.fgraph.outputs[0].owner.op, Blockwise)

    # Equal batch lengths are always fine, even when both are 1.
    for batch_dim in (2, 1):
        lhs = np.ones((batch_dim, 3, 5)).astype(config.floatX)
        rhs = np.ones((batch_dim, 5, 3)).astype(config.floatX)
        np.testing.assert_allclose(
            fn(lhs, rhs), np.full((batch_dim, 3, 3), 5.0)
        )

    # A length-1 batch dimension that was not declared broadcastable must
    # be rejected with the dedicated error, on either side.
    for left_batch, right_batch in ((1, 2), (2, 1)):
        lhs = np.ones((left_batch, 3, 5)).astype(config.floatX)
        rhs = np.ones((right_batch, 5, 3)).astype(config.floatX)
        with pytest.raises(ValueError, match="Runtime broadcasting not allowed"):
            fn(lhs, rhs)

    # Error message is backend specific
    with pytest.raises(ValueError):
        fn(
            np.ones((2, 3, 5)).astype(config.floatX),
            np.ones((3, 5, 3)).astype(config.floatX),
        )
@pytest.mark.parametrize("mode", ("FAST_COMPILE", "FAST_RUN"))
def test_runtime_broadcast(mode):
    """Blockwise must reject runtime broadcasting in both compilation modes."""
    check_blockwise_runtime_broadcasting(mode)
class TestOp(Op):
    def make_node(self, *inputs):
        return Apply(self, inputs, [i.type() for i in inputs])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论