提交 bdf98ca9 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Ricardo Vieira

Add get_target_language function and remove tests.compile.test_modes

上级 b12cd96a
......@@ -7,6 +7,8 @@ import logging
import warnings
from typing import Optional, Tuple, Union
from typing_extensions import Literal
from pytensor.compile.function.types import Supervisor
from pytensor.configdefaults import config
from pytensor.graph.destroyhandler import DestroyHandler
......@@ -530,3 +532,26 @@ def register_mode(name, mode):
if name in predefined_modes:
raise ValueError(f"Mode name already taken: {name}")
predefined_modes[name] = mode
def get_target_language(mode=None) -> Tuple[Literal["py", "c", "numba", "jax"], ...]:
"""Get the compilation target language."""
if mode is None:
mode = get_default_mode()
linker = mode.linker
if isinstance(linker, NumbaLinker):
return ("numba",)
if isinstance(linker, JAXLinker):
return ("jax",)
if isinstance(linker, PerformLinker):
return ("py",)
if isinstance(linker, CLinker):
return ("c",)
if isinstance(linker, (VMLinker, OpWiseCLinker)):
return ("c", "py") if config.cxx else ("py",)
raise Exception(f"Unsupported Linker: {linker}")
import copy
import pytest
from pytensor.compile.function import function
from pytensor.compile.mode import AddFeatureOptimizer, Mode
from pytensor.compile.mode import (
AddFeatureOptimizer,
Mode,
get_default_mode,
get_target_language,
)
from pytensor.configdefaults import config
from pytensor.graph.features import NoOutputFromInplace
from pytensor.graph.rewriting.db import RewriteDatabaseQuery, SequenceDB
from pytensor.link.basic import LocalLinker
from pytensor.tensor.math import dot, tanh
from pytensor.tensor.type import matrix
from pytensor.tensor.type import matrix, vector
def test_Mode_basic():
......@@ -48,3 +59,86 @@ def test_including():
new_mode = mode.including("fast_compile")
assert set(new_mode._optimizer.include) == {"merge", "fast_compile"}
class TestBunchOfModes:
def test_modes(self):
# this is a quick test after the LazyLinker branch merge
# to check that all the current modes can still be used.
linker_classes_involved = []
predef_modes = ["FAST_COMPILE", "FAST_RUN", "DEBUG_MODE"]
# Linkers to use with regular Mode
if config.cxx:
linkers = ["py", "c|py", "c|py_nogc", "vm", "vm_nogc", "cvm", "cvm_nogc"]
else:
linkers = ["py", "c|py", "c|py_nogc", "vm", "vm_nogc"]
modes = predef_modes + [Mode(linker, "fast_run") for linker in linkers]
for mode in modes:
x = matrix()
y = vector()
f = function([x, y], x + y, mode=mode)
# test that it runs something
f([[1, 2], [3, 4]], [5, 6])
linker_classes_involved.append(f.maker.mode.linker.__class__)
# print 'MODE:', mode, f.maker.mode.linker, 'stop'
# regression check:
# there should be
# - `VMLinker`
# - OpWiseCLinker (FAST_RUN)
# - PerformLinker (FAST_COMPILE)
# - DebugMode's Linker (DEBUG_MODE)
assert 4 == len(set(linker_classes_involved))
class TestOldModesProblem:
def test_modes(self):
# Then, build a mode with the same linker, and a modified optimizer
default_mode = get_default_mode()
modified_mode = default_mode.including("specialize")
# The following line used to fail, with Python 2.4, in July 2012,
# because an fgraph was associated to the default linker
copy.deepcopy(modified_mode)
# More straightforward test
linker = get_default_mode().linker
assert not hasattr(linker, "fgraph") or linker.fgraph is None
def test_get_target_language():
with config.change_flags(mode=Mode(linker="py")):
res = get_target_language()
assert res == ("py",)
res = get_target_language(Mode(linker="py"))
assert res == ("py",)
res = get_target_language(Mode(linker="c"))
assert res == ("c",)
res = get_target_language(Mode(linker="c|py"))
assert res == ("c", "py")
res = get_target_language(Mode(linker="vm"))
assert res == ("c", "py")
with config.change_flags(cxx=""):
res = get_target_language(Mode(linker="vm"))
assert res == ("py",)
res = get_target_language(Mode(linker="jax"))
assert res == ("jax",)
res = get_target_language(Mode(linker="numba"))
assert res == ("numba",)
class MyLinker(LocalLinker):
pass
test_mode = Mode(linker=MyLinker())
with pytest.raises(Exception):
get_target_language(test_mode)
"""
Test compilation modes
"""
import copy
from pytensor.compile.function import function
from pytensor.compile.mode import Mode, get_default_mode
from pytensor.configdefaults import config
from pytensor.tensor.type import matrix, vector
class TestBunchOfModes:
def test_modes(self):
# this is a quick test after the LazyLinker branch merge
# to check that all the current modes can still be used.
linker_classes_involved = []
predef_modes = ["FAST_COMPILE", "FAST_RUN", "DEBUG_MODE"]
# Linkers to use with regular Mode
if config.cxx:
linkers = ["py", "c|py", "c|py_nogc", "vm", "vm_nogc", "cvm", "cvm_nogc"]
else:
linkers = ["py", "c|py", "c|py_nogc", "vm", "vm_nogc"]
modes = predef_modes + [Mode(linker, "fast_run") for linker in linkers]
for mode in modes:
x = matrix()
y = vector()
f = function([x, y], x + y, mode=mode)
# test that it runs something
f([[1, 2], [3, 4]], [5, 6])
linker_classes_involved.append(f.maker.mode.linker.__class__)
# print 'MODE:', mode, f.maker.mode.linker, 'stop'
# regression check:
# there should be
# - `VMLinker`
# - OpWiseCLinker (FAST_RUN)
# - PerformLinker (FAST_COMPILE)
# - DebugMode's Linker (DEBUG_MODE)
assert 4 == len(set(linker_classes_involved))
class TestOldModesProblem:
def test_modes(self):
# Then, build a mode with the same linker, and a modified optimizer
default_mode = get_default_mode()
modified_mode = default_mode.including("specialize")
# The following line used to fail, with Python 2.4, in July 2012,
# because an fgraph was associated to the default linker
copy.deepcopy(modified_mode)
# More straightforward test
linker = get_default_mode().linker
assert not hasattr(linker, "fgraph") or linker.fgraph is None
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论