提交 f4de2fd2 authored 作者: Michael Osthege's avatar Michael Osthege 提交者: Maxim Kochurov

Run pydot/graphviz tests in CI

Closes #151
上级 f92e1096
......@@ -50,4 +50,5 @@ dependencies:
- typing_extensions
# optional
- cython
- graphviz
- pydot
......@@ -9,7 +9,7 @@ import pytensor.d3viz as d3v
from pytensor import compile
from pytensor.compile.function import function
from pytensor.configdefaults import config
from pytensor.d3viz.formatting import pydot_imported, pydot_imported_msg
from pytensor.printing import pydot_imported, pydot_imported_msg
from tests.d3viz import models
......
......@@ -2,7 +2,8 @@ import numpy as np
import pytest
from pytensor import config, function
from pytensor.d3viz.formatting import PyDotFormatter, pydot_imported, pydot_imported_msg
from pytensor.d3viz.formatting import PyDotFormatter
from pytensor.printing import pydot_imported, pydot_imported_msg
if not pydot_imported:
......@@ -21,21 +22,23 @@ class TestPyDotFormatter:
nc = dict(zip(a, b))
return nc
def test_mlp(self):
@pytest.mark.parametrize("mode", ["FAST_RUN", "FAST_COMPILE"])
def test_mlp(self, mode):
m = models.Mlp()
f = function(m.inputs, m.outputs)
f = function(m.inputs, m.outputs, mode=mode)
pdf = PyDotFormatter()
graph = pdf(f)
expected = 11
if config.mode == "FAST_COMPILE":
expected = 12
if mode == "FAST_RUN":
expected = 13
elif mode == "FAST_COMPILE":
expected = 14
assert len(graph.get_nodes()) == expected
nc = self.node_counts(graph)
if config.mode == "FAST_COMPILE":
assert nc["apply"] == 6
else:
assert nc["apply"] == 5
if mode == "FAST_RUN":
assert nc["apply"] == 7
elif mode == "FAST_COMPILE":
assert nc["apply"] == 8
assert nc["output"] == 1
def test_ofg(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论