提交 67ff0800 authored 作者: Ricardo's avatar Ricardo 提交者: Ricardo Vieira

Remove redundant erf(c) rewrites

上级 6176f28c
......@@ -6,7 +6,7 @@ import itertools
import logging
import operator
import warnings
from functools import reduce
from functools import partial, reduce
import numpy as np
......@@ -23,6 +23,7 @@ from aesara.graph.opt import (
in2out,
local_optimizer,
)
from aesara.graph.opt_utils import get_clients_at_depth
from aesara.misc.safe_asarray import _asarray
from aesara.tensor.basic import (
Alloc,
......@@ -2596,26 +2597,8 @@ def local_greedy_distributor(fgraph, node):
return [rval]
def get_clients(fgraph, node):
"""
Used by erf/erfc opt to track less frequent op.
"""
return [c for c, i in fgraph.clients[node.outputs[0]] if c != "output"]
def get_clients2(fgraph, node):
"""
Used by erf/erfc opt to track less frequent op.
"""
l = []
for c, i in fgraph.clients[node.outputs[0]]:
if c != "output":
for var in c.outputs:
l.extend([cc for cc, ii in fgraph.clients[var] if cc != "output"])
return l
get_clients_at_depth1 = partial(get_clients_at_depth, depth=1)
get_clients_at_depth2 = partial(get_clients_at_depth, depth=2)
# 1+erf(x)=>erfc(-x)
local_one_plus_erf = PatternSub(
......@@ -2624,61 +2607,56 @@ local_one_plus_erf = PatternSub(
allow_multiple_clients=True,
name="local_one_plus_erf",
tracks=[erf],
get_nodes=get_clients,
get_nodes=get_clients_at_depth1,
)
register_canonicalize(local_one_plus_erf)
register_stabilize(local_one_plus_erf)
register_specialize(local_one_plus_erf)
# Only one of the two rewrites below is needed if a canonicalization is added
# for sub(x, y) -> add(x, -y) or a specialization for add(x, -y) -> sub(x, y)
# 1-erf(x)=>erfc(x)
local_one_minus_erf = PatternSub(
(sub, 1, (erf, "x")),
(erfc, "x"),
allow_multiple_clients=True,
name="local_one_minus_erf",
tracks=[erf],
get_nodes=get_clients_at_depth1,
)
register_canonicalize(local_one_minus_erf)
register_stabilize(local_one_minus_erf)
register_specialize(local_one_minus_erf)
local_one_minus_erf2 = PatternSub(
(add, 1, (mul, -1, (erf, "x"))),
(add, 1, (neg, (erf, "x"))),
(erfc, "x"),
allow_multiple_clients=True,
name="local_one_minus_erf2",
tracks=[erf],
get_nodes=get_clients_at_depth2,
)
register_canonicalize(local_one_minus_erf2)
register_stabilize(local_one_minus_erf2)
register_specialize(local_one_minus_erf2)
# 1+(-erf(x))=>erfc(x) This is a different graph then the previous as
# the canonicalize don't work completely
local_one_plus_neg_erf = PatternSub(
(add, 1, (neg, (erf, "x"))),
(erfc, "x"),
allow_multiple_clients=True,
name="local_one_plus_neg_erf",
tracks=[erf],
get_nodes=get_clients2,
)
register_canonicalize(local_one_plus_neg_erf)
register_stabilize(local_one_plus_neg_erf)
register_specialize(local_one_plus_neg_erf)
# (-1)+erf(x) => -erfc(x) don't need erf(x)+(-1) as the canonicalize
# will put the -1 as the first argument.
# (-1)+erf(x) => -erfc(x)
# There is no need for erf(x)+(-1) nor erf(x) - 1, as the canonicalize will
# convert those to the matched pattern
local_erf_minus_one = PatternSub(
(add, -1, (erf, "x")),
(neg, (erfc, "x")),
allow_multiple_clients=True,
name="local_erf_minus_one",
tracks=[erf],
get_nodes=get_clients,
get_nodes=get_clients_at_depth1,
)
register_canonicalize(local_erf_minus_one)
register_stabilize(local_erf_minus_one)
register_specialize(local_erf_minus_one)
# Only one of the two rewrites below is needed if a canonicalization is added
# for sub(x, y) -> add(x, -y) or a specialization for add(x, -y) -> sub(x, y)
# 1-erfc(x) => erf(x)
local_one_minus_erfc = PatternSub(
(sub, 1, (erfc, "x")),
......@@ -2686,7 +2664,7 @@ local_one_minus_erfc = PatternSub(
allow_multiple_clients=True,
name="local_one_minus_erfc",
tracks=[erfc],
get_nodes=get_clients,
get_nodes=get_clients_at_depth1,
)
register_canonicalize(local_one_minus_erfc)
register_stabilize(local_one_minus_erfc)
......@@ -2698,39 +2676,12 @@ local_one_minus_erfc2 = PatternSub(
allow_multiple_clients=True,
name="local_one_minus_erfc2",
tracks=[erfc],
get_nodes=get_clients2,
get_nodes=get_clients_at_depth2,
)
register_canonicalize(local_one_minus_erfc2)
register_stabilize(local_one_minus_erfc2)
register_specialize(local_one_minus_erfc2)
local_one_minus_erfc3 = PatternSub(
(add, 1, (mul, -1, (erfc, "x"))),
(erf, "x"),
allow_multiple_clients=True,
name="local_one_minus_erfc3",
tracks=[erfc],
get_nodes=get_clients2,
)
register_canonicalize(local_one_minus_erfc3)
register_stabilize(local_one_minus_erfc3)
register_specialize(local_one_minus_erfc3)
# 1+(-erfc(x)) => erf(x) This is a different graph then the previous as
# the canonicalize don't work completely
local_one_add_neg_erfc = PatternSub(
(add, 1, (neg, (erfc, "x"))),
(erf, "x"),
allow_multiple_clients=True,
name="local_one_add_neg_erfc",
tracks=[erfc],
get_nodes=get_clients2,
)
register_canonicalize(local_one_add_neg_erfc)
register_stabilize(local_one_add_neg_erfc)
register_specialize(local_one_add_neg_erfc)
# (-1)+erfc(-x)=>erf(x)
local_erf_neg_minus_one = PatternSub(
(add, -1, (erfc, (neg, "x"))),
......@@ -2738,25 +2689,12 @@ local_erf_neg_minus_one = PatternSub(
allow_multiple_clients=True,
name="local_erf_neg_minus_one",
tracks=[erfc],
get_nodes=get_clients,
get_nodes=get_clients_at_depth1,
)
register_canonicalize(local_erf_neg_minus_one)
register_stabilize(local_erf_neg_minus_one)
register_specialize(local_erf_neg_minus_one)
# (-1)+erfc(-1*x)=>erf(x)
local_erf_neg_minus_one2 = PatternSub(
(add, -1, (erfc, (mul, -1, "x"))),
(erf, "x"),
allow_multiple_clients=True,
name="local_erf_neg_minus_one2",
tracks=[erfc],
get_nodes=get_clients,
)
register_canonicalize(local_erf_neg_minus_one2)
register_stabilize(local_erf_neg_minus_one2)
register_specialize(local_erf_neg_minus_one2)
@register_stabilize
@register_specialize
......
......@@ -2742,20 +2742,19 @@ class TestLocalErf:
self.mode = (
get_default_mode()
.including("canonicalize", "fast_run")
.excluding("gpu", "fusion")
.excluding("gpu", "fusion", "inplace")
)
self.mode._optimizer.position_cutoff = 1.50001
def test_local_one_plus_erf(self):
val = np.asarray([-30, -3, -2, -1, 0, 1, 2, 3, 30], dtype=config.floatX)
x = vector()
f = function([x], 1 + erf(x), mode=self.mode)
assert [n.op for n in f.maker.fgraph.toposort()] == [mul, erfc]
assert [n.op for n in f.maker.fgraph.toposort()] == [neg, erfc]
f(val)
f = function([x], erf(x) + 1, mode=self.mode)
assert [n.op for n in f.maker.fgraph.toposort()] == [mul, erfc]
assert [n.op for n in f.maker.fgraph.toposort()] == [neg, erfc]
f(val)
f = function([x], erf(x) + 2, mode=self.mode)
......@@ -2780,6 +2779,9 @@ class TestLocalErf:
f = function([x], (-erf(x)) + 1, mode=self.mode)
assert [n.op for n in f.maker.fgraph.toposort()] == [erfc]
f = function([x], (-1.0 * erf(x)) + 1, mode=self.mode)
assert [n.op for n in f.maker.fgraph.toposort()] == [erfc]
f = function([x], 2 - erf(x), mode=self.mode)
topo = f.maker.fgraph.toposort()
assert len(topo) == 2
......@@ -2794,14 +2796,14 @@ class TestLocalErf:
x = vector()
f = function([x], erf(x) - 1, mode=self.mode)
assert [n.op for n in f.maker.fgraph.toposort()] == [erfc, mul]
assert [n.op for n in f.maker.fgraph.toposort()] == [erfc, neg]
f(val)
f = function([x], erf(x) + (-1), mode=self.mode)
assert [n.op for n in f.maker.fgraph.toposort()] == [erfc, mul]
assert [n.op for n in f.maker.fgraph.toposort()] == [erfc, neg]
f = function([x], -1 + erf(x), mode=self.mode)
assert [n.op for n in f.maker.fgraph.toposort()] == [erfc, mul]
assert [n.op for n in f.maker.fgraph.toposort()] == [erfc, neg]
f = function([x], erf(x) - 2, mode=self.mode)
topo = f.maker.fgraph.toposort()
......@@ -2821,12 +2823,10 @@ class TestLocalErfc:
def setup_method(self):
self.mode_fusion = (
get_default_mode()
.including("canonicalize")
.including("fast_run")
.excluding("gpu")
.including("canonicalize", "fast_run")
.excluding("gpu", "inplace")
)
self.mode = self.mode_fusion.excluding("fusion")
self.mode._optimizer.position_cutoff = 1.50001
def test_local_one_minus_erfc(self):
# test opt: 1-erfc(x) => erf(x) and -erfc(x)+1 => erf(x)
......@@ -2841,6 +2841,9 @@ class TestLocalErfc:
f = function([x], (-erfc(x)) + 1, mode=self.mode)
assert [n.op for n in f.maker.fgraph.toposort()] == [erf]
f = function([x], (-1.0 * erfc(x)) + 1, mode=self.mode)
assert [n.op for n in f.maker.fgraph.toposort()] == [erf]
f = function([x], 2 - erfc(x), mode=self.mode)
topo = f.maker.fgraph.toposort()
assert len(topo) == 2
......@@ -2863,6 +2866,9 @@ class TestLocalErfc:
f = function([x], erfc(-x) + (-1), mode=self.mode)
assert [n.op for n in f.maker.fgraph.toposort()] == [erf]
f = function([x], erfc(-1.0 * x) + (-1), mode=self.mode)
assert [n.op for n in f.maker.fgraph.toposort()] == [erf]
@pytest.mark.xfail()
def test_local_log_erfc(self):
val = [-30, -27, -26, -11, -10, -3, -2, -1, 0, 1, 2, 3, 10, 11, 26, 27, 28, 30]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论