提交 2be119c0 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Use direct imports from theano.gof in theano.scalar.basic

上级 706ef19b
...@@ -20,11 +20,12 @@ from textwrap import dedent ...@@ -20,11 +20,12 @@ from textwrap import dedent
import numpy as np import numpy as np
import theano import theano
from theano import gof, printing from theano import printing
from theano.configdefaults import config from theano.configdefaults import config
from theano.gof.fg import FunctionGraph from theano.gof.fg import FunctionGraph
from theano.gof.graph import Apply, Constant, Variable from theano.gof.graph import Apply, Constant, Variable, clone, list_of_nodes
from theano.gof.op import COp from theano.gof.op import COp
from theano.gof.opt import MergeOptimizer
from theano.gof.type import Type from theano.gof.type import Type
from theano.gof.utils import ( from theano.gof.utils import (
MetaObject, MetaObject,
...@@ -118,7 +119,7 @@ get_scalar_type.cache = {} ...@@ -118,7 +119,7 @@ get_scalar_type.cache = {}
def as_scalar(x, name=None): def as_scalar(x, name=None):
from ..tensor import TensorType, scalar_from_tensor from ..tensor import TensorType, scalar_from_tensor
if isinstance(x, gof.Apply): if isinstance(x, Apply):
if len(x.outputs) != 1: if len(x.outputs) != 1:
raise ValueError( raise ValueError(
"It is ambiguous which output of a multi-output" "It is ambiguous which output of a multi-output"
...@@ -4123,7 +4124,7 @@ class Composite(ScalarOp): ...@@ -4123,7 +4124,7 @@ class Composite(ScalarOp):
# the fgraph to be set to the variable as we need to pickle # the fgraph to be set to the variable as we need to pickle
# them for the cache of c module to work. # them for the cache of c module to work.
fgraph = FunctionGraph(self.inputs, self.outputs) fgraph = FunctionGraph(self.inputs, self.outputs)
gof.MergeOptimizer().optimize(fgraph) MergeOptimizer().optimize(fgraph)
for node in fgraph.apply_nodes: for node in fgraph.apply_nodes:
if not isinstance(node.op, ScalarOp): if not isinstance(node.op, ScalarOp):
raise ValueError( raise ValueError(
...@@ -4148,7 +4149,7 @@ class Composite(ScalarOp): ...@@ -4148,7 +4149,7 @@ class Composite(ScalarOp):
[isinstance(var.owner.op, Composite) for var in outputs] [isinstance(var.owner.op, Composite) for var in outputs]
): ):
# No inner Composite # No inner Composite
inputs, outputs = gof.graph.clone(inputs, outputs) inputs, outputs = clone(inputs, outputs)
else: else:
# Inner Composite that we need to flatten # Inner Composite that we need to flatten
assert len(outputs) == 1 assert len(outputs) == 1
...@@ -4170,7 +4171,7 @@ class Composite(ScalarOp): ...@@ -4170,7 +4171,7 @@ class Composite(ScalarOp):
inputs, outputs = res[0], res2[1] inputs, outputs = res[0], res2[1]
# Next assert comment just for speed # Next assert comment just for speed
# assert not any([isinstance(node.op, Composite) for node in # assert not any([isinstance(node.op, Composite) for node in
# theano.gof.graph.ops(inputs, outputs)]) # ops(inputs, outputs)])
self.inputs = copy(inputs) self.inputs = copy(inputs)
self.outputs = copy(outputs) self.outputs = copy(outputs)
...@@ -4188,7 +4189,7 @@ class Composite(ScalarOp): ...@@ -4188,7 +4189,7 @@ class Composite(ScalarOp):
if impl == "py": if impl == "py":
self.init_py_impls() # self._impls self.init_py_impls() # self._impls
if impl not in self.prepare_node_called: if impl not in self.prepare_node_called:
for n in theano.gof.graph.list_of_nodes(self.inputs, self.outputs): for n in list_of_nodes(self.inputs, self.outputs):
n.op.prepare_node(n, None, None, impl) n.op.prepare_node(n, None, None, impl)
self.prepare_node_called.add(impl) self.prepare_node_called.add(impl)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论