提交 ebd59050 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Use direct imports from theano.gof.utils in theano.link.c.basic

上级 2be119c0
...@@ -12,9 +12,10 @@ from io import StringIO ...@@ -12,9 +12,10 @@ from io import StringIO
import numpy as np import numpy as np
from theano.configdefaults import config from theano.configdefaults import config
from theano.gof import graph, utils from theano.gof import graph
from theano.gof.callcache import CallCache from theano.gof.callcache import CallCache
from theano.gof.compilelock import get_lock, release_lock from theano.gof.compilelock import get_lock, release_lock
from theano.gof.utils import MethodNotDefined, difference, uniq
from theano.link.basic import Container, Linker, LocalLinker, PerformLinker from theano.link.basic import Container, Linker, LocalLinker, PerformLinker
from theano.link.c.cmodule import ( from theano.link.c.cmodule import (
METH_VARARGS, METH_VARARGS,
...@@ -670,7 +671,7 @@ class CLinker(Linker): ...@@ -670,7 +671,7 @@ class CLinker(Linker):
variable.type.c_literal(variable.data) variable.type.c_literal(variable.data)
self.consts.append(variable) self.consts.append(variable)
self.orphans.remove(variable) self.orphans.remove(variable)
except (utils.MethodNotDefined, NotImplementedError): except (MethodNotDefined, NotImplementedError):
pass pass
self.temps = list( self.temps = list(
...@@ -851,7 +852,7 @@ class CLinker(Linker): ...@@ -851,7 +852,7 @@ class CLinker(Linker):
# type-specific support code # type-specific support code
try: try:
c_support_code_apply.append(op.c_support_code_apply(node, name)) c_support_code_apply.append(op.c_support_code_apply(node, name))
except utils.MethodNotDefined: except MethodNotDefined:
pass pass
else: else:
# The following will be executed if the "try" block succeeds # The following will be executed if the "try" block succeeds
...@@ -861,7 +862,7 @@ class CLinker(Linker): ...@@ -861,7 +862,7 @@ class CLinker(Linker):
try: try:
c_init_code_apply.append(op.c_init_code_apply(node, name)) c_init_code_apply.append(op.c_init_code_apply(node, name))
except utils.MethodNotDefined: except MethodNotDefined:
pass pass
else: else:
assert isinstance(c_init_code_apply[-1], str), ( assert isinstance(c_init_code_apply[-1], str), (
...@@ -873,7 +874,7 @@ class CLinker(Linker): ...@@ -873,7 +874,7 @@ class CLinker(Linker):
assert isinstance(struct_init, str), ( assert isinstance(struct_init, str), (
str(node.op) + " didn't return a string for c_init_code_struct" str(node.op) + " didn't return a string for c_init_code_struct"
) )
except utils.MethodNotDefined: except MethodNotDefined:
pass pass
try: try:
...@@ -881,7 +882,7 @@ class CLinker(Linker): ...@@ -881,7 +882,7 @@ class CLinker(Linker):
assert isinstance(struct_support, str), ( assert isinstance(struct_support, str), (
str(node.op) + " didn't return a string for c_support_code_struct" str(node.op) + " didn't return a string for c_support_code_struct"
) )
except utils.MethodNotDefined: except MethodNotDefined:
pass pass
try: try:
...@@ -889,13 +890,13 @@ class CLinker(Linker): ...@@ -889,13 +890,13 @@ class CLinker(Linker):
assert isinstance(struct_cleanup, str), ( assert isinstance(struct_cleanup, str), (
str(node.op) + " didn't return a string for c_cleanup_code_struct" str(node.op) + " didn't return a string for c_cleanup_code_struct"
) )
except utils.MethodNotDefined: except MethodNotDefined:
pass pass
# emit c_code # emit c_code
try: try:
behavior = op.c_code(node, name, isyms, osyms, sub) behavior = op.c_code(node, name, isyms, osyms, sub)
except utils.MethodNotDefined: except MethodNotDefined:
raise NotImplementedError(f"{op} cannot produce C code") raise NotImplementedError(f"{op} cannot produce C code")
assert isinstance( assert isinstance(
behavior, str behavior, str
...@@ -907,7 +908,7 @@ class CLinker(Linker): ...@@ -907,7 +908,7 @@ class CLinker(Linker):
try: try:
cleanup = op.c_code_cleanup(node, name, isyms, osyms, sub) cleanup = op.c_code_cleanup(node, name, isyms, osyms, sub)
except utils.MethodNotDefined: except MethodNotDefined:
cleanup = "" cleanup = ""
_logger.info(f"compiling un-versioned Apply {node}") _logger.info(f"compiling un-versioned Apply {node}")
...@@ -930,7 +931,7 @@ class CLinker(Linker): ...@@ -930,7 +931,7 @@ class CLinker(Linker):
args = [] args = []
args += [ args += [
f"storage_{symbol[variable]}" f"storage_{symbol[variable]}"
for variable in utils.uniq(self.inputs + self.outputs + self.orphans) for variable in uniq(self.inputs + self.outputs + self.orphans)
] ]
# <<<<HASH_PLACEHOLDER>>>> will be replaced by a hash of the whole # <<<<HASH_PLACEHOLDER>>>> will be replaced by a hash of the whole
...@@ -991,7 +992,7 @@ class CLinker(Linker): ...@@ -991,7 +992,7 @@ class CLinker(Linker):
ret.extend(support_code) ret.extend(support_code)
else: else:
ret.append(support_code) ret.append(support_code)
except utils.MethodNotDefined: except MethodNotDefined:
pass pass
return ret return ret
...@@ -1029,10 +1030,10 @@ class CLinker(Linker): ...@@ -1029,10 +1030,10 @@ class CLinker(Linker):
ret += x.c_compile_args(c_compiler) ret += x.c_compile_args(c_compiler)
except TypeError: except TypeError:
ret += x.c_compile_args() ret += x.c_compile_args()
except utils.MethodNotDefined: except MethodNotDefined:
pass pass
ret = utils.uniq(ret) # to remove duplicate ret = uniq(ret) # to remove duplicate
# The args set by the compiler include the user flags. We do not want # The args set by the compiler include the user flags. We do not want
# to reorder them # to reorder them
ret += c_compiler.compile_args() ret += c_compiler.compile_args()
...@@ -1047,7 +1048,7 @@ class CLinker(Linker): ...@@ -1047,7 +1048,7 @@ class CLinker(Linker):
ret.remove(i) ret.remove(i)
except ValueError: except ValueError:
pass # in case the value is not there pass # in case the value is not there
except utils.MethodNotDefined: except MethodNotDefined:
pass pass
return ret return ret
...@@ -1067,9 +1068,9 @@ class CLinker(Linker): ...@@ -1067,9 +1068,9 @@ class CLinker(Linker):
ret += x.c_headers(c_compiler) ret += x.c_headers(c_compiler)
except TypeError: except TypeError:
ret += x.c_headers() ret += x.c_headers()
except utils.MethodNotDefined: except MethodNotDefined:
pass pass
return utils.uniq(ret) return uniq(ret)
def init_code(self): def init_code(self):
""" """
...@@ -1083,9 +1084,9 @@ class CLinker(Linker): ...@@ -1083,9 +1084,9 @@ class CLinker(Linker):
for x in [y.type for y in self.variables] + [y.op for y in self.node_order]: for x in [y.type for y in self.variables] + [y.op for y in self.node_order]:
try: try:
ret += x.c_init_code() ret += x.c_init_code()
except utils.MethodNotDefined: except MethodNotDefined:
pass pass
return utils.uniq(ret) return uniq(ret)
def c_compiler(self): def c_compiler(self):
c_compiler = None c_compiler = None
...@@ -1124,10 +1125,10 @@ class CLinker(Linker): ...@@ -1124,10 +1125,10 @@ class CLinker(Linker):
ret += x.c_header_dirs(c_compiler) ret += x.c_header_dirs(c_compiler)
except TypeError: except TypeError:
ret += x.c_header_dirs() ret += x.c_header_dirs()
except utils.MethodNotDefined: except MethodNotDefined:
pass pass
# filter out empty strings/None # filter out empty strings/None
return [r for r in utils.uniq(ret) if r] return [r for r in uniq(ret) if r]
def libraries(self): def libraries(self):
""" """
...@@ -1145,9 +1146,9 @@ class CLinker(Linker): ...@@ -1145,9 +1146,9 @@ class CLinker(Linker):
ret += x.c_libraries(c_compiler) ret += x.c_libraries(c_compiler)
except TypeError: except TypeError:
ret += x.c_libraries() ret += x.c_libraries()
except utils.MethodNotDefined: except MethodNotDefined:
pass pass
return utils.uniq(ret) return uniq(ret)
def lib_dirs(self): def lib_dirs(self):
""" """
...@@ -1165,10 +1166,10 @@ class CLinker(Linker): ...@@ -1165,10 +1166,10 @@ class CLinker(Linker):
ret += x.c_lib_dirs(c_compiler) ret += x.c_lib_dirs(c_compiler)
except TypeError: except TypeError:
ret += x.c_lib_dirs() ret += x.c_lib_dirs()
except utils.MethodNotDefined: except MethodNotDefined:
pass pass
# filter out empty strings/None # filter out empty strings/None
return [r for r in utils.uniq(ret) if r] return [r for r in uniq(ret) if r]
def __compile__( def __compile__(
self, input_storage=None, output_storage=None, storage_map=None, keep_lock=False self, input_storage=None, output_storage=None, storage_map=None, keep_lock=False
...@@ -1962,7 +1963,7 @@ class OpWiseCLinker(LocalLinker): ...@@ -1962,7 +1963,7 @@ class OpWiseCLinker(LocalLinker):
if no_recycling is True: if no_recycling is True:
no_recycling = list(storage_map.values()) no_recycling = list(storage_map.values())
no_recycling = utils.difference(no_recycling, input_storage) no_recycling = difference(no_recycling, input_storage)
else: else:
no_recycling = [ no_recycling = [
storage_map[r] for r in no_recycling if r not in fgraph.inputs storage_map[r] for r in no_recycling if r not in fgraph.inputs
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论