提交 9c445e0a authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Use direct theano.gof imports in theano.d3viz.formatting

上级 22c67a69
...@@ -8,8 +8,10 @@ from functools import reduce ...@@ -8,8 +8,10 @@ from functools import reduce
import numpy as np import numpy as np
import theano import theano
from theano import gof
from theano.compile import Function, builders from theano.compile import Function, builders
from theano.gof.fg import FunctionGraph
from theano.gof.graph import Apply, Constant, Variable
from theano.gof.graph import inputs as graph_inputs
from theano.printing import pydot_imported, pydot_imported_msg from theano.printing import pydot_imported, pydot_imported_msg
...@@ -125,16 +127,16 @@ class PyDotFormatter: ...@@ -125,16 +127,16 @@ class PyDotFormatter:
if isinstance(fct, Function): if isinstance(fct, Function):
profile = getattr(fct, "profile", None) profile = getattr(fct, "profile", None)
fgraph = fct.maker.fgraph fgraph = fct.maker.fgraph
elif isinstance(fct, gof.FunctionGraph): elif isinstance(fct, FunctionGraph):
fgraph = fct fgraph = fct
else: else:
if isinstance(fct, gof.Variable): if isinstance(fct, Variable):
fct = [fct] fct = [fct]
elif isinstance(fct, gof.Apply): elif isinstance(fct, Apply):
fct = fct.outputs fct = fct.outputs
assert isinstance(fct, (list, tuple)) assert isinstance(fct, (list, tuple))
assert all(isinstance(v, gof.Variable) for v in fct) assert all(isinstance(v, Variable) for v in fct)
fgraph = gof.FunctionGraph(inputs=gof.graph.inputs(fct), outputs=fct) fgraph = FunctionGraph(inputs=graph_inputs(fct), outputs=fct)
outputs = fgraph.outputs outputs = fgraph.outputs
topo = fgraph.toposort() topo = fgraph.toposort()
...@@ -172,7 +174,7 @@ class PyDotFormatter: ...@@ -172,7 +174,7 @@ class PyDotFormatter:
"label": var_label(var), "label": var_label(var),
"node_type": "input", "node_type": "input",
} }
if isinstance(var, gof.Constant): if isinstance(var, Constant):
vparams["node_type"] = "constant_input" vparams["node_type"] = "constant_input"
elif isinstance(var, theano.tensor.sharedvar.TensorSharedVariable): elif isinstance(var, theano.tensor.sharedvar.TensorSharedVariable):
vparams["node_type"] = "shared_input" vparams["node_type"] = "shared_input"
...@@ -260,7 +262,7 @@ def var_label(var, precision=3): ...@@ -260,7 +262,7 @@ def var_label(var, precision=3):
"""Return label of variable node.""" """Return label of variable node."""
if var.name is not None: if var.name is not None:
return var.name return var.name
elif isinstance(var, gof.Constant): elif isinstance(var, Constant):
h = np.asarray(var.data) h = np.asarray(var.data)
is_const = False is_const = False
if h.ndim == 0: if h.ndim == 0:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论