提交 f4de2fd2 authored 作者: Michael Osthege's avatar Michael Osthege 提交者: Maxim Kochurov

Run pydot/graphviz tests in CI

Closes #151
上级 f92e1096
...@@ -50,4 +50,5 @@ dependencies: ...@@ -50,4 +50,5 @@ dependencies:
- typing_extensions - typing_extensions
# optional # optional
- cython - cython
- graphviz
- pydot
...@@ -9,7 +9,7 @@ import pytensor.d3viz as d3v ...@@ -9,7 +9,7 @@ import pytensor.d3viz as d3v
from pytensor import compile from pytensor import compile
from pytensor.compile.function import function from pytensor.compile.function import function
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.d3viz.formatting import pydot_imported, pydot_imported_msg from pytensor.printing import pydot_imported, pydot_imported_msg
from tests.d3viz import models from tests.d3viz import models
......
...@@ -2,7 +2,8 @@ import numpy as np ...@@ -2,7 +2,8 @@ import numpy as np
import pytest import pytest
from pytensor import config, function from pytensor import config, function
from pytensor.d3viz.formatting import PyDotFormatter, pydot_imported, pydot_imported_msg from pytensor.d3viz.formatting import PyDotFormatter
from pytensor.printing import pydot_imported, pydot_imported_msg
if not pydot_imported: if not pydot_imported:
...@@ -21,21 +22,23 @@ class TestPyDotFormatter: ...@@ -21,21 +22,23 @@ class TestPyDotFormatter:
nc = dict(zip(a, b)) nc = dict(zip(a, b))
return nc return nc
def test_mlp(self): @pytest.mark.parametrize("mode", ["FAST_RUN", "FAST_COMPILE"])
def test_mlp(self, mode):
m = models.Mlp() m = models.Mlp()
f = function(m.inputs, m.outputs) f = function(m.inputs, m.outputs, mode=mode)
pdf = PyDotFormatter() pdf = PyDotFormatter()
graph = pdf(f) graph = pdf(f)
expected = 11 if mode == "FAST_RUN":
if config.mode == "FAST_COMPILE": expected = 13
expected = 12 elif mode == "FAST_COMPILE":
expected = 14
assert len(graph.get_nodes()) == expected assert len(graph.get_nodes()) == expected
nc = self.node_counts(graph) nc = self.node_counts(graph)
if config.mode == "FAST_COMPILE": if mode == "FAST_RUN":
assert nc["apply"] == 6 assert nc["apply"] == 7
else: elif mode == "FAST_COMPILE":
assert nc["apply"] == 5 assert nc["apply"] == 8
assert nc["output"] == 1 assert nc["output"] == 1
def test_ofg(self): def test_ofg(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论