提交 7f6a333b authored 作者: goodfeli's avatar goodfeli

Merge pull request #83 from dwf/misc_pep8

Misc style fixes in op.py
......@@ -15,8 +15,6 @@ __docformat__ = "restructuredtext en"
import logging
import warnings
import numpy
import theano
from theano import config
......@@ -142,7 +140,6 @@ class CLinkerObject(object):
"""
return self.c_code_cache_version()
def c_compile_args(self):
"""Optional: Return a list of compile args recommended to compile the
code returned by other methods in this class.
......@@ -176,6 +173,7 @@ class CLinkerObject(object):
"""
raise utils.MethodNotDefined("c_no_compile_args", type(self), self.__class__.__name__)
class CLinkerOp(CLinkerObject):
"""
Interface definition for `Op` subclasses compiled by `CLinker`.
......@@ -415,7 +413,7 @@ class PureOp(object):
no_recycling=[])
required = thunk()
assert not required # We provided all inputs
assert not required # We provided all inputs
for output in node.outputs:
# Check that the output has been computed
......@@ -469,6 +467,7 @@ class PureOp(object):
"""
raise utils.MethodNotDefined("perform", type(self), self.__class__.__name__)
class Op(utils.object2, PureOp, CLinkerOp):
"""Convenience class to bundle `PureOp` and `CLinkerOp`"""
def __new__(cls, *args, **kwargs):
......@@ -521,13 +520,15 @@ class Op(utils.object2, PureOp, CLinkerOp):
no_recycling=e_no_recycling)
logger.debug('Trying CLinker.make_thunk')
fill_storage, node_input_filters, node_output_filters = cl.make_thunk(
input_storage = node_input_storage,
output_storage = node_output_storage)
outputs = cl.make_thunk(input_storage=node_input_storage,
output_storage=node_output_storage)
fill_storage, node_input_filters, node_output_filters = outputs
def rval():
fill_storage()
for o in node.outputs:
compute_map[o][0] = True
rval.cthunk = fill_storage.cthunk
rval.inputs = node_input_storage
rval.outputs = node_output_storage
......@@ -540,11 +541,13 @@ class Op(utils.object2, PureOp, CLinkerOp):
p = node.op.perform
# default arguments are stored in the closure of `rval`
def rval(p=p, i=node_input_storage, o=node_output_storage, n=node):
r = p(n, [x[0] for x in i], o)
for o in node.outputs:
compute_map[o][0] = True
return r
rval.inputs = node_input_storage
rval.outputs = node_output_storage
rval.perform = p
......@@ -567,6 +570,7 @@ def get_test_value(v):
v_tensor = theano.tensor.as_tensor_variable(v)
return PureOp._get_test_value(v_tensor)
def missing_test_message(msg):
""" Displays msg, a message saying that some test_value is missing,
in the appropriate form based on config.compute_test_value:
......@@ -577,16 +581,14 @@ def missing_test_message(msg):
warn: display msg as a warning
raise: raise an AttributeError with msg as the exception text
"""
action = config.compute_test_value
if action == 'raise':
raise AttributeError(msg)
elif action == 'warn':
warnings.warn(msg, stacklevel = 2)
warnings.warn(msg, stacklevel=2)
else:
assert action in [ 'ignore', 'off' ]
assert action in ['ignore', 'off']
def debug_error_message(msg):
""" Displays a message saying that an error was found in some
......@@ -598,28 +600,24 @@ def debug_error_message(msg):
#this message should never be called when the debugger is off
assert action != 'off'
if action in ['raise','ignore']:
if action in ['raise', 'ignore']:
raise ValueError(msg)
else:
assert action == 'warn'
warnings.warn(msg, stacklevel = 2)
warnings.warn(msg, stacklevel=2)
def debug_assert(condition, msg):
def debug_assert(condition, msg):
if not condition:
action = config.compute_test_value
if action in ['raise', 'ignore']:
raise AssertionError(msg)
else:
assert action == 'warn'
warnings.warn(msg, stacklevel = 2)
warnings.warn(msg, stacklevel=2)
def get_debug_values(*args):
""" Given a list of variables, does one of three things:
1. If the interactive debugger is off, returns an empty list
......@@ -649,9 +647,11 @@ def get_debug_values(*args):
rval.append(get_test_value(arg))
except AttributeError:
if hasattr(arg, 'name') and arg.name is not None:
missing_test_message("Argument " + str(i) + "('" + arg.name + "') has no test value")
missing_test_message("Argument " + str(i) + "('" + arg.name +
"') has no test value")
else:
missing_test_message("Argument " + str(i) + " has no test value")
missing_test_message("Argument " + str(i) +
" has no test value")
return []
return [ tuple(rval) ]
return [tuple(rval)]
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论