提交 1ff084bf authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Make imports from theano.gof.utils and theano.gof.toolbox explicit

上级 c65bf858
...@@ -144,19 +144,18 @@ from functools import reduce ...@@ -144,19 +144,18 @@ from functools import reduce
from theano import config from theano import config
from theano.gof import ( from theano.gof import (
utils,
Op, Op,
view_roots, view_roots,
local_optimizer, local_optimizer,
Optimizer, Optimizer,
InconsistencyError, InconsistencyError,
toolbox,
SequenceDB, SequenceDB,
EquilibriumOptimizer, EquilibriumOptimizer,
Apply, Apply,
ReplacementDidntRemovedError, ReplacementDidntRemovedError,
) )
from theano.gof.utils import TestValueError from theano.gof.toolbox import ReplaceValidate
from theano.gof.utils import TestValueError, MethodNotDefined, memoize
from theano.gof.params_type import ParamsType from theano.gof.params_type import ParamsType
from theano.gof.opt import inherit_stack_trace from theano.gof.opt import inherit_stack_trace
from theano.printing import pprint, FunctionPrinter, debugprint from theano.printing import pprint, FunctionPrinter, debugprint
...@@ -417,7 +416,7 @@ def ldflags(libs=True, flags=False, libs_dir=False, include_dir=False): ...@@ -417,7 +416,7 @@ def ldflags(libs=True, flags=False, libs_dir=False, include_dir=False):
) )
@utils.memoize @memoize
def _ldflags(ldflags_str, libs, flags, libs_dir, include_dir): def _ldflags(ldflags_str, libs, flags, libs_dir, include_dir):
"""Extract list of compilation flags from a string. """Extract list of compilation flags from a string.
...@@ -1084,7 +1083,7 @@ class Gemm(GemmRelated): ...@@ -1084,7 +1083,7 @@ class Gemm(GemmRelated):
_z, _a, _x, _y, _b = inp _z, _a, _x, _y, _b = inp
(_zout,) = out (_zout,) = out
if node.inputs[0].type.dtype.startswith("complex"): if node.inputs[0].type.dtype.startswith("complex"):
raise utils.MethodNotDefined("%s.c_code" % self.__class__.__name__) raise MethodNotDefined("%s.c_code" % self.__class__.__name__)
full_code = self.build_gemm_call() % dict(locals(), **sub) full_code = self.build_gemm_call() % dict(locals(), **sub)
return full_code return full_code
...@@ -1463,7 +1462,7 @@ class GemmOptimizer(Optimizer): ...@@ -1463,7 +1462,7 @@ class GemmOptimizer(Optimizer):
self.warned = False self.warned = False
def add_requirements(self, fgraph): def add_requirements(self, fgraph):
fgraph.attach_feature(toolbox.ReplaceValidate()) fgraph.attach_feature(ReplaceValidate())
def apply(self, fgraph): def apply(self, fgraph):
did_something = True did_something = True
...@@ -1670,7 +1669,7 @@ class Dot22(GemmRelated): ...@@ -1670,7 +1669,7 @@ class Dot22(GemmRelated):
_x, _y = inp _x, _y = inp
(_zout,) = out (_zout,) = out
if node.inputs[0].type.dtype.startswith("complex"): if node.inputs[0].type.dtype.startswith("complex"):
raise utils.MethodNotDefined("%s.c_code" % self.__class__.__name__) raise MethodNotDefined("%s.c_code" % self.__class__.__name__)
if len(self.c_libraries()) <= 0: if len(self.c_libraries()) <= 0:
return super(Dot22, self).c_code(node, name, (_x, _y), (_zout,), sub) return super(Dot22, self).c_code(node, name, (_x, _y), (_zout,), sub)
full_code = self.build_gemm_call() % dict(locals(), **sub) full_code = self.build_gemm_call() % dict(locals(), **sub)
...@@ -1940,7 +1939,7 @@ class Dot22Scalar(GemmRelated): ...@@ -1940,7 +1939,7 @@ class Dot22Scalar(GemmRelated):
_x, _y, _a = inp _x, _y, _a = inp
(_zout,) = out (_zout,) = out
if node.inputs[0].type.dtype.startswith("complex"): if node.inputs[0].type.dtype.startswith("complex"):
raise utils.MethodNotDefined("%s.c_code" % self.__class__.__name__) raise MethodNotDefined("%s.c_code" % self.__class__.__name__)
if len(self.c_libraries()) <= 0: if len(self.c_libraries()) <= 0:
return super(Dot22Scalar, self).c_code(node, name, (_x, _y), (_zout,), sub) return super(Dot22Scalar, self).c_code(node, name, (_x, _y), (_zout,), sub)
full_code = self.build_gemm_call() % dict(locals(), **sub) full_code = self.build_gemm_call() % dict(locals(), **sub)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论