Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
42e7ed18
提交
42e7ed18
authored
4月 07, 2017
作者:
notoraptor
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Extend ParamsType: handle enumerations types.
Add public: * has_type() * get_field() * get_enum() * enum_from_alias() * get_params() Update documentation for ParamsType.
上级
87c35457
隐藏空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
214 行增加
和
5 行删除
+214
-5
params_type.txt
doc/library/gof/params_type.txt
+2
-2
params_type.py
theano/gof/params_type.py
+212
-3
没有找到文件。
doc/library/gof/params_type.txt
浏览文件 @
42e7ed18
...
@@ -12,4 +12,5 @@ Reference
...
@@ -12,4 +12,5 @@ Reference
:platform: Unix, Windows
:platform: Unix, Windows
:synopsis: Wrapper class for op params
:synopsis: Wrapper class for op params
:members:
:members:
.. moduleauthor:: LISA
:member-order: bysource
\ No newline at end of file
.. moduleauthor:: LISA
theano/gof/params_type.py
浏览文件 @
42e7ed18
...
@@ -63,13 +63,60 @@ In ``c_code()`` implementation (with ``param = sub['params']``):
...
@@ -63,13 +63,60 @@ In ``c_code()`` implementation (with ``param = sub['params']``):
See :class:`QuadraticOpFunc` and :class:`QuadraticCOpFunc` in ``theano/gof/tests/test_params_type.py``
See :class:`QuadraticOpFunc` and :class:`QuadraticCOpFunc` in ``theano/gof/tests/test_params_type.py``
for complete working examples.
for complete working examples.
Combining ParamsType with Theano enumeration types
--------------------------------------------------
Theano provide some enumeration types that allow to create constant primitive values (integer and floating values)
available in both Python and C code. See :class:`theano.gof.type.EnumType` and its subclasses for more details.
If your ParamsType contains Theano enumeration types, then constants defined inside these
enumerations will be directly available as ParamsType attributes.
**Example**::
from theano.gof import ParamsType, EnumType, EnumList
wrapper = ParamsType(enum1=EnumList('CONSTANT_1', 'CONSTANT_2', 'CONSTANT_3'),
enum2=EnumType(PI=3.14, EPSILON=0.001))
# Each enum constant is available as a wrapper attribute:
print(wrapper.CONSTANT_1, wrapper.CONSTANT_2, wrapper.CONSTANT_3,
wrapper.PI, wrapper.EPSILON)
# For convenience, you can also look for a constant by name with
# ``ParamsType.get_enum()`` method.
pi = wrapper.get_enum('PI')
epsilon = wrapper.get_enum('EPSILON')
constant_2 = wrapper.get_enum('CONSTANT_2')
print(pi, epsilon, constant_2)
This implies that a ParamsType cannot contain different enum types with common enum names::
# Following line will raise an error,
# as there is a "CONSTANT_1" defined both in enum1 and enum2.
wrapper = ParamsType(enum1=EnumList('CONSTANT_1', 'CONSTANT_2'),
enum2=EnumType(CONSTANT_1=0, CONSTANT_3=5))
If your enum types contain constant aliases, you can retrive them from ParamsType
with ``ParamsType.enum_from_alias(alias)`` method (see :class:`theano.gof.type.EnumType`
for more info about enumeration aliases).
.. code-block:: python
wrapper = ParamsType(enum1=EnumList('A', ('B', 'beta'), 'C'),
enum2=EnumList(('D', 'delta'), 'E', 'F'))
b1 = wrapper.B
b2 = wrapper.get_enum('B')
b3 = wrapper.enum_from_alias('beta')
assert b1 == b2 == b3
"""
"""
from
__future__
import
absolute_import
,
print_function
,
division
from
__future__
import
absolute_import
,
print_function
,
division
import
re
import
re
import
hashlib
import
hashlib
from
theano.gof.utils
import
MethodNotDefined
,
c_cpp_keywords
from
theano.gof.utils
import
MethodNotDefined
,
c_cpp_keywords
from
theano.gof
import
Type
from
theano.gof
import
Type
,
EnumType
class
Params
(
dict
):
class
Params
(
dict
):
...
@@ -191,7 +238,28 @@ class ParamsType(Type):
...
@@ -191,7 +238,28 @@ class ParamsType(Type):
self
.
length
=
len
(
kwargs
)
self
.
length
=
len
(
kwargs
)
self
.
fields
=
tuple
(
sorted
(
kwargs
.
keys
()))
self
.
fields
=
tuple
(
sorted
(
kwargs
.
keys
()))
self
.
types
=
tuple
(
kwargs
[
field
]
for
field
in
self
.
fields
)
self
.
types
=
tuple
(
kwargs
[
field
]
for
field
in
self
.
fields
)
self
.
name
=
self
.
generate_struct_name
()
self
.
name
=
self
.
__generate_struct_name
()
self
.
__const_to_enum
=
{}
self
.
__alias_to_enum
=
{}
enum_types
=
[
t
for
t
in
self
.
types
if
isinstance
(
t
,
EnumType
)]
if
enum_types
:
# We don't want same enum names in different enum types.
if
sum
(
len
(
t
)
for
t
in
enum_types
)
!=
len
(
set
(
k
for
t
in
enum_types
for
k
in
t
)):
raise
AttributeError
(
'Wrapper: found different enum types with common constant names.'
)
# We don't want same aliases in different enum types.
if
sum
(
len
(
t
.
aliases
)
for
t
in
enum_types
)
!=
len
(
set
(
alias
for
t
in
enum_types
for
alias
in
t
.
aliases
)):
raise
AttributeError
(
'Wrapper: found different enum types with common constant aliases.'
)
# We map each enum name to the enum type in which it is defined.
# We will then use this dict to find enum value when looking for enum name in Wrapper object directly.
self
.
__const_to_enum
=
{
enum_name
:
enum_type
for
enum_type
in
enum_types
for
enum_name
in
enum_type
}
self
.
__alias_to_enum
=
{
alias
:
enum_type
for
enum_type
in
enum_types
for
alias
in
enum_type
.
aliases
}
def
__getattr__
(
self
,
key
):
# Now we can access value of each enum defined inside enum types wrapped into the current Wrapper.
if
key
in
self
.
__const_to_enum
:
return
self
.
__const_to_enum
[
key
][
key
]
return
super
(
ParamsType
,
self
)
.
__getattr__
(
self
,
key
)
def
__repr__
(
self
):
def
__repr__
(
self
):
return
'ParamsType<
%
s>'
%
', '
.
join
([(
'
%
s:
%
s'
%
(
self
.
fields
[
i
],
self
.
types
[
i
]))
for
i
in
range
(
self
.
length
)])
return
'ParamsType<
%
s>'
%
', '
.
join
([(
'
%
s:
%
s'
%
(
self
.
fields
[
i
],
self
.
types
[
i
]))
for
i
in
range
(
self
.
length
)])
...
@@ -202,7 +270,7 @@ class ParamsType(Type):
...
@@ -202,7 +270,7 @@ class ParamsType(Type):
def
__hash__
(
self
):
def
__hash__
(
self
):
return
hash
((
type
(
self
),)
+
self
.
fields
+
self
.
types
)
return
hash
((
type
(
self
),)
+
self
.
fields
+
self
.
types
)
def
generate_struct_name
(
self
):
def
__
generate_struct_name
(
self
):
# This method tries to generate an unique name for the current instance.
# This method tries to generate an unique name for the current instance.
# This name is intended to be used as struct name in C code and as constant
# This name is intended to be used as struct name in C code and as constant
# definition to check if a similar ParamsType has already been created
# definition to check if a similar ParamsType has already been created
...
@@ -213,6 +281,147 @@ class ParamsType(Type):
...
@@ -213,6 +281,147 @@ class ParamsType(Type):
types_hex
=
hashlib
.
md5
(
types_string
)
.
hexdigest
()
types_hex
=
hashlib
.
md5
(
types_string
)
.
hexdigest
()
return
'_Params_
%
s_
%
s'
%
(
fields_hex
,
types_hex
)
return
'_Params_
%
s_
%
s'
%
(
fields_hex
,
types_hex
)
def
has_type
(
self
,
theano_type
):
"""
Return True if current ParamsType contains the specified Theano type.
"""
return
theano_type
in
self
.
types
def
get_field
(
self
,
theano_type
):
"""
Return the name (string) of the first field associated to
the given Theano type. Fields are sorted in lexicographic
order. Raise an exception if this Theano type is not
in the current ParamsType.
This method is intended to be used to retrieve a field name
when we know that current ParamsType contains the given
Theano type only once.
"""
return
self
.
fields
[
self
.
types
.
index
(
theano_type
)]
def
get_enum
(
self
,
key
):
"""
Look for a constant named ``key`` in the Theano enumeration types
wrapped into current ParamsType. Return value of the constant found,
or raise an exception if either the constant is not found or
current wrapper does not contain any Theano enumeration type.
**Example**::
from theano.gof import ParamsType, EnumType, EnumList
from theano.scalar import Scalar
wrapper = ParamsType(scalar=Scalar('int32'),
letters=EnumType(A=1, B=2, C=3),
digits=EnumList('ZERO', 'ONE', 'TWO'))
print(wrapper.get_enum('C')) # 3
print(wrapper.get_enum('TWO')) # 2
# You can also directly do:
print(wrapper.C)
print(wrapper.TWO)
"""
return
self
.
__const_to_enum
[
key
][
key
]
def
enum_from_alias
(
self
,
alias
):
"""
Look for a constant that has alias ``alias`` in the Theano enumeration types
wrapped into current ParamsType. Return value of the constant found,
or raise an exception if either
1. there is no constant with this alias,
2. there is no constant which name is this alias, or
3. current wrapper does not contain any Theano enumeration type.
**Example**::
from theano.gof import ParamsType, EnumType, EnumList
from theano.scalar import Scalar
wrapper = ParamsType(scalar=Scalar('int32'),
letters=EnumType(A=(1, 'alpha'), B=(2, 'beta'), C=3),
digits=EnumList(('ZERO', 'nothing'), ('ONE', 'unit'), ('TWO', 'couple')))
print(wrapper.get_enum('C')) # 3
print(wrapper.get_enum('TWO')) # 2
print(wrapper.enum_from_alias('alpha')) # 1
print(wrapper.enum_from_alias('nothing')) # 0
# For the following, alias 'C' is not defined, so the method looks for
# a constant named 'C', and finds it.
print(wrapper.enum_from_alias('C')) # 3
.. note::
Unlike with constant names, you can **NOT** access constants values directly with aliases through
ParamsType (ie. you can't write ``wrapper.alpha``). You **must** use ``wrapper.enum_from_alias()``
method to do that.
"""
return
self
.
__alias_to_enum
[
alias
]
.
fromalias
(
alias
)
if
alias
in
self
.
__alias_to_enum
else
self
.
__const_to_enum
[
alias
][
alias
]
def
get_params
(
self
,
*
objects
,
**
kwargs
):
"""
Convenient method to extract fields values from a list of Python objects and key-value args,
and wrap them into a :class:`Params` object compatible with current ParamsType.
For each field defined in the current ParamsType, a value for this field
is looked for in the given objects attributes (looking for attributes with this field name)
and key-values args (looking for a key equal to this field name), from left to right
(first object, then, ..., then last object, then key-value args), replacing a previous
field value found with any value found in next step, so that only the last field value
found is retained.
Fields values given in objects and kwargs must be compatible with types
associated to corresponding fields in current ParamsType.
**Example**::
import numpy
from theano.gof import ParamsType
from theano.tensor import dmatrix
from theano.scalar import Scalar
class MyObject:
def __init__(self):
self.a = 10
self.b = numpy.asarray([[1, 2, 3], [4, 5, 6]])
params_type = ParamsType(a=Scalar('int32'), b=dmatrix, c=Scalar('bool'))
o = MyObject()
value_for_c = False
# Value for c can't be retrieved from o, so we add a value for that field in kwargs.
params1 = params_type.get_params(o, c=value_for_c)
# params.a contains 10
# params.b contains [[1, 2, 3], [4, 5, 6]]
# params.c contains value_for_c
print(params)
"""
fields_values
=
dict
()
# We collect fields values from given objects.
# If a field is present in many objects, only the field in the last object will be retained.
for
obj
in
objects
:
for
field
in
self
.
fields
:
try
:
fields_values
[
field
]
=
getattr
(
obj
,
field
)
except
Exception
:
pass
# We then collect fields values from given kwargs.
# A field value in kwargs will replace any previous value collected from objects for this field.
for
field
in
self
.
fields
:
if
field
in
kwargs
:
fields_values
[
field
]
=
kwargs
[
field
]
# Then we filter the fields values and we create the Params object.
filtered
=
{
self
.
fields
[
i
]:
self
.
types
[
i
]
.
filter
(
fields_values
[
self
.
fields
[
i
]],
strict
=
False
,
allow_downcast
=
True
)
for
i
in
range
(
self
.
length
)}
return
Params
(
self
,
**
filtered
)
# Returns a Params object with expected attributes or (in strict mode) checks that data has expected attributes.
# Returns a Params object with expected attributes or (in strict mode) checks that data has expected attributes.
def
filter
(
self
,
data
,
strict
=
False
,
allow_downcast
=
None
):
def
filter
(
self
,
data
,
strict
=
False
,
allow_downcast
=
None
):
if
strict
and
not
isinstance
(
data
,
Params
):
if
strict
and
not
isinstance
(
data
,
Params
):
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论