提交 67ff0800 authored 作者: Ricardo's avatar Ricardo 提交者: Ricardo Vieira

Remove redundant erf(c) rewrites

上级 6176f28c
...@@ -6,7 +6,7 @@ import itertools ...@@ -6,7 +6,7 @@ import itertools
import logging import logging
import operator import operator
import warnings import warnings
from functools import reduce from functools import partial, reduce
import numpy as np import numpy as np
...@@ -23,6 +23,7 @@ from aesara.graph.opt import ( ...@@ -23,6 +23,7 @@ from aesara.graph.opt import (
in2out, in2out,
local_optimizer, local_optimizer,
) )
from aesara.graph.opt_utils import get_clients_at_depth
from aesara.misc.safe_asarray import _asarray from aesara.misc.safe_asarray import _asarray
from aesara.tensor.basic import ( from aesara.tensor.basic import (
Alloc, Alloc,
...@@ -2596,26 +2597,8 @@ def local_greedy_distributor(fgraph, node): ...@@ -2596,26 +2597,8 @@ def local_greedy_distributor(fgraph, node):
return [rval] return [rval]
def get_clients(fgraph, node): get_clients_at_depth1 = partial(get_clients_at_depth, depth=1)
""" get_clients_at_depth2 = partial(get_clients_at_depth, depth=2)
Used by erf/erfc opt to track less frequent op.
"""
return [c for c, i in fgraph.clients[node.outputs[0]] if c != "output"]
def get_clients2(fgraph, node):
"""
Used by erf/erfc opt to track less frequent op.
"""
l = []
for c, i in fgraph.clients[node.outputs[0]]:
if c != "output":
for var in c.outputs:
l.extend([cc for cc, ii in fgraph.clients[var] if cc != "output"])
return l
# 1+erf(x)=>erfc(-x) # 1+erf(x)=>erfc(-x)
local_one_plus_erf = PatternSub( local_one_plus_erf = PatternSub(
...@@ -2624,61 +2607,56 @@ local_one_plus_erf = PatternSub( ...@@ -2624,61 +2607,56 @@ local_one_plus_erf = PatternSub(
allow_multiple_clients=True, allow_multiple_clients=True,
name="local_one_plus_erf", name="local_one_plus_erf",
tracks=[erf], tracks=[erf],
get_nodes=get_clients, get_nodes=get_clients_at_depth1,
) )
register_canonicalize(local_one_plus_erf) register_canonicalize(local_one_plus_erf)
register_stabilize(local_one_plus_erf) register_stabilize(local_one_plus_erf)
register_specialize(local_one_plus_erf) register_specialize(local_one_plus_erf)
# Only one of the two rewrites below is needed if a canonicalization is added
# for sub(x, y) -> add(x, -y) or a specialization for add(x, -y) -> sub(x, y)
# 1-erf(x)=>erfc(x) # 1-erf(x)=>erfc(x)
local_one_minus_erf = PatternSub( local_one_minus_erf = PatternSub(
(sub, 1, (erf, "x")), (sub, 1, (erf, "x")),
(erfc, "x"), (erfc, "x"),
allow_multiple_clients=True, allow_multiple_clients=True,
name="local_one_minus_erf", name="local_one_minus_erf",
tracks=[erf],
get_nodes=get_clients_at_depth1,
) )
register_canonicalize(local_one_minus_erf) register_canonicalize(local_one_minus_erf)
register_stabilize(local_one_minus_erf) register_stabilize(local_one_minus_erf)
register_specialize(local_one_minus_erf) register_specialize(local_one_minus_erf)
local_one_minus_erf2 = PatternSub( local_one_minus_erf2 = PatternSub(
(add, 1, (mul, -1, (erf, "x"))), (add, 1, (neg, (erf, "x"))),
(erfc, "x"), (erfc, "x"),
allow_multiple_clients=True, allow_multiple_clients=True,
name="local_one_minus_erf2", name="local_one_minus_erf2",
tracks=[erf],
get_nodes=get_clients_at_depth2,
) )
register_canonicalize(local_one_minus_erf2) register_canonicalize(local_one_minus_erf2)
register_stabilize(local_one_minus_erf2) register_stabilize(local_one_minus_erf2)
register_specialize(local_one_minus_erf2) register_specialize(local_one_minus_erf2)
# 1+(-erf(x))=>erfc(x) This is a different graph than the previous as # (-1)+erf(x) => -erfc(x)
# the canonicalize doesn't work completely # There is no need for erf(x)+(-1) nor erf(x) - 1, as the canonicalize will
local_one_plus_neg_erf = PatternSub( # convert those to the matched pattern
(add, 1, (neg, (erf, "x"))),
(erfc, "x"),
allow_multiple_clients=True,
name="local_one_plus_neg_erf",
tracks=[erf],
get_nodes=get_clients2,
)
register_canonicalize(local_one_plus_neg_erf)
register_stabilize(local_one_plus_neg_erf)
register_specialize(local_one_plus_neg_erf)
# (-1)+erf(x) => -erfc(x) don't need erf(x)+(-1) as the canonicalize
# will put the -1 as the first argument.
local_erf_minus_one = PatternSub( local_erf_minus_one = PatternSub(
(add, -1, (erf, "x")), (add, -1, (erf, "x")),
(neg, (erfc, "x")), (neg, (erfc, "x")),
allow_multiple_clients=True, allow_multiple_clients=True,
name="local_erf_minus_one", name="local_erf_minus_one",
tracks=[erf], tracks=[erf],
get_nodes=get_clients, get_nodes=get_clients_at_depth1,
) )
register_canonicalize(local_erf_minus_one) register_canonicalize(local_erf_minus_one)
register_stabilize(local_erf_minus_one) register_stabilize(local_erf_minus_one)
register_specialize(local_erf_minus_one) register_specialize(local_erf_minus_one)
# Only one of the two rewrites below is needed if a canonicalization is added
# for sub(x, y) -> add(x, -y) or a specialization for add(x, -y) -> sub(x, y)
# 1-erfc(x) => erf(x) # 1-erfc(x) => erf(x)
local_one_minus_erfc = PatternSub( local_one_minus_erfc = PatternSub(
(sub, 1, (erfc, "x")), (sub, 1, (erfc, "x")),
...@@ -2686,7 +2664,7 @@ local_one_minus_erfc = PatternSub( ...@@ -2686,7 +2664,7 @@ local_one_minus_erfc = PatternSub(
allow_multiple_clients=True, allow_multiple_clients=True,
name="local_one_minus_erfc", name="local_one_minus_erfc",
tracks=[erfc], tracks=[erfc],
get_nodes=get_clients, get_nodes=get_clients_at_depth1,
) )
register_canonicalize(local_one_minus_erfc) register_canonicalize(local_one_minus_erfc)
register_stabilize(local_one_minus_erfc) register_stabilize(local_one_minus_erfc)
...@@ -2698,39 +2676,12 @@ local_one_minus_erfc2 = PatternSub( ...@@ -2698,39 +2676,12 @@ local_one_minus_erfc2 = PatternSub(
allow_multiple_clients=True, allow_multiple_clients=True,
name="local_one_minus_erfc2", name="local_one_minus_erfc2",
tracks=[erfc], tracks=[erfc],
get_nodes=get_clients2, get_nodes=get_clients_at_depth2,
) )
register_canonicalize(local_one_minus_erfc2) register_canonicalize(local_one_minus_erfc2)
register_stabilize(local_one_minus_erfc2) register_stabilize(local_one_minus_erfc2)
register_specialize(local_one_minus_erfc2) register_specialize(local_one_minus_erfc2)
local_one_minus_erfc3 = PatternSub(
(add, 1, (mul, -1, (erfc, "x"))),
(erf, "x"),
allow_multiple_clients=True,
name="local_one_minus_erfc3",
tracks=[erfc],
get_nodes=get_clients2,
)
register_canonicalize(local_one_minus_erfc3)
register_stabilize(local_one_minus_erfc3)
register_specialize(local_one_minus_erfc3)
# 1+(-erfc(x)) => erf(x) This is a different graph than the previous as
# the canonicalize doesn't work completely
local_one_add_neg_erfc = PatternSub(
(add, 1, (neg, (erfc, "x"))),
(erf, "x"),
allow_multiple_clients=True,
name="local_one_add_neg_erfc",
tracks=[erfc],
get_nodes=get_clients2,
)
register_canonicalize(local_one_add_neg_erfc)
register_stabilize(local_one_add_neg_erfc)
register_specialize(local_one_add_neg_erfc)
# (-1)+erfc(-x)=>erf(x) # (-1)+erfc(-x)=>erf(x)
local_erf_neg_minus_one = PatternSub( local_erf_neg_minus_one = PatternSub(
(add, -1, (erfc, (neg, "x"))), (add, -1, (erfc, (neg, "x"))),
...@@ -2738,25 +2689,12 @@ local_erf_neg_minus_one = PatternSub( ...@@ -2738,25 +2689,12 @@ local_erf_neg_minus_one = PatternSub(
allow_multiple_clients=True, allow_multiple_clients=True,
name="local_erf_neg_minus_one", name="local_erf_neg_minus_one",
tracks=[erfc], tracks=[erfc],
get_nodes=get_clients, get_nodes=get_clients_at_depth1,
) )
register_canonicalize(local_erf_neg_minus_one) register_canonicalize(local_erf_neg_minus_one)
register_stabilize(local_erf_neg_minus_one) register_stabilize(local_erf_neg_minus_one)
register_specialize(local_erf_neg_minus_one) register_specialize(local_erf_neg_minus_one)
# (-1)+erfc(-1*x)=>erf(x)
local_erf_neg_minus_one2 = PatternSub(
(add, -1, (erfc, (mul, -1, "x"))),
(erf, "x"),
allow_multiple_clients=True,
name="local_erf_neg_minus_one2",
tracks=[erfc],
get_nodes=get_clients,
)
register_canonicalize(local_erf_neg_minus_one2)
register_stabilize(local_erf_neg_minus_one2)
register_specialize(local_erf_neg_minus_one2)
@register_stabilize @register_stabilize
@register_specialize @register_specialize
......
...@@ -2742,20 +2742,19 @@ class TestLocalErf: ...@@ -2742,20 +2742,19 @@ class TestLocalErf:
self.mode = ( self.mode = (
get_default_mode() get_default_mode()
.including("canonicalize", "fast_run") .including("canonicalize", "fast_run")
.excluding("gpu", "fusion") .excluding("gpu", "fusion", "inplace")
) )
self.mode._optimizer.position_cutoff = 1.50001
def test_local_one_plus_erf(self): def test_local_one_plus_erf(self):
val = np.asarray([-30, -3, -2, -1, 0, 1, 2, 3, 30], dtype=config.floatX) val = np.asarray([-30, -3, -2, -1, 0, 1, 2, 3, 30], dtype=config.floatX)
x = vector() x = vector()
f = function([x], 1 + erf(x), mode=self.mode) f = function([x], 1 + erf(x), mode=self.mode)
assert [n.op for n in f.maker.fgraph.toposort()] == [mul, erfc] assert [n.op for n in f.maker.fgraph.toposort()] == [neg, erfc]
f(val) f(val)
f = function([x], erf(x) + 1, mode=self.mode) f = function([x], erf(x) + 1, mode=self.mode)
assert [n.op for n in f.maker.fgraph.toposort()] == [mul, erfc] assert [n.op for n in f.maker.fgraph.toposort()] == [neg, erfc]
f(val) f(val)
f = function([x], erf(x) + 2, mode=self.mode) f = function([x], erf(x) + 2, mode=self.mode)
...@@ -2780,6 +2779,9 @@ class TestLocalErf: ...@@ -2780,6 +2779,9 @@ class TestLocalErf:
f = function([x], (-erf(x)) + 1, mode=self.mode) f = function([x], (-erf(x)) + 1, mode=self.mode)
assert [n.op for n in f.maker.fgraph.toposort()] == [erfc] assert [n.op for n in f.maker.fgraph.toposort()] == [erfc]
f = function([x], (-1.0 * erf(x)) + 1, mode=self.mode)
assert [n.op for n in f.maker.fgraph.toposort()] == [erfc]
f = function([x], 2 - erf(x), mode=self.mode) f = function([x], 2 - erf(x), mode=self.mode)
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
assert len(topo) == 2 assert len(topo) == 2
...@@ -2794,14 +2796,14 @@ class TestLocalErf: ...@@ -2794,14 +2796,14 @@ class TestLocalErf:
x = vector() x = vector()
f = function([x], erf(x) - 1, mode=self.mode) f = function([x], erf(x) - 1, mode=self.mode)
assert [n.op for n in f.maker.fgraph.toposort()] == [erfc, mul] assert [n.op for n in f.maker.fgraph.toposort()] == [erfc, neg]
f(val) f(val)
f = function([x], erf(x) + (-1), mode=self.mode) f = function([x], erf(x) + (-1), mode=self.mode)
assert [n.op for n in f.maker.fgraph.toposort()] == [erfc, mul] assert [n.op for n in f.maker.fgraph.toposort()] == [erfc, neg]
f = function([x], -1 + erf(x), mode=self.mode) f = function([x], -1 + erf(x), mode=self.mode)
assert [n.op for n in f.maker.fgraph.toposort()] == [erfc, mul] assert [n.op for n in f.maker.fgraph.toposort()] == [erfc, neg]
f = function([x], erf(x) - 2, mode=self.mode) f = function([x], erf(x) - 2, mode=self.mode)
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
...@@ -2821,12 +2823,10 @@ class TestLocalErfc: ...@@ -2821,12 +2823,10 @@ class TestLocalErfc:
def setup_method(self): def setup_method(self):
self.mode_fusion = ( self.mode_fusion = (
get_default_mode() get_default_mode()
.including("canonicalize") .including("canonicalize", "fast_run")
.including("fast_run") .excluding("gpu", "inplace")
.excluding("gpu")
) )
self.mode = self.mode_fusion.excluding("fusion") self.mode = self.mode_fusion.excluding("fusion")
self.mode._optimizer.position_cutoff = 1.50001
def test_local_one_minus_erfc(self): def test_local_one_minus_erfc(self):
# test opt: 1-erfc(x) => erf(x) and -erfc(x)+1 => erf(x) # test opt: 1-erfc(x) => erf(x) and -erfc(x)+1 => erf(x)
...@@ -2841,6 +2841,9 @@ class TestLocalErfc: ...@@ -2841,6 +2841,9 @@ class TestLocalErfc:
f = function([x], (-erfc(x)) + 1, mode=self.mode) f = function([x], (-erfc(x)) + 1, mode=self.mode)
assert [n.op for n in f.maker.fgraph.toposort()] == [erf] assert [n.op for n in f.maker.fgraph.toposort()] == [erf]
f = function([x], (-1.0 * erfc(x)) + 1, mode=self.mode)
assert [n.op for n in f.maker.fgraph.toposort()] == [erf]
f = function([x], 2 - erfc(x), mode=self.mode) f = function([x], 2 - erfc(x), mode=self.mode)
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
assert len(topo) == 2 assert len(topo) == 2
...@@ -2863,6 +2866,9 @@ class TestLocalErfc: ...@@ -2863,6 +2866,9 @@ class TestLocalErfc:
f = function([x], erfc(-x) + (-1), mode=self.mode) f = function([x], erfc(-x) + (-1), mode=self.mode)
assert [n.op for n in f.maker.fgraph.toposort()] == [erf] assert [n.op for n in f.maker.fgraph.toposort()] == [erf]
f = function([x], erfc(-1.0 * x) + (-1), mode=self.mode)
assert [n.op for n in f.maker.fgraph.toposort()] == [erf]
@pytest.mark.xfail() @pytest.mark.xfail()
def test_local_log_erfc(self): def test_local_log_erfc(self):
val = [-30, -27, -26, -11, -10, -3, -2, -1, 0, 1, 2, 3, 10, 11, 26, 27, 28, 30] val = [-30, -27, -26, -11, -10, -3, -2, -1, 0, 1, 2, 3, 10, 11, 26, 27, 28, 30]
......
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论