提交 175b67b9 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Lazy pydot import

上级 3409264d
...@@ -12,13 +12,7 @@ import pytensor ...@@ -12,13 +12,7 @@ import pytensor
from pytensor.compile import Function, builders from pytensor.compile import Function, builders
from pytensor.graph.basic import Apply, Constant, Variable, graph_inputs from pytensor.graph.basic import Apply, Constant, Variable, graph_inputs
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.printing import pydot_imported, pydot_imported_msg from pytensor.printing import _try_pydot_import
try:
from pytensor.printing import pd
except ImportError:
pass
class PyDotFormatter: class PyDotFormatter:
...@@ -41,8 +35,7 @@ class PyDotFormatter: ...@@ -41,8 +35,7 @@ class PyDotFormatter:
def __init__(self, compact=True): def __init__(self, compact=True):
"""Construct PyDotFormatter object.""" """Construct PyDotFormatter object."""
if not pydot_imported: _try_pydot_import()
raise ImportError("Failed to import pydot. " + pydot_imported_msg)
self.compact = compact self.compact = compact
self.node_colors = { self.node_colors = {
...@@ -115,6 +108,8 @@ class PyDotFormatter: ...@@ -115,6 +108,8 @@ class PyDotFormatter:
pydot.Dot pydot.Dot
Pydot graph of `fct` Pydot graph of `fct`
""" """
pd = _try_pydot_import()
if graph is None: if graph is None:
graph = pd.Dot() graph = pd.Dot()
...@@ -356,6 +351,8 @@ def type_to_str(t): ...@@ -356,6 +351,8 @@ def type_to_str(t):
def dict_to_pdnode(d): def dict_to_pdnode(d):
"""Create pydot node from dict.""" """Create pydot node from dict."""
pd = _try_pydot_import()
e = dict() e = dict()
for k, v in d.items(): for k, v in d.items():
if v is not None: if v is not None:
......
...@@ -26,39 +26,6 @@ from pytensor.graph.utils import Scratchpad ...@@ -26,39 +26,6 @@ from pytensor.graph.utils import Scratchpad
IDTypesType = Literal["id", "int", "CHAR", "auto", ""] IDTypesType = Literal["id", "int", "CHAR", "auto", ""]
pydot_imported = False
pydot_imported_msg = ""
try:
# pydot-ng is a fork of pydot that is better maintained
import pydot_ng as pd
if pd.find_graphviz():
pydot_imported = True
else:
pydot_imported_msg = "pydot-ng can't find graphviz. Install graphviz."
except ImportError:
try:
# fall back on pydot if necessary
import pydot as pd
if hasattr(pd, "find_graphviz"):
if pd.find_graphviz():
pydot_imported = True
else:
pydot_imported_msg = "pydot can't find graphviz"
else:
pd.Dot.create(pd.Dot())
pydot_imported = True
except ImportError:
# tests should not fail on optional dependency
pydot_imported_msg = (
"Install the python package pydot or pydot-ng. Install graphviz."
)
except Exception as e:
pydot_imported_msg = "An error happened while importing/trying pydot: "
pydot_imported_msg += str(e.args)
_logger = logging.getLogger("pytensor.printing") _logger = logging.getLogger("pytensor.printing")
VALID_ASSOC = {"left", "right", "either"} VALID_ASSOC = {"left", "right", "either"}
...@@ -1196,6 +1163,48 @@ default_colorCodes = { ...@@ -1196,6 +1163,48 @@ default_colorCodes = {
} }
def _try_pydot_import():
pydot_imported = False
pydot_imported_msg = ""
try:
# pydot-ng is a fork of pydot that is better maintained
import pydot_ng as pd
if pd.find_graphviz():
pydot_imported = True
else:
pydot_imported_msg = "pydot-ng can't find graphviz. Install graphviz."
except ImportError:
try:
# fall back on pydot if necessary
import pydot as pd
if hasattr(pd, "find_graphviz"):
if pd.find_graphviz():
pydot_imported = True
else:
pydot_imported_msg = "pydot can't find graphviz"
else:
pd.Dot.create(pd.Dot())
pydot_imported = True
except ImportError:
# tests should not fail on optional dependency
pydot_imported_msg = (
"Install the python package pydot or pydot-ng. Install graphviz."
)
except Exception as e:
pydot_imported_msg = "An error happened while importing/trying pydot: "
pydot_imported_msg += str(e.args)
if not pydot_imported:
raise ImportError(
"Failed to import pydot. You must install graphviz "
"and either pydot or pydot-ng for "
f"`pydotprint` to work:\n {pydot_imported_msg}",
)
return pd
def pydotprint( def pydotprint(
fct, fct,
outfile: Path | str | None = None, outfile: Path | str | None = None,
...@@ -1288,6 +1297,8 @@ def pydotprint( ...@@ -1288,6 +1297,8 @@ def pydotprint(
scan separately after the top level debugprint output. scan separately after the top level debugprint output.
""" """
pd = _try_pydot_import()
from pytensor.scan.op import Scan from pytensor.scan.op import Scan
if colorCodes is None: if colorCodes is None:
...@@ -1320,12 +1331,6 @@ def pydotprint( ...@@ -1320,12 +1331,6 @@ def pydotprint(
outputs = fct.outputs outputs = fct.outputs
topo = fct.toposort() topo = fct.toposort()
fgraph = fct fgraph = fct
if not pydot_imported:
raise RuntimeError(
"Failed to import pydot. You must install graphviz "
"and either pydot or pydot-ng for "
f"`pydotprint` to work:\n {pydot_imported_msg}",
)
g = pd.Dot() g = pd.Dot()
......
...@@ -9,12 +9,14 @@ import pytensor.d3viz as d3v ...@@ -9,12 +9,14 @@ import pytensor.d3viz as d3v
from pytensor import compile from pytensor import compile
from pytensor.compile.function import function from pytensor.compile.function import function
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.printing import pydot_imported, pydot_imported_msg from pytensor.printing import _try_pydot_import
from tests.d3viz import models from tests.d3viz import models
if not pydot_imported: try:
pytest.skip("pydot not available: " + pydot_imported_msg, allow_module_level=True) _try_pydot_import()
except Exception as e:
pytest.skip(f"pydot not available: {e!s}", allow_module_level=True)
class TestD3Viz: class TestD3Viz:
......
...@@ -3,11 +3,13 @@ import pytest ...@@ -3,11 +3,13 @@ import pytest
from pytensor import config, function from pytensor import config, function
from pytensor.d3viz.formatting import PyDotFormatter from pytensor.d3viz.formatting import PyDotFormatter
from pytensor.printing import pydot_imported, pydot_imported_msg from pytensor.printing import _try_pydot_import
if not pydot_imported: try:
pytest.skip("pydot not available: " + pydot_imported_msg, allow_module_level=True) _try_pydot_import()
except Exception as e:
pytest.skip(f"pydot not available: {e!s}", allow_module_level=True)
from tests.d3viz import models from tests.d3viz import models
......
...@@ -5,7 +5,7 @@ import pytensor ...@@ -5,7 +5,7 @@ import pytensor
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.printing import debugprint, pydot_imported, pydotprint from pytensor.printing import _try_pydot_import, debugprint, pydotprint
from pytensor.tensor.type import dvector, iscalar, scalar, vector from pytensor.tensor.type import dvector, iscalar, scalar, vector
...@@ -686,6 +686,13 @@ def test_debugprint_compiled_fn(): ...@@ -686,6 +686,13 @@ def test_debugprint_compiled_fn():
assert truth.strip() == out.strip() assert truth.strip() == out.strip()
try:
_try_pydot_import()
pydot_imported = True
except Exception:
pydot_imported = False
@pytest.mark.skipif(not pydot_imported, reason="pydot not available") @pytest.mark.skipif(not pydot_imported, reason="pydot not available")
def test_pydotprint(): def test_pydotprint():
def f_pow2(x_tm1): def f_pow2(x_tm1):
......
...@@ -17,13 +17,13 @@ from pytensor.printing import ( ...@@ -17,13 +17,13 @@ from pytensor.printing import (
PatternPrinter, PatternPrinter,
PPrinter, PPrinter,
Print, Print,
_try_pydot_import,
char_from_number, char_from_number,
debugprint, debugprint,
default_printer, default_printer,
get_node_by_id, get_node_by_id,
min_informative_str, min_informative_str,
pp, pp,
pydot_imported,
pydotprint, pydotprint,
) )
from pytensor.tensor import as_tensor_variable from pytensor.tensor import as_tensor_variable
...@@ -31,6 +31,13 @@ from pytensor.tensor.type import dmatrix, dvector, matrix ...@@ -31,6 +31,13 @@ from pytensor.tensor.type import dmatrix, dvector, matrix
from tests.graph.utils import MyInnerGraphOp, MyOp, MyVariable from tests.graph.utils import MyInnerGraphOp, MyOp, MyVariable
try:
_try_pydot_import()
pydot_imported = True
except Exception:
pydot_imported = False
@pytest.mark.parametrize( @pytest.mark.parametrize(
"number,s", "number,s",
[ [
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论