提交 9c445e0a authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Use direct theano.gof imports in theano.d3viz.formatting

上级 22c67a69
......@@ -8,8 +8,10 @@ from functools import reduce
import numpy as np
import theano
from theano import gof
from theano.compile import Function, builders
from theano.gof.fg import FunctionGraph
from theano.gof.graph import Apply, Constant, Variable
from theano.gof.graph import inputs as graph_inputs
from theano.printing import pydot_imported, pydot_imported_msg
......@@ -125,16 +127,16 @@ class PyDotFormatter:
if isinstance(fct, Function):
profile = getattr(fct, "profile", None)
fgraph = fct.maker.fgraph
elif isinstance(fct, gof.FunctionGraph):
elif isinstance(fct, FunctionGraph):
fgraph = fct
else:
if isinstance(fct, gof.Variable):
if isinstance(fct, Variable):
fct = [fct]
elif isinstance(fct, gof.Apply):
elif isinstance(fct, Apply):
fct = fct.outputs
assert isinstance(fct, (list, tuple))
assert all(isinstance(v, gof.Variable) for v in fct)
fgraph = gof.FunctionGraph(inputs=gof.graph.inputs(fct), outputs=fct)
assert all(isinstance(v, Variable) for v in fct)
fgraph = FunctionGraph(inputs=graph_inputs(fct), outputs=fct)
outputs = fgraph.outputs
topo = fgraph.toposort()
......@@ -172,7 +174,7 @@ class PyDotFormatter:
"label": var_label(var),
"node_type": "input",
}
if isinstance(var, gof.Constant):
if isinstance(var, Constant):
vparams["node_type"] = "constant_input"
elif isinstance(var, theano.tensor.sharedvar.TensorSharedVariable):
vparams["node_type"] = "shared_input"
......@@ -260,7 +262,7 @@ def var_label(var, precision=3):
"""Return label of variable node."""
if var.name is not None:
return var.name
elif isinstance(var, gof.Constant):
elif isinstance(var, Constant):
h = np.asarray(var.data)
is_const = False
if h.ndim == 0:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论