提交 6193b72e authored 作者: Frederic's avatar Frederic

pep8

上级 316ac334
...@@ -51,8 +51,8 @@ from theano import config ...@@ -51,8 +51,8 @@ from theano import config
# of cutils_ext. # of cutils_ext.
from theano.configparser import AddConfigVar, StrParam from theano.configparser import AddConfigVar, StrParam
AddConfigVar('gcc.cxxflags', AddConfigVar('gcc.cxxflags',
"Extra compiler flags for gcc", "Extra compiler flags for gcc",
StrParam("")) StrParam(""))
# gof imports # gof imports
from theano.gof import graph from theano.gof import graph
...@@ -432,7 +432,7 @@ class CLinker(link.Linker): ...@@ -432,7 +432,7 @@ class CLinker(link.Linker):
def accept(self, fgraph, no_recycling=None): def accept(self, fgraph, no_recycling=None):
"""WRITEME""" """WRITEME"""
if no_recycling is None: if no_recycling is None:
no_recycling = [] no_recycling = []
if self.fgraph is not None and self.fgraph is not fgraph: if self.fgraph is not None and self.fgraph is not fgraph:
return type(self)().accept(fgraph, no_recycling) return type(self)().accept(fgraph, no_recycling)
#raise Exception("Cannot accept from a Linker that is already" #raise Exception("Cannot accept from a Linker that is already"
...@@ -533,7 +533,7 @@ class CLinker(link.Linker): ...@@ -533,7 +533,7 @@ class CLinker(link.Linker):
if isinstance(variable, graph.Constant): if isinstance(variable, graph.Constant):
try: try:
symbol[variable] = ("(" + variable.type.c_literal( symbol[variable] = ("(" + variable.type.c_literal(
variable.data) + ")") variable.data) + ")")
self.consts.append(variable) self.consts.append(variable)
self.orphans.remove(variable) self.orphans.remove(variable)
continue continue
...@@ -635,8 +635,8 @@ class CLinker(link.Linker): ...@@ -635,8 +635,8 @@ class CLinker(link.Linker):
else: else:
# The following will be executed if the "try" block succeeds # The following will be executed if the "try" block succeeds
assert isinstance(c_support_code_apply[-1], basestring), ( assert isinstance(c_support_code_apply[-1], basestring), (
str(node.op) + str(node.op) +
" didn't return a string for c_support_code_apply") " didn't return a string for c_support_code_apply")
try: try:
c_init_code_apply.append(op.c_init_code_apply(node, name)) c_init_code_apply.append(op.c_init_code_apply(node, name))
except utils.MethodNotDefined: except utils.MethodNotDefined:
...@@ -1041,14 +1041,15 @@ class CLinker(link.Linker): ...@@ -1041,14 +1041,15 @@ class CLinker(link.Linker):
no_recycle list. no_recycle list.
""" """
return self.cmodule_key_(self.fgraph, self.no_recycling, return self.cmodule_key_(self.fgraph, self.no_recycling,
compile_args=self.compile_args(), compile_args=self.compile_args(),
libraries=self.libraries(), libraries=self.libraries(),
header_dirs=self.header_dirs(), header_dirs=self.header_dirs(),
c_compiler=self.c_compiler(), c_compiler=self.c_compiler(),
) )
def cmodule_key_(self, fgraph, no_recycling, compile_args=None, libraries=None, def cmodule_key_(self, fgraph, no_recycling, compile_args=None,
header_dirs=None, insert_config_md5=True, c_compiler=None): libraries=None, header_dirs=None, insert_config_md5=True,
c_compiler=None):
""" """
Do the actual computation of cmodule_key in a static method Do the actual computation of cmodule_key in a static method
to allow it to be reused in scalar.Composite.__eq__ to allow it to be reused in scalar.Composite.__eq__
...@@ -1064,7 +1065,7 @@ class CLinker(link.Linker): ...@@ -1064,7 +1065,7 @@ class CLinker(link.Linker):
# seen 'so far' in the loop below # seen 'so far' in the loop below
fgraph_computed_set = set() fgraph_computed_set = set()
fgraph_inputs_dict = dict((i, (-1, pos)) for pos, i in fgraph_inputs_dict = dict((i, (-1, pos)) for pos, i in
enumerate(fgraph.inputs)) enumerate(fgraph.inputs))
constant_ids = dict() constant_ids = dict()
op_pos = {} # Apply -> topological position op_pos = {} # Apply -> topological position
...@@ -1172,7 +1173,7 @@ class CLinker(link.Linker): ...@@ -1172,7 +1173,7 @@ class CLinker(link.Linker):
sig.append(( sig.append((
node.op, node.op,
tuple((i.type, in_sig(i, node_pos, ipos)) tuple((i.type, in_sig(i, node_pos, ipos))
for ipos, i in enumerate(node.inputs)), for ipos, i in enumerate(node.inputs)),
(1, # Increment if cmodule change its handling of outputs (1, # Increment if cmodule change its handling of outputs
tuple(o in no_recycling for o in node.outputs)))) tuple(o in no_recycling for o in node.outputs))))
...@@ -1468,10 +1469,10 @@ class OpWiseCLinker(link.LocalLinker): ...@@ -1468,10 +1469,10 @@ class OpWiseCLinker(link.LocalLinker):
__cache__ = {} __cache__ = {}
def __init__(self, def __init__(self,
fallback_on_perform=True, fallback_on_perform=True,
allow_gc=None, allow_gc=None,
nice_errors=True, nice_errors=True,
schedule=None): schedule=None):
if allow_gc is None: if allow_gc is None:
allow_gc = config.allow_gc allow_gc = config.allow_gc
self.fgraph = None self.fgraph = None
...@@ -1486,10 +1487,10 @@ class OpWiseCLinker(link.LocalLinker): ...@@ -1486,10 +1487,10 @@ class OpWiseCLinker(link.LocalLinker):
no_recycling = [] no_recycling = []
if self.fgraph is not None and self.fgraph is not fgraph: if self.fgraph is not None and self.fgraph is not fgraph:
return type(self)( return type(self)(
fallback_on_perform=self.fallback_on_perform, fallback_on_perform=self.fallback_on_perform,
allow_gc=self.allow_gc, allow_gc=self.allow_gc,
nice_errors=self.nice_errors nice_errors=self.nice_errors
).accept(fgraph, no_recycling) ).accept(fgraph, no_recycling)
#raise Exception("Cannot accept from a Linker that is #raise Exception("Cannot accept from a Linker that is
#already tied to another FunctionGraph.") #already tied to another FunctionGraph.")
self.fgraph = fgraph self.fgraph = fgraph
...@@ -1510,7 +1511,7 @@ class OpWiseCLinker(link.LocalLinker): ...@@ -1510,7 +1511,7 @@ class OpWiseCLinker(link.LocalLinker):
no_recycling = self.no_recycling no_recycling = self.no_recycling
input_storage, output_storage, storage_map = link.map_storage( input_storage, output_storage, storage_map = link.map_storage(
fgraph, order, input_storage, output_storage) fgraph, order, input_storage, output_storage)
if self.allow_gc: if self.allow_gc:
computed, last_user = link.gc_helper(order) computed, last_user = link.gc_helper(order)
post_thunk_old_storage = [] post_thunk_old_storage = []
...@@ -1533,9 +1534,9 @@ class OpWiseCLinker(link.LocalLinker): ...@@ -1533,9 +1534,9 @@ class OpWiseCLinker(link.LocalLinker):
if theano.config.cxx: if theano.config.cxx:
node.op._op_use_c_code = True node.op._op_use_c_code = True
thunks += [node.op.make_thunk(node, thunks += [node.op.make_thunk(node,
storage_map, storage_map,
compute_map, compute_map,
no_recycling)] no_recycling)]
thunks[-1].inputs = [storage_map[v] for v in node.inputs] thunks[-1].inputs = [storage_map[v] for v in node.inputs]
thunks[-1].outputs = [storage_map[v] for v in node.outputs] thunks[-1].outputs = [storage_map[v] for v in node.outputs]
...@@ -1558,9 +1559,9 @@ class OpWiseCLinker(link.LocalLinker): ...@@ -1558,9 +1559,9 @@ class OpWiseCLinker(link.LocalLinker):
for r in no_recycling if r not in fgraph.inputs] for r in no_recycling if r not in fgraph.inputs]
f = link.streamline(fgraph, thunks, order, f = link.streamline(fgraph, thunks, order,
post_thunk_old_storage, post_thunk_old_storage,
no_recycling=no_recycling, no_recycling=no_recycling,
nice_errors=self.nice_errors) nice_errors=self.nice_errors)
f.allow_gc = self.allow_gc f.allow_gc = self.allow_gc
......
...@@ -3,8 +3,9 @@ import unittest ...@@ -3,8 +3,9 @@ import unittest
from nose.plugins.skip import SkipTest from nose.plugins.skip import SkipTest
import theano
from theano.gof.link import PerformLinker from theano.gof.link import PerformLinker
from theano.gof.cc import * from theano.gof.cc import CLinker, DualLinker, OpWiseCLinker
from theano.gof.type import Type from theano.gof.type import Type
from theano.gof.graph import Variable, Apply, Constant from theano.gof.graph import Variable, Apply, Constant
from theano.gof.op import Op from theano.gof.op import Op
...@@ -265,6 +266,7 @@ def test_opwiseclinker_straightforward(): ...@@ -265,6 +266,7 @@ def test_opwiseclinker_straightforward():
# The python version of bad_sub always return -10. # The python version of bad_sub always return -10.
assert fn(2.0, 2.0, 2.0) == -6 assert fn(2.0, 2.0, 2.0) == -6
def test_opwiseclinker_constant(): def test_opwiseclinker_constant():
x, y, z = inputs() x, y, z = inputs()
x = Constant(tdouble, 7.2, name='x') x = Constant(tdouble, 7.2, name='x')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论