提交 17835f40 authored 作者: notoraptor's avatar notoraptor

Update ParamsType:

- return a list of strings for c_support_code(). - add utility functions (to be used in other PRs). - Fix typos.
上级 1776fb3a
......@@ -256,12 +256,25 @@ class ParamsType(Type):
if [a for a in all_aliases if a in all_enums]:
raise AttributeError('ParamsType: found aliases that have same names as constants.')
# We map each enum name to the enum type in which it is defined.
# We will then use this dict to find enum value when looking for enum name in Wrapper object directly.
# We will then use this dict to find enum value when looking for enum name in ParamsType object directly.
self.__const_to_enum = {enum_name: enum_type for enum_type in enum_types for enum_name in enum_type}
self.__alias_to_enum = {alias: enum_type for enum_type in enum_types for alias in enum_type.aliases}
def __setstate__(self, state):
# NB:
# I have overridden __getattr__ to make enum constants available through
# the ParamsType when it contains enum types. To do that, I use some internal
# attributes: self.__const_to_enum and self.__alias_to_enum. These attributes
# are normally found by Python without need to call getattr(), but when the
# ParamsType is unpickled, it seems gettatr() may be called at a point before
# __const_to_enum or __alias_to_enum are unpickled, so that gettatr() can't find
# those attributes, and then loop infinitely.
# For this reason, I must add this trivial implementation of __setstate__()
# to avoid errors when unpickling.
self.__dict__.update(state)
def __getattr__(self, key):
# Now we can access value of each enum defined inside enum types wrapped into the current Wrapper.
# Now we can access value of each enum defined inside enum types wrapped into the current ParamsType.
if key in self.__const_to_enum:
return self.__const_to_enum[key][key]
return super(ParamsType, self).__getattr__(self, key)
......@@ -293,6 +306,14 @@ class ParamsType(Type):
"""
return theano_type in self.types
def get_type(self, field_name):
"""
Return the Theano type associated to the given field name
in the current ParamsType.
"""
return self.types[self.fields.index(field_name)]
def get_field(self, theano_type):
"""
Return the name (string) of the first field associated to
......@@ -427,6 +448,18 @@ class ParamsType(Type):
for i in range(self.length)}
return Params(self, **filtered)
def extended(self, **kwargs):
"""
Return a copy of current ParamsType
extended with attributes given in kwargs.
New attributes must follow same rules as in
ParamsType constructor.
"""
self_to_dict = {self.fields[i]: self.types[i] for i in range(self.length)}
self_to_dict.update(kwargs)
return ParamsType(**self_to_dict)
# Returns a Params object with expected attributes or (in strict mode) checks that data has expected attributes.
def filter(self, data, strict=False, allow_downcast=None):
if strict and not isinstance(data, Params):
......@@ -531,7 +564,11 @@ class ParamsType(Type):
for attribute_name, type_instance in zip(self.fields, self.types):
try:
c_support_code_set.add(type_instance.c_support_code())
# c_support_code() may return a code string or a list of code strings.
support_code = type_instance.c_support_code()
if not isinstance(support_code, list):
support_code = [support_code]
c_support_code_set.update(support_code)
except MethodNotDefined:
pass
......@@ -550,7 +587,6 @@ class ParamsType(Type):
'extract_code': type_instance.c_extract(attribute_name, sub)
})
support_code = '\n'.join(sorted(list(c_support_code_set)))
struct_declare = '\n'.join(c_declare_list)
struct_init = '\n'.join(c_init_list)
struct_cleanup = '\n'.join(c_cleanup_list)
......@@ -570,8 +606,8 @@ class ParamsType(Type):
""" % ('\n'.join(
[('case %d: extract_%s(object); break;' % (i, self.fields[i])) for i in range(self.length)])
)
return """
%(support_code)s
final_struct_code = """
/** ParamsType %(struct_name)s **/
#ifndef %(struct_name_defined)s
#define %(struct_name_defined)s
struct %(struct_name)s {
......@@ -611,11 +647,13 @@ class ParamsType(Type):
}
};
#endif
""" % dict(support_code=support_code,
struct_name_defined=struct_name_defined, struct_name=struct_name, struct_declare=struct_declare,
/** End ParamsType %(struct_name)s **/
""" % dict(struct_name_defined=struct_name_defined, struct_name=struct_name, struct_declare=struct_declare,
struct_init=struct_init, struct_cleanup=struct_cleanup, struct_extract=struct_extract,
struct_extract_method=struct_extract_method)
return list(sorted(list(c_support_code_set))) + [final_struct_code]
def c_code_cache_version(self):
return ((1, 8), tuple(t.c_code_cache_version() for t in self.types))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论