Unverified 提交 d4e8f736 authored 作者: Luca Citi's avatar Luca Citi 提交者: GitHub

Use stricter numerical tolerance in rewrites and allow casting in `PatternNodeRewriter` (#1526)

* Implemented allow_cast in PatternNodeRewriter to allow rewrites that would otherwise fail when the new and old dtype differ. Example: `np.array(1., "float64") - sigmoid(x)` cannot be rewritten as `sigmoid(-x)` (where x is an fmatrix) because the type would change. This commit allows an automatic cast to be added so the expression is rewritten as `cast(sigmoid(-x), "float64")`. Relevant tests added. * Added test cases for which issue #1497 fails * Changed PatternNodeRewriter::transform to allow types that do not contain dtype like MyType in the tests * Address #1497 by changing instances of np.isclose to a function isclose, which uses 10 ULPs by default * Addressed failed tests (with older python/numpy versions) * Addressed feedback by ricardoV94 * Test PatternNodeRewriter doesn't support multi-output nodes in pattern But it's fine if they're just root inputs --------- Co-authored-by: 's avatarLuca Citi <lciti@ieee.org> Co-authored-by: 's avatarRicardo Vieira <ricardo.vieira1994@gmail.com>
上级 0c138495
......@@ -1550,6 +1550,7 @@ class PatternNodeRewriter(NodeRewriter):
tracks=(),
get_nodes=None,
values_eq_approx=None,
allow_cast=True,
):
"""
......@@ -1572,6 +1573,10 @@ class PatternNodeRewriter(NodeRewriter):
If you provide `tracks`, you must provide this parameter. It must be a
function that takes the tracked node and returns a list of nodes on
which we will try this rewrite.
values_eq_approx
TODO
allow_cast
Automatically cast the output of the rewrite whenever new and old types differ
Notes
-----
......@@ -1586,6 +1591,7 @@ class PatternNodeRewriter(NodeRewriter):
self.in_pattern = convert_strs_to_vars(in_pattern, var_map=var_map)
self.out_pattern = convert_strs_to_vars(out_pattern, var_map=var_map)
self.values_eq_approx = values_eq_approx
self.allow_cast = allow_cast
if isinstance(in_pattern, list | tuple):
self.op = self.in_pattern[0]
elif isinstance(in_pattern, dict):
......@@ -1630,6 +1636,10 @@ class PatternNodeRewriter(NodeRewriter):
if node.op != self.op:
return False
if len(node.outputs) != 1:
# PatternNodeRewriter doesn't support replacing multi-output nodes
return False
s = unify(self.in_pattern, node.out)
if s is False:
......@@ -1652,19 +1662,20 @@ class PatternNodeRewriter(NodeRewriter):
):
return False
if ret.owner:
[old_out] = node.outputs
if not old_out.type.is_super(ret.type):
# Type doesn't match
if not (
len(node.outputs) == len(ret.owner.outputs)
and all(
o.type.is_super(new_o.type)
for o, new_o in zip(node.outputs, ret.owner.outputs, strict=True)
)
self.allow_cast
and isinstance(old_out.type, pytensor.tensor.TensorType)
and isinstance(ret.type, pytensor.tensor.TensorType)
):
return False
else:
# ret is just an input variable
assert len(node.outputs) == 1
if not node.outputs[0].type.is_super(ret.type):
# Try to cast tensors
ret = ret.astype(old_out.type.dtype)
if not old_out.type.is_super(ret.type):
# Still doesn't match
return False
return [ret]
......
......@@ -2440,7 +2440,7 @@ def local_log1p(fgraph, node):
log_arg.owner.inputs, only_process_constants=True
)
# scalar_inputs are potentially dimshuffled and fill'd scalars
if scalars and np.allclose(np.sum(scalars), 1):
if scalars and isclose(np.sum(scalars), 1):
if nonconsts:
ninp = variadic_add(*nonconsts)
if ninp.dtype != log_arg.type.dtype:
......@@ -3045,6 +3045,21 @@ def local_grad_log_erfc_neg(fgraph, node):
return [ret]
def isclose(x, ref, rtol=0, atol=0, num_ulps=10):
"""
Returns
-------
bool
True iff x is a constant close to ref (by default 10 ULPs).
"""
x = np.asarray(x)
if np.issubdtype(x.dtype, np.floating):
atol = atol + num_ulps * np.abs(np.spacing(x.dtype.type(ref)))
return np.allclose(x, ref, rtol=rtol, atol=atol)
def _skip_mul_1(r):
if r.owner and r.owner.op == mul:
not_is_1 = [i for i in r.owner.inputs if not _is_1(i)]
......@@ -3063,7 +3078,7 @@ def _is_1(expr):
"""
try:
v = get_underlying_scalar_constant_value(expr)
return np.isclose(v, 1)
return isclose(v, 1)
except NotScalarConstantError:
return False
......@@ -3124,7 +3139,7 @@ def is_1pexp(t, only_process_constants=True):
scal_sum = scalars[0]
for s in scalars[1:]:
scal_sum = scal_sum + s
if np.allclose(scal_sum, 1):
if isclose(scal_sum, 1):
return False, maybe_exp.owner.inputs[0]
return None
......@@ -3224,7 +3239,7 @@ def is_neg(var):
for idx, mul_input in enumerate(var_node.inputs):
try:
constant = get_underlying_scalar_constant_value(mul_input)
is_minus_1 = np.isclose(constant, -1)
is_minus_1 = isclose(constant, -1)
except NotScalarConstantError:
is_minus_1 = False
if is_minus_1:
......@@ -3632,7 +3647,7 @@ def local_reciprocal_1_plus_exp(fgraph, node):
# scalar_inputs are potentially dimshuffled and fill'd scalars
if len(nonconsts) == 1:
if nonconsts[0].owner and nonconsts[0].owner.op == exp:
if scalars_ and np.allclose(np.sum(scalars_), 1):
if scalars_ and isclose(np.sum(scalars_), 1):
out = [
alloc_like(
sigmoid(neg(nonconsts[0].owner.inputs[0])),
......
......@@ -41,6 +41,7 @@ from tests.graph.utils import (
op_y,
op_z,
)
from tests.unittest_tools import assert_equal_computations
class AssertNoChanges(Feature):
......@@ -725,22 +726,35 @@ def test_patternsub_invalid_dtype(out_pattern):
assert e.type.is_super(fg.outputs[0].type)
def test_patternsub_different_output_lengths():
# Test that PatternNodeRewriter won't replace nodes with different numbers of outputs
ps = PatternNodeRewriter(
(op1, "x"),
("x"),
def test_patternsub_multi_output_nodes():
# Test that PatternNodeRewriter won't attempt to replace multi-output nodes
multiple_op_ps = PatternNodeRewriter(
(op_multiple_outputs, "x"),
"x",
name="ps",
)
rewriter = in2out(ps)
single_op_ps = PatternNodeRewriter(
(op_y, "x"),
"x",
name="ps",
)
rewriter = in2out(multiple_op_ps, single_op_ps)
x = MyVariable("x")
e1, e2 = op_multiple_outputs(x)
o = op1(e1)
o1, o2 = op_y(e1), op_y(e2)
fgraph = FunctionGraph(inputs=[x], outputs=[e2, e1], copy_inputs=False)
rewriter.rewrite(fgraph)
# This shouldn't rewrite because PatternNodeRewriter has no way of specifying which output(s) are being matched
assert_equal_computations(fgraph.outputs, [e2, e1])
fgraph = FunctionGraph(inputs=[x], outputs=[o])
fgraph = FunctionGraph(inputs=[x], outputs=[o2, o1], copy_inputs=False)
rewriter.rewrite(fgraph)
assert fgraph.outputs[0].owner.op == op1
# Having a variable that comes out of a multi-output node should be fine
assert_equal_computations(fgraph.outputs, [e2, e1])
class TestSequentialNodeRewriter:
......
......@@ -107,6 +107,9 @@ class MyOpCastType2(MyOp):
class MyOpMultipleOutputs(MyOp):
def __init__(self, name, dmap=None, x=None):
super().__init__(name=name, dmap=dmap, x=x, n_outs=2)
def make_node(self, input):
outputs = [input.type(), input.type()]
return Apply(self, [input], outputs)
......
......@@ -50,6 +50,7 @@ from pytensor.tensor.math import (
bitwise_and,
bitwise_or,
bitwise_xor,
cast,
conj,
cosh,
deg2rad,
......@@ -124,6 +125,7 @@ from pytensor.tensor.type import (
dvector,
fmatrices,
fmatrix,
fscalar,
ftensor4,
fvector,
imatrices,
......@@ -4069,25 +4071,36 @@ class TestSigmoidRewrites:
def test_local_1msigmoid(self):
m = self.get_mode(excluding=["fusion", "inplace"])
x = fmatrix()
x = fscalar()
xd = dscalar()
# Test `exp_over_1_plus_exp`
f = pytensor.function([x], 1 - exp(x) / (1 + exp(x)), mode=m)
# FIXME: PatternNodeRewriter does not copy stack trace
# (see https://github.com/Theano/Theano/issues/4581)
# assert check_stack_trace(f, ops_to_check=[neg, sigmoid])
assert [node.op for node in f.maker.fgraph.toposort()] == [neg, sigmoid]
assert equal_computations(f.maker.fgraph.outputs, [sigmoid(-x)])
# Test `inv_1_plus_exp`
f = pytensor.function([x], 1 - pt.fill(x, 1.0) / (1 + exp(-x)), mode=m)
# assert check_stack_trace(f, ops_to_check=[neg, sigmoid])
assert [node.op for node in f.maker.fgraph.toposort()] == [neg, sigmoid]
assert equal_computations(f.maker.fgraph.outputs, [sigmoid(-x)])
# Test float constant
f = pytensor.function(
[x], np.array(1.000001, dtype="float32") - sigmoid(x), mode=m
)
assert [node.op for node in f.maker.fgraph.toposort()] == [neg, sigmoid]
for out, expected in [
(np.array(1.0, "float32") - sigmoid(x), sigmoid(-x)),
(np.array(1.0, "float64") - pt.sigmoid(x), cast(sigmoid(-x), "float64")),
(np.array(1.0, "float32") - sigmoid(xd), sigmoid(-xd)),
(np.array(1.0, "float64") - sigmoid(xd), sigmoid(-xd)),
(np.sum(1 / np.array([2, 3, 6], "float32")) - sigmoid(x), sigmoid(-x)),
(np.sum(1 / np.array([2, 3, 6], "float64")) - sigmoid(xd), sigmoid(-xd)),
(np.float32(1 - 9e-6) - sigmoid(x), np.float32(1 - 9e-6) - sigmoid(x)),
(np.float64(1 - 1e-9) - sigmoid(xd), np.float64(1 - 1e-9) - sigmoid(xd)),
]:
rewritten = rewrite_graph(
out, include=["canonicalize", "specialize", "stabilize"]
)
utt.assert_equal_computations([rewritten], [expected], original=out)
def test_local_sigm_times_exp(self):
"""
......@@ -4235,7 +4248,8 @@ class TestSoftplusRewrites:
f(np.random.random((54, 11)).astype(config.floatX))
# Test close to 1
out = log(1.000001 - sigmoid(x))
x_dtype = np.dtype(x.dtype).type
out = log(np.nextafter(x_dtype(1), x_dtype(2)) - sigmoid(x))
f = pytensor.function([x], out, mode=self.m)
topo = f.maker.fgraph.toposort()
assert len(topo) == 2
......
......@@ -11,6 +11,7 @@ import pytensor
from pytensor.compile.debugmode import str_diagnostic
from pytensor.configdefaults import config
from pytensor.gradient import verify_grad as orig_verify_grad
from pytensor.graph.basic import equal_computations
from pytensor.tensor.basic import as_tensor_variable
from pytensor.tensor.math import _allclose
from pytensor.tensor.math import add as pt_add
......@@ -279,6 +280,41 @@ def assert_allclose(expected, value, rtol=None, atol=None):
raise WrongValue(expected, value, rtol, atol)
def assert_equal_computations(rewritten, expected, *args, original=None, **kwargs):
"""
Assert that `rewritten` computes the same as `expected`.
Parameters
----------
rewritten
The expression after the rewrite pass.
expected
The reference expression to compare against.
*args, **kwargs
Extra arguments forwarded to equal_computations.
original : optional
If given, will be printed in the error message.
"""
__tracebackhide__ = True # Hide traceback for py.test
ok = equal_computations(rewritten, expected, *args, **kwargs)
if not ok:
parts = []
def _dprint(expr):
return pytensor.dprint(expr, print_type=True, file="str")
if original is not None:
parts.append(f"\nOriginal:\n{_dprint(original)}")
parts.append(f"\nRewritten:\n{_dprint(rewritten)}")
parts.append(f"\nExpected:\n{_dprint(expected)}")
raise AssertionError("equal_computations failed\n" + "".join(parts))
return True
class AttemptManyTimes:
"""Decorator for unit tests that forces a unit test to be attempted
multiple times. The test needs to pass a certain number of times for it to
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论