提交 723addd1 authored 作者: notoraptor's avatar notoraptor

Add new abstract attribute `enum` in COp

that will be used by get_op_params() to auto-generate macros for the op. Add two methods to help create enums quickly. Add a test for this new feature. Simplify code. Add new op param types `EnumType` and `EnumList`. This replaces the previous implementation. C macros are created into `EnumType.c_support_code()`. Examples added in gof/tests/test_types. Flake8.
上级 8e292493
...@@ -55,6 +55,8 @@ from theano.gof.link import \ ...@@ -55,6 +55,8 @@ from theano.gof.link import \
from theano.gof.op import \ from theano.gof.op import \
Op, OpenMPOp, PureOp, COp, ops_with_inner_function Op, OpenMPOp, PureOp, COp, ops_with_inner_function
from theano.gof.type import EnumType, EnumList
from theano.gof.opt import ( from theano.gof.opt import (
Optimizer, Optimizer,
optimizer, inplace_optimizer, optimizer, inplace_optimizer,
......
...@@ -2,9 +2,9 @@ from __future__ import absolute_import, print_function, division ...@@ -2,9 +2,9 @@ from __future__ import absolute_import, print_function, division
import numpy as np import numpy as np
import theano import theano
from theano import Op, Apply from theano import Op, Apply, scalar
from theano.tensor import TensorType from theano.tensor import TensorType
from theano.gof.type import CDataType from theano.gof.type import CDataType, EnumType, EnumList
from nose.plugins.skip import SkipTest from nose.plugins.skip import SkipTest
...@@ -76,3 +76,118 @@ def test_cdata(): ...@@ -76,3 +76,118 @@ def test_cdata():
v2 = f(v) v2 = f(v)
assert (v2 == v).all() assert (v2 == v).all()
class TestOpEnumList(Op):
__props__ = ('op_chosen',)
params_type = EnumList('ADD', 'SUB', 'MULTIPLY', 'DIVIDE')
def __init__(self, choose_op):
assert self.params_type.ADD == 0
assert self.params_type.SUB == 1
assert self.params_type.MULTIPLY == 2
assert self.params_type.DIVIDE == 3
op_to_const = {'+': self.params_type.ADD,
'-': self.params_type.SUB,
'*': self.params_type.MULTIPLY,
'/': self.params_type.DIVIDE}
self.op_chosen = op_to_const[choose_op]
def get_params(self, node):
return self.op_chosen
def make_node(self, a, b):
return Apply(self, [scalar.as_scalar(a), scalar.as_scalar(b)], [scalar.float64()])
def perform(self, node, inputs, outputs, op):
a, b = inputs
o, = outputs
if op == self.params_type.ADD:
o[0] = a + b
elif op == self.params_type.SUB:
o[0] = a - b
elif op == self.params_type.MULTIPLY:
o[0] = a * b
elif op == self.params_type.DIVIDE:
o[0] = a / b
else:
raise NotImplementedError('Unknown op id ' + str(op))
def c_code_cache_version(self):
return (1,)
def c_code(self, node, name, inputs, outputs, sub):
a, b = inputs
o, = outputs
fail = sub['fail']
op = sub['params']
return """
switch((int)%(op)s) {
case ADD:
%(o)s = %(a)s + %(b)s;
break;
case SUB:
%(o)s = %(a)s - %(b)s;
break;
case MULTIPLY:
%(o)s = %(a)s * %(b)s;
break;
case DIVIDE:
%(o)s = %(a)s / %(b)s;
break;
default:
{%(fail)s}
break;
}
""" % locals()
def test_enum():
# Check that invalid enum name raises exception.
for invalid_name in ('a', '_A', '0'):
try:
EnumList(invalid_name)
except AttributeError:
pass
else:
raise Exception('EnumList with invalid name should faild.')
try:
EnumType(**{invalid_name: 0})
except AttributeError:
pass
else:
raise Exception('EnumType with invalid name should fail.')
# Check that invalid enum value raises exception.
try:
EnumType(INVALID_VALUE='string is not allowe.')
except ValueError:
pass
else:
raise Exception('EnumType with invalid value should fail.')
# Check EnumType.
e1 = EnumType(C1=True, C2=12, C3=True, C4=-1, C5=False, C6=0.0)
e2 = EnumType(C1=1, C2=12, C3=1, C4=-1.0, C5=0.0, C6=0)
assert e1 == e2
assert not (e1 != e2)
assert hash(e1) == hash(e2)
# Test an op with EnumList.
a = scalar.int32()
b = scalar.int32()
c_add = TestOpEnumList('+')(a, b)
c_sub = TestOpEnumList('-')(a, b)
c_multiply = TestOpEnumList('*')(a, b)
c_divide = TestOpEnumList('/')(a, b)
f = theano.function([a, b], [c_add, c_sub, c_multiply, c_divide])
va = 12
vb = 15
ref_add = va + vb
ref_sub = va - vb
ref_multiply = va * vb
ref_divide = va // vb
ref = [ref_add, ref_sub, ref_multiply, ref_divide]
out = f(va, vb)
assert ref == out, (ref, out)
...@@ -10,6 +10,7 @@ import ctypes ...@@ -10,6 +10,7 @@ import ctypes
from six import string_types from six import string_types
import re
import theano import theano
from theano.gof import utils from theano.gof import utils
from theano.gof.utils import MethodNotDefined, object2 from theano.gof.utils import MethodNotDefined, object2
...@@ -808,3 +809,137 @@ class CDataTypeConstant(graph.Constant): ...@@ -808,3 +809,137 @@ class CDataTypeConstant(graph.Constant):
return (self.type,) return (self.type,)
CDataType.Constant = CDataTypeConstant CDataType.Constant = CDataTypeConstant
class EnumType(Type, dict):
"""
Class that allows to create enumerations of constant values.
Constants are available as object attributes in Python code and as macro-defined constants in C code.
Constants can be floating values, integers, or booleans (automatically converted to integers).
Constants name must start with a capital letter and contain capital letters, underscores or digits.
This type is intended to be used as op parameter type.
Example::
enum = EnumType(CONSTANT_1=0, CONSTANT_2=1, CONSTANT_3=2.5, CONSTANT_4=False, CONSTANT_5=True)
print (enum.CONSTANT_1, enum.CONSTANT_2, enum.CONSTANT_3, enum.CONSTANT_4, enum.CONSTANT_5)
In C code:
.. code-block:: c
int constant_1 = CONSTANT_1;
int constant_2 = CONSTANT_2;
double constant_3 = CONSTANT_3;
int constant_4 = CONSTANT_4; // constant_4 == 0
int constant_5 = CONSTANT_5; // constant_5 == 1
..note::
This Type is not complete and should never be used for regular graph operations.
"""
def __init__(self, **kwargs):
for k in kwargs:
if re.match('^[A-Z][A-Z0-9_]*$', k) is None:
raise AttributeError('EnumType: invalid enum name: "%s". '
'Only capital letters, underscores and digits '
'are allowed.' % k)
if isinstance(kwargs[k], bool):
kwargs[k] = int(kwargs[k])
elif not isinstance(kwargs[k], (int, float)):
raise ValueError('EnumType: enum "%s": expected integer or floating value, got "%s".'
% (k, type(kwargs[k]).__name__))
super(EnumType, self).__init__(**kwargs)
def __repr__(self):
return 'EnumType(%s)' % ', '.join('%s:%s' % (k, self[k]) for k in sorted(self.keys()))
def __getattr__(self, key):
if key in self:
return self[key]
return Type.__getattr__(self, key)
def __setattr__(self, key, value):
if key in self:
raise NotImplementedError('EnumType values are immutable.')
Type.__setattr__(self, key, value)
def __setitem__(self, key, value):
raise NotImplementedError('EnumType values are immutable.')
def __delitem__(self, key):
raise NotImplementedError('EnumType values are immutable.')
def __hash__(self):
# All values are Python basic types, then easy to hash.
return hash((type(self),) + tuple((k, self[k]) for k in sorted(self.keys())))
def __eq__(self, other):
return (type(self) == type(other) and
len(self) == len(other) and
all(k in other for k in self) and
all(self[k] == other[k] for k in self))
# EnumType should be used to create constants available in both Python and C code.
# However, for convenience, we make sure EnumType can have a value, like other common types,
# such that it could be used as-is as an op param.
# As we currently allow enum constants to be booleans, integers or floating values,
# we choose the biggest basic type (i.e. float) as type of enum values.
def filter(self, data, strict=False, allow_downcast=None):
if strict:
assert isinstance(data, float)
return data
assert isinstance(data, (bool, int, float))
return float(data)
def values_eq(self, a, b):
return a == b
def values_eq_approx(self, a, b):
return float(a) == float(b)
def c_support_code(self):
return ''.join("""
#define %s %s
""" % (k, str(self[k])) for k in sorted(self.keys()))
def c_declare(self, name, sub, check_input=True):
return """double %(name)s;""" % locals()
def c_init(self, name, sub):
return "%(name)s = 0;" % locals()
def c_cleanup(self, name, sub):
return ""
def c_extract(self, name, sub, check_input=True):
return """
%(name)s = PyFloat_AsDouble(py_%(name)s);
if (PyErr_Occurred()) {
%(name)s = 0;
}
""" % locals()
class EnumList(EnumType):
""""
Class that allows to create enumeration of constant integer values.
Same as :class:`EnumType`, but automatically gives an unique integer value to each constant in a list of
constants names (constant at index i in the list will receive value i, i from ``0`` to ``len(constants)-1``).
Example::
enum = EnumList(CONSTANT_1, CONSTANT_2, CONSTANT_3, CONSTANT_4, CONSTANT_5)
print (enum.CONSTANT_1, enum.CONSTANT_2, enum.CONSTANT_3, enum.CONSTANT_4, enum.CONSTANT_5)
# will print: 0 1 2 3 4
"""
def __init__(self, *args):
if len(args) > len(set(args)):
raise AttributeError('EnumList: some constants names are duplicated.')
super(EnumList, self).__init__(**{const_name: const_rank for (const_rank, const_name) in enumerate(args)})
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论