提交 99f646ce authored 作者: Michael Osthege's avatar Michael Osthege 提交者: Brandon T. Willard

Remove dependency on theano.Mode in theano.configdefaults

上级 92155d4d
...@@ -121,3 +121,20 @@ class TestConfigTypes: ...@@ -121,3 +121,20 @@ class TestConfigTypes:
with pytest.raises(ValueError, match="Invalid value"): with pytest.raises(ValueError, match="Invalid value"):
cp._apply("notadevice") cp._apply("notadevice")
assert str(cp) == "None (cpu, opencl*, cuda*) " assert str(cp) == "None (cpu, opencl*, cuda*) "
def test_mode_apply():
from theano import configdefaults
assert configdefaults.filter_mode("DebugMode") == "DebugMode"
with pytest.raises(ValueError, match="Expected one of"):
configdefaults.filter_mode("not_a_mode")
# test with theano.Mode instance
import theano.compile.mode
assert (
configdefaults.filter_mode(theano.compile.mode.FAST_COMPILE)
== theano.compile.mode.FAST_COMPILE
)
...@@ -465,7 +465,7 @@ AddConfigVar( ...@@ -465,7 +465,7 @@ AddConfigVar(
# Also, please be careful not to modify the first item in the enum when adding # Also, please be careful not to modify the first item in the enum when adding
# new modes, since it is the default mode. # new modes, since it is the default mode.
def filter_mode(val): def filter_mode(val):
if val in [ str_options = [
"Mode", "Mode",
"DebugMode", "DebugMode",
"FAST_RUN", "FAST_RUN",
...@@ -473,17 +473,19 @@ def filter_mode(val): ...@@ -473,17 +473,19 @@ def filter_mode(val):
"FAST_COMPILE", "FAST_COMPILE",
"DEBUG_MODE", "DEBUG_MODE",
"JAX", "JAX",
]: ]
if val in str_options:
return val return val
# This can be executed before Theano is completly imported, so # This can be executed before Theano is completly imported, so
# theano.Mode is not always available. # theano.Mode is not always available.
elif hasattr(theano, "Mode") and isinstance(val, theano.Mode): # Instead of isinstance(val, theano.Mode),
# we can inspect the __mro__ of the object!
for type_ in type(val).__mro__:
if "theano.compile.mode.Mode" in str(type_):
return val return val
else:
raise ValueError( raise ValueError(
"Expected one of those string 'Mode', 'DebugMode'," f"Expected one of {str_options}, or an instance of theano.Mode. "
" 'FAST_RUN', 'NanGuardMode', 'FAST_COMPILE'," f"Instead got: {val}."
" 'DEBUG_MODE' or an instance of Mode."
) )
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论