Commit 6cd90ee9 authored by Ricardo Vieira, committed by Ricardo Vieira

Fix RaiseAndCheck C implementation with tensor conditions.

For performance, the Op now always converts the inputs to boolean scalars. Also do not constant-fold if it would raise.
Parent b9fc4f8e
......@@ -10,10 +10,11 @@ from pytensor.compile import JAX
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.ops import DeepCopyOp, TypeCastingOp
from pytensor.configdefaults import config
from pytensor.graph import Constant
from pytensor.graph.fg import FunctionGraph
from pytensor.ifelse import IfElse
from pytensor.link.utils import fgraph_to_python
from pytensor.raise_op import Assert, CheckAndRaise
from pytensor.raise_op import CheckAndRaise
if config.floatX == "float64":
......@@ -73,11 +74,14 @@ def jax_funcify_IfElse(op, **kwargs):
return ifelse
@jax_funcify.register(Assert)
@jax_funcify.register(CheckAndRaise)
def jax_funcify_CheckAndRaise(op, **kwargs):
def jax_funcify_CheckAndRaise(op, node, **kwargs):
conds = node.inputs[1:]
if any(isinstance(cond, Constant) and not bool(cond.data) for cond in conds):
raise op.exc_type(op.msg)
warnings.warn(
f"""Skipping `CheckAndRaise` Op (assertion: {op.msg}) as JAX tracing would remove it.""",
f"""Skipping {op} Op (assertion: {op.msg}) as JAX tracing would remove it.""",
stacklevel=2,
)
......
......@@ -22,6 +22,7 @@ from pytensor.tensor.basic import (
Eye,
Join,
MakeVector,
ScalarFromTensor,
Split,
TensorFromScalar,
)
......@@ -79,6 +80,14 @@ def pytorch_funcify_CastingOp(op, node, **kwargs):
return type_cast
@pytorch_funcify.register(ScalarFromTensor)
def pytorch_funcify_ScalarFromTensor(op, node, **kwargs):
    """PyTorch dispatch for `ScalarFromTensor`.

    Returns a callable that unwraps a 0-d tensor into its scalar element.
    """

    def scalar_from_tensor(x):
        # Indexing with an empty tuple extracts the element of a 0-d tensor.
        scalar = x[()]
        return scalar

    return scalar_from_tensor
@pytorch_funcify.register(CheckAndRaise)
def pytorch_funcify_CheckAndRaise(op, **kwargs):
error = op.exc_type
......@@ -86,7 +95,7 @@ def pytorch_funcify_CheckAndRaise(op, **kwargs):
def assert_fn(x, *conditions):
for cond in conditions:
if not cond.item():
if not cond:
raise error(msg)
return x
......
......@@ -2,15 +2,13 @@
from textwrap import indent
import numpy as np
from pytensor.gradient import DisconnectedType
from pytensor.graph.basic import Apply, Variable
from pytensor.graph.basic import Apply, Constant, Variable
from pytensor.graph.replace import _vectorize_node
from pytensor.link.c.op import COp
from pytensor.link.c.params_type import ParamsType
from pytensor.link.c.type import Generic
from pytensor.scalar.basic import ScalarType
from pytensor.scalar.basic import ScalarType, as_scalar
from pytensor.tensor.type import DenseTensorType
......@@ -56,18 +54,6 @@ class CheckAndRaise(COp):
msg = self.msg
return f"{name}{{raises={exc_name}, msg='{msg}'}}"
def __eq__(self, other):
    # Two CheckAndRaise Ops are interchangeable exactly when they are the
    # same Op class and would raise the same exception type with the same
    # message; the conditions live on the node, not on the Op.
    if type(self) is not type(other):
        return False
    if self.msg == other.msg and self.exc_type == other.exc_type:
        return True
    return False

def __hash__(self):
    # Keep hashing consistent with __eq__ above: equal Ops (same msg and
    # exc_type) must produce the same hash so graph rewrites can merge them.
    return hash((self.msg, self.exc_type))
def make_node(self, value: Variable, *conds: Variable):
"""
......@@ -84,12 +70,10 @@ class CheckAndRaise(COp):
if not isinstance(value, Variable):
value = pt.as_tensor_variable(value)
conds = [
pt.as_tensor_variable(c) if not isinstance(c, Variable) else c
for c in conds
]
assert all(c.type.ndim == 0 for c in conds)
conds = [as_scalar(c) for c in conds]
for i, cond in enumerate(conds):
if cond.dtype != "bool":
conds[i] = cond.astype("bool")
return Apply(
self,
......@@ -101,7 +85,7 @@ class CheckAndRaise(COp):
(out,) = outputs
val, *conds = inputs
out[0] = val
if not np.all(conds):
if not all(conds):
raise self.exc_type(self.msg)
def grad(self, input, output_gradients):
......@@ -117,38 +101,20 @@ class CheckAndRaise(COp):
)
value_name, *cond_names = inames
out_name = onames[0]
check = []
fail_code = props["fail"]
param_struct_name = props["params"]
msg = self.msg.replace('"', '\\"').replace("\n", "\\n")
for idx, cond_name in enumerate(cond_names):
if isinstance(node.inputs[0].type, DenseTensorType):
check.append(
f"""
if(PyObject_IsTrue((PyObject *){cond_name}) == 0) {{
PyObject * exc_type = {param_struct_name}->exc_type;
Py_INCREF(exc_type);
PyErr_SetString(exc_type, "{msg}");
Py_XDECREF(exc_type);
{indent(fail_code, " " * 4)}
}}
"""
)
else:
check.append(
f"""
if({cond_name} == 0) {{
PyObject * exc_type = {param_struct_name}->exc_type;
Py_INCREF(exc_type);
PyErr_SetString(exc_type, "{msg}");
Py_XDECREF(exc_type);
{indent(fail_code, " " * 4)}
}}
"""
)
check = "\n".join(check)
all_conds = " && ".join(cond_names)
check = f"""
if(!({all_conds})) {{
PyObject * exc_type = {param_struct_name}->exc_type;
Py_INCREF(exc_type);
PyErr_SetString(exc_type, "{msg}");
Py_XDECREF(exc_type);
{indent(fail_code, " " * 4)}
}}
"""
if isinstance(node.inputs[0].type, DenseTensorType):
res = f"""
......@@ -162,14 +128,19 @@ class CheckAndRaise(COp):
{check}
{out_name} = {value_name};
"""
return res
return "\n".join((check, res))
def c_code_cache_version(self):
return (1, 1)
return (2,)
def infer_shape(self, fgraph, node, input_shapes):
    """The output is the checked value passed through, so it keeps the
    shape of the first input; condition shapes are irrelevant."""
    value_shape, *_cond_shapes = input_shapes
    return [value_shape]
def do_constant_folding(self, fgraph, node):
    """Allow constant folding only when the check provably passes.

    Folding a failing `CheckAndRaise` would raise at rewrite time rather
    than at run time, so every condition must be a constant that is truthy.
    """
    for cond in node.inputs[1:]:
        if not (isinstance(cond, Constant) and bool(cond.data)):
            return False
    return True
class Assert(CheckAndRaise):
"""Implements assertion in a computational graph.
......
......@@ -732,20 +732,15 @@ def is_an_upcast(type1, type2):
@register_useless
@register_specialize
@node_rewriter(None)
@node_rewriter([CheckAndRaise])
def local_remove_useless_assert(fgraph, node):
if not isinstance(node.op, CheckAndRaise):
return False
new_conds = []
n_conds = len(node.inputs[1:])
for c in node.inputs[1:]:
try:
const = get_scalar_constant_value(c)
if 0 != const.ndim or const == 0:
# Should we raise an error here? How to be sure it
# is not caught?
if not const:
new_conds.append(c)
except NotScalarConstantError:
new_conds.append(c)
......
......@@ -487,8 +487,8 @@ class TestUselessCheckAndRaise:
def test_local_remove_useless_2(self):
"""Remove `CheckAndRaise` conditions that are always true."""
x = scalar()
y = scalar()
x = scalar("x")
y = ps.bool("y")
fg = FunctionGraph(outputs=[assert_op(x, y, 1)], clone=False)
fg_res = rewrite_graph(fg, include=["canonicalize", "specialize"])
topo = fg_res.toposort()
......@@ -497,8 +497,8 @@ class TestUselessCheckAndRaise:
def test_local_remove_useless_3(self):
"""Don't remove `CheckAndRaise` conditions that are always false."""
x = scalar()
y = scalar()
x = scalar("x")
y = ps.bool("y")
fg = FunctionGraph(outputs=[assert_op(x, y, 0)], clone=False)
fg_res = rewrite_graph(fg, include=["canonicalize", "specialize"])
topo = fg_res.toposort()
......@@ -1559,7 +1559,7 @@ def test_local_merge_alloc():
output = pt.alloc(pt.alloc(m, y, 1, 1), x, y2, z, w)
f = function([m, x, y, y2, z, w], output, mode=rewrite_mode)
topo = f.maker.fgraph.toposort()
assert len(topo) == 3
assert len(topo) == 4
assert isinstance(topo[-2].op, Assert)
assert isinstance(topo[-1].op, Alloc)
o = f(0.0, 1, 2, 2, 3, 4)
......@@ -1616,7 +1616,7 @@ def test_local_useless_alloc():
useless_alloc.rewrite(g)
topo = g.toposort()
assert len(topo) == 3
assert len(topo) == 4
assert isinstance(topo[-2].op, Assert)
assert isinstance(topo[-1].op, Alloc)
......
......@@ -932,7 +932,7 @@ class TestFusion:
),
(fx,),
(fxv,),
4,
5,
(np.zeros_like(fxv),),
("float32",),
),
......
......@@ -8,6 +8,7 @@ from pytensor import function
from pytensor import tensor as pt
from pytensor.compile.mode import Mode
from pytensor.configdefaults import config
from pytensor.graph import rewrite_graph
from pytensor.graph.basic import Constant, applys_between, equal_computations
from pytensor.npy_2_compat import old_np_unique
from pytensor.raise_op import Assert
......@@ -1252,11 +1253,17 @@ def test_broadcast_shape_symbolic_one_symbolic():
]
res_shape = broadcast_shape(*index_shapes, arrays_are_shapes=True)
from pytensor.graph.rewriting.utils import rewrite_graph
res_shape = rewrite_graph(res_shape)
assert res_shape[0].data == 1
assert res_shape[1].data == 1
with pytest.raises(AssertionError, match="Could not broadcast dimensions"):
# broadcast_shape doesn't treat int_div as a constant 1
res_shape[2].eval()
res_shape = broadcast_shape(
*index_shapes, arrays_are_shapes=True, allow_runtime_broadcast=True
)
res_shape = rewrite_graph(res_shape)
assert res_shape[0].data == 1
assert res_shape[1].data == 1
assert res_shape[2].data == 3
......
......@@ -82,19 +82,26 @@ def test_CheckAndRaise_basic_c(linker):
with pytest.raises(CustomException, match=exc_msg):
y_fn(0)
assert y_fn(1) == 1.0
x = pt.vector()
x_val = np.array([1.0], dtype=pytensor.config.floatX)
y = check_and_raise(x, conds)
y_fn = pytensor.function([conds, x], y.shape, mode=Mode(linker, OPT_FAST_RUN))
y_fn = pytensor.function([conds, x], y, mode=Mode(linker, OPT_FAST_RUN))
with pytest.raises(CustomException, match=exc_msg):
y_fn(0, x_val)
assert np.array_equal(y_fn(1, x_val), x_val)
x_val = np.array([1.0], dtype=pytensor.config.floatX)
y_fn = pytensor.function([conds, x], y.shape, mode=Mode(linker, OPT_FAST_RUN))
# The shape doesn't depend on y so the Assert is dropped from the graph
assert np.array_equal(y_fn(0, x_val), x_val)
y = check_and_raise(x, pt.as_tensor(0))
y_grad = pytensor.grad(y.sum(), [x])
y_grad = pytensor.grad(y.sum(), x)
y_fn = pytensor.function([x], y_grad, mode=Mode(linker, OPT_FAST_RUN))
assert np.array_equal(y_fn(x_val), [x_val])
# The gradient doesn't depend on y, just its shape, so the Assert is dropped from the graph
assert np.array_equal(y_fn(x_val), x_val)
@pytest.mark.parametrize(
......
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment