提交 2be119c0 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Use direct imports from theano.gof in theano.scalar.basic

上级 706ef19b
......@@ -20,11 +20,12 @@ from textwrap import dedent
import numpy as np
import theano
from theano import gof, printing
from theano import printing
from theano.configdefaults import config
from theano.gof.fg import FunctionGraph
from theano.gof.graph import Apply, Constant, Variable
from theano.gof.graph import Apply, Constant, Variable, clone, list_of_nodes
from theano.gof.op import COp
from theano.gof.opt import MergeOptimizer
from theano.gof.type import Type
from theano.gof.utils import (
MetaObject,
......@@ -118,7 +119,7 @@ get_scalar_type.cache = {}
def as_scalar(x, name=None):
from ..tensor import TensorType, scalar_from_tensor
if isinstance(x, gof.Apply):
if isinstance(x, Apply):
if len(x.outputs) != 1:
raise ValueError(
"It is ambiguous which output of a multi-output"
......@@ -4123,7 +4124,7 @@ class Composite(ScalarOp):
# the fgraph to be set to the variable as we need to pickle
# them for the cache of c module to work.
fgraph = FunctionGraph(self.inputs, self.outputs)
gof.MergeOptimizer().optimize(fgraph)
MergeOptimizer().optimize(fgraph)
for node in fgraph.apply_nodes:
if not isinstance(node.op, ScalarOp):
raise ValueError(
......@@ -4148,7 +4149,7 @@ class Composite(ScalarOp):
[isinstance(var.owner.op, Composite) for var in outputs]
):
# No inner Composite
inputs, outputs = gof.graph.clone(inputs, outputs)
inputs, outputs = clone(inputs, outputs)
else:
# Inner Composite that we need to flatten
assert len(outputs) == 1
......@@ -4170,7 +4171,7 @@ class Composite(ScalarOp):
inputs, outputs = res[0], res2[1]
# Next assert comment just for speed
# assert not any([isinstance(node.op, Composite) for node in
# theano.gof.graph.ops(inputs, outputs)])
# ops(inputs, outputs)])
self.inputs = copy(inputs)
self.outputs = copy(outputs)
......@@ -4188,7 +4189,7 @@ class Composite(ScalarOp):
if impl == "py":
self.init_py_impls() # self._impls
if impl not in self.prepare_node_called:
for n in theano.gof.graph.list_of_nodes(self.inputs, self.outputs):
for n in list_of_nodes(self.inputs, self.outputs):
n.op.prepare_node(n, None, None, impl)
self.prepare_node_called.add(impl)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论