提交 1ff084bf authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Make imports from theano.gof.utils and theano.gof.toolbox explicit

上级 c65bf858
......@@ -144,19 +144,18 @@ from functools import reduce
from theano import config
from theano.gof import (
utils,
Op,
view_roots,
local_optimizer,
Optimizer,
InconsistencyError,
toolbox,
SequenceDB,
EquilibriumOptimizer,
Apply,
ReplacementDidntRemovedError,
)
from theano.gof.utils import TestValueError
from theano.gof.toolbox import ReplaceValidate
from theano.gof.utils import TestValueError, MethodNotDefined, memoize
from theano.gof.params_type import ParamsType
from theano.gof.opt import inherit_stack_trace
from theano.printing import pprint, FunctionPrinter, debugprint
......@@ -417,7 +416,7 @@ def ldflags(libs=True, flags=False, libs_dir=False, include_dir=False):
)
@utils.memoize
@memoize
def _ldflags(ldflags_str, libs, flags, libs_dir, include_dir):
"""Extract list of compilation flags from a string.
......@@ -1084,7 +1083,7 @@ class Gemm(GemmRelated):
_z, _a, _x, _y, _b = inp
(_zout,) = out
if node.inputs[0].type.dtype.startswith("complex"):
raise utils.MethodNotDefined("%s.c_code" % self.__class__.__name__)
raise MethodNotDefined("%s.c_code" % self.__class__.__name__)
full_code = self.build_gemm_call() % dict(locals(), **sub)
return full_code
......@@ -1463,7 +1462,7 @@ class GemmOptimizer(Optimizer):
self.warned = False
def add_requirements(self, fgraph):
fgraph.attach_feature(toolbox.ReplaceValidate())
fgraph.attach_feature(ReplaceValidate())
def apply(self, fgraph):
did_something = True
......@@ -1670,7 +1669,7 @@ class Dot22(GemmRelated):
_x, _y = inp
(_zout,) = out
if node.inputs[0].type.dtype.startswith("complex"):
raise utils.MethodNotDefined("%s.c_code" % self.__class__.__name__)
raise MethodNotDefined("%s.c_code" % self.__class__.__name__)
if len(self.c_libraries()) <= 0:
return super(Dot22, self).c_code(node, name, (_x, _y), (_zout,), sub)
full_code = self.build_gemm_call() % dict(locals(), **sub)
......@@ -1940,7 +1939,7 @@ class Dot22Scalar(GemmRelated):
_x, _y, _a = inp
(_zout,) = out
if node.inputs[0].type.dtype.startswith("complex"):
raise utils.MethodNotDefined("%s.c_code" % self.__class__.__name__)
raise MethodNotDefined("%s.c_code" % self.__class__.__name__)
if len(self.c_libraries()) <= 0:
return super(Dot22Scalar, self).c_code(node, name, (_x, _y), (_zout,), sub)
full_code = self.build_gemm_call() % dict(locals(), **sub)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论