提交 4025a2dc authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #5755 from notoraptor/op-param-gpudnnsoftmax

Use Op params for GpuDnnSoftmax
...@@ -12,4 +12,5 @@ Reference ...@@ -12,4 +12,5 @@ Reference
:platform: Unix, Windows :platform: Unix, Windows
:synopsis: Wrapper class for op params :synopsis: Wrapper class for op params
:members: :members:
.. moduleauthor:: LISA :member-order: bysource
\ No newline at end of file .. moduleauthor:: LISA
...@@ -808,7 +808,7 @@ class Op(utils.object2, PureOp, CLinkerOp): ...@@ -808,7 +808,7 @@ class Op(utils.object2, PureOp, CLinkerOp):
field = wrapper.fields[i] field = wrapper.fields[i]
_type = wrapper.types[i] _type = wrapper.types[i]
wrap_dict[field] = _type.filter(getattr(self, field), strict=False, allow_downcast=True) wrap_dict[field] = _type.filter(getattr(self, field), strict=False, allow_downcast=True)
return theano.gof.Params(wrapper, **wrap_dict) return self.params_type.get_params(self)
raise theano.gof.utils.MethodNotDefined('get_params') raise theano.gof.utils.MethodNotDefined('get_params')
def prepare_node(self, node, storage_map, compute_map, impl): def prepare_node(self, node, storage_map, compute_map, impl):
......
...@@ -63,13 +63,60 @@ In ``c_code()`` implementation (with ``param = sub['params']``): ...@@ -63,13 +63,60 @@ In ``c_code()`` implementation (with ``param = sub['params']``):
See :class:`QuadraticOpFunc` and :class:`QuadraticCOpFunc` in ``theano/gof/tests/test_params_type.py`` See :class:`QuadraticOpFunc` and :class:`QuadraticCOpFunc` in ``theano/gof/tests/test_params_type.py``
for complete working examples. for complete working examples.
Combining ParamsType with Theano enumeration types
--------------------------------------------------
Theano provide some enumeration types that allow to create constant primitive values (integer and floating values)
available in both Python and C code. See :class:`theano.gof.type.EnumType` and its subclasses for more details.
If your ParamsType contains Theano enumeration types, then constants defined inside these
enumerations will be directly available as ParamsType attributes.
**Example**::
from theano.gof import ParamsType, EnumType, EnumList
wrapper = ParamsType(enum1=EnumList('CONSTANT_1', 'CONSTANT_2', 'CONSTANT_3'),
enum2=EnumType(PI=3.14, EPSILON=0.001))
# Each enum constant is available as a wrapper attribute:
print(wrapper.CONSTANT_1, wrapper.CONSTANT_2, wrapper.CONSTANT_3,
wrapper.PI, wrapper.EPSILON)
# For convenience, you can also look for a constant by name with
# ``ParamsType.get_enum()`` method.
pi = wrapper.get_enum('PI')
epsilon = wrapper.get_enum('EPSILON')
constant_2 = wrapper.get_enum('CONSTANT_2')
print(pi, epsilon, constant_2)
This implies that a ParamsType cannot contain different enum types with common enum names::
# Following line will raise an error,
# as there is a "CONSTANT_1" defined both in enum1 and enum2.
wrapper = ParamsType(enum1=EnumList('CONSTANT_1', 'CONSTANT_2'),
enum2=EnumType(CONSTANT_1=0, CONSTANT_3=5))
If your enum types contain constant aliases, you can retrive them from ParamsType
with ``ParamsType.enum_from_alias(alias)`` method (see :class:`theano.gof.type.EnumType`
for more info about enumeration aliases).
.. code-block:: python
wrapper = ParamsType(enum1=EnumList('A', ('B', 'beta'), 'C'),
enum2=EnumList(('D', 'delta'), 'E', 'F'))
b1 = wrapper.B
b2 = wrapper.get_enum('B')
b3 = wrapper.enum_from_alias('beta')
assert b1 == b2 == b3
""" """
from __future__ import absolute_import, print_function, division from __future__ import absolute_import, print_function, division
import re import re
import hashlib import hashlib
from theano.gof.utils import MethodNotDefined, c_cpp_keywords from theano.gof.utils import MethodNotDefined, c_cpp_keywords
from theano.gof import Type from theano.gof import Type, EnumType
class Params(dict): class Params(dict):
...@@ -193,6 +240,32 @@ class ParamsType(Type): ...@@ -193,6 +240,32 @@ class ParamsType(Type):
self.types = tuple(kwargs[field] for field in self.fields) self.types = tuple(kwargs[field] for field in self.fields)
self.name = self.generate_struct_name() self.name = self.generate_struct_name()
self.__const_to_enum = {}
self.__alias_to_enum = {}
enum_types = [t for t in self.types if isinstance(t, EnumType)]
if enum_types:
# We don't want same enum names in different enum types.
if sum(len(t) for t in enum_types) != len(set(k for t in enum_types for k in t)):
raise AttributeError('ParamsType: found different enum types with common constants names.')
# We don't want same aliases in different enum types.
if sum(len(t.aliases) for t in enum_types) != len(set(alias for t in enum_types for alias in t.aliases)):
raise AttributeError('ParamsType: found different enum types with common constants aliases.')
# We don't want aliases that have same names as some constants.
all_enums = {e for t in enum_types for e in t}
all_aliases = {a for t in enum_types for a in t.aliases}
if [a for a in all_aliases if a in all_enums]:
raise AttributeError('ParamsType: found aliases that have same names as constants.')
# We map each enum name to the enum type in which it is defined.
# We will then use this dict to find enum value when looking for enum name in Wrapper object directly.
self.__const_to_enum = {enum_name: enum_type for enum_type in enum_types for enum_name in enum_type}
self.__alias_to_enum = {alias: enum_type for enum_type in enum_types for alias in enum_type.aliases}
def __getattr__(self, key):
# Now we can access value of each enum defined inside enum types wrapped into the current Wrapper.
if key in self.__const_to_enum:
return self.__const_to_enum[key][key]
return super(ParamsType, self).__getattr__(self, key)
def __repr__(self): def __repr__(self):
return 'ParamsType<%s>' % ', '.join([('%s:%s' % (self.fields[i], self.types[i])) for i in range(self.length)]) return 'ParamsType<%s>' % ', '.join([('%s:%s' % (self.fields[i], self.types[i])) for i in range(self.length)])
...@@ -213,6 +286,147 @@ class ParamsType(Type): ...@@ -213,6 +286,147 @@ class ParamsType(Type):
types_hex = hashlib.md5(types_string).hexdigest() types_hex = hashlib.md5(types_string).hexdigest()
return '_Params_%s_%s' % (fields_hex, types_hex) return '_Params_%s_%s' % (fields_hex, types_hex)
def has_type(self, theano_type):
"""
Return True if current ParamsType contains the specified Theano type.
"""
return theano_type in self.types
def get_field(self, theano_type):
"""
Return the name (string) of the first field associated to
the given Theano type. Fields are sorted in lexicographic
order. Raise an exception if this Theano type is not
in the current ParamsType.
This method is intended to be used to retrieve a field name
when we know that current ParamsType contains the given
Theano type only once.
"""
return self.fields[self.types.index(theano_type)]
def get_enum(self, key):
"""
Look for a constant named ``key`` in the Theano enumeration types
wrapped into current ParamsType. Return value of the constant found,
or raise an exception if either the constant is not found or
current wrapper does not contain any Theano enumeration type.
**Example**::
from theano.gof import ParamsType, EnumType, EnumList
from theano.scalar import Scalar
wrapper = ParamsType(scalar=Scalar('int32'),
letters=EnumType(A=1, B=2, C=3),
digits=EnumList('ZERO', 'ONE', 'TWO'))
print(wrapper.get_enum('C')) # 3
print(wrapper.get_enum('TWO')) # 2
# You can also directly do:
print(wrapper.C)
print(wrapper.TWO)
"""
return self.__const_to_enum[key][key]
def enum_from_alias(self, alias):
"""
Look for a constant that has alias ``alias`` in the Theano enumeration types
wrapped into current ParamsType. Return value of the constant found,
or raise an exception if either
1. there is no constant with this alias,
2. there is no constant which name is this alias, or
3. current wrapper does not contain any Theano enumeration type.
**Example**::
from theano.gof import ParamsType, EnumType, EnumList
from theano.scalar import Scalar
wrapper = ParamsType(scalar=Scalar('int32'),
letters=EnumType(A=(1, 'alpha'), B=(2, 'beta'), C=3),
digits=EnumList(('ZERO', 'nothing'), ('ONE', 'unit'), ('TWO', 'couple')))
print(wrapper.get_enum('C')) # 3
print(wrapper.get_enum('TWO')) # 2
print(wrapper.enum_from_alias('alpha')) # 1
print(wrapper.enum_from_alias('nothing')) # 0
# For the following, alias 'C' is not defined, so the method looks for
# a constant named 'C', and finds it.
print(wrapper.enum_from_alias('C')) # 3
.. note::
Unlike with constant names, you can **NOT** access constants values directly with aliases through
ParamsType (ie. you can't write ``wrapper.alpha``). You **must** use ``wrapper.enum_from_alias()``
method to do that.
"""
return self.__alias_to_enum[alias].fromalias(alias) if alias in self.__alias_to_enum else self.__const_to_enum[alias][alias]
def get_params(self, *objects, **kwargs):
"""
Convenient method to extract fields values from a list of Python objects and key-value args,
and wrap them into a :class:`Params` object compatible with current ParamsType.
For each field defined in the current ParamsType, a value for this field
is looked for in the given objects attributes (looking for attributes with this field name)
and key-values args (looking for a key equal to this field name), from left to right
(first object, then, ..., then last object, then key-value args), replacing a previous
field value found with any value found in next step, so that only the last field value
found is retained.
Fields values given in objects and kwargs must be compatible with types
associated to corresponding fields in current ParamsType.
**Example**::
import numpy
from theano.gof import ParamsType
from theano.tensor import dmatrix
from theano.scalar import Scalar
class MyObject:
def __init__(self):
self.a = 10
self.b = numpy.asarray([[1, 2, 3], [4, 5, 6]])
params_type = ParamsType(a=Scalar('int32'), b=dmatrix, c=Scalar('bool'))
o = MyObject()
value_for_c = False
# Value for c can't be retrieved from o, so we add a value for that field in kwargs.
params = params_type.get_params(o, c=value_for_c)
# params.a contains 10
# params.b contains [[1, 2, 3], [4, 5, 6]]
# params.c contains value_for_c
print(params)
"""
fields_values = dict()
# We collect fields values from given objects.
# If a field is present in many objects, only the field in the last object will be retained.
for obj in objects:
for field in self.fields:
try:
fields_values[field] = getattr(obj, field)
except Exception:
pass
# We then collect fields values from given kwargs.
# A field value in kwargs will replace any previous value collected from objects for this field.
for field in self.fields:
if field in kwargs:
fields_values[field] = kwargs[field]
# Then we filter the fields values and we create the Params object.
filtered = {self.fields[i]: self.types[i].filter(fields_values[self.fields[i]], strict=False, allow_downcast=True)
for i in range(self.length)}
return Params(self, **filtered)
# Returns a Params object with expected attributes or (in strict mode) checks that data has expected attributes. # Returns a Params object with expected attributes or (in strict mode) checks that data has expected attributes.
def filter(self, data, strict=False, allow_downcast=None): def filter(self, data, strict=False, allow_downcast=None):
if strict and not isinstance(data, Params): if strict and not isinstance(data, Params):
...@@ -309,12 +523,18 @@ class ParamsType(Type): ...@@ -309,12 +523,18 @@ class ParamsType(Type):
sub = {'fail': '{this->setErrorOccurred(); return;}'} sub = {'fail': '{this->setErrorOccurred(); return;}'}
struct_name = self.name struct_name = self.name
struct_name_defined = struct_name.upper() struct_name_defined = struct_name.upper()
c_support_code_set = set()
c_declare_list = [] c_declare_list = []
c_init_list = [] c_init_list = []
c_cleanup_list = [] c_cleanup_list = []
c_extract_list = [] c_extract_list = []
for attribute_name, type_instance in zip(self.fields, self.types): for attribute_name, type_instance in zip(self.fields, self.types):
try:
c_support_code_set.add(type_instance.c_support_code())
except MethodNotDefined:
pass
c_declare_list.append(type_instance.c_declare(attribute_name, sub)) c_declare_list.append(type_instance.c_declare(attribute_name, sub))
c_init_list.append(type_instance.c_init(attribute_name, sub)) c_init_list.append(type_instance.c_init(attribute_name, sub))
...@@ -330,6 +550,7 @@ class ParamsType(Type): ...@@ -330,6 +550,7 @@ class ParamsType(Type):
'extract_code': type_instance.c_extract(attribute_name, sub) 'extract_code': type_instance.c_extract(attribute_name, sub)
}) })
support_code = '\n'.join(sorted(list(c_support_code_set)))
struct_declare = '\n'.join(c_declare_list) struct_declare = '\n'.join(c_declare_list)
struct_init = '\n'.join(c_init_list) struct_init = '\n'.join(c_init_list)
struct_cleanup = '\n'.join(c_cleanup_list) struct_cleanup = '\n'.join(c_cleanup_list)
...@@ -350,6 +571,7 @@ class ParamsType(Type): ...@@ -350,6 +571,7 @@ class ParamsType(Type):
[('case %d: extract_%s(object); break;' % (i, self.fields[i])) for i in range(self.length)]) [('case %d: extract_%s(object); break;' % (i, self.fields[i])) for i in range(self.length)])
) )
return """ return """
%(support_code)s
#ifndef %(struct_name_defined)s #ifndef %(struct_name_defined)s
#define %(struct_name_defined)s #define %(struct_name_defined)s
struct %(struct_name)s { struct %(struct_name)s {
...@@ -389,12 +611,13 @@ class ParamsType(Type): ...@@ -389,12 +611,13 @@ class ParamsType(Type):
} }
}; };
#endif #endif
""" % dict(struct_name_defined=struct_name_defined, struct_name=struct_name, struct_declare=struct_declare, """ % dict(support_code=support_code,
struct_name_defined=struct_name_defined, struct_name=struct_name, struct_declare=struct_declare,
struct_init=struct_init, struct_cleanup=struct_cleanup, struct_extract=struct_extract, struct_init=struct_init, struct_cleanup=struct_cleanup, struct_extract=struct_extract,
struct_extract_method=struct_extract_method) struct_extract_method=struct_extract_method)
def c_code_cache_version(self): def c_code_cache_version(self):
return ((1, 7), tuple(t.c_code_cache_version() for t in self.types)) return ((1, 8), tuple(t.c_code_cache_version() for t in self.types))
# As this struct has constructor and destructor, it could be instanciated on stack, # As this struct has constructor and destructor, it could be instanciated on stack,
# but current implementations of C ops will then pass the instance by value at functions, # but current implementations of C ops will then pass the instance by value at functions,
......
...@@ -6,7 +6,7 @@ from theano.gof import Op, COp, Apply ...@@ -6,7 +6,7 @@ from theano.gof import Op, COp, Apply
from theano import Generic from theano import Generic
from theano.scalar import Scalar from theano.scalar import Scalar
from theano.tensor import TensorType from theano.tensor import TensorType
from theano.gof import ParamsType, Params from theano.gof import ParamsType, Params, EnumList
from theano import tensor from theano import tensor
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
...@@ -213,6 +213,35 @@ class TestParamsType(TestCase): ...@@ -213,6 +213,35 @@ class TestParamsType(TestCase):
a3=2000.0 - 0.00000000000000001) a3=2000.0 - 0.00000000000000001)
assert w.values_eq_approx(o1, o3) assert w.values_eq_approx(o1, o3)
def test_params_type_with_enums(self):
# Test that we fail if we create a params type with common enum names inside different enum types.
try:
ParamsType(enum1=EnumList('A', 'B', 'C'), enum2=EnumList('A', 'B', 'F'))
except AttributeError:
pass
else:
raise Exception('ParamsType should fail with common enum names inside different enum types.')
# Test that we fail if we create a params type with common names in both aliases and constants.
try:
ParamsType(enum1=EnumList(('A', 'a'), ('B', 'b')), enum2=EnumList(('ONE', 'a'), ('TWO', 'two')))
except AttributeError:
ParamsType(enum1=EnumList(('A', 'a'), ('B', 'b')), enum2=EnumList(('ONE', 'one'), ('TWO', 'two')))
else:
raise Exception('ParamsType should fail when there are aliases with same names as some constants.')
# Test that we can access enum values through wrapper directly.
w = ParamsType(enum1=EnumList('A', ('B', 'beta'), 'C'), enum2=EnumList(('D', 'delta'), 'E', 'F'))
assert w.A == 0 and w.B == 1 and w.C == 2
assert w.D == 0 and w.E == 1 and w.F == 2
# Test constants access through aliases.
assert w.enum_from_alias('beta') == w.B
assert w.enum_from_alias('delta') == w.D
assert w.enum_from_alias('C') == w.C # C is not an alias, so it should return a constant named C.
# Test that other regular wrapper attributes are still available.
assert len(w.fields) == len(w.types) == w.length
assert w.name
def test_op_params(self): def test_op_params(self):
a, b, c = 2, 3, -7 a, b, c = 2, 3, -7
x = tensor.matrix(dtype='float64') x = tensor.matrix(dtype='float64')
......
...@@ -81,18 +81,19 @@ def test_cdata(): ...@@ -81,18 +81,19 @@ def test_cdata():
class MyOpEnumList(Op): class MyOpEnumList(Op):
__props__ = ('op_chosen',) __props__ = ('op_chosen',)
params_type = EnumList('ADD', 'SUB', 'MULTIPLY', 'DIVIDE', ctype='unsigned long long') params_type = EnumList(('ADD', '+'), ('SUB', '-'), ('MULTIPLY', '*'), ('DIVIDE', '/'), ctype='unsigned long long')
def __init__(self, choose_op): def __init__(self, choose_op):
assert self.params_type.ADD == 0 assert self.params_type.ADD == 0
assert self.params_type.SUB == 1 assert self.params_type.SUB == 1
assert self.params_type.MULTIPLY == 2 assert self.params_type.MULTIPLY == 2
assert self.params_type.DIVIDE == 3 assert self.params_type.DIVIDE == 3
op_to_const = {'+': self.params_type.ADD, assert self.params_type.fromalias('+') == self.params_type.ADD
'-': self.params_type.SUB, assert self.params_type.fromalias('-') == self.params_type.SUB
'*': self.params_type.MULTIPLY, assert self.params_type.fromalias('*') == self.params_type.MULTIPLY
'/': self.params_type.DIVIDE} assert self.params_type.fromalias('/') == self.params_type.DIVIDE
self.op_chosen = op_to_const[choose_op] assert self.params_type.has_alias(choose_op)
self.op_chosen = choose_op
def get_params(self, node): def get_params(self, node):
return self.op_chosen return self.op_chosen
...@@ -204,7 +205,7 @@ class TestEnumTypes(TestCase): ...@@ -204,7 +205,7 @@ class TestEnumTypes(TestCase):
# Check that invalid enum value raises exception. # Check that invalid enum value raises exception.
try: try:
EnumType(INVALID_VALUE='string is not allowed.') EnumType(INVALID_VALUE='string is not allowed.')
except ValueError: except TypeError:
pass pass
else: else:
raise Exception('EnumType with invalid value should fail.') raise Exception('EnumType with invalid value should fail.')
...@@ -218,6 +219,23 @@ class TestEnumTypes(TestCase): ...@@ -218,6 +219,23 @@ class TestEnumTypes(TestCase):
# Check access to attributes. # Check access to attributes.
assert len((e1.ctype, e1.C1, e1.C2, e1.C3, e1.C4, e1.C5, e1.C6)) == 7 assert len((e1.ctype, e1.C1, e1.C2, e1.C3, e1.C4, e1.C5, e1.C6)) == 7
# Check enum with aliases.
e1 = EnumType(A=('alpha', 0), B=('beta', 1), C=2)
e2 = EnumType(A=('alpha', 0), B=('beta', 1), C=2)
e3 = EnumType(A=('a', 0), B=('beta', 1), C=2)
assert e1 == e2
assert e1 != e3
assert e1.filter('beta') == e1.fromalias('beta') == e1.B == 1
assert e1.filter('C') == e1.fromalias('C') == e1.C == 2
# Check that invalid alias (same as a constant) raises exception.
try:
EnumList(('A', 'a'), ('B', 'B'))
except TypeError:
EnumList(('A', 'a'), ('B', 'b'))
else:
raise Exception('Enum with an alias name equal to a constant name should fail.')
def test_op_with_enumlist(self): def test_op_with_enumlist(self):
a = scalar.int32() a = scalar.int32()
b = scalar.int32() b = scalar.int32()
......
...@@ -814,13 +814,20 @@ CDataType.Constant = CDataTypeConstant ...@@ -814,13 +814,20 @@ CDataType.Constant = CDataTypeConstant
class EnumType(Type, dict): class EnumType(Type, dict):
""" """
Main subclasses:
- :class:`EnumList`
- :class:`CEnumType`
Op parameter class that allows to create enumerations of constant values. Op parameter class that allows to create enumerations of constant values.
- Constants are available as object attributes in Python code and as macro-defined constants in C code. - Constants are available as object attributes in Python code and as macro-defined constants in C code.
- Constants can be floating values, integers, or booleans (automatically converted to integers). - Constants can be floating values, integers, or booleans (automatically converted to integers).
- Constants name must start with a capital letter and contain capital letters, underscores or digits. - Constants name must start with a capital letter and contain capital letters, underscores or digits.
- A constant can have an alias, and then be available through both constant name and constant alias.
Example:: **Example**
.. code-block:: python
enum = EnumType(CONSTANT_1=1, CONSTANT_2=2.5, CONSTANT_3=False, CONSTANT_4=True) enum = EnumType(CONSTANT_1=1, CONSTANT_2=2.5, CONSTANT_3=False, CONSTANT_4=True)
print (enum.CONSTANT_1, enum.CONSTANT_2, enum.CONSTANT_3, enum.CONSTANT_4) print (enum.CONSTANT_1, enum.CONSTANT_2, enum.CONSTANT_3, enum.CONSTANT_4)
...@@ -849,35 +856,111 @@ class EnumType(Type, dict): ...@@ -849,35 +856,111 @@ class EnumType(Type, dict):
size_t value = op_param_value; // contains enum.CONSTANT_1, i.e 0 size_t value = op_param_value; // contains enum.CONSTANT_1, i.e 0
**Example with aliases**
When creating an enum, you can give some aliases to specific constants while keeping other constants without aliases.
An alias must be a string, and there is currently no string format constraints.
To give an alias to a constant in the EnumType constructor, use the following key-value syntax::
constant_name=(constant_alias, constant_value)
You can then retrieve a constant from an alias with method ``EnumType.fromalias()``.
Aliases are intended to be used in Python code only (only constants names are available in C code).
Especially, an alias will be recognized by ``Enumtype.filter()`` method with non-strict filtering,
allowing a maximum flexibility for converting strings to numeric constants available in Python and C code.
.. code-block:: python
from theano.gof import EnumType
# You can remark that constant 'C' does not have an alias.
enum = EnumType(A=('alpha', 1), B=('beta', 2), C=3, D=('delta', 4))
# Constants are all directly available by name.
print(enum.A, enum.B, enum.C, enum.D)
# But we can also now get some constants by alias.
a = enum.fromalias('alpha')
b = enum.fromalias('beta')
d = enum.fromalias('delta')
# If method fromalias() receives an unknown alias,
# it will looks for a constant with this alias
# as exact constant name.
c = enum.fromalias('C') # will get enum.C
# An alias defined in an EnumType will be correctly converted with non-strict filtering.
value = enum.filter('delta', strict=False)
# value now contaisn enum.D, ie. 4.
.. note:: .. note::
This Type (and subclasses) is not complete and should never be used for regular graph operations. This Type (and subclasses) is not complete and should never be used for regular graph operations.
""" """
def check_ctype(self): def __init_ctype(self, ctype):
# C type may be a list of keywords, e.g. "unsigned long long". # C type may be a list of keywords, e.g. "unsigned long long".
# We should check each part. # We should check each part.
if not all(re.match('^[A-Za-z_][A-Za-z0-9_]*$', el) for el in self.ctype.split()): ctype_parts = ctype.split()
raise TypeError('%s: invalid C type' % type(self).__name__) if not all(re.match('^[A-Za-z_][A-Za-z0-9_]*$', el) for el in ctype_parts):
raise TypeError('%s: invalid C type.' % type(self).__name__)
self.ctype = ' '.join(ctype_parts)
def __init__(self, **kwargs): def __init__(self, **kwargs):
self.ctype = kwargs.pop('ctype', 'double') self.__init_ctype(kwargs.pop('ctype', 'double'))
self.check_ctype() self.aliases = dict()
for k in kwargs: for k in kwargs:
if re.match('^[A-Z][A-Z0-9_]*$', k) is None: if re.match('^[A-Z][A-Z0-9_]*$', k) is None:
raise AttributeError('%s: invalid enum name: "%s". ' raise AttributeError('%s: invalid enum name: "%s". '
'Only capital letters, underscores and digits ' 'Only capital letters, underscores and digits '
'are allowed.' % (type(self).__name__, k)) 'are allowed.' % (type(self).__name__, k))
if isinstance(kwargs[k], (list, tuple)):
if len(kwargs[k]) != 2:
raise TypeError('%s: when using a tuple to define a constant, your tuple should contain 2 values: '
'constant alias followed by constant value.' % type(self).__name__)
alias, value = kwargs[k]
if not isinstance(alias, str):
raise TypeError('%s: constant alias should be a string, got "%s".'
% (type(self).__name__, alias))
if alias == k:
raise TypeError("%s: it's useless to create an alias "
"with the same name as its associated constant." % type(self).__name__)
if alias in self.aliases:
raise TypeError('%s: consant alias "%s" already used.' % (type(self).__name__, alias))
self.aliases[alias] = k
kwargs[k] = value
if isinstance(kwargs[k], bool): if isinstance(kwargs[k], bool):
kwargs[k] = int(kwargs[k]) kwargs[k] = int(kwargs[k])
elif not isinstance(kwargs[k], (int, float)): elif not isinstance(kwargs[k], (int, float)):
raise ValueError('%s: constant "%s": expected integer or floating value, got "%s".' raise TypeError('%s: constant "%s": expected integer or floating value, got "%s".'
% (type(self).__name__, k, type(kwargs[k]).__name__)) % (type(self).__name__, k, type(kwargs[k]).__name__))
if [a for a in self.aliases if a in self]:
raise TypeError("%s: some aliases have same names as constants." % type(self).__name__)
super(EnumType, self).__init__(**kwargs) super(EnumType, self).__init__(**kwargs)
def fromalias(self, alias):
"""
Get a constant value by its alias.
If there is not such alias in this enum, look for a constant
with this alias as constant name.
"""
return self[self.aliases[alias]] if alias in self.aliases else self[alias]
def has_alias(self, alias):
"""
return True if and only if this enum has this alias.
"""
return alias in self.aliases
def __repr__(self): def __repr__(self):
return '%s(%s)' % (type(self).__name__, ', '.join('%s:%s' % (k, self[k]) for k in sorted(self.keys()))) names_to_aliases = {constant_name: '' for constant_name in self}
for alias in self.aliases:
names_to_aliases[self.aliases[alias]] = '(%s)' % alias
return '%s<%s>(%s)' % (type(self).__name__, self.ctype,
', '.join('%s%s:%s' % (k, names_to_aliases[k], self[k]) for k in sorted(self.keys())))
def __getattr__(self, key): def __getattr__(self, key):
if key in self: if key in self:
...@@ -897,14 +980,19 @@ class EnumType(Type, dict): ...@@ -897,14 +980,19 @@ class EnumType(Type, dict):
def __hash__(self): def __hash__(self):
# All values are Python basic types, then easy to hash. # All values are Python basic types, then easy to hash.
return hash((type(self), self.ctype) + tuple((k, self[k]) for k in sorted(self.keys()))) return hash((type(self), self.ctype) +
tuple((k, self[k]) for k in sorted(self.keys())) +
tuple((a, self.aliases[a]) for a in sorted(self.aliases.keys())))
def __eq__(self, other): def __eq__(self, other):
return (type(self) == type(other) and return (type(self) == type(other) and
self.ctype == other.ctype and self.ctype == other.ctype and
len(self) == len(other) and len(self) == len(other) and
len(self.aliases) == len(other.aliases) and
all(k in other for k in self) and all(k in other for k in self) and
all(self[k] == other[k] for k in self)) all(a in other.aliases for a in self.aliases) and
all(self[k] == other[k] for k in self) and
all(self.aliases[a] == other.aliases[a] for a in self.aliases))
# EnumType should be used to create constants available in both Python and C code. # EnumType should be used to create constants available in both Python and C code.
# However, for convenience, we make sure EnumType can have a value, like other common types, # However, for convenience, we make sure EnumType can have a value, like other common types,
...@@ -912,8 +1000,13 @@ class EnumType(Type, dict): ...@@ -912,8 +1000,13 @@ class EnumType(Type, dict):
# C type of value is defined in self.ctype. # C type of value is defined in self.ctype.
def filter(self, data, strict=False, allow_downcast=None): def filter(self, data, strict=False, allow_downcast=None):
if not strict and isinstance(data, bool): if not strict:
data = int(data) if isinstance(data, bool):
data = int(data)
elif isinstance(data, str):
# We now accept strings as data values.
# Strings should be a constant alias or a constant name.
data = self.fromalias(data)
assert data in self.values() assert data in self.values()
return data return data
...@@ -947,7 +1040,7 @@ class EnumType(Type, dict): ...@@ -947,7 +1040,7 @@ class EnumType(Type, dict):
return """%(ctype)s %(name)s;""" % dict(ctype=self.ctype, name=name) return """%(ctype)s %(name)s;""" % dict(ctype=self.ctype, name=name)
def c_init(self, name, sub): def c_init(self, name, sub):
return "%(name)s = 0;" % dict(name=name) return "%(name)s = (%(ctype)s)0;" % dict(name=name, ctype=self.ctype)
def c_cleanup(self, name, sub): def c_cleanup(self, name, sub):
return "" return ""
...@@ -965,11 +1058,14 @@ class EnumType(Type, dict): ...@@ -965,11 +1058,14 @@ class EnumType(Type, dict):
""" % dict(ctype=self.ctype, name=name, fail=sub['fail']) """ % dict(ctype=self.ctype, name=name, fail=sub['fail'])
def c_code_cache_version(self): def c_code_cache_version(self):
return (1,) return (1, 1)
class EnumList(EnumType): class EnumList(EnumType):
""" """
**Inherit from**:
- :class:`EnumType`
Op parameter class that allows to create enumeration of constant values. Op parameter class that allows to create enumeration of constant values.
Same as :class:`EnumType`, but automatically gives an unique integer value for each constant in a list of Same as :class:`EnumType`, but automatically gives an unique integer value for each constant in a list of
constants names (constant at index ``i`` in the list will receive value ``i``, constants names (constant at index ``i`` in the list will receive value ``i``,
...@@ -986,6 +1082,14 @@ class EnumList(EnumType): ...@@ -986,6 +1082,14 @@ class EnumList(EnumType):
enum = EnumList('CONSTANT_1', 'CONSTANT_2', 'CONSTANT_3', 'CONSTANT_4', ctype='unsigned int') enum = EnumList('CONSTANT_1', 'CONSTANT_2', 'CONSTANT_3', 'CONSTANT_4', ctype='unsigned int')
Like :class:`EnumType`, you can also add an alias to a constant, by replacing the only constant name
(e.g. ``'CONSTANT_NAME'``) by a couple with constant name first and constant alias second
(e.g. ``('CONSTANT_NAME', 'constant_alias')``).
.. code-block:: python
enum = EnumList(('A', 'alpha'), ('B', 'beta'), 'C', 'D', 'E', 'F', ('G', 'gamma'))
See test class :class:`theano.gof.tests.test_types.TestOpEnumList` for a working example. See test class :class:`theano.gof.tests.test_types.TestOpEnumList` for a working example.
""" """
...@@ -995,16 +1099,35 @@ class EnumList(EnumType): ...@@ -995,16 +1099,35 @@ class EnumList(EnumType):
type(self).__name__ + ': expected 0 or only 1 extra parameter "ctype".' type(self).__name__ + ': expected 0 or only 1 extra parameter "ctype".'
ctype = kwargs.pop('ctype', 'int') ctype = kwargs.pop('ctype', 'int')
if len(args) > len(set(args)): for arg_rank, arg in enumerate(args):
raise AttributeError(type(self).__name__ + ': some constants names are duplicated.') if isinstance(arg, (list, tuple)):
if len(arg) != 2:
raise TypeError('%s: when using a tuple to define a constant, your tuple should contain 2 values: '
'constant name followed by constant alias.' % type(self).__name__)
constant_name, constant_alias = arg
if not isinstance(constant_alias, str):
raise TypeError('%s: constant alias should be a string, got "%s".'
% (type(self).__name__, constant_alias))
constant_value = (constant_alias, arg_rank)
else:
constant_name = arg
constant_value = arg_rank
if not isinstance(constant_name, str):
raise TypeError('%s: constant name should be a string, got "%s".'
% (type(self).__name__, constant_name))
if constant_name in kwargs:
raise TypeError('%s: constant name already used ("%s").' % (type(self).__name__, constant_name))
kwargs[constant_name] = constant_value
kwargs = {const_name: const_rank for (const_rank, const_name) in enumerate(args)}
kwargs.update(ctype=ctype) kwargs.update(ctype=ctype)
super(EnumList, self).__init__(**kwargs) super(EnumList, self).__init__(**kwargs)
class CEnumType(EnumList): class CEnumType(EnumList):
""" """
**Inherit from**:
- :class:`EnumList`
Op parameter class that allows to create enumeration of constant values that represent C-defined constants. Op parameter class that allows to create enumeration of constant values that represent C-defined constants.
- Constant should have same names as in C. - Constant should have same names as in C.
...@@ -1020,6 +1143,8 @@ class CEnumType(EnumList): ...@@ -1020,6 +1143,8 @@ class CEnumType(EnumList):
enum = CEnumType('CONSTANT_CNAME_1', 'CONSTANT_CNAME_2', 'CONSTANT_CNAME_3', ctype='long') enum = CEnumType('CONSTANT_CNAME_1', 'CONSTANT_CNAME_2', 'CONSTANT_CNAME_3', ctype='long')
Like :class:`EnumList`, you can also add an alias to a constant, with same syntax as in :class:`EnumList`.
See test class :class:`theano.gof.tests.test_types.TestOpCEnumType` for a working example. See test class :class:`theano.gof.tests.test_types.TestOpCEnumType` for a working example.
.. note:: .. note::
......
...@@ -12,7 +12,7 @@ from theano import Op, Apply, tensor, config, Variable ...@@ -12,7 +12,7 @@ from theano import Op, Apply, tensor, config, Variable
from theano.scalar import as_scalar, constant, Log, get_scalar_type from theano.scalar import as_scalar, constant, Log, get_scalar_type
from theano.tensor import as_tensor_variable from theano.tensor import as_tensor_variable
from theano.gradient import DisconnectedType, grad_not_implemented from theano.gradient import DisconnectedType, grad_not_implemented
from theano.gof import Optimizer, local_optimizer, COp from theano.gof import Optimizer, local_optimizer, COp, ParamsType, CEnumType
from theano.gof.cmodule import GCC_compiler from theano.gof.cmodule import GCC_compiler
from theano.gof.type import CDataType, Generic from theano.gof.type import CDataType, Generic
from theano.compile import optdb from theano.compile import optdb
...@@ -234,6 +234,11 @@ class DnnBase(COp): ...@@ -234,6 +234,11 @@ class DnnBase(COp):
ptr = ctx.cudnn_handle.value ptr = ctx.cudnn_handle.value
res = handle_type.make_value(ptr) res = handle_type.make_value(ptr)
ctx.cudnn_handle_param = res ctx.cudnn_handle_param = res
if isinstance(self.params_type, ParamsType):
if not self.params_type.has_type(handle_type):
raise TypeError('DnnBase: params_type must take into account the cuDNN handle type.')
handle_field = self.params_type.get_field(handle_type)
return self.params_type.get_params(self, **{handle_field: ctx.cudnn_handle_param})
return ctx.cudnn_handle_param return ctx.cudnn_handle_param
def __init__(self, files=None, c_func=None): def __init__(self, files=None, c_func=None):
...@@ -1504,6 +1509,18 @@ class GpuDnnSoftmaxBase(DnnBase): ...@@ -1504,6 +1509,18 @@ class GpuDnnSoftmaxBase(DnnBase):
""" """
__props__ = ('mode', 'algo') __props__ = ('mode', 'algo')
# Neither inputs nor output types properties are used
# neither in dnn_base.c nor in dnn_softmax*.c,
# so we can disable input checking.
check_input = False
params_type = ParamsType(algo=CEnumType(('CUDNN_SOFTMAX_FAST', 'fast'),
('CUDNN_SOFTMAX_LOG', 'log'),
('CUDNN_SOFTMAX_ACCURATE', 'accurate'),
ctype='cudnnSoftmaxAlgorithm_t'),
mode=CEnumType(('CUDNN_SOFTMAX_MODE_INSTANCE', 'instance'),
('CUDNN_SOFTMAX_MODE_CHANNEL', 'channel'),
ctype='cudnnSoftmaxMode_t'),
handle=handle_type)
def __init__(self, algo, mode): def __init__(self, algo, mode):
DnnBase.__init__(self, [self.file], self.c_func) DnnBase.__init__(self, [self.file], self.c_func)
...@@ -1520,21 +1537,6 @@ class GpuDnnSoftmaxBase(DnnBase): ...@@ -1520,21 +1537,6 @@ class GpuDnnSoftmaxBase(DnnBase):
else: else:
return [shape[1]] return [shape[1]]
def get_op_params(self):
if self.mode == 'instance':
mode = "CUDNN_SOFTMAX_MODE_INSTANCE"
else:
mode = "CUDNN_SOFTMAX_MODE_CHANNEL"
if self.algo == 'fast':
algo = "CUDNN_SOFTMAX_FAST"
elif self.algo == 'log':
algo = "CUDNN_SOFTMAX_LOG"
else:
algo = "CUDNN_SOFTMAX_ACCURATE"
return [("SOFTMAX_MODE", mode), ("SOFTMAX_ALGO", algo)]
class GpuDnnSoftmax(GpuDnnSoftmaxBase): class GpuDnnSoftmax(GpuDnnSoftmaxBase):
......
...@@ -35,7 +35,7 @@ if (APPLY_SPECIFIC(output) != NULL) ...@@ -35,7 +35,7 @@ if (APPLY_SPECIFIC(output) != NULL)
int APPLY_SPECIFIC(softmax)(PyGpuArrayObject *x, int APPLY_SPECIFIC(softmax)(PyGpuArrayObject *x,
PyGpuArrayObject **out, PyGpuArrayObject **out,
cudnnHandle_t _handle) { PARAMS_TYPE* wrapper) {
PyGpuContextObject *c = x->context; PyGpuContextObject *c = x->context;
cudnnStatus_t err; cudnnStatus_t err;
...@@ -83,9 +83,9 @@ int APPLY_SPECIFIC(softmax)(PyGpuArrayObject *x, ...@@ -83,9 +83,9 @@ int APPLY_SPECIFIC(softmax)(PyGpuArrayObject *x,
cuda_wait((*out)->ga.data, GPUARRAY_CUDA_WAIT_WRITE); cuda_wait((*out)->ga.data, GPUARRAY_CUDA_WAIT_WRITE);
err = cudnnSoftmaxForward( err = cudnnSoftmaxForward(
_handle, wrapper->handle,
SOFTMAX_ALGO, wrapper->algo,
SOFTMAX_MODE, wrapper->mode,
alpha, alpha,
APPLY_SPECIFIC(input), APPLY_SPECIFIC(input),
PyGpuArray_DEV_DATA(x), PyGpuArray_DEV_DATA(x),
......
...@@ -46,7 +46,7 @@ if (APPLY_SPECIFIC(dx) != NULL) ...@@ -46,7 +46,7 @@ if (APPLY_SPECIFIC(dx) != NULL)
int APPLY_SPECIFIC(softmax_grad)(PyGpuArrayObject *dy, int APPLY_SPECIFIC(softmax_grad)(PyGpuArrayObject *dy,
PyGpuArrayObject *sm, PyGpuArrayObject *sm,
PyGpuArrayObject **dx, PyGpuArrayObject **dx,
cudnnHandle_t _handle) { PARAMS_TYPE* wrapper) {
PyGpuContextObject *c = dy->context; PyGpuContextObject *c = dy->context;
cudnnStatus_t err; cudnnStatus_t err;
...@@ -97,9 +97,9 @@ int APPLY_SPECIFIC(softmax_grad)(PyGpuArrayObject *dy, ...@@ -97,9 +97,9 @@ int APPLY_SPECIFIC(softmax_grad)(PyGpuArrayObject *dy,
cuda_wait((*dx)->ga.data, GPUARRAY_CUDA_WAIT_WRITE); cuda_wait((*dx)->ga.data, GPUARRAY_CUDA_WAIT_WRITE);
err = cudnnSoftmaxBackward( err = cudnnSoftmaxBackward(
_handle, wrapper->handle,
SOFTMAX_ALGO, wrapper->algo,
SOFTMAX_MODE, wrapper->mode,
alpha, alpha,
APPLY_SPECIFIC(sm), APPLY_SPECIFIC(sm),
PyGpuArray_DEV_DATA(sm), PyGpuArray_DEV_DATA(sm),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论