提交 99f646ce authored 作者: Michael Osthege's avatar Michael Osthege 提交者: Brandon T. Willard

Remove dependency on theano.Mode in theano.configdefaults

上级 92155d4d
......@@ -121,3 +121,20 @@ class TestConfigTypes:
with pytest.raises(ValueError, match="Invalid value"):
cp._apply("notadevice")
assert str(cp) == "None (cpu, opencl*, cuda*) "
def test_mode_apply():
from theano import configdefaults
assert configdefaults.filter_mode("DebugMode") == "DebugMode"
with pytest.raises(ValueError, match="Expected one of"):
configdefaults.filter_mode("not_a_mode")
# test with theano.Mode instance
import theano.compile.mode
assert (
configdefaults.filter_mode(theano.compile.mode.FAST_COMPILE)
== theano.compile.mode.FAST_COMPILE
)
......@@ -465,7 +465,7 @@ AddConfigVar(
# Also, please be careful not to modify the first item in the enum when adding
# new modes, since it is the default mode.
def filter_mode(val):
if val in [
str_options = [
"Mode",
"DebugMode",
"FAST_RUN",
......@@ -473,18 +473,20 @@ def filter_mode(val):
"FAST_COMPILE",
"DEBUG_MODE",
"JAX",
]:
]
if val in str_options:
return val
# This can be executed before Theano is completly imported, so
# theano.Mode is not always available.
elif hasattr(theano, "Mode") and isinstance(val, theano.Mode):
return val
else:
raise ValueError(
"Expected one of those string 'Mode', 'DebugMode',"
" 'FAST_RUN', 'NanGuardMode', 'FAST_COMPILE',"
" 'DEBUG_MODE' or an instance of Mode."
)
# Instead of isinstance(val, theano.Mode),
# we can inspect the __mro__ of the object!
for type_ in type(val).__mro__:
if "theano.compile.mode.Mode" in str(type_):
return val
raise ValueError(
f"Expected one of {str_options}, or an instance of theano.Mode. "
f"Instead got: {val}."
)
AddConfigVar(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论