提交 42e7ed18 authored 作者: notoraptor's avatar notoraptor

Extend ParamsType: handle enumerations types.

Add public: * has_type() * get_field() * get_enum() * enum_from_alias() * get_params() Update documentation for ParamsType.
上级 87c35457
...@@ -12,4 +12,5 @@ Reference ...@@ -12,4 +12,5 @@ Reference
:platform: Unix, Windows :platform: Unix, Windows
:synopsis: Wrapper class for op params :synopsis: Wrapper class for op params
:members: :members:
.. moduleauthor:: LISA :member-order: bysource
\ No newline at end of file .. moduleauthor:: LISA
...@@ -63,13 +63,60 @@ In ``c_code()`` implementation (with ``param = sub['params']``): ...@@ -63,13 +63,60 @@ In ``c_code()`` implementation (with ``param = sub['params']``):
See :class:`QuadraticOpFunc` and :class:`QuadraticCOpFunc` in ``theano/gof/tests/test_params_type.py`` See :class:`QuadraticOpFunc` and :class:`QuadraticCOpFunc` in ``theano/gof/tests/test_params_type.py``
for complete working examples. for complete working examples.
Combining ParamsType with Theano enumeration types
--------------------------------------------------
Theano provide some enumeration types that allow to create constant primitive values (integer and floating values)
available in both Python and C code. See :class:`theano.gof.type.EnumType` and its subclasses for more details.
If your ParamsType contains Theano enumeration types, then constants defined inside these
enumerations will be directly available as ParamsType attributes.
**Example**::
from theano.gof import ParamsType, EnumType, EnumList
wrapper = ParamsType(enum1=EnumList('CONSTANT_1', 'CONSTANT_2', 'CONSTANT_3'),
enum2=EnumType(PI=3.14, EPSILON=0.001))
# Each enum constant is available as a wrapper attribute:
print(wrapper.CONSTANT_1, wrapper.CONSTANT_2, wrapper.CONSTANT_3,
wrapper.PI, wrapper.EPSILON)
# For convenience, you can also look for a constant by name with
# ``ParamsType.get_enum()`` method.
pi = wrapper.get_enum('PI')
epsilon = wrapper.get_enum('EPSILON')
constant_2 = wrapper.get_enum('CONSTANT_2')
print(pi, epsilon, constant_2)
This implies that a ParamsType cannot contain different enum types with common enum names::
# Following line will raise an error,
# as there is a "CONSTANT_1" defined both in enum1 and enum2.
wrapper = ParamsType(enum1=EnumList('CONSTANT_1', 'CONSTANT_2'),
enum2=EnumType(CONSTANT_1=0, CONSTANT_3=5))
If your enum types contain constant aliases, you can retrive them from ParamsType
with ``ParamsType.enum_from_alias(alias)`` method (see :class:`theano.gof.type.EnumType`
for more info about enumeration aliases).
.. code-block:: python
wrapper = ParamsType(enum1=EnumList('A', ('B', 'beta'), 'C'),
enum2=EnumList(('D', 'delta'), 'E', 'F'))
b1 = wrapper.B
b2 = wrapper.get_enum('B')
b3 = wrapper.enum_from_alias('beta')
assert b1 == b2 == b3
""" """
from __future__ import absolute_import, print_function, division from __future__ import absolute_import, print_function, division
import re import re
import hashlib import hashlib
from theano.gof.utils import MethodNotDefined, c_cpp_keywords from theano.gof.utils import MethodNotDefined, c_cpp_keywords
from theano.gof import Type from theano.gof import Type, EnumType
class Params(dict): class Params(dict):
...@@ -191,7 +238,28 @@ class ParamsType(Type): ...@@ -191,7 +238,28 @@ class ParamsType(Type):
self.length = len(kwargs) self.length = len(kwargs)
self.fields = tuple(sorted(kwargs.keys())) self.fields = tuple(sorted(kwargs.keys()))
self.types = tuple(kwargs[field] for field in self.fields) self.types = tuple(kwargs[field] for field in self.fields)
self.name = self.generate_struct_name() self.name = self.__generate_struct_name()
self.__const_to_enum = {}
self.__alias_to_enum = {}
enum_types = [t for t in self.types if isinstance(t, EnumType)]
if enum_types:
# We don't want same enum names in different enum types.
if sum(len(t) for t in enum_types) != len(set(k for t in enum_types for k in t)):
raise AttributeError('Wrapper: found different enum types with common constant names.')
# We don't want same aliases in different enum types.
if sum(len(t.aliases) for t in enum_types) != len(set(alias for t in enum_types for alias in t.aliases)):
raise AttributeError('Wrapper: found different enum types with common constant aliases.')
# We map each enum name to the enum type in which it is defined.
# We will then use this dict to find enum value when looking for enum name in Wrapper object directly.
self.__const_to_enum = {enum_name: enum_type for enum_type in enum_types for enum_name in enum_type}
self.__alias_to_enum = {alias: enum_type for enum_type in enum_types for alias in enum_type.aliases}
def __getattr__(self, key):
# Now we can access value of each enum defined inside enum types wrapped into the current Wrapper.
if key in self.__const_to_enum:
return self.__const_to_enum[key][key]
return super(ParamsType, self).__getattr__(self, key)
def __repr__(self): def __repr__(self):
return 'ParamsType<%s>' % ', '.join([('%s:%s' % (self.fields[i], self.types[i])) for i in range(self.length)]) return 'ParamsType<%s>' % ', '.join([('%s:%s' % (self.fields[i], self.types[i])) for i in range(self.length)])
...@@ -202,7 +270,7 @@ class ParamsType(Type): ...@@ -202,7 +270,7 @@ class ParamsType(Type):
def __hash__(self): def __hash__(self):
return hash((type(self),) + self.fields + self.types) return hash((type(self),) + self.fields + self.types)
def generate_struct_name(self): def __generate_struct_name(self):
# This method tries to generate an unique name for the current instance. # This method tries to generate an unique name for the current instance.
# This name is intended to be used as struct name in C code and as constant # This name is intended to be used as struct name in C code and as constant
# definition to check if a similar ParamsType has already been created # definition to check if a similar ParamsType has already been created
...@@ -213,6 +281,147 @@ class ParamsType(Type): ...@@ -213,6 +281,147 @@ class ParamsType(Type):
types_hex = hashlib.md5(types_string).hexdigest() types_hex = hashlib.md5(types_string).hexdigest()
return '_Params_%s_%s' % (fields_hex, types_hex) return '_Params_%s_%s' % (fields_hex, types_hex)
def has_type(self, theano_type):
"""
Return True if current ParamsType contains the specified Theano type.
"""
return theano_type in self.types
def get_field(self, theano_type):
"""
Return the name (string) of the first field associated to
the given Theano type. Fields are sorted in lexicographic
order. Raise an exception if this Theano type is not
in the current ParamsType.
This method is intended to be used to retrieve a field name
when we know that current ParamsType contains the given
Theano type only once.
"""
return self.fields[self.types.index(theano_type)]
def get_enum(self, key):
"""
Look for a constant named ``key`` in the Theano enumeration types
wrapped into current ParamsType. Return value of the constant found,
or raise an exception if either the constant is not found or
current wrapper does not contain any Theano enumeration type.
**Example**::
from theano.gof import ParamsType, EnumType, EnumList
from theano.scalar import Scalar
wrapper = ParamsType(scalar=Scalar('int32'),
letters=EnumType(A=1, B=2, C=3),
digits=EnumList('ZERO', 'ONE', 'TWO'))
print(wrapper.get_enum('C')) # 3
print(wrapper.get_enum('TWO')) # 2
# You can also directly do:
print(wrapper.C)
print(wrapper.TWO)
"""
return self.__const_to_enum[key][key]
def enum_from_alias(self, alias):
"""
Look for a constant that has alias ``alias`` in the Theano enumeration types
wrapped into current ParamsType. Return value of the constant found,
or raise an exception if either
1. there is no constant with this alias,
2. there is no constant which name is this alias, or
3. current wrapper does not contain any Theano enumeration type.
**Example**::
from theano.gof import ParamsType, EnumType, EnumList
from theano.scalar import Scalar
wrapper = ParamsType(scalar=Scalar('int32'),
letters=EnumType(A=(1, 'alpha'), B=(2, 'beta'), C=3),
digits=EnumList(('ZERO', 'nothing'), ('ONE', 'unit'), ('TWO', 'couple')))
print(wrapper.get_enum('C')) # 3
print(wrapper.get_enum('TWO')) # 2
print(wrapper.enum_from_alias('alpha')) # 1
print(wrapper.enum_from_alias('nothing')) # 0
# For the following, alias 'C' is not defined, so the method looks for
# a constant named 'C', and finds it.
print(wrapper.enum_from_alias('C')) # 3
.. note::
Unlike with constant names, you can **NOT** access constants values directly with aliases through
ParamsType (ie. you can't write ``wrapper.alpha``). You **must** use ``wrapper.enum_from_alias()``
method to do that.
"""
return self.__alias_to_enum[alias].fromalias(alias) if alias in self.__alias_to_enum else self.__const_to_enum[alias][alias]
def get_params(self, *objects, **kwargs):
"""
Convenient method to extract fields values from a list of Python objects and key-value args,
and wrap them into a :class:`Params` object compatible with current ParamsType.
For each field defined in the current ParamsType, a value for this field
is looked for in the given objects attributes (looking for attributes with this field name)
and key-values args (looking for a key equal to this field name), from left to right
(first object, then, ..., then last object, then key-value args), replacing a previous
field value found with any value found in next step, so that only the last field value
found is retained.
Fields values given in objects and kwargs must be compatible with types
associated to corresponding fields in current ParamsType.
**Example**::
import numpy
from theano.gof import ParamsType
from theano.tensor import dmatrix
from theano.scalar import Scalar
class MyObject:
def __init__(self):
self.a = 10
self.b = numpy.asarray([[1, 2, 3], [4, 5, 6]])
params_type = ParamsType(a=Scalar('int32'), b=dmatrix, c=Scalar('bool'))
o = MyObject()
value_for_c = False
# Value for c can't be retrieved from o, so we add a value for that field in kwargs.
params1 = params_type.get_params(o, c=value_for_c)
# params.a contains 10
# params.b contains [[1, 2, 3], [4, 5, 6]]
# params.c contains value_for_c
print(params)
"""
fields_values = dict()
# We collect fields values from given objects.
# If a field is present in many objects, only the field in the last object will be retained.
for obj in objects:
for field in self.fields:
try:
fields_values[field] = getattr(obj, field)
except Exception:
pass
# We then collect fields values from given kwargs.
# A field value in kwargs will replace any previous value collected from objects for this field.
for field in self.fields:
if field in kwargs:
fields_values[field] = kwargs[field]
# Then we filter the fields values and we create the Params object.
filtered = {self.fields[i]: self.types[i].filter(fields_values[self.fields[i]], strict=False, allow_downcast=True)
for i in range(self.length)}
return Params(self, **filtered)
# Returns a Params object with expected attributes or (in strict mode) checks that data has expected attributes. # Returns a Params object with expected attributes or (in strict mode) checks that data has expected attributes.
def filter(self, data, strict=False, allow_downcast=None): def filter(self, data, strict=False, allow_downcast=None):
if strict and not isinstance(data, Params): if strict and not isinstance(data, Params):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论