提交 4936e630 authored 作者: Ramana.S's avatar Ramana.S

Test added, code made complaint with pep8 standards

上级 db6fc48d
...@@ -509,7 +509,7 @@ class FromFunctionOp(gof.Op): ...@@ -509,7 +509,7 @@ class FromFunctionOp(gof.Op):
self.infer_shape = self._infer_shape self.infer_shape = self._infer_shape
def __eq__(self, other): def __eq__(self, other):
return (type(self) == type(other) and return (isinstance(self, type(other)) and
self.__fn == other.__fn) self.__fn == other.__fn)
def __hash__(self): def __hash__(self):
...@@ -523,7 +523,7 @@ class FromFunctionOp(gof.Op): ...@@ -523,7 +523,7 @@ class FromFunctionOp(gof.Op):
if not self.itypes: if not self.itypes:
raise NotImplementedError("itypes not defined") raise NotImplementedError("itypes not defined")
if not self.otypes : if not self.otypes:
raise NotImplementedError("otypes not defined") raise NotImplementedError("otypes not defined")
if len(inputs) != len(self.itypes): if len(inputs) != len(self.itypes):
......
...@@ -16,7 +16,7 @@ import pickle ...@@ -16,7 +16,7 @@ import pickle
# reachable from pickle (as in it cannot be defined inline) # reachable from pickle (as in it cannot be defined inline)
@as_op([dmatrix, dmatrix], dmatrix) @as_op([dmatrix, dmatrix], dmatrix)
def mul(a, b): def mul(a, b):
return a*b return a * b
class OpDecoratorTests(utt.InferShapeTester): class OpDecoratorTests(utt.InferShapeTester):
......
...@@ -38,6 +38,7 @@ class CLinkerObject(object): ...@@ -38,6 +38,7 @@ class CLinkerObject(object):
Standard elements of an Op or Type used with the CLinker. Standard elements of an Op or Type used with the CLinker.
""" """
def c_headers(self): def c_headers(self):
""" """
Optional: Return a list of header files required by code returned by Optional: Return a list of header files required by code returned by
...@@ -59,7 +60,8 @@ class CLinkerObject(object): ...@@ -59,7 +60,8 @@ class CLinkerObject(object):
Subclass does not implement this method. Subclass does not implement this method.
""" """
raise utils.MethodNotDefined("c_headers", type(self), self.__class__.__name__) raise utils.MethodNotDefined(
"c_headers", type(self), self.__class__.__name__)
def c_header_dirs(self): def c_header_dirs(self):
""" """
...@@ -82,7 +84,10 @@ class CLinkerObject(object): ...@@ -82,7 +84,10 @@ class CLinkerObject(object):
Subclass does not implement this method. Subclass does not implement this method.
""" """
raise utils.MethodNotDefined("c_header_dirs", type(self), self.__class__.__name__) raise utils.MethodNotDefined(
"c_header_dirs",
type(self),
self.__class__.__name__)
def c_libraries(self): def c_libraries(self):
""" """
...@@ -105,7 +110,8 @@ class CLinkerObject(object): ...@@ -105,7 +110,8 @@ class CLinkerObject(object):
Subclass does not implement this method. Subclass does not implement this method.
""" """
raise utils.MethodNotDefined("c_libraries", type(self), self.__class__.__name__) raise utils.MethodNotDefined(
"c_libraries", type(self), self.__class__.__name__)
def c_lib_dirs(self): def c_lib_dirs(self):
""" """
...@@ -128,7 +134,8 @@ class CLinkerObject(object): ...@@ -128,7 +134,8 @@ class CLinkerObject(object):
Subclass does not implement this method. Subclass does not implement this method.
""" """
raise utils.MethodNotDefined("c_lib_dirs", type(self), self.__class__.__name__) raise utils.MethodNotDefined(
"c_lib_dirs", type(self), self.__class__.__name__)
def c_support_code(self): def c_support_code(self):
""" """
...@@ -144,7 +151,10 @@ class CLinkerObject(object): ...@@ -144,7 +151,10 @@ class CLinkerObject(object):
Subclass does not implement this method. Subclass does not implement this method.
""" """
raise utils.MethodNotDefined("c_support_code", type(self), self.__class__.__name__) raise utils.MethodNotDefined(
"c_support_code",
type(self),
self.__class__.__name__)
def c_code_cache_version(self): def c_code_cache_version(self):
""" """
...@@ -182,7 +192,10 @@ class CLinkerObject(object): ...@@ -182,7 +192,10 @@ class CLinkerObject(object):
Subclass does not implement this method. Subclass does not implement this method.
""" """
raise utils.MethodNotDefined("c_compile_args", type(self), self.__class__.__name__) raise utils.MethodNotDefined(
"c_compile_args",
type(self),
self.__class__.__name__)
def c_no_compile_args(self): def c_no_compile_args(self):
""" """
...@@ -203,7 +216,10 @@ class CLinkerObject(object): ...@@ -203,7 +216,10 @@ class CLinkerObject(object):
The subclass does not override this method. The subclass does not override this method.
""" """
raise utils.MethodNotDefined("c_no_compile_args", type(self), self.__class__.__name__) raise utils.MethodNotDefined(
"c_no_compile_args",
type(self),
self.__class__.__name__)
def c_init_code(self): def c_init_code(self):
""" """
...@@ -512,7 +528,8 @@ class PureOp(object): ...@@ -512,7 +528,8 @@ class PureOp(object):
MethodNotDefined : the subclass does not override this method. MethodNotDefined : the subclass does not override this method.
""" """
raise utils.MethodNotDefined("make_node", type(self), self.__class__.__name__) raise utils.MethodNotDefined(
"make_node", type(self), self.__class__.__name__)
@classmethod @classmethod
def _get_test_value(cls, v): def _get_test_value(cls, v):
...@@ -542,7 +559,7 @@ class PureOp(object): ...@@ -542,7 +559,7 @@ class PureOp(object):
"For compute_test_value, one input test value does not" "For compute_test_value, one input test value does not"
" have the requested type.\n") " have the requested type.\n")
tr = getattr(v.tag, 'trace', []) tr = getattr(v.tag, 'trace', [])
if type(tr) is list and len(tr) > 0: if isinstance(tr, list) and len(tr) > 0:
detailed_err_msg += ( detailed_err_msg += (
" \nBacktrace when that variable is created:\n") " \nBacktrace when that variable is created:\n")
# Print separate message for each element in the list # Print separate message for each element in the list
...@@ -612,10 +629,14 @@ class PureOp(object): ...@@ -612,10 +629,14 @@ class PureOp(object):
except AttributeError: except AttributeError:
# no test-value was specified, act accordingly # no test-value was specified, act accordingly
if config.compute_test_value == 'warn': if config.compute_test_value == 'warn':
warnings.warn('Warning, Cannot compute test value: input %i (%s) of Op %s missing default value' % (i, ins, node), stacklevel=2) warnings.warn(
'Warning, Cannot compute test value: input %i (%s) of Op %s missing default value' %
(i, ins, node), stacklevel=2)
run_perform = False run_perform = False
elif config.compute_test_value == 'raise': elif config.compute_test_value == 'raise':
raise ValueError('Cannot compute test value: input %i (%s) of Op %s missing default value' % (i, ins, node)) raise ValueError(
'Cannot compute test value: input %i (%s) of Op %s missing default value' %
(i, ins, node))
elif config.compute_test_value == 'ignore': elif config.compute_test_value == 'ignore':
# silently skip test # silently skip test
run_perform = False run_perform = False
...@@ -623,7 +644,9 @@ class PureOp(object): ...@@ -623,7 +644,9 @@ class PureOp(object):
import pdb import pdb
pdb.post_mortem(sys.exc_info()[2]) pdb.post_mortem(sys.exc_info()[2])
else: else:
raise ValueError('%s is invalid for option config.compute_Test_value' % config.compute_test_value) raise ValueError(
'%s is invalid for option config.compute_Test_value' %
config.compute_test_value)
# if all inputs have test-values, run the actual op # if all inputs have test-values, run the actual op
if run_perform: if run_perform:
...@@ -653,7 +676,8 @@ class PureOp(object): ...@@ -653,7 +676,8 @@ class PureOp(object):
for output in node.outputs: for output in node.outputs:
# Check that the output has been computed # Check that the output has been computed
assert compute_map[output][0], (output, storage_map[output][0]) assert compute_map[output][
0], (output, storage_map[output][0])
# add 'test_value' to output tag, so that downstream ops can use these # add 'test_value' to output tag, so that downstream ops can use these
# numerical values as inputs to their perform method. # numerical values as inputs to their perform method.
...@@ -746,7 +770,8 @@ class PureOp(object): ...@@ -746,7 +770,8 @@ class PureOp(object):
The subclass does not override this method. The subclass does not override this method.
""" """
raise utils.MethodNotDefined("perform", type(self), self.__class__.__name__) raise utils.MethodNotDefined(
"perform", type(self), self.__class__.__name__)
def do_constant_folding(self, node): def do_constant_folding(self, node):
""" """
...@@ -810,7 +835,8 @@ class Op(utils.object2, PureOp, CLinkerOp): ...@@ -810,7 +835,8 @@ class Op(utils.object2, PureOp, CLinkerOp):
def __eq__(self, other): def __eq__(self, other):
if hasattr(self, '__props__'): if hasattr(self, '__props__'):
return (type(self) == type(other) and self._props() == other._props()) return (isinstance(self, type(other))
and self._props() == other._props())
else: else:
return NotImplemented return NotImplemented
...@@ -938,7 +964,7 @@ class Op(utils.object2, PureOp, CLinkerOp): ...@@ -938,7 +964,7 @@ class Op(utils.object2, PureOp, CLinkerOp):
if not hasattr(self, 'itypes'): if not hasattr(self, 'itypes'):
raise NotImplementedError("itypes not defined") raise NotImplementedError("itypes not defined")
if not hasattr(self, 'otypes') : if not hasattr(self, 'otypes'):
raise NotImplementedError("otypes not defined") raise NotImplementedError("otypes not defined")
if len(inputs) != len(self.itypes): if len(inputs) != len(self.itypes):
...@@ -1220,7 +1246,9 @@ class COp(Op): ...@@ -1220,7 +1246,9 @@ class COp(Op):
""" """
section_re = re.compile(r'^#section ([a-zA-Z0-9_]+)$', re.MULTILINE) section_re = re.compile(r'^#section ([a-zA-Z0-9_]+)$', re.MULTILINE)
backward_re = re.compile(r'^THEANO_(APPLY|SUPPORT)_CODE_SECTION$', re.MULTILINE) backward_re = re.compile(
r'^THEANO_(APPLY|SUPPORT)_CODE_SECTION$',
re.MULTILINE)
# This is the set of allowed markers # This is the set of allowed markers
SECTIONS = set([ SECTIONS = set([
'init_code', 'init_code_apply', 'init_code_struct', 'init_code', 'init_code_apply', 'init_code_struct',
...@@ -1320,7 +1348,8 @@ class COp(Op): ...@@ -1320,7 +1348,8 @@ class COp(Op):
n = 1 n = 1
while n < len(split): while n < len(split):
if split[n] not in self.SECTIONS: if split[n] not in self.SECTIONS:
raise ValueError("Unknown section type (in file %s): %s" % raise ValueError(
"Unknown section type (in file %s): %s" %
(self.func_files[i], split[n])) (self.func_files[i], split[n]))
if split[n] not in self.code_sections: if split[n] not in self.code_sections:
self.code_sections[split[n]] = "" self.code_sections[split[n]] = ""
...@@ -1377,8 +1406,9 @@ class COp(Op): ...@@ -1377,8 +1406,9 @@ class COp(Op):
if check_input: if check_input:
# Extract the various properties of the input and output variables # Extract the various properties of the input and output variables
variables = node.inputs + node.outputs variables = node.inputs + node.outputs
variable_names = (["INPUT_%i" % i for i in range(len(node.inputs))] + variable_names = (["INPUT_%i" %
["OUTPUT_%i" % i for i in range(len(node.inputs))]) i for i in range(len(node.inputs))] + ["OUTPUT_%i" %
i for i in range(len(node.inputs))])
# Generate dtype macros # Generate dtype macros
for i, v in enumerate(variables): for i, v in enumerate(variables):
...@@ -1389,7 +1419,9 @@ class COp(Op): ...@@ -1389,7 +1419,9 @@ class COp(Op):
macro_name = "DTYPE_" + vname macro_name = "DTYPE_" + vname
macro_value = "npy_" + v.dtype macro_value = "npy_" + v.dtype
define_macros.append(define_template % (macro_name, macro_value)) define_macros.append(
define_template %
(macro_name, macro_value))
undef_macros.append(undef_template % macro_name) undef_macros.append(undef_template % macro_name)
d = numpy.dtype(v.dtype) d = numpy.dtype(v.dtype)
...@@ -1397,13 +1429,17 @@ class COp(Op): ...@@ -1397,13 +1429,17 @@ class COp(Op):
macro_name = "TYPENUM_" + vname macro_name = "TYPENUM_" + vname
macro_value = d.num macro_value = d.num
define_macros.append(define_template % (macro_name, macro_value)) define_macros.append(
define_template %
(macro_name, macro_value))
undef_macros.append(undef_template % macro_name) undef_macros.append(undef_template % macro_name)
macro_name = "ITEMSIZE_" + vname macro_name = "ITEMSIZE_" + vname
macro_value = d.itemsize macro_value = d.itemsize
define_macros.append(define_template % (macro_name, macro_value)) define_macros.append(
define_template %
(macro_name, macro_value))
undef_macros.append(undef_template % macro_name) undef_macros.append(undef_template % macro_name)
# Generate a macro to mark code as being apply-specific # Generate a macro to mark code as being apply-specific
......
...@@ -29,7 +29,7 @@ class MyType(Type): ...@@ -29,7 +29,7 @@ class MyType(Type):
self.thingy = thingy self.thingy = thingy
def __eq__(self, other): def __eq__(self, other):
return type(other) == type(self) and other.thingy == self.thingy return isinstance(other, type(self)) and other.thingy == self.thingy
def __str__(self): def __str__(self):
return str(self.thingy) return str(self.thingy)
...@@ -157,6 +157,7 @@ class TestOp: ...@@ -157,6 +157,7 @@ class TestOp:
class TestMakeThunk(unittest.TestCase): class TestMakeThunk(unittest.TestCase):
def test_no_c_code(self): def test_no_c_code(self):
class IncOnePython(Op): class IncOnePython(Op):
"""An Op with only a Python (perform) implementation""" """An Op with only a Python (perform) implementation"""
...@@ -234,6 +235,28 @@ class TestMakeThunk(unittest.TestCase): ...@@ -234,6 +235,28 @@ class TestMakeThunk(unittest.TestCase):
self.assertRaises((NotImplementedError, utils.MethodNotDefined), self.assertRaises((NotImplementedError, utils.MethodNotDefined),
thunk) thunk)
def test_no_make_node(self):
class IncOne(Op):
"""An Op without make_node"""
__props__ = ()
itypes = [T.fmatrix]
otypes = [T.fmatrix]
def perform(self, node, inputs, outputs):
input, = inputs
output, = outputs
output[0] = input + 1
x_input = T.fmatrix('x')
o = IncOne()(x_input)
# Confirming that make_node method is implemented
try:
self.assertRaises((NotImplementedError, utils.MethodNotDefined),
o.owner.op.make_node, x_input)
except AssertionError:
pass
def test_test_value_python_objects(): def test_test_value_python_objects():
for x in ([0, 1, 2], 0, 0.5, 1): for x in ([0, 1, 2], 0, 0.5, 1):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论