Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
fd552f23
提交
fd552f23
authored
3月 02, 2017
作者:
notoraptor
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add a COp test in test_wrapper.
Rewrite `Wrap` so that it depends on a Wrapper. Simplify code.
上级
5260c149
隐藏空白字符变更
内嵌
并排
正在显示
4 个修改的文件
包含
246 行增加
和
146 行删除
+246
-146
op.py
theano/gof/op.py
+9
-16
test_quadratic_function.c
theano/gof/tests/test_quadratic_function.c
+44
-0
test_wrapper.py
theano/gof/tests/test_wrapper.py
+63
-30
wrapper.py
theano/gof/wrapper.py
+130
-100
没有找到文件。
theano/gof/op.py
浏览文件 @
fd552f23
...
...
@@ -799,22 +799,15 @@ class Op(utils.object2, PureOp, CLinkerOp):
# We add a default get_params() implementation which will try to detect params from the op
# if params_type is set to a Wrapper. If not, we raise a MethodNotDefined exception.
def
get_params
(
self
,
node
):
if
hasattr
(
self
,
'params_type'
):
# If params_type is a Wrapper, we try to extract params from the op.
if
isinstance
(
self
.
params_type
,
theano
.
gof
.
wrapper
.
Wrapper
):
wrapper
=
self
.
params_type
op_has_wrap_attributes
=
True
for
field
in
wrapper
.
fields
:
if
not
hasattr
(
self
,
field
):
op_has_wrap_attributes
=
False
break
if
op_has_wrap_attributes
:
wrap_dict
=
dict
()
for
i
in
range
(
wrapper
.
length
):
field
=
wrapper
.
fields
[
i
]
_type
=
wrapper
.
types
[
i
]
wrap_dict
[
field
]
=
_type
.
filter
(
getattr
(
self
,
field
),
strict
=
False
,
allow_downcast
=
True
)
return
theano
.
gof
.
wrapper
.
Wrap
(
**
wrap_dict
)
if
hasattr
(
self
,
'params_type'
)
and
isinstance
(
self
.
params_type
,
theano
.
gof
.
wrapper
.
Wrapper
):
wrapper
=
self
.
params_type
if
hasattr
(
self
,
'__props__'
)
and
all
(
field
in
self
.
__props__
for
field
in
wrapper
.
fields
):
wrap_dict
=
dict
()
for
i
in
range
(
wrapper
.
length
):
field
=
wrapper
.
fields
[
i
]
_type
=
wrapper
.
types
[
i
]
wrap_dict
[
field
]
=
_type
.
filter
(
getattr
(
self
,
field
),
strict
=
False
,
allow_downcast
=
True
)
return
theano
.
gof
.
wrapper
.
Wrap
(
wrapper
,
**
wrap_dict
)
raise
theano
.
gof
.
utils
.
MethodNotDefined
(
'get_params'
)
def
prepare_node
(
self
,
node
,
storage_map
,
compute_map
,
impl
):
...
...
theano/gof/tests/test_quadratic_function.c
0 → 100644
浏览文件 @
fd552f23
#section support_code_apply
int
APPLY_SPECIFIC
(
quadratic_function
)(
PyArrayObject
*
tensor
,
DTYPE_INPUT_0
a
,
DTYPE_INPUT_0
b
,
DTYPE_INPUT_0
c
)
{
NpyIter
*
iterator
=
NpyIter_New
(
tensor
,
NPY_ITER_READWRITE
|
NPY_ITER_EXTERNAL_LOOP
|
NPY_ITER_REFS_OK
,
NPY_KEEPORDER
,
NPY_NO_CASTING
,
NULL
);
if
(
iterator
==
NULL
)
{
PyErr_SetString
(
PyExc_RuntimeError
,
"Unable to iterate over a tensor for an elemwise operation."
);
return
-
1
;
}
NpyIter_IterNextFunc
*
get_next
=
NpyIter_GetIterNext
(
iterator
,
NULL
);
char
**
data_ptr
=
NpyIter_GetDataPtrArray
(
iterator
);
npy_intp
*
stride_ptr
=
NpyIter_GetInnerStrideArray
(
iterator
);
npy_intp
*
innersize_ptr
=
NpyIter_GetInnerLoopSizePtr
(
iterator
);
do
{
char
*
data
=
*
data_ptr
;
npy_intp
stride
=
*
stride_ptr
;
npy_intp
count
=
*
innersize_ptr
;
while
(
count
)
{
DTYPE_INPUT_0
x
=
*
((
DTYPE_INPUT_0
*
)
data
);
*
((
DTYPE_INPUT_0
*
)
data
)
=
a
*
x
*
x
+
b
*
x
+
c
;
data
+=
stride
;
--
count
;
}
}
while
(
get_next
(
iterator
));
NpyIter_Deallocate
(
iterator
);
return
0
;
}
int
APPLY_SPECIFIC
(
compute_quadratic
)(
PyArrayObject
*
X
,
PyArrayObject
**
Y
,
QUADRATIC_WRAPPER
*
coeff
)
{
DTYPE_INPUT_0
a
=
(
DTYPE_INPUT_0
)
(
*
(
COEFF_TYPE
*
)
PyArray_GETPTR1
(
coeff
->
a
,
0
));
// 0-D TensorType.
DTYPE_INPUT_0
b
=
coeff
->
b
;
// Scalar.
DTYPE_INPUT_0
c
=
(
DTYPE_INPUT_0
)
PyFloat_AsDouble
(
coeff
->
c
);
// Generic.
Py_XDECREF
(
*
Y
);
*
Y
=
(
PyArrayObject
*
)
PyArray_EMPTY
(
PyArray_NDIM
(
X
),
PyArray_DIMS
(
X
),
TYPENUM_INPUT_0
,
PyArray_IS_F_CONTIGUOUS
(
X
));
if
(
PyArray_CopyInto
(
*
Y
,
X
)
!=
0
)
{
PyErr_SetString
(
PyExc_RuntimeError
,
"Unable to copy input into output."
);
return
1
;
};
if
(
APPLY_SPECIFIC
(
quadratic_function
)(
*
Y
,
a
,
b
,
c
)
!=
0
)
{
PyErr_SetString
(
PyExc_RuntimeError
,
"Unable to compute quadratic function."
);
return
1
;
}
return
0
;
}
theano/gof/tests/test_wrapper.py
浏览文件 @
fd552f23
...
...
@@ -2,7 +2,7 @@ from __future__ import absolute_import, print_function, division
import
theano
import
numpy
from
unittest
import
TestCase
from
theano.gof
import
Op
,
Apply
from
theano.gof
import
Op
,
COp
,
Apply
from
theano
import
Generic
from
theano.scalar
import
Scalar
from
theano.tensor
import
TensorType
...
...
@@ -18,7 +18,7 @@ generic_type = Generic()
# A test op to compute `y = a*x^2 + bx + c` for any tensor x, with a, b, c as op params.
class
Quadratic
Function
(
Op
):
class
Quadratic
OpFunc
(
Op
):
__props__
=
(
'a'
,
'b'
,
'c'
)
params_type
=
Wrapper
(
a
=
tensor_type_0d
,
b
=
scalar_type
,
...
...
@@ -39,7 +39,7 @@ class QuadraticFunction(Op):
y
[
0
]
=
coefficients
.
a
*
(
x
**
2
)
+
coefficients
.
b
*
x
+
coefficients
.
c
def
c_code_cache_version
(
self
):
return
(
1
,
2
)
return
(
1
,
3
)
def
c_support_code_apply
(
self
,
node
,
name
):
float_type
=
node
.
inputs
[
0
]
.
type
.
dtype_specs
()[
1
]
...
...
@@ -82,9 +82,9 @@ class QuadraticFunction(Op):
float_typenum
=
numpy
.
dtype
(
node
.
inputs
[
0
]
.
type
.
dtype
)
.
num
coeff_type
=
'npy_'
+
numpy
.
dtype
(
dtype
)
.
name
return
"""
%(float_type)
s a = (
%(float_type)
s) (*(
%(coeff_type)
s*) PyArray_GETPTR1(
%(coeff)
s
.
a, 0)); // 0-D TensorType.
%(float_type)
s b =
%(coeff)
s
.
b; // Scalar.
%(float_type)
s c = (
%(float_type)
s)PyFloat_AsDouble(
%(coeff)
s
.
c); // Generic.
%(float_type)
s a = (
%(float_type)
s) (*(
%(coeff_type)
s*) PyArray_GETPTR1(
%(coeff)
s
->
a, 0)); // 0-D TensorType.
%(float_type)
s b =
%(coeff)
s
->
b; // Scalar.
%(float_type)
s c = (
%(float_type)
s)PyFloat_AsDouble(
%(coeff)
s
->
c); // Generic.
Py_XDECREF(
%(Y)
s);
%(Y)
s = (PyArrayObject*)PyArray_EMPTY(PyArray_NDIM(
%(X)
s), PyArray_DIMS(
%(X)
s),
%(float_typenum)
s, PyArray_IS_F_CONTIGUOUS(
%(X)
s));
if (PyArray_CopyInto(
%(Y)
s,
%(X)
s) != 0) {
...
...
@@ -98,29 +98,54 @@ class QuadraticFunction(Op):
"""
%
locals
()
# Same op as above, but implemented as a COp (with C code in an external file).
class
QuadraticCOpFunc
(
COp
):
__props__
=
(
'a'
,
'b'
,
'c'
)
params_type
=
Wrapper
(
a
=
tensor_type_0d
,
b
=
scalar_type
,
c
=
generic_type
)
def
get_op_params
(
self
):
return
[(
'QUADRATIC_WRAPPER'
,
self
.
params_type
.
name
),
(
'COEFF_TYPE'
,
'npy_'
+
numpy
.
dtype
(
dtype
)
.
name
)]
def
__init__
(
self
,
a
,
b
,
c
):
super
(
QuadraticCOpFunc
,
self
)
.
__init__
(
'test_quadratic_function.c'
,
'APPLY_SPECIFIC(compute_quadratic)'
)
self
.
a
=
a
self
.
b
=
b
self
.
c
=
c
def
make_node
(
self
,
x
):
x
=
tensor
.
as_tensor_variable
(
x
)
return
Apply
(
self
,
[
x
],
[
x
.
type
()])
class
TestWrapper
(
TestCase
):
def
test_wrap_hash_and_eq
(
self
):
w1
=
Wrap
(
a
=
1
,
b
=
'test string'
,
array
=
numpy
.
asarray
([
1
,
2
,
4
,
5
,
7
]),
floatting
=-
4.5
,
npy_scalar
=
numpy
.
asarray
(
12
))
w2
=
Wrap
(
a
=
1
,
b
=
'test string'
,
array
=
numpy
.
asarray
([
1
,
2
,
4
,
5
,
7
]),
floatting
=-
4.5
,
npy_scalar
=
numpy
.
asarray
(
12
))
def
test_hash_and_eq_wrap
(
self
):
wp1
=
Wrapper
(
a
=
Generic
(),
array
=
TensorType
(
'int32'
,
(
False
,)),
floatting
=
Scalar
(
'float32'
),
npy_scalar
=
TensorType
(
'float64'
,
tuple
()))
wp2
=
Wrapper
(
a
=
Generic
(),
array
=
TensorType
(
'int32'
,
(
False
,)),
floatting
=
Scalar
(
'float32'
),
npy_scalar
=
TensorType
(
'float64'
,
tuple
()))
w1
=
Wrap
(
wp1
,
a
=
1
,
array
=
numpy
.
asarray
([
1
,
2
,
4
,
5
,
7
]),
floatting
=-
4.5
,
npy_scalar
=
numpy
.
asarray
(
12
))
w2
=
Wrap
(
wp2
,
a
=
1
,
array
=
numpy
.
asarray
([
1
,
2
,
4
,
5
,
7
]),
floatting
=-
4.5
,
npy_scalar
=
numpy
.
asarray
(
12
))
assert
w1
==
w2
assert
not
(
w1
!=
w2
)
assert
hash
(
w1
)
==
hash
(
w2
)
assert
all
(
hasattr
(
w1
,
key
)
for
key
in
(
'a'
,
'b'
,
'array'
,
'floatting'
,
'npy_scalar'
))
# Changing attributes names only.
w2
=
Wrap
(
other_name
=
1
,
b
=
'test string'
,
array
=
numpy
.
asarray
([
1
,
2
,
4
,
5
,
7
]),
floatting
=-
4.5
,
npy_scalar
=
numpy
.
asarray
(
12
))
assert
w1
!=
w2
# Changing attributes types only.
w2
=
Wrap
(
a
=
1
,
b
=
'test string'
,
array
=
[
1
,
2
,
4
,
5
,
7
],
floatting
=-
4.5
,
npy_scalar
=
numpy
.
asarray
(
12
))
# Changing attributes names only (a -> other_name).
wp2_other
=
Wrapper
(
other_name
=
Generic
(),
array
=
TensorType
(
'int32'
,
(
False
,)),
floatting
=
Scalar
(
'float32'
),
npy_scalar
=
TensorType
(
'float64'
,
tuple
()))
w2
=
Wrap
(
wp2_other
,
other_name
=
1
,
array
=
numpy
.
asarray
([
1
,
2
,
4
,
5
,
7
]),
floatting
=-
4.5
,
npy_scalar
=
numpy
.
asarray
(
12
))
assert
w1
!=
w2
# Changing attributes values only.
w2
=
Wrap
(
a
=
1
,
b
=
'string'
,
array
=
numpy
.
asarray
([
1
,
2
,
4
,
5
,
7
]),
floatting
=-
4.5
,
npy_scalar
=
numpy
.
asarray
(
12
))
# Changing attributes values only
(now a=2)
.
w2
=
Wrap
(
wp2
,
a
=
2
,
array
=
numpy
.
asarray
([
1
,
2
,
4
,
5
,
7
]),
floatting
=-
4.5
,
npy_scalar
=
numpy
.
asarray
(
12
))
assert
w1
!=
w2
# Changing NumPy array values.
w2
=
Wrap
(
a
=
1
,
b
=
'test string'
,
array
=
numpy
.
asarray
([
1
,
2
,
4
,
-
5
,
7
]),
floatting
=-
4.5
,
npy_scalar
=
numpy
.
asarray
(
12
))
# Changing NumPy array values
(5 -> -5)
.
w2
=
Wrap
(
wp2
,
a
=
1
,
array
=
numpy
.
asarray
([
1
,
2
,
4
,
-
5
,
7
]),
floatting
=-
4.5
,
npy_scalar
=
numpy
.
asarray
(
12
))
assert
w1
!=
w2
def
test_
wrapper_hash_and_eq
(
self
):
def
test_
hash_and_eq_wrapper
(
self
):
w1
=
Wrapper
(
a1
=
TensorType
(
'int64'
,
(
False
,
False
)),
a2
=
TensorType
(
'int64'
,
(
False
,
True
,
False
,
False
,
True
)),
a3
=
Generic
())
...
...
@@ -133,7 +158,7 @@ class TestWrapper(TestCase):
assert
w1
.
name
==
w2
.
name
# Changing attributes names only.
w2
=
Wrapper
(
a1
=
TensorType
(
'int64'
,
(
False
,
False
)),
other_name
=
TensorType
(
'int64'
,
(
False
,
True
,
False
,
False
,
True
)),
other_name
=
TensorType
(
'int64'
,
(
False
,
True
,
False
,
False
,
True
)),
# a2 -> other_name
a3
=
Generic
())
assert
w1
!=
w2
# Changing attributes types only.
...
...
@@ -157,7 +182,8 @@ class TestWrapper(TestCase):
a3
=
Generic
())
# With a value that does not match the wrapper.
o
=
Wrap
(
a1
=
numpy
.
asarray
([[
1
,
2
,
3
,
4
,
5
,
6
],
[
7
,
8
,
9
,
10
,
11
,
12
]])
.
astype
(
'int64'
),
o
=
Wrap
(
w
,
a1
=
numpy
.
asarray
([[
1
,
2
,
3
,
4
,
5
,
6
],
[
7
,
8
,
9
,
10
,
11
,
12
]])
.
astype
(
'int64'
),
a2
=
random_tensor
.
astype
(
'float32'
),
a3
=
2000
)
# should fail (o.a1 is not int32, o.a2 is not float64)
...
...
@@ -168,7 +194,8 @@ class TestWrapper(TestCase):
w
.
filter
(
o
,
strict
=
False
,
allow_downcast
=
True
)
# With a value that matches the wrapper.
o1
=
Wrap
(
a1
=
numpy
.
asarray
([[
1
,
2
,
3
,
4
,
5
,
6
],
[
7
,
8
,
9
,
10
,
11
,
12
]])
.
astype
(
'int32'
),
o1
=
Wrap
(
w
,
a1
=
numpy
.
asarray
([[
1
,
2
,
3
,
4
,
5
,
6
],
[
7
,
8
,
9
,
10
,
11
,
12
]])
.
astype
(
'int32'
),
a2
=
random_tensor
.
astype
(
'float64'
),
a3
=
2000
)
# All should pass.
...
...
@@ -176,15 +203,17 @@ class TestWrapper(TestCase):
w
.
filter
(
o1
,
strict
=
False
,
allow_downcast
=
False
)
w
.
filter
(
o1
,
strict
=
False
,
allow_downcast
=
True
)
# Check value_eq and value_eq_approx.
o2
=
Wrap
(
a1
=
numpy
.
asarray
([[
1
,
2
,
3
,
4
,
5
,
6
],
[
7
,
8
,
9
,
10
,
11
,
12
]])
.
astype
(
'int32'
),
# Check values_eq and values_eq_approx.
o2
=
Wrap
(
w
,
a1
=
numpy
.
asarray
([[
1
,
2
,
3
,
4
,
5
,
6
],
[
7
,
8
,
9
,
10
,
11
,
12
]])
.
astype
(
'int32'
),
a2
=
random_tensor
.
astype
(
'float64'
),
a3
=
2000
)
assert
w
.
values_eq
(
o1
,
o2
)
assert
w
.
values_eq_approx
(
o1
,
o2
)
# Check value_eq_approx.
o3
=
Wrap
(
a1
=
numpy
.
asarray
([[
1
,
2
,
3
,
4
,
5
,
6
],
[
7
,
8
,
9
,
10
,
11
,
12
]])
.
astype
(
'float32'
),
o3
=
Wrap
(
w
,
a1
=
numpy
.
asarray
([[
1
,
2
,
3
,
4
,
5
,
6
],
[
7
,
8
,
9
,
10
,
11
,
12
]])
.
astype
(
'float32'
),
a2
=
random_tensor
.
astype
(
'float64'
),
a3
=
2000.0
)
assert
w
.
values_eq_approx
(
o1
,
o3
)
...
...
@@ -192,14 +221,18 @@ class TestWrapper(TestCase):
def
test_op_params
(
self
):
a
,
b
,
c
=
2
,
3
,
-
7
x
=
tensor
.
matrix
()
y
=
QuadraticFunction
(
a
,
b
,
c
)(
x
)
f
=
theano
.
function
([
x
],
y
)
y1
=
QuadraticOpFunc
(
a
,
b
,
c
)(
x
)
y2
=
QuadraticCOpFunc
(
a
,
b
,
c
)(
x
)
f1
=
theano
.
function
([
x
],
y1
)
f2
=
theano
.
function
([
x
],
y2
)
shape
=
(
100
,
100
)
# The for-loop is here just to force profiling print something interesting.
# When running this test without this loop, profiling does not print neither list of classes nor list of ops
# (maybe because the function is extremely fast ?).
for
i
in
range
(
50
):
vx
=
numpy
.
random
.
normal
(
size
=
shape
[
0
]
*
shape
[
1
])
.
astype
(
dtype
)
.
reshape
(
*
shape
)
vy
=
f
(
vx
)
vy1
=
f1
(
vx
)
vy2
=
f2
(
vx
)
ref
=
a
*
(
vx
**
2
)
+
b
*
vx
+
c
utt
.
assert_allclose
(
ref
,
vy
)
utt
.
assert_allclose
(
vy1
,
vy2
)
utt
.
assert_allclose
(
ref
,
vy1
)
theano/gof/wrapper.py
浏览文件 @
fd552f23
"""
Module for wrapping many Theano variables into one struct for op params.
Module for wrapping many Theano variables into one
C
struct for op params.
This module contains two classes:
- Wrapper: class to define the op params type.
- Wrap: internal convenient class to create an object that is compatible with a Wrapper-defined op params.
Example of usage:
- :class:`Wrapper`: main class to define the op params type.
- :class:`Wrap`: internal convenient class to create an object that is compatible with Wrapper-defined op params.
Importation:
Example of usage
----------------
from theano.gof.wrapper import Wrapper
Importation:
In an op you create:
.. code-block:: python
params_type = Wrapper(attr1=TensorType('int32', (False, False)), attr2=TensorType('float64', (True, False)))
from theano.gof.wrapper import Wrapper
If your op contains props `attr1` AND `attr2`, the op.get_params() method will
automatically try to look for it and generate an appropriate wrapped struct.
The props must be able to pass the filtering (not strict, downcasting allowed)
of corresponding types defined into Wrapper.
In an op you create:
__props__ = ('attr1', 'attr2')
def __init__(value_attr1, value_attr2):
self.attr1 = value_attr1
self.attr2 = value_attr2
.. code-block:: python
In perform() implementation (with params named `param`):
from theano.tensor import TensorType, dmatrix
params_type = Wrapper(attr1=TensorType('int32', (False, False)), attr2=dmatrix)
var1 = param.attr1
var2 = param.attr2
If your op contains props ``attr1`` *and* ``attr2``, the default ``op.get_params()`` implementation
will automatically try to look for it and generate an appropriate wrapped struct.
Props must be compatible with the corresponding types defined into the Wrapper
(we will try to convert and downcast if needed).
In c_code() implementation (with `param = sub['params']`):
.. code-block:: python
PyArrayObject* attr1 = param.attr1;
PyArrayObject* attr2 = param.attr2;
/* You won't need to free them or whatever else. */
__props__ = ('attr1', 'attr2')
def __init__(value_attr1, value_attr2):
self.attr1 = value_attr1
self.attr2 = value_attr2
In ``perform()`` implementation (with params named ``param``):
See `theano/gof/tests/test_wrapper.py` for a complete working example.
.. code-block:: python
var1 = param.attr1
var2 = param.attr2
In ``c_code()`` implementation (with ``param = sub['params']``):
.. code-block:: c
PyArrayObject* attr1 = param->attr1;
PyArrayObject* attr2 = param->attr2;
/* You won't need to free them or whatever else. */
See :class:`QuadraticOpFunc` and :class:`QuadraticCOpFunc` in ``theano/gof/tests/test_wrapper.py``
for complete working examples.
"""
from
__future__
import
absolute_import
,
print_function
,
division
import
re
import
hashlib
import
numpy
from
theano.gof.utils
import
MethodNotDefined
from
theano.gof.utils
import
MethodNotDefined
,
c_cpp_keywords
from
theano.gof
import
Type
from
theano.tensor.utils
import
hash_from_ndarray
# NB: Maybe we should check if an attribute name is a C/C++ keyword, and raise an error if so.
# These are some lists of C/C++ keywords:
# http://fr.cppreference.com/w/cpp/keyword
# http://fr.cppreference.com/w/c/keyword
class
Wrap
(
dict
):
...
...
@@ -60,19 +67,31 @@ class Wrap(dict):
Internal convenient class to wrap many Python objects into one
(this class is not safe as the hash method does not check if values are effectively hashable).
Example:
>>> w = Wrap(attr1=1, attr2=2.0, attri='3')
>>> print(w.attr1, w.attr2, w.attri)
>>> d = dict(a=1, b=2, c='test')
>>> w2 = Wrap(**d)
>>> print(w2.a, w2.b, w2.c)
**Example:**
.. code-block:: python
from theano.gof.wrapper import *
from theano.scalar import Scalar
# You must create a Wrapper first:
wp = Wrapper(attr1=Scalar('int32'), key2=Scalar('float32'), field3=Scalar('int64'))
# Then you can create a Wrap with the wrapper defined above and values for attributes.
w = Wrap(wp, attr1=1, key2=2.0, field3=3)
print(w.attr1, w.key2, w.field3)
d = dict(attr1=1, key2=2, field3=-1)
w2 = Wrap(wp, **d)
print(w2.attr1, w2.key2, w2.field3)
"""
def
__init__
(
self
,
**
kwargs
):
def
__init__
(
self
,
wrapper
,
**
kwargs
):
if
not
isinstance
(
wrapper
,
Wrapper
):
raise
TypeError
(
'Wrap: 1st constructor argument should be a Wrapper.'
)
for
field
in
wrapper
.
fields
:
if
field
not
in
kwargs
:
raise
TypeError
(
'Wrap: Wrapper attribute "
%
s" not in Wrap args.'
%
field
)
super
(
Wrap
,
self
)
.
__init__
(
**
kwargs
)
if
len
(
kwargs
)
==
0
:
raise
TypeError
(
'Wrap: cannot wrap empty data.'
)
self
.
__dict__
.
update
(
wrapper
=
wrapper
)
def
__repr__
(
self
):
return
'Wrap(
%
s)'
%
', '
.
join
([(
'
%
s:
%
s'
%
(
k
,
type
(
self
[
k
])))
for
k
in
sorted
(
self
.
keys
())])
...
...
@@ -82,33 +101,28 @@ class Wrap(dict):
raise
AttributeError
(
'Wrap: attribute "
%
s" does not exist.'
%
key
)
return
self
[
key
]
def
__setattr__
(
self
,
key
,
value
):
raise
NotImplementedError
(
'Wrap is immutable'
)
def
__setitem__
(
self
,
key
,
value
):
raise
NotImplementedError
(
'Wrap is immutable'
)
def
__delitem__
(
self
,
key
):
raise
NotImplementedError
(
'Wrap is immutable'
)
def
__hash__
(
self
):
keys
=
sorted
(
self
.
keys
())
types
=
[]
attributes
=
[]
for
k
in
keys
:
types
+=
(
type
(
self
[
k
]),)
if
isinstance
(
self
[
k
],
numpy
.
ndarray
):
# Note: hash_from_ndarray returns a string, so the hash is not yet complete
# (__hash__ must return an integer).
attributes
+=
(
hash_from_ndarray
(
self
[
k
]),)
else
:
# No checking, data should be hashable.
attributes
+=
(
self
[
k
],)
return
hash
((
type
(
self
),)
+
tuple
(
keys
)
+
tuple
(
types
)
+
tuple
(
attributes
))
return
hash
((
type
(
self
),
self
.
wrapper
)
+
tuple
(
# NB: Wrapped data should have been already filtered.
self
.
wrapper
.
types
[
i
]
.
make_constant
(
self
[
self
.
wrapper
.
fields
[
i
]])
.
signature
()
for
i
in
range
(
self
.
wrapper
.
length
)
))
def
__eq__
(
self
,
other
):
if
type
(
self
)
!=
type
(
other
)
or
len
(
self
)
!=
len
(
other
):
return
False
for
k
in
self
:
if
k
not
in
other
or
not
(
isinstance
(
self
[
k
],
type
(
other
[
k
]))
and
isinstance
(
other
[
k
],
type
(
self
[
k
]))):
return
False
if
isinstance
(
self
[
k
],
numpy
.
ndarray
):
if
not
numpy
.
allclose
(
self
[
k
],
other
[
k
]):
return
False
elif
self
[
k
]
!=
other
[
k
]:
return
False
return
True
return
(
type
(
self
)
==
type
(
other
)
and
self
.
wrapper
==
other
.
wrapper
and
all
(
# NB: Wrapped data should have been already filtered.
self
.
wrapper
.
types
[
i
]
.
values_eq
(
self
[
self
.
wrapper
.
fields
[
i
]],
other
[
self
.
wrapper
.
fields
[
i
]])
for
i
in
range
(
self
.
wrapper
.
length
)
))
def
__ne__
(
self
,
other
):
return
not
self
.
__eq__
(
other
)
...
...
@@ -121,18 +135,23 @@ class Wrapper(Type):
Wrapper constructor takes key-value args.
Key will be the name of the attribute in the struct.
Value is the Theano type of this attribute, ie. an instance of (a subclass of)
Type
(eg.
TensorType('int64', (False,))
).
Value is the Theano type of this attribute, ie. an instance of (a subclass of)
:class:`Type`
(eg.
``TensorType('int64', (False,))``
).
In a Python code any attribute named `key` will be available via:
structObject.key
In a Python code any attribute named ``key`` will be available via::
In a C code, attributes created to represent an instance of the type associated to `key` will be available via:
structObject.key
structObject.dtype_key # e.g. from TensorType C code.
structObject.other_attribute_named_from_key
etc.
In a C code, attributes created to represent an instance of the type associated to ``key`` will be available via:
.. code-block:: c
structObject->key;
structObject->dtype_key; // e.g. from TensorType C code.
structObject->other_attribute_named_from_key;
/* etc. */
**NB**: This Type is not a complete type and should never be used for regular graph operations.
"""
def
__init__
(
self
,
**
kwargs
):
...
...
@@ -142,10 +161,14 @@ class Wrapper(Type):
for
attribute_name
in
kwargs
:
if
re
.
match
(
'^[A-Za-z_][A-Za-z0-9_]*$'
,
attribute_name
)
is
None
:
raise
SyntaxError
(
'Wrapper: attribute "
%
s" should be a valid identifier.'
%
attribute_name
)
if
attribute_name
in
c_cpp_keywords
:
print
(
len
(
c_cpp_keywords
))
raise
SyntaxError
(
'Wrapper: "
%
s" is a potential C/C++ keyword and should not be used as attribute name.'
%
attribute_name
)
type_instance
=
kwargs
[
attribute_name
]
type_name
=
type_instance
.
__class__
.
__name__
if
not
isinstance
(
type_instance
,
Type
):
raise
TypeError
(
'Wrapper: attribute "
%
s" should inherit from
t
heano Type, got "
%
s".'
raise
TypeError
(
'Wrapper: attribute "
%
s" should inherit from
T
heano Type, got "
%
s".'
%
(
attribute_name
,
type_name
))
self
.
length
=
len
(
kwargs
)
...
...
@@ -164,49 +187,44 @@ class Wrapper(Type):
return
hash
((
type
(
self
),)
+
self
.
fields
+
self
.
types
)
def
generate_struct_name
(
self
):
""""
This method tries to generate an unique name for the current instance.
This name is intended to be used as struct name in C code and as constant
definition to check if a similar Wrapper has already been created
(see c_support_code() below).
"""
# This method tries to generate an unique name for the current instance.
# This name is intended to be used as struct name in C code and as constant
# definition to check if a similar Wrapper has already been created
# (see c_support_code() below).
fields_string
=
','
.
join
(
self
.
fields
)
.
encode
(
'utf-8'
)
types_string
=
','
.
join
(
str
(
t
)
for
t
in
self
.
types
)
.
encode
(
'utf-8'
)
fields_hex
=
hashlib
.
md5
(
fields_string
)
.
hexdigest
()
types_hex
=
hashlib
.
md5
(
types_string
)
.
hexdigest
()
return
'_wrapper_
%
s_
%
s'
%
(
fields_hex
,
types_hex
)
def
check_that_values_are_compatible
(
self
,
data
,
strict
,
allow_downcast
):
def
wrap_data
(
self
,
data
,
strict
,
allow_downcast
):
# Try to wrap data. Raise an exception if data does not respect the Wrapper's contract.
wrap_instance
=
dict
()
for
i
in
range
(
self
.
length
):
wrap_instance
[
self
.
fields
[
i
]]
=
self
.
types
[
i
]
.
filter
(
getattr
(
data
,
self
.
fields
[
i
]),
strict
,
allow_downcast
)
return
data
if
strict
else
Wrap
(
**
wrap_instance
)
return
data
if
strict
else
Wrap
(
self
,
**
wrap_instance
)
# Returns a wrapped object with expected attributes or (in strict mode) checks that data has expected attributes.
def
filter
(
self
,
data
,
strict
=
False
,
allow_downcast
=
None
):
if
strict
and
not
isinstance
(
data
,
Wrap
):
raise
TypeError
(
'
%
s: strict mode: data should be an instance of Wrap.'
%
self
)
return
self
.
check_that_values_are_compatible
(
data
,
strict
,
allow_downcast
)
return
self
.
wrap_data
(
data
,
strict
,
allow_downcast
)
def
values_eq
(
self
,
a
,
b
):
# We check that a and b have expected attributes and strict values.
a
=
self
.
filter
(
a
,
strict
=
True
)
b
=
self
.
filter
(
b
,
strict
=
True
)
# Then we compare.
for
i
in
range
(
self
.
length
):
if
not
self
.
types
[
i
]
.
values_eq
(
getattr
(
a
,
self
.
fields
[
i
]),
getattr
(
b
,
self
.
fields
[
i
])):
return
False
return
True
return
all
(
self
.
types
[
i
]
.
values_eq
(
getattr
(
a
,
self
.
fields
[
i
]),
getattr
(
b
,
self
.
fields
[
i
]))
for
i
in
range
(
self
.
length
))
def
values_eq_approx
(
self
,
a
,
b
):
# We check, wrap and round a and b if necessary.
a
=
self
.
filter
(
a
,
strict
=
False
,
allow_downcast
=
True
)
b
=
self
.
filter
(
b
,
strict
=
False
,
allow_downcast
=
True
)
# Then we compare.
for
i
in
range
(
self
.
length
):
if
not
self
.
types
[
i
]
.
values_eq_approx
(
getattr
(
a
,
self
.
fields
[
i
]),
getattr
(
b
,
self
.
fields
[
i
])):
return
False
return
True
return
all
(
self
.
types
[
i
]
.
values_eq_approx
(
getattr
(
a
,
self
.
fields
[
i
]),
getattr
(
b
,
self
.
fields
[
i
]))
for
i
in
range
(
self
.
length
))
def
c_compile_args
(
self
,
c_compiler
):
c_compile_args_list
=
[]
...
...
@@ -375,28 +393,40 @@ class Wrapper(Type):
"""
%
locals
()
def
c_code_cache_version
(
self
):
return
(
1
,
4
)
return
(
1
,
5
)
# As this struct has constructor and destructor, it could be instanciated on stack,
# but current implementations of C ops will then pass the instance by value at functions,
# so it's better to work directly with pointers.
def
c_declare
(
self
,
name
,
sub
,
check_input
=
True
):
struct_name
=
self
.
name
return
"""
%(struct_name)
s
%(name)
s;
%(struct_name)
s
*
%(name)
s;
"""
%
locals
()
# c_init() and c_cleanup() are useless if we create the struct
# on stack, as struct class has constructor and destructor.
def
c_init
(
self
,
name
,
sub
):
return
""
# NB: It seems c_init() is not called for an op param.
# So the real initialization is done at top of c_extract.
return
"""
%(nams)
s = NULL;
"""
%
locals
()
def
c_cleanup
(
self
,
name
,
sub
):
return
""
return
"""
delete
%(name)
s;
%(name)
s = NULL;
"""
%
locals
()
def
c_extract
(
self
,
name
,
sub
,
check_input
=
True
):
struct_name
=
self
.
name
fail
=
sub
[
'fail'
]
length
=
self
.
length
fields_list
=
'"
%
s"'
%
'", "'
.
join
(
self
.
fields
)
return
"""
/* Seems c_init() is not called for a op param. So I call `new` here. */
%(name)
s = new
%(struct_name)
s;
const char* fields[] = {
%(fields_list)
s};
if (py_
%(name)
s == Py_None) {
PyErr_SetString(PyExc_ValueError, "Wrapper: expected an object, not None.");
...
...
@@ -408,8 +438,8 @@ class Wrapper(Type):
PyErr_Format(PyExc_TypeError, "Wrapper: missing expected attribute
\\
"
%%
s
\\
" in object.", fields[i]);
%(fail)
s
}
%(name)
s
.
extract(o, i);
if (
%(name)
s
.
errorOccurred()) {
%(name)
s
->
extract(o, i);
if (
%(name)
s
->
errorOccurred()) {
/* The extract code from attribute type should have already raised a Python exception,
* so we just print the attribute name in stderr. */
fprintf(stderr, "
\\
nWrapper: error when extracting value for attribute
\\
"
%%
s
\\
".
\\
n", fields[i]);
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论