提交 c846d395 authored 作者: Pascal Lamblin's avatar Pascal Lamblin 提交者: GitHub

Merge pull request #5918 from notoraptor/update-c-support-code-and-params-type

Allow c_support_code() to support list of strings and update ParamsType
......@@ -151,7 +151,7 @@ There are less methods to define for an Op than for a Type:
.. method:: c_support_code()
Allows you to specify helper functions/structs that the
Allows you to specify helper functions/structs (in a string or a list of string) that the
:ref:`op` needs. That code will be reused for each apply of
this op. It will be inserted at global scope.
......
......@@ -114,7 +114,7 @@ the most important ones:
.. method:: c_support_code()
Allows to add helper functions/structs that the :ref:`type` needs.
Allows to add helper functions/structs (in a string or a list of strings) that the :ref:`type` needs.
.. method:: c_compiler()
......
......@@ -369,7 +369,7 @@ commonly used.
.. method:: c_support_code()
Returns a string containing some support C code for this op. This code
Returns a string or a list of strings containing some support C code for this op. This code
will be included at the global scope level and can be used to define
functions and structs that will be used by every apply of this op.
......
......@@ -914,7 +914,11 @@ class CLinker(link.Linker):
for x in [y.type for y in self.variables] + [
y.op for y in self.node_order]:
try:
ret.append(x.c_support_code())
support_code = x.c_support_code()
if isinstance(support_code, list):
ret.extend(support_code)
else:
ret.append(support_code)
except utils.MethodNotDefined:
pass
return ret
......
......@@ -152,7 +152,7 @@ class CLinkerObject(object):
def c_support_code(self):
"""
Optional: Return utility code for use by a `Variable` or `Op` to be
Optional: Return utility code (a string, or a list of strings) for use by a `Variable` or `Op` to be
included at global scope prior to the rest of the code for this class.
QUESTION: How many times will this support code be emitted for a graph
......@@ -802,12 +802,10 @@ class Op(utils.object2, PureOp, CLinkerOp):
if hasattr(self, 'params_type') and isinstance(self.params_type, theano.gof.ParamsType):
wrapper = self.params_type
if not all(hasattr(self, field) for field in wrapper.fields):
raise AttributeError('%s: missing attributes for ParamsType parameter.' % type(self).__name__)
wrap_dict = dict()
for i in range(wrapper.length):
field = wrapper.fields[i]
_type = wrapper.types[i]
wrap_dict[field] = _type.filter(getattr(self, field), strict=False, allow_downcast=True)
# Let's print missing attributes for debugging.
not_found = tuple(field for field in wrapper.fields if not hasattr(self, field))
raise AttributeError('%s: missing attributes %s for ParamsType.' % (type(self).__name__, not_found))
# ParamsType.get_params() will apply filtering to attributes.
return self.params_type.get_params(self)
raise theano.gof.utils.MethodNotDefined('get_params')
......
......@@ -256,12 +256,25 @@ class ParamsType(Type):
if [a for a in all_aliases if a in all_enums]:
raise AttributeError('ParamsType: found aliases that have same names as constants.')
# We map each enum name to the enum type in which it is defined.
# We will then use this dict to find enum value when looking for enum name in Wrapper object directly.
# We will then use this dict to find enum value when looking for enum name in ParamsType object directly.
self.__const_to_enum = {enum_name: enum_type for enum_type in enum_types for enum_name in enum_type}
self.__alias_to_enum = {alias: enum_type for enum_type in enum_types for alias in enum_type.aliases}
def __setstate__(self, state):
# NB:
# I have overridden __getattr__ to make enum constants available through
# the ParamsType when it contains enum types. To do that, I use some internal
# attributes: self.__const_to_enum and self.__alias_to_enum. These attributes
# are normally found by Python without need to call getattr(), but when the
# ParamsType is unpickled, it seems gettatr() may be called at a point before
# __const_to_enum or __alias_to_enum are unpickled, so that gettatr() can't find
# those attributes, and then loop infinitely.
# For this reason, I must add this trivial implementation of __setstate__()
# to avoid errors when unpickling.
self.__dict__.update(state)
def __getattr__(self, key):
# Now we can access value of each enum defined inside enum types wrapped into the current Wrapper.
# Now we can access value of each enum defined inside enum types wrapped into the current ParamsType.
if key in self.__const_to_enum:
return self.__const_to_enum[key][key]
return super(ParamsType, self).__getattr__(self, key)
......@@ -293,6 +306,14 @@ class ParamsType(Type):
"""
return theano_type in self.types
def get_type(self, field_name):
"""
Return the Theano type associated to the given field name
in the current ParamsType.
"""
return self.types[self.fields.index(field_name)]
def get_field(self, theano_type):
"""
Return the name (string) of the first field associated to
......@@ -427,6 +448,18 @@ class ParamsType(Type):
for i in range(self.length)}
return Params(self, **filtered)
def extended(self, **kwargs):
"""
Return a copy of current ParamsType
extended with attributes given in kwargs.
New attributes must follow same rules as in
ParamsType constructor.
"""
self_to_dict = {self.fields[i]: self.types[i] for i in range(self.length)}
self_to_dict.update(kwargs)
return ParamsType(**self_to_dict)
# Returns a Params object with expected attributes or (in strict mode) checks that data has expected attributes.
def filter(self, data, strict=False, allow_downcast=None):
if strict and not isinstance(data, Params):
......@@ -531,7 +564,11 @@ class ParamsType(Type):
for attribute_name, type_instance in zip(self.fields, self.types):
try:
c_support_code_set.add(type_instance.c_support_code())
# c_support_code() may return a code string or a list of code strings.
support_code = type_instance.c_support_code()
if not isinstance(support_code, list):
support_code = [support_code]
c_support_code_set.update(support_code)
except MethodNotDefined:
pass
......@@ -550,7 +587,6 @@ class ParamsType(Type):
'extract_code': type_instance.c_extract(attribute_name, sub)
})
support_code = '\n'.join(sorted(list(c_support_code_set)))
struct_declare = '\n'.join(c_declare_list)
struct_init = '\n'.join(c_init_list)
struct_cleanup = '\n'.join(c_cleanup_list)
......@@ -570,8 +606,8 @@ class ParamsType(Type):
""" % ('\n'.join(
[('case %d: extract_%s(object); break;' % (i, self.fields[i])) for i in range(self.length)])
)
return """
%(support_code)s
final_struct_code = """
/** ParamsType %(struct_name)s **/
#ifndef %(struct_name_defined)s
#define %(struct_name_defined)s
struct %(struct_name)s {
......@@ -611,13 +647,15 @@ class ParamsType(Type):
}
};
#endif
""" % dict(support_code=support_code,
struct_name_defined=struct_name_defined, struct_name=struct_name, struct_declare=struct_declare,
/** End ParamsType %(struct_name)s **/
""" % dict(struct_name_defined=struct_name_defined, struct_name=struct_name, struct_declare=struct_declare,
struct_init=struct_init, struct_cleanup=struct_cleanup, struct_extract=struct_extract,
struct_extract_method=struct_extract_method)
return list(sorted(list(c_support_code_set))) + [final_struct_code]
def c_code_cache_version(self):
return ((1, 8), tuple(t.c_code_cache_version() for t in self.types))
return ((2,), tuple(t.c_code_cache_version() for t in self.types))
# As this struct has constructor and destructor, it could be instanciated on stack,
# but current implementations of C ops will then pass the instance by value at functions,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论