提交 6b189ee3 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Remove `specialize_device` database

上级 e10b9515
......@@ -248,11 +248,6 @@ optdb.register("specialize", EquilibriumDB(), "fast_run", "fast_compile", positi
# Miscellaneous speed-oriented special cases that break canonicalization.
optdb.register("uncanonicalize", EquilibriumDB(), "fast_run", position=3)
# Miscellaneous speed-oriented special cases that depend on the device.
optdb.register(
"specialize_device", EquilibriumDB(), "fast_compile", "fast_run", position=48.6
) # must be after gpu stuff at 48.5
# especially constant merge
optdb.register("merge2", MergeOptimizer(), "fast_run", "merge", position=49)
......
......@@ -205,25 +205,6 @@ def register_uncanonicalize(
return node_rewriter
def register_specialize_device(
    node_rewriter: Union[RewriteDatabase, Rewriter, str], *tags: str, **kwargs
):
    """Register a rewrite in the ``specialize_device`` database.

    The function can be used in two ways:

    * Called with a rewriter (or rewrite database), it registers that object
      in ``compile.optdb["specialize_device"]`` under the ``"fast_run"`` tag
      plus any extra ``tags``, and returns the object unchanged so it can be
      stacked with other registration decorators.
    * Called with a string, it returns a decorator; the string is then
      treated as an additional tag when the decorated rewriter is registered.

    Parameters
    ----------
    node_rewriter
        The rewriter to register, or a tag name when used as a decorator
        factory.
    *tags
        Extra tags forwarded to the database ``register`` call.
    **kwargs
        Extra keyword arguments forwarded to the database ``register`` call.
        An optional ``name`` entry overrides the registration name, which
        otherwise defaults to ``node_rewriter.__name__``.

    Returns
    -------
    The registered rewriter, or a decorator when ``node_rewriter`` is a string.
    """
    if isinstance(node_rewriter, str):
        # Decorator-factory form: the string argument becomes the first tag.
        def register(inner_rewriter: Union[RewriteDatabase, Rewriter]):
            return register_specialize_device(
                inner_rewriter, node_rewriter, *tags, **kwargs
            )

        return register

    # ``name`` must not be forwarded to ``register`` as a keyword, so pop it;
    # a missing or falsy value falls back to the rewriter's own name.
    name = kwargs.pop("name", None) or node_rewriter.__name__
    compile.optdb["specialize_device"].register(
        name, node_rewriter, "fast_run", *tags, **kwargs
    )
    return node_rewriter
@register_canonicalize
@register_specialize
@node_rewriter([TensorFromScalar])
......
......@@ -88,7 +88,6 @@ from pytensor.tensor.rewriting.basic import (
local_fill_sink,
register_canonicalize,
register_specialize,
register_specialize_device,
register_stabilize,
register_uncanonicalize,
register_useless,
......@@ -2078,12 +2077,14 @@ def local_pow_specialize(fgraph, node):
return False
@register_specialize_device
@register_specialize
@node_rewriter([at_pow])
def local_pow_specialize_device(fgraph, node):
"""
This rewrite is not the same on all device. We do it only on cpu here.
def local_pow_to_nested_squaring(fgraph, node):
"""Convert a large power exponent to multiple squaring operations.
Note: This sounds like the kind of thing any half-decent compiler can do by itself?
"""
if node.op == at_pow:
# the idea here is that we have pow(x, y)
odtype = node.outputs[0].dtype
......
......@@ -1672,12 +1672,12 @@ def test_local_pow_specialize():
utt.assert_allclose(f(val_no0), val_no0 ** (-0.5))
def test_local_pow_specialize_device_more_aggressive_on_cpu():
def test_local_pow_to_nested_squaring():
mode = config.mode
if mode == "FAST_COMPILE":
mode = "FAST_RUN"
mode = get_mode(mode)
mode = mode.excluding("fusion").excluding("gpu")
mode = mode.excluding("fusion")
v = vector()
val = np.arange(10, dtype=config.floatX)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论