提交 2cf617d5 authored 作者: Ricardo's avatar Ricardo 提交者: Rémi Louf

Remove redundant erf(c) rewrites

上级 bbf937c3
......@@ -2575,8 +2575,6 @@ register_canonicalize(local_one_plus_erf)
register_stabilize(local_one_plus_erf)
register_specialize(local_one_plus_erf)
# Only one of the two rewrites below is needed if a canonicalization is added
# for sub(x, y) -> add(x, -y) or a specialization for add(x, -y) -> sub(x, y)
# 1-erf(x)=>erfc(x)
local_one_minus_erf = PatternNodeRewriter(
(sub, 1, (erf, "x")),
......@@ -2590,21 +2588,9 @@ register_canonicalize(local_one_minus_erf)
register_stabilize(local_one_minus_erf)
register_specialize(local_one_minus_erf)
local_one_minus_erf2 = PatternNodeRewriter(
(add, 1, (neg, (erf, "x"))),
(erfc, "x"),
allow_multiple_clients=True,
name="local_one_minus_erf2",
tracks=[erf],
get_nodes=get_clients_at_depth2,
)
register_canonicalize(local_one_minus_erf2)
register_stabilize(local_one_minus_erf2)
register_specialize(local_one_minus_erf2)
# (-1)+erf(x) => -erfc(x)
# There is no need for erf(x)+(-1) nor erf(x) - 1, as the canonicalize will
# convert those to the matched pattern
# There is no need for erf(x)+(-1) nor erf(x) - 1, as the `local_add_mul`
# canonicalize will convert those to the matched pattern
local_erf_minus_one = PatternNodeRewriter(
(add, -1, (erf, "x")),
(neg, (erfc, "x")),
......@@ -2617,8 +2603,6 @@ register_canonicalize(local_erf_minus_one)
register_stabilize(local_erf_minus_one)
register_specialize(local_erf_minus_one)
# Only one of the two rewrites below is needed if a canonicalization is added
# for sub(x, y) -> add(x, -y) or a specialization for add(x, -y) -> sub(x, y)
# 1-erfc(x) => erf(x)
local_one_minus_erfc = PatternNodeRewriter(
(sub, 1, (erfc, "x")),
......@@ -2632,21 +2616,9 @@ register_canonicalize(local_one_minus_erfc)
register_stabilize(local_one_minus_erfc)
register_specialize(local_one_minus_erfc)
local_one_minus_erfc2 = PatternNodeRewriter(
(add, 1, (neg, (erfc, "x"))),
(erf, "x"),
allow_multiple_clients=True,
name="local_one_minus_erfc2",
tracks=[erfc],
get_nodes=get_clients_at_depth2,
)
register_canonicalize(local_one_minus_erfc2)
register_stabilize(local_one_minus_erfc2)
register_specialize(local_one_minus_erfc2)
# (-1)+erfc(-x)=>erf(x)
# erfc(-x)-1=>erf(x)
local_erf_neg_minus_one = PatternNodeRewriter(
(add, -1, (erfc, (neg, "x"))),
(sub, (erfc, (neg, "x")), 1),
(erf, "x"),
allow_multiple_clients=True,
name="local_erf_neg_minus_one",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论