提交 175b67b9 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Lazy pydot import

上级 3409264d
......@@ -12,13 +12,7 @@ import pytensor
from pytensor.compile import Function, builders
from pytensor.graph.basic import Apply, Constant, Variable, graph_inputs
from pytensor.graph.fg import FunctionGraph
from pytensor.printing import pydot_imported, pydot_imported_msg
try:
from pytensor.printing import pd
except ImportError:
pass
from pytensor.printing import _try_pydot_import
class PyDotFormatter:
......@@ -41,8 +35,7 @@ class PyDotFormatter:
def __init__(self, compact=True):
"""Construct PyDotFormatter object."""
if not pydot_imported:
raise ImportError("Failed to import pydot. " + pydot_imported_msg)
_try_pydot_import()
self.compact = compact
self.node_colors = {
......@@ -115,6 +108,8 @@ class PyDotFormatter:
pydot.Dot
Pydot graph of `fct`
"""
pd = _try_pydot_import()
if graph is None:
graph = pd.Dot()
......@@ -356,6 +351,8 @@ def type_to_str(t):
def dict_to_pdnode(d):
"""Create pydot node from dict."""
pd = _try_pydot_import()
e = dict()
for k, v in d.items():
if v is not None:
......
......@@ -26,39 +26,6 @@ from pytensor.graph.utils import Scratchpad
IDTypesType = Literal["id", "int", "CHAR", "auto", ""]
pydot_imported = False
pydot_imported_msg = ""
try:
# pydot-ng is a fork of pydot that is better maintained
import pydot_ng as pd
if pd.find_graphviz():
pydot_imported = True
else:
pydot_imported_msg = "pydot-ng can't find graphviz. Install graphviz."
except ImportError:
try:
# fall back on pydot if necessary
import pydot as pd
if hasattr(pd, "find_graphviz"):
if pd.find_graphviz():
pydot_imported = True
else:
pydot_imported_msg = "pydot can't find graphviz"
else:
pd.Dot.create(pd.Dot())
pydot_imported = True
except ImportError:
# tests should not fail on optional dependency
pydot_imported_msg = (
"Install the python package pydot or pydot-ng. Install graphviz."
)
except Exception as e:
pydot_imported_msg = "An error happened while importing/trying pydot: "
pydot_imported_msg += str(e.args)
_logger = logging.getLogger("pytensor.printing")
VALID_ASSOC = {"left", "right", "either"}
......@@ -1196,6 +1163,48 @@ default_colorCodes = {
}
def _try_pydot_import():
pydot_imported = False
pydot_imported_msg = ""
try:
# pydot-ng is a fork of pydot that is better maintained
import pydot_ng as pd
if pd.find_graphviz():
pydot_imported = True
else:
pydot_imported_msg = "pydot-ng can't find graphviz. Install graphviz."
except ImportError:
try:
# fall back on pydot if necessary
import pydot as pd
if hasattr(pd, "find_graphviz"):
if pd.find_graphviz():
pydot_imported = True
else:
pydot_imported_msg = "pydot can't find graphviz"
else:
pd.Dot.create(pd.Dot())
pydot_imported = True
except ImportError:
# tests should not fail on optional dependency
pydot_imported_msg = (
"Install the python package pydot or pydot-ng. Install graphviz."
)
except Exception as e:
pydot_imported_msg = "An error happened while importing/trying pydot: "
pydot_imported_msg += str(e.args)
if not pydot_imported:
raise ImportError(
"Failed to import pydot. You must install graphviz "
"and either pydot or pydot-ng for "
f"`pydotprint` to work:\n {pydot_imported_msg}",
)
return pd
def pydotprint(
fct,
outfile: Path | str | None = None,
......@@ -1288,6 +1297,8 @@ def pydotprint(
scan separately after the top level debugprint output.
"""
pd = _try_pydot_import()
from pytensor.scan.op import Scan
if colorCodes is None:
......@@ -1320,12 +1331,6 @@ def pydotprint(
outputs = fct.outputs
topo = fct.toposort()
fgraph = fct
if not pydot_imported:
raise RuntimeError(
"Failed to import pydot. You must install graphviz "
"and either pydot or pydot-ng for "
f"`pydotprint` to work:\n {pydot_imported_msg}",
)
g = pd.Dot()
......
......@@ -9,12 +9,14 @@ import pytensor.d3viz as d3v
from pytensor import compile
from pytensor.compile.function import function
from pytensor.configdefaults import config
from pytensor.printing import pydot_imported, pydot_imported_msg
from pytensor.printing import _try_pydot_import
from tests.d3viz import models
if not pydot_imported:
pytest.skip("pydot not available: " + pydot_imported_msg, allow_module_level=True)
try:
_try_pydot_import()
except Exception as e:
pytest.skip(f"pydot not available: {e!s}", allow_module_level=True)
class TestD3Viz:
......
......@@ -3,11 +3,13 @@ import pytest
from pytensor import config, function
from pytensor.d3viz.formatting import PyDotFormatter
from pytensor.printing import pydot_imported, pydot_imported_msg
from pytensor.printing import _try_pydot_import
if not pydot_imported:
pytest.skip("pydot not available: " + pydot_imported_msg, allow_module_level=True)
try:
_try_pydot_import()
except Exception as e:
pytest.skip(f"pydot not available: {e!s}", allow_module_level=True)
from tests.d3viz import models
......
......@@ -5,7 +5,7 @@ import pytensor
import pytensor.tensor as pt
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.printing import debugprint, pydot_imported, pydotprint
from pytensor.printing import _try_pydot_import, debugprint, pydotprint
from pytensor.tensor.type import dvector, iscalar, scalar, vector
......@@ -686,6 +686,13 @@ def test_debugprint_compiled_fn():
assert truth.strip() == out.strip()
try:
_try_pydot_import()
pydot_imported = True
except Exception:
pydot_imported = False
@pytest.mark.skipif(not pydot_imported, reason="pydot not available")
def test_pydotprint():
def f_pow2(x_tm1):
......
......@@ -17,13 +17,13 @@ from pytensor.printing import (
PatternPrinter,
PPrinter,
Print,
_try_pydot_import,
char_from_number,
debugprint,
default_printer,
get_node_by_id,
min_informative_str,
pp,
pydot_imported,
pydotprint,
)
from pytensor.tensor import as_tensor_variable
......@@ -31,6 +31,13 @@ from pytensor.tensor.type import dmatrix, dvector, matrix
from tests.graph.utils import MyInnerGraphOp, MyOp, MyVariable
try:
_try_pydot_import()
pydot_imported = True
except Exception:
pydot_imported = False
@pytest.mark.parametrize(
"number,s",
[
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论