Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
4025a2dc
提交
4025a2dc
authored
4月 18, 2017
作者:
Frédéric Bastien
提交者:
GitHub
4月 18, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #5755 from notoraptor/op-param-gpudnnsoftmax
Use Op params for GpuDnnSoftmax
上级
a9b6985d
574d0203
隐藏空白字符变更
内嵌
并排
正在显示
9 个修改的文件
包含
453 行增加
和
56 行删除
+453
-56
params_type.txt
doc/library/gof/params_type.txt
+2
-2
op.py
theano/gof/op.py
+1
-1
params_type.py
theano/gof/params_type.py
+226
-3
test_params_type.py
theano/gof/tests/test_params_type.py
+30
-1
test_types.py
theano/gof/tests/test_types.py
+25
-7
type.py
theano/gof/type.py
+143
-18
dnn.py
theano/gpuarray/dnn.py
+18
-16
dnn_softmax.c
theano/gpuarray/dnn_softmax.c
+4
-4
dnn_softmax_grad.c
theano/gpuarray/dnn_softmax_grad.c
+4
-4
没有找到文件。
doc/library/gof/params_type.txt
浏览文件 @
4025a2dc
...
@@ -12,4 +12,5 @@ Reference
...
@@ -12,4 +12,5 @@ Reference
:platform: Unix, Windows
:platform: Unix, Windows
:synopsis: Wrapper class for op params
:synopsis: Wrapper class for op params
:members:
:members:
.. moduleauthor:: LISA
:member-order: bysource
\ No newline at end of file
.. moduleauthor:: LISA
theano/gof/op.py
浏览文件 @
4025a2dc
...
@@ -808,7 +808,7 @@ class Op(utils.object2, PureOp, CLinkerOp):
...
@@ -808,7 +808,7 @@ class Op(utils.object2, PureOp, CLinkerOp):
field
=
wrapper
.
fields
[
i
]
field
=
wrapper
.
fields
[
i
]
_type
=
wrapper
.
types
[
i
]
_type
=
wrapper
.
types
[
i
]
wrap_dict
[
field
]
=
_type
.
filter
(
getattr
(
self
,
field
),
strict
=
False
,
allow_downcast
=
True
)
wrap_dict
[
field
]
=
_type
.
filter
(
getattr
(
self
,
field
),
strict
=
False
,
allow_downcast
=
True
)
return
theano
.
gof
.
Params
(
wrapper
,
**
wrap_dict
)
return
self
.
params_type
.
get_params
(
self
)
raise
theano
.
gof
.
utils
.
MethodNotDefined
(
'get_params'
)
raise
theano
.
gof
.
utils
.
MethodNotDefined
(
'get_params'
)
def
prepare_node
(
self
,
node
,
storage_map
,
compute_map
,
impl
):
def
prepare_node
(
self
,
node
,
storage_map
,
compute_map
,
impl
):
...
...
theano/gof/params_type.py
浏览文件 @
4025a2dc
...
@@ -63,13 +63,60 @@ In ``c_code()`` implementation (with ``param = sub['params']``):
...
@@ -63,13 +63,60 @@ In ``c_code()`` implementation (with ``param = sub['params']``):
See :class:`QuadraticOpFunc` and :class:`QuadraticCOpFunc` in ``theano/gof/tests/test_params_type.py``
See :class:`QuadraticOpFunc` and :class:`QuadraticCOpFunc` in ``theano/gof/tests/test_params_type.py``
for complete working examples.
for complete working examples.
Combining ParamsType with Theano enumeration types
--------------------------------------------------
Theano provide some enumeration types that allow to create constant primitive values (integer and floating values)
available in both Python and C code. See :class:`theano.gof.type.EnumType` and its subclasses for more details.
If your ParamsType contains Theano enumeration types, then constants defined inside these
enumerations will be directly available as ParamsType attributes.
**Example**::
from theano.gof import ParamsType, EnumType, EnumList
wrapper = ParamsType(enum1=EnumList('CONSTANT_1', 'CONSTANT_2', 'CONSTANT_3'),
enum2=EnumType(PI=3.14, EPSILON=0.001))
# Each enum constant is available as a wrapper attribute:
print(wrapper.CONSTANT_1, wrapper.CONSTANT_2, wrapper.CONSTANT_3,
wrapper.PI, wrapper.EPSILON)
# For convenience, you can also look for a constant by name with
# ``ParamsType.get_enum()`` method.
pi = wrapper.get_enum('PI')
epsilon = wrapper.get_enum('EPSILON')
constant_2 = wrapper.get_enum('CONSTANT_2')
print(pi, epsilon, constant_2)
This implies that a ParamsType cannot contain different enum types with common enum names::
# Following line will raise an error,
# as there is a "CONSTANT_1" defined both in enum1 and enum2.
wrapper = ParamsType(enum1=EnumList('CONSTANT_1', 'CONSTANT_2'),
enum2=EnumType(CONSTANT_1=0, CONSTANT_3=5))
If your enum types contain constant aliases, you can retrive them from ParamsType
with ``ParamsType.enum_from_alias(alias)`` method (see :class:`theano.gof.type.EnumType`
for more info about enumeration aliases).
.. code-block:: python
wrapper = ParamsType(enum1=EnumList('A', ('B', 'beta'), 'C'),
enum2=EnumList(('D', 'delta'), 'E', 'F'))
b1 = wrapper.B
b2 = wrapper.get_enum('B')
b3 = wrapper.enum_from_alias('beta')
assert b1 == b2 == b3
"""
"""
from
__future__
import
absolute_import
,
print_function
,
division
from
__future__
import
absolute_import
,
print_function
,
division
import
re
import
re
import
hashlib
import
hashlib
from
theano.gof.utils
import
MethodNotDefined
,
c_cpp_keywords
from
theano.gof.utils
import
MethodNotDefined
,
c_cpp_keywords
from
theano.gof
import
Type
from
theano.gof
import
Type
,
EnumType
class
Params
(
dict
):
class
Params
(
dict
):
...
@@ -193,6 +240,32 @@ class ParamsType(Type):
...
@@ -193,6 +240,32 @@ class ParamsType(Type):
self
.
types
=
tuple
(
kwargs
[
field
]
for
field
in
self
.
fields
)
self
.
types
=
tuple
(
kwargs
[
field
]
for
field
in
self
.
fields
)
self
.
name
=
self
.
generate_struct_name
()
self
.
name
=
self
.
generate_struct_name
()
self
.
__const_to_enum
=
{}
self
.
__alias_to_enum
=
{}
enum_types
=
[
t
for
t
in
self
.
types
if
isinstance
(
t
,
EnumType
)]
if
enum_types
:
# We don't want same enum names in different enum types.
if
sum
(
len
(
t
)
for
t
in
enum_types
)
!=
len
(
set
(
k
for
t
in
enum_types
for
k
in
t
)):
raise
AttributeError
(
'ParamsType: found different enum types with common constants names.'
)
# We don't want same aliases in different enum types.
if
sum
(
len
(
t
.
aliases
)
for
t
in
enum_types
)
!=
len
(
set
(
alias
for
t
in
enum_types
for
alias
in
t
.
aliases
)):
raise
AttributeError
(
'ParamsType: found different enum types with common constants aliases.'
)
# We don't want aliases that have same names as some constants.
all_enums
=
{
e
for
t
in
enum_types
for
e
in
t
}
all_aliases
=
{
a
for
t
in
enum_types
for
a
in
t
.
aliases
}
if
[
a
for
a
in
all_aliases
if
a
in
all_enums
]:
raise
AttributeError
(
'ParamsType: found aliases that have same names as constants.'
)
# We map each enum name to the enum type in which it is defined.
# We will then use this dict to find enum value when looking for enum name in Wrapper object directly.
self
.
__const_to_enum
=
{
enum_name
:
enum_type
for
enum_type
in
enum_types
for
enum_name
in
enum_type
}
self
.
__alias_to_enum
=
{
alias
:
enum_type
for
enum_type
in
enum_types
for
alias
in
enum_type
.
aliases
}
def
__getattr__
(
self
,
key
):
# Now we can access value of each enum defined inside enum types wrapped into the current Wrapper.
if
key
in
self
.
__const_to_enum
:
return
self
.
__const_to_enum
[
key
][
key
]
return
super
(
ParamsType
,
self
)
.
__getattr__
(
self
,
key
)
def
__repr__
(
self
):
def
__repr__
(
self
):
return
'ParamsType<
%
s>'
%
', '
.
join
([(
'
%
s:
%
s'
%
(
self
.
fields
[
i
],
self
.
types
[
i
]))
for
i
in
range
(
self
.
length
)])
return
'ParamsType<
%
s>'
%
', '
.
join
([(
'
%
s:
%
s'
%
(
self
.
fields
[
i
],
self
.
types
[
i
]))
for
i
in
range
(
self
.
length
)])
...
@@ -213,6 +286,147 @@ class ParamsType(Type):
...
@@ -213,6 +286,147 @@ class ParamsType(Type):
types_hex
=
hashlib
.
md5
(
types_string
)
.
hexdigest
()
types_hex
=
hashlib
.
md5
(
types_string
)
.
hexdigest
()
return
'_Params_
%
s_
%
s'
%
(
fields_hex
,
types_hex
)
return
'_Params_
%
s_
%
s'
%
(
fields_hex
,
types_hex
)
def
has_type
(
self
,
theano_type
):
"""
Return True if current ParamsType contains the specified Theano type.
"""
return
theano_type
in
self
.
types
def
get_field
(
self
,
theano_type
):
"""
Return the name (string) of the first field associated to
the given Theano type. Fields are sorted in lexicographic
order. Raise an exception if this Theano type is not
in the current ParamsType.
This method is intended to be used to retrieve a field name
when we know that current ParamsType contains the given
Theano type only once.
"""
return
self
.
fields
[
self
.
types
.
index
(
theano_type
)]
def
get_enum
(
self
,
key
):
"""
Look for a constant named ``key`` in the Theano enumeration types
wrapped into current ParamsType. Return value of the constant found,
or raise an exception if either the constant is not found or
current wrapper does not contain any Theano enumeration type.
**Example**::
from theano.gof import ParamsType, EnumType, EnumList
from theano.scalar import Scalar
wrapper = ParamsType(scalar=Scalar('int32'),
letters=EnumType(A=1, B=2, C=3),
digits=EnumList('ZERO', 'ONE', 'TWO'))
print(wrapper.get_enum('C')) # 3
print(wrapper.get_enum('TWO')) # 2
# You can also directly do:
print(wrapper.C)
print(wrapper.TWO)
"""
return
self
.
__const_to_enum
[
key
][
key
]
def
enum_from_alias
(
self
,
alias
):
"""
Look for a constant that has alias ``alias`` in the Theano enumeration types
wrapped into current ParamsType. Return value of the constant found,
or raise an exception if either
1. there is no constant with this alias,
2. there is no constant which name is this alias, or
3. current wrapper does not contain any Theano enumeration type.
**Example**::
from theano.gof import ParamsType, EnumType, EnumList
from theano.scalar import Scalar
wrapper = ParamsType(scalar=Scalar('int32'),
letters=EnumType(A=(1, 'alpha'), B=(2, 'beta'), C=3),
digits=EnumList(('ZERO', 'nothing'), ('ONE', 'unit'), ('TWO', 'couple')))
print(wrapper.get_enum('C')) # 3
print(wrapper.get_enum('TWO')) # 2
print(wrapper.enum_from_alias('alpha')) # 1
print(wrapper.enum_from_alias('nothing')) # 0
# For the following, alias 'C' is not defined, so the method looks for
# a constant named 'C', and finds it.
print(wrapper.enum_from_alias('C')) # 3
.. note::
Unlike with constant names, you can **NOT** access constants values directly with aliases through
ParamsType (ie. you can't write ``wrapper.alpha``). You **must** use ``wrapper.enum_from_alias()``
method to do that.
"""
return
self
.
__alias_to_enum
[
alias
]
.
fromalias
(
alias
)
if
alias
in
self
.
__alias_to_enum
else
self
.
__const_to_enum
[
alias
][
alias
]
def
get_params
(
self
,
*
objects
,
**
kwargs
):
"""
Convenient method to extract fields values from a list of Python objects and key-value args,
and wrap them into a :class:`Params` object compatible with current ParamsType.
For each field defined in the current ParamsType, a value for this field
is looked for in the given objects attributes (looking for attributes with this field name)
and key-values args (looking for a key equal to this field name), from left to right
(first object, then, ..., then last object, then key-value args), replacing a previous
field value found with any value found in next step, so that only the last field value
found is retained.
Fields values given in objects and kwargs must be compatible with types
associated to corresponding fields in current ParamsType.
**Example**::
import numpy
from theano.gof import ParamsType
from theano.tensor import dmatrix
from theano.scalar import Scalar
class MyObject:
def __init__(self):
self.a = 10
self.b = numpy.asarray([[1, 2, 3], [4, 5, 6]])
params_type = ParamsType(a=Scalar('int32'), b=dmatrix, c=Scalar('bool'))
o = MyObject()
value_for_c = False
# Value for c can't be retrieved from o, so we add a value for that field in kwargs.
params = params_type.get_params(o, c=value_for_c)
# params.a contains 10
# params.b contains [[1, 2, 3], [4, 5, 6]]
# params.c contains value_for_c
print(params)
"""
fields_values
=
dict
()
# We collect fields values from given objects.
# If a field is present in many objects, only the field in the last object will be retained.
for
obj
in
objects
:
for
field
in
self
.
fields
:
try
:
fields_values
[
field
]
=
getattr
(
obj
,
field
)
except
Exception
:
pass
# We then collect fields values from given kwargs.
# A field value in kwargs will replace any previous value collected from objects for this field.
for
field
in
self
.
fields
:
if
field
in
kwargs
:
fields_values
[
field
]
=
kwargs
[
field
]
# Then we filter the fields values and we create the Params object.
filtered
=
{
self
.
fields
[
i
]:
self
.
types
[
i
]
.
filter
(
fields_values
[
self
.
fields
[
i
]],
strict
=
False
,
allow_downcast
=
True
)
for
i
in
range
(
self
.
length
)}
return
Params
(
self
,
**
filtered
)
# Returns a Params object with expected attributes or (in strict mode) checks that data has expected attributes.
# Returns a Params object with expected attributes or (in strict mode) checks that data has expected attributes.
def
filter
(
self
,
data
,
strict
=
False
,
allow_downcast
=
None
):
def
filter
(
self
,
data
,
strict
=
False
,
allow_downcast
=
None
):
if
strict
and
not
isinstance
(
data
,
Params
):
if
strict
and
not
isinstance
(
data
,
Params
):
...
@@ -309,12 +523,18 @@ class ParamsType(Type):
...
@@ -309,12 +523,18 @@ class ParamsType(Type):
sub
=
{
'fail'
:
'{this->setErrorOccurred(); return;}'
}
sub
=
{
'fail'
:
'{this->setErrorOccurred(); return;}'
}
struct_name
=
self
.
name
struct_name
=
self
.
name
struct_name_defined
=
struct_name
.
upper
()
struct_name_defined
=
struct_name
.
upper
()
c_support_code_set
=
set
()
c_declare_list
=
[]
c_declare_list
=
[]
c_init_list
=
[]
c_init_list
=
[]
c_cleanup_list
=
[]
c_cleanup_list
=
[]
c_extract_list
=
[]
c_extract_list
=
[]
for
attribute_name
,
type_instance
in
zip
(
self
.
fields
,
self
.
types
):
for
attribute_name
,
type_instance
in
zip
(
self
.
fields
,
self
.
types
):
try
:
c_support_code_set
.
add
(
type_instance
.
c_support_code
())
except
MethodNotDefined
:
pass
c_declare_list
.
append
(
type_instance
.
c_declare
(
attribute_name
,
sub
))
c_declare_list
.
append
(
type_instance
.
c_declare
(
attribute_name
,
sub
))
c_init_list
.
append
(
type_instance
.
c_init
(
attribute_name
,
sub
))
c_init_list
.
append
(
type_instance
.
c_init
(
attribute_name
,
sub
))
...
@@ -330,6 +550,7 @@ class ParamsType(Type):
...
@@ -330,6 +550,7 @@ class ParamsType(Type):
'extract_code'
:
type_instance
.
c_extract
(
attribute_name
,
sub
)
'extract_code'
:
type_instance
.
c_extract
(
attribute_name
,
sub
)
})
})
support_code
=
'
\n
'
.
join
(
sorted
(
list
(
c_support_code_set
)))
struct_declare
=
'
\n
'
.
join
(
c_declare_list
)
struct_declare
=
'
\n
'
.
join
(
c_declare_list
)
struct_init
=
'
\n
'
.
join
(
c_init_list
)
struct_init
=
'
\n
'
.
join
(
c_init_list
)
struct_cleanup
=
'
\n
'
.
join
(
c_cleanup_list
)
struct_cleanup
=
'
\n
'
.
join
(
c_cleanup_list
)
...
@@ -350,6 +571,7 @@ class ParamsType(Type):
...
@@ -350,6 +571,7 @@ class ParamsType(Type):
[(
'case
%
d: extract_
%
s(object); break;'
%
(
i
,
self
.
fields
[
i
]))
for
i
in
range
(
self
.
length
)])
[(
'case
%
d: extract_
%
s(object); break;'
%
(
i
,
self
.
fields
[
i
]))
for
i
in
range
(
self
.
length
)])
)
)
return
"""
return
"""
%(support_code)
s
#ifndef
%(struct_name_defined)
s
#ifndef
%(struct_name_defined)
s
#define
%(struct_name_defined)
s
#define
%(struct_name_defined)
s
struct
%(struct_name)
s {
struct
%(struct_name)
s {
...
@@ -389,12 +611,13 @@ class ParamsType(Type):
...
@@ -389,12 +611,13 @@ class ParamsType(Type):
}
}
};
};
#endif
#endif
"""
%
dict
(
struct_name_defined
=
struct_name_defined
,
struct_name
=
struct_name
,
struct_declare
=
struct_declare
,
"""
%
dict
(
support_code
=
support_code
,
struct_name_defined
=
struct_name_defined
,
struct_name
=
struct_name
,
struct_declare
=
struct_declare
,
struct_init
=
struct_init
,
struct_cleanup
=
struct_cleanup
,
struct_extract
=
struct_extract
,
struct_init
=
struct_init
,
struct_cleanup
=
struct_cleanup
,
struct_extract
=
struct_extract
,
struct_extract_method
=
struct_extract_method
)
struct_extract_method
=
struct_extract_method
)
def
c_code_cache_version
(
self
):
def
c_code_cache_version
(
self
):
return
((
1
,
7
),
tuple
(
t
.
c_code_cache_version
()
for
t
in
self
.
types
))
return
((
1
,
8
),
tuple
(
t
.
c_code_cache_version
()
for
t
in
self
.
types
))
# As this struct has constructor and destructor, it could be instanciated on stack,
# As this struct has constructor and destructor, it could be instanciated on stack,
# but current implementations of C ops will then pass the instance by value at functions,
# but current implementations of C ops will then pass the instance by value at functions,
...
...
theano/gof/tests/test_params_type.py
浏览文件 @
4025a2dc
...
@@ -6,7 +6,7 @@ from theano.gof import Op, COp, Apply
...
@@ -6,7 +6,7 @@ from theano.gof import Op, COp, Apply
from
theano
import
Generic
from
theano
import
Generic
from
theano.scalar
import
Scalar
from
theano.scalar
import
Scalar
from
theano.tensor
import
TensorType
from
theano.tensor
import
TensorType
from
theano.gof
import
ParamsType
,
Params
from
theano.gof
import
ParamsType
,
Params
,
EnumList
from
theano
import
tensor
from
theano
import
tensor
from
theano.tests
import
unittest_tools
as
utt
from
theano.tests
import
unittest_tools
as
utt
...
@@ -213,6 +213,35 @@ class TestParamsType(TestCase):
...
@@ -213,6 +213,35 @@ class TestParamsType(TestCase):
a3
=
2000.0
-
0.00000000000000001
)
a3
=
2000.0
-
0.00000000000000001
)
assert
w
.
values_eq_approx
(
o1
,
o3
)
assert
w
.
values_eq_approx
(
o1
,
o3
)
def
test_params_type_with_enums
(
self
):
# Test that we fail if we create a params type with common enum names inside different enum types.
try
:
ParamsType
(
enum1
=
EnumList
(
'A'
,
'B'
,
'C'
),
enum2
=
EnumList
(
'A'
,
'B'
,
'F'
))
except
AttributeError
:
pass
else
:
raise
Exception
(
'ParamsType should fail with common enum names inside different enum types.'
)
# Test that we fail if we create a params type with common names in both aliases and constants.
try
:
ParamsType
(
enum1
=
EnumList
((
'A'
,
'a'
),
(
'B'
,
'b'
)),
enum2
=
EnumList
((
'ONE'
,
'a'
),
(
'TWO'
,
'two'
)))
except
AttributeError
:
ParamsType
(
enum1
=
EnumList
((
'A'
,
'a'
),
(
'B'
,
'b'
)),
enum2
=
EnumList
((
'ONE'
,
'one'
),
(
'TWO'
,
'two'
)))
else
:
raise
Exception
(
'ParamsType should fail when there are aliases with same names as some constants.'
)
# Test that we can access enum values through wrapper directly.
w
=
ParamsType
(
enum1
=
EnumList
(
'A'
,
(
'B'
,
'beta'
),
'C'
),
enum2
=
EnumList
((
'D'
,
'delta'
),
'E'
,
'F'
))
assert
w
.
A
==
0
and
w
.
B
==
1
and
w
.
C
==
2
assert
w
.
D
==
0
and
w
.
E
==
1
and
w
.
F
==
2
# Test constants access through aliases.
assert
w
.
enum_from_alias
(
'beta'
)
==
w
.
B
assert
w
.
enum_from_alias
(
'delta'
)
==
w
.
D
assert
w
.
enum_from_alias
(
'C'
)
==
w
.
C
# C is not an alias, so it should return a constant named C.
# Test that other regular wrapper attributes are still available.
assert
len
(
w
.
fields
)
==
len
(
w
.
types
)
==
w
.
length
assert
w
.
name
def
test_op_params
(
self
):
def
test_op_params
(
self
):
a
,
b
,
c
=
2
,
3
,
-
7
a
,
b
,
c
=
2
,
3
,
-
7
x
=
tensor
.
matrix
(
dtype
=
'float64'
)
x
=
tensor
.
matrix
(
dtype
=
'float64'
)
...
...
theano/gof/tests/test_types.py
浏览文件 @
4025a2dc
...
@@ -81,18 +81,19 @@ def test_cdata():
...
@@ -81,18 +81,19 @@ def test_cdata():
class
MyOpEnumList
(
Op
):
class
MyOpEnumList
(
Op
):
__props__
=
(
'op_chosen'
,)
__props__
=
(
'op_chosen'
,)
params_type
=
EnumList
(
'ADD'
,
'SUB'
,
'MULTIPLY'
,
'DIVIDE'
,
ctype
=
'unsigned long long'
)
params_type
=
EnumList
(
(
'ADD'
,
'+'
),
(
'SUB'
,
'-'
),
(
'MULTIPLY'
,
'*'
),
(
'DIVIDE'
,
'/'
)
,
ctype
=
'unsigned long long'
)
def
__init__
(
self
,
choose_op
):
def
__init__
(
self
,
choose_op
):
assert
self
.
params_type
.
ADD
==
0
assert
self
.
params_type
.
ADD
==
0
assert
self
.
params_type
.
SUB
==
1
assert
self
.
params_type
.
SUB
==
1
assert
self
.
params_type
.
MULTIPLY
==
2
assert
self
.
params_type
.
MULTIPLY
==
2
assert
self
.
params_type
.
DIVIDE
==
3
assert
self
.
params_type
.
DIVIDE
==
3
op_to_const
=
{
'+'
:
self
.
params_type
.
ADD
,
assert
self
.
params_type
.
fromalias
(
'+'
)
==
self
.
params_type
.
ADD
'-'
:
self
.
params_type
.
SUB
,
assert
self
.
params_type
.
fromalias
(
'-'
)
==
self
.
params_type
.
SUB
'*'
:
self
.
params_type
.
MULTIPLY
,
assert
self
.
params_type
.
fromalias
(
'*'
)
==
self
.
params_type
.
MULTIPLY
'/'
:
self
.
params_type
.
DIVIDE
}
assert
self
.
params_type
.
fromalias
(
'/'
)
==
self
.
params_type
.
DIVIDE
self
.
op_chosen
=
op_to_const
[
choose_op
]
assert
self
.
params_type
.
has_alias
(
choose_op
)
self
.
op_chosen
=
choose_op
def
get_params
(
self
,
node
):
def
get_params
(
self
,
node
):
return
self
.
op_chosen
return
self
.
op_chosen
...
@@ -204,7 +205,7 @@ class TestEnumTypes(TestCase):
...
@@ -204,7 +205,7 @@ class TestEnumTypes(TestCase):
# Check that invalid enum value raises exception.
# Check that invalid enum value raises exception.
try
:
try
:
EnumType
(
INVALID_VALUE
=
'string is not allowed.'
)
EnumType
(
INVALID_VALUE
=
'string is not allowed.'
)
except
Valu
eError
:
except
Typ
eError
:
pass
pass
else
:
else
:
raise
Exception
(
'EnumType with invalid value should fail.'
)
raise
Exception
(
'EnumType with invalid value should fail.'
)
...
@@ -218,6 +219,23 @@ class TestEnumTypes(TestCase):
...
@@ -218,6 +219,23 @@ class TestEnumTypes(TestCase):
# Check access to attributes.
# Check access to attributes.
assert
len
((
e1
.
ctype
,
e1
.
C1
,
e1
.
C2
,
e1
.
C3
,
e1
.
C4
,
e1
.
C5
,
e1
.
C6
))
==
7
assert
len
((
e1
.
ctype
,
e1
.
C1
,
e1
.
C2
,
e1
.
C3
,
e1
.
C4
,
e1
.
C5
,
e1
.
C6
))
==
7
# Check enum with aliases.
e1
=
EnumType
(
A
=
(
'alpha'
,
0
),
B
=
(
'beta'
,
1
),
C
=
2
)
e2
=
EnumType
(
A
=
(
'alpha'
,
0
),
B
=
(
'beta'
,
1
),
C
=
2
)
e3
=
EnumType
(
A
=
(
'a'
,
0
),
B
=
(
'beta'
,
1
),
C
=
2
)
assert
e1
==
e2
assert
e1
!=
e3
assert
e1
.
filter
(
'beta'
)
==
e1
.
fromalias
(
'beta'
)
==
e1
.
B
==
1
assert
e1
.
filter
(
'C'
)
==
e1
.
fromalias
(
'C'
)
==
e1
.
C
==
2
# Check that invalid alias (same as a constant) raises exception.
try
:
EnumList
((
'A'
,
'a'
),
(
'B'
,
'B'
))
except
TypeError
:
EnumList
((
'A'
,
'a'
),
(
'B'
,
'b'
))
else
:
raise
Exception
(
'Enum with an alias name equal to a constant name should fail.'
)
def
test_op_with_enumlist
(
self
):
def
test_op_with_enumlist
(
self
):
a
=
scalar
.
int32
()
a
=
scalar
.
int32
()
b
=
scalar
.
int32
()
b
=
scalar
.
int32
()
...
...
theano/gof/type.py
浏览文件 @
4025a2dc
...
@@ -814,13 +814,20 @@ CDataType.Constant = CDataTypeConstant
...
@@ -814,13 +814,20 @@ CDataType.Constant = CDataTypeConstant
class
EnumType
(
Type
,
dict
):
class
EnumType
(
Type
,
dict
):
"""
"""
Main subclasses:
- :class:`EnumList`
- :class:`CEnumType`
Op parameter class that allows to create enumerations of constant values.
Op parameter class that allows to create enumerations of constant values.
- Constants are available as object attributes in Python code and as macro-defined constants in C code.
- Constants are available as object attributes in Python code and as macro-defined constants in C code.
- Constants can be floating values, integers, or booleans (automatically converted to integers).
- Constants can be floating values, integers, or booleans (automatically converted to integers).
- Constants name must start with a capital letter and contain capital letters, underscores or digits.
- Constants name must start with a capital letter and contain capital letters, underscores or digits.
- A constant can have an alias, and then be available through both constant name and constant alias.
Example::
**Example**
.. code-block:: python
enum = EnumType(CONSTANT_1=1, CONSTANT_2=2.5, CONSTANT_3=False, CONSTANT_4=True)
enum = EnumType(CONSTANT_1=1, CONSTANT_2=2.5, CONSTANT_3=False, CONSTANT_4=True)
print (enum.CONSTANT_1, enum.CONSTANT_2, enum.CONSTANT_3, enum.CONSTANT_4)
print (enum.CONSTANT_1, enum.CONSTANT_2, enum.CONSTANT_3, enum.CONSTANT_4)
...
@@ -849,35 +856,111 @@ class EnumType(Type, dict):
...
@@ -849,35 +856,111 @@ class EnumType(Type, dict):
size_t value = op_param_value; // contains enum.CONSTANT_1, i.e 0
size_t value = op_param_value; // contains enum.CONSTANT_1, i.e 0
**Example with aliases**
When creating an enum, you can give some aliases to specific constants while keeping other constants without aliases.
An alias must be a string, and there is currently no string format constraints.
To give an alias to a constant in the EnumType constructor, use the following key-value syntax::
constant_name=(constant_alias, constant_value)
You can then retrieve a constant from an alias with method ``EnumType.fromalias()``.
Aliases are intended to be used in Python code only (only constants names are available in C code).
Especially, an alias will be recognized by ``Enumtype.filter()`` method with non-strict filtering,
allowing a maximum flexibility for converting strings to numeric constants available in Python and C code.
.. code-block:: python
from theano.gof import EnumType
# You can remark that constant 'C' does not have an alias.
enum = EnumType(A=('alpha', 1), B=('beta', 2), C=3, D=('delta', 4))
# Constants are all directly available by name.
print(enum.A, enum.B, enum.C, enum.D)
# But we can also now get some constants by alias.
a = enum.fromalias('alpha')
b = enum.fromalias('beta')
d = enum.fromalias('delta')
# If method fromalias() receives an unknown alias,
# it will looks for a constant with this alias
# as exact constant name.
c = enum.fromalias('C') # will get enum.C
# An alias defined in an EnumType will be correctly converted with non-strict filtering.
value = enum.filter('delta', strict=False)
# value now contaisn enum.D, ie. 4.
.. note::
.. note::
This Type (and subclasses) is not complete and should never be used for regular graph operations.
This Type (and subclasses) is not complete and should never be used for regular graph operations.
"""
"""
def
check_ctype
(
self
):
def
__init_ctype
(
self
,
ctype
):
# C type may be a list of keywords, e.g. "unsigned long long".
# C type may be a list of keywords, e.g. "unsigned long long".
# We should check each part.
# We should check each part.
if
not
all
(
re
.
match
(
'^[A-Za-z_][A-Za-z0-9_]*$'
,
el
)
for
el
in
self
.
ctype
.
split
()):
ctype_parts
=
ctype
.
split
()
raise
TypeError
(
'
%
s: invalid C type'
%
type
(
self
)
.
__name__
)
if
not
all
(
re
.
match
(
'^[A-Za-z_][A-Za-z0-9_]*$'
,
el
)
for
el
in
ctype_parts
):
raise
TypeError
(
'
%
s: invalid C type.'
%
type
(
self
)
.
__name__
)
self
.
ctype
=
' '
.
join
(
ctype_parts
)
def
__init__
(
self
,
**
kwargs
):
def
__init__
(
self
,
**
kwargs
):
self
.
ctype
=
kwargs
.
pop
(
'ctype'
,
'double'
)
self
.
__init_ctype
(
kwargs
.
pop
(
'ctype'
,
'double'
)
)
self
.
check_ctype
()
self
.
aliases
=
dict
()
for
k
in
kwargs
:
for
k
in
kwargs
:
if
re
.
match
(
'^[A-Z][A-Z0-9_]*$'
,
k
)
is
None
:
if
re
.
match
(
'^[A-Z][A-Z0-9_]*$'
,
k
)
is
None
:
raise
AttributeError
(
'
%
s: invalid enum name: "
%
s". '
raise
AttributeError
(
'
%
s: invalid enum name: "
%
s". '
'Only capital letters, underscores and digits '
'Only capital letters, underscores and digits '
'are allowed.'
%
(
type
(
self
)
.
__name__
,
k
))
'are allowed.'
%
(
type
(
self
)
.
__name__
,
k
))
if
isinstance
(
kwargs
[
k
],
(
list
,
tuple
)):
if
len
(
kwargs
[
k
])
!=
2
:
raise
TypeError
(
'
%
s: when using a tuple to define a constant, your tuple should contain 2 values: '
'constant alias followed by constant value.'
%
type
(
self
)
.
__name__
)
alias
,
value
=
kwargs
[
k
]
if
not
isinstance
(
alias
,
str
):
raise
TypeError
(
'
%
s: constant alias should be a string, got "
%
s".'
%
(
type
(
self
)
.
__name__
,
alias
))
if
alias
==
k
:
raise
TypeError
(
"
%
s: it's useless to create an alias "
"with the same name as its associated constant."
%
type
(
self
)
.
__name__
)
if
alias
in
self
.
aliases
:
raise
TypeError
(
'
%
s: consant alias "
%
s" already used.'
%
(
type
(
self
)
.
__name__
,
alias
))
self
.
aliases
[
alias
]
=
k
kwargs
[
k
]
=
value
if
isinstance
(
kwargs
[
k
],
bool
):
if
isinstance
(
kwargs
[
k
],
bool
):
kwargs
[
k
]
=
int
(
kwargs
[
k
])
kwargs
[
k
]
=
int
(
kwargs
[
k
])
elif
not
isinstance
(
kwargs
[
k
],
(
int
,
float
)):
elif
not
isinstance
(
kwargs
[
k
],
(
int
,
float
)):
raise
ValueError
(
'
%
s: constant "
%
s": expected integer or floating value, got "
%
s".'
raise
TypeError
(
'
%
s: constant "
%
s": expected integer or floating value, got "
%
s".'
%
(
type
(
self
)
.
__name__
,
k
,
type
(
kwargs
[
k
])
.
__name__
))
%
(
type
(
self
)
.
__name__
,
k
,
type
(
kwargs
[
k
])
.
__name__
))
if
[
a
for
a
in
self
.
aliases
if
a
in
self
]:
raise
TypeError
(
"
%
s: some aliases have same names as constants."
%
type
(
self
)
.
__name__
)
super
(
EnumType
,
self
)
.
__init__
(
**
kwargs
)
super
(
EnumType
,
self
)
.
__init__
(
**
kwargs
)
def
fromalias
(
self
,
alias
):
"""
Get a constant value by its alias.
If there is not such alias in this enum, look for a constant
with this alias as constant name.
"""
return
self
[
self
.
aliases
[
alias
]]
if
alias
in
self
.
aliases
else
self
[
alias
]
def
has_alias
(
self
,
alias
):
"""
return True if and only if this enum has this alias.
"""
return
alias
in
self
.
aliases
def
__repr__
(
self
):
def
__repr__
(
self
):
return
'
%
s(
%
s)'
%
(
type
(
self
)
.
__name__
,
', '
.
join
(
'
%
s:
%
s'
%
(
k
,
self
[
k
])
for
k
in
sorted
(
self
.
keys
())))
names_to_aliases
=
{
constant_name
:
''
for
constant_name
in
self
}
for
alias
in
self
.
aliases
:
names_to_aliases
[
self
.
aliases
[
alias
]]
=
'(
%
s)'
%
alias
return
'
%
s<
%
s>(
%
s)'
%
(
type
(
self
)
.
__name__
,
self
.
ctype
,
', '
.
join
(
'
%
s
%
s:
%
s'
%
(
k
,
names_to_aliases
[
k
],
self
[
k
])
for
k
in
sorted
(
self
.
keys
())))
def
__getattr__
(
self
,
key
):
def
__getattr__
(
self
,
key
):
if
key
in
self
:
if
key
in
self
:
...
@@ -897,14 +980,19 @@ class EnumType(Type, dict):
...
@@ -897,14 +980,19 @@ class EnumType(Type, dict):
def
__hash__
(
self
):
def
__hash__
(
self
):
# All values are Python basic types, then easy to hash.
# All values are Python basic types, then easy to hash.
return
hash
((
type
(
self
),
self
.
ctype
)
+
tuple
((
k
,
self
[
k
])
for
k
in
sorted
(
self
.
keys
())))
return
hash
((
type
(
self
),
self
.
ctype
)
+
tuple
((
k
,
self
[
k
])
for
k
in
sorted
(
self
.
keys
()))
+
tuple
((
a
,
self
.
aliases
[
a
])
for
a
in
sorted
(
self
.
aliases
.
keys
())))
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
return
(
type
(
self
)
==
type
(
other
)
and
return
(
type
(
self
)
==
type
(
other
)
and
self
.
ctype
==
other
.
ctype
and
self
.
ctype
==
other
.
ctype
and
len
(
self
)
==
len
(
other
)
and
len
(
self
)
==
len
(
other
)
and
len
(
self
.
aliases
)
==
len
(
other
.
aliases
)
and
all
(
k
in
other
for
k
in
self
)
and
all
(
k
in
other
for
k
in
self
)
and
all
(
self
[
k
]
==
other
[
k
]
for
k
in
self
))
all
(
a
in
other
.
aliases
for
a
in
self
.
aliases
)
and
all
(
self
[
k
]
==
other
[
k
]
for
k
in
self
)
and
all
(
self
.
aliases
[
a
]
==
other
.
aliases
[
a
]
for
a
in
self
.
aliases
))
# EnumType should be used to create constants available in both Python and C code.
# EnumType should be used to create constants available in both Python and C code.
# However, for convenience, we make sure EnumType can have a value, like other common types,
# However, for convenience, we make sure EnumType can have a value, like other common types,
...
@@ -912,8 +1000,13 @@ class EnumType(Type, dict):
...
@@ -912,8 +1000,13 @@ class EnumType(Type, dict):
# C type of value is defined in self.ctype.
# C type of value is defined in self.ctype.
def
filter
(
self
,
data
,
strict
=
False
,
allow_downcast
=
None
):
def
filter
(
self
,
data
,
strict
=
False
,
allow_downcast
=
None
):
if
not
strict
and
isinstance
(
data
,
bool
):
if
not
strict
:
data
=
int
(
data
)
if
isinstance
(
data
,
bool
):
data
=
int
(
data
)
elif
isinstance
(
data
,
str
):
# We now accept strings as data values.
# Strings should be a constant alias or a constant name.
data
=
self
.
fromalias
(
data
)
assert
data
in
self
.
values
()
assert
data
in
self
.
values
()
return
data
return
data
...
@@ -947,7 +1040,7 @@ class EnumType(Type, dict):
...
@@ -947,7 +1040,7 @@ class EnumType(Type, dict):
return
"""
%(ctype)
s
%(name)
s;"""
%
dict
(
ctype
=
self
.
ctype
,
name
=
name
)
return
"""
%(ctype)
s
%(name)
s;"""
%
dict
(
ctype
=
self
.
ctype
,
name
=
name
)
def
c_init
(
self
,
name
,
sub
):
def
c_init
(
self
,
name
,
sub
):
return
"
%(name)
s =
0;"
%
dict
(
name
=
nam
e
)
return
"
%(name)
s =
(
%(ctype)
s)0;"
%
dict
(
name
=
name
,
ctype
=
self
.
ctyp
e
)
def
c_cleanup
(
self
,
name
,
sub
):
def
c_cleanup
(
self
,
name
,
sub
):
return
""
return
""
...
@@ -965,11 +1058,14 @@ class EnumType(Type, dict):
...
@@ -965,11 +1058,14 @@ class EnumType(Type, dict):
"""
%
dict
(
ctype
=
self
.
ctype
,
name
=
name
,
fail
=
sub
[
'fail'
])
"""
%
dict
(
ctype
=
self
.
ctype
,
name
=
name
,
fail
=
sub
[
'fail'
])
def
c_code_cache_version
(
self
):
def
c_code_cache_version
(
self
):
return
(
1
,)
return
(
1
,
1
)
class
EnumList
(
EnumType
):
class
EnumList
(
EnumType
):
"""
"""
**Inherit from**:
- :class:`EnumType`
Op parameter class that allows to create enumeration of constant values.
Op parameter class that allows to create enumeration of constant values.
Same as :class:`EnumType`, but automatically gives an unique integer value for each constant in a list of
Same as :class:`EnumType`, but automatically gives an unique integer value for each constant in a list of
constants names (constant at index ``i`` in the list will receive value ``i``,
constants names (constant at index ``i`` in the list will receive value ``i``,
...
@@ -986,6 +1082,14 @@ class EnumList(EnumType):
...
@@ -986,6 +1082,14 @@ class EnumList(EnumType):
enum = EnumList('CONSTANT_1', 'CONSTANT_2', 'CONSTANT_3', 'CONSTANT_4', ctype='unsigned int')
enum = EnumList('CONSTANT_1', 'CONSTANT_2', 'CONSTANT_3', 'CONSTANT_4', ctype='unsigned int')
Like :class:`EnumType`, you can also add an alias to a constant, by replacing the only constant name
(e.g. ``'CONSTANT_NAME'``) by a couple with constant name first and constant alias second
(e.g. ``('CONSTANT_NAME', 'constant_alias')``).
.. code-block:: python
enum = EnumList(('A', 'alpha'), ('B', 'beta'), 'C', 'D', 'E', 'F', ('G', 'gamma'))
See test class :class:`theano.gof.tests.test_types.TestOpEnumList` for a working example.
See test class :class:`theano.gof.tests.test_types.TestOpEnumList` for a working example.
"""
"""
...
@@ -995,16 +1099,35 @@ class EnumList(EnumType):
...
@@ -995,16 +1099,35 @@ class EnumList(EnumType):
type
(
self
)
.
__name__
+
': expected 0 or only 1 extra parameter "ctype".'
type
(
self
)
.
__name__
+
': expected 0 or only 1 extra parameter "ctype".'
ctype
=
kwargs
.
pop
(
'ctype'
,
'int'
)
ctype
=
kwargs
.
pop
(
'ctype'
,
'int'
)
if
len
(
args
)
>
len
(
set
(
args
)):
for
arg_rank
,
arg
in
enumerate
(
args
):
raise
AttributeError
(
type
(
self
)
.
__name__
+
': some constants names are duplicated.'
)
if
isinstance
(
arg
,
(
list
,
tuple
)):
if
len
(
arg
)
!=
2
:
raise
TypeError
(
'
%
s: when using a tuple to define a constant, your tuple should contain 2 values: '
'constant name followed by constant alias.'
%
type
(
self
)
.
__name__
)
constant_name
,
constant_alias
=
arg
if
not
isinstance
(
constant_alias
,
str
):
raise
TypeError
(
'
%
s: constant alias should be a string, got "
%
s".'
%
(
type
(
self
)
.
__name__
,
constant_alias
))
constant_value
=
(
constant_alias
,
arg_rank
)
else
:
constant_name
=
arg
constant_value
=
arg_rank
if
not
isinstance
(
constant_name
,
str
):
raise
TypeError
(
'
%
s: constant name should be a string, got "
%
s".'
%
(
type
(
self
)
.
__name__
,
constant_name
))
if
constant_name
in
kwargs
:
raise
TypeError
(
'
%
s: constant name already used ("
%
s").'
%
(
type
(
self
)
.
__name__
,
constant_name
))
kwargs
[
constant_name
]
=
constant_value
kwargs
=
{
const_name
:
const_rank
for
(
const_rank
,
const_name
)
in
enumerate
(
args
)}
kwargs
.
update
(
ctype
=
ctype
)
kwargs
.
update
(
ctype
=
ctype
)
super
(
EnumList
,
self
)
.
__init__
(
**
kwargs
)
super
(
EnumList
,
self
)
.
__init__
(
**
kwargs
)
class
CEnumType
(
EnumList
):
class
CEnumType
(
EnumList
):
"""
"""
**Inherit from**:
- :class:`EnumList`
Op parameter class that allows to create enumeration of constant values that represent C-defined constants.
Op parameter class that allows to create enumeration of constant values that represent C-defined constants.
- Constant should have same names as in C.
- Constant should have same names as in C.
...
@@ -1020,6 +1143,8 @@ class CEnumType(EnumList):
...
@@ -1020,6 +1143,8 @@ class CEnumType(EnumList):
enum = CEnumType('CONSTANT_CNAME_1', 'CONSTANT_CNAME_2', 'CONSTANT_CNAME_3', ctype='long')
enum = CEnumType('CONSTANT_CNAME_1', 'CONSTANT_CNAME_2', 'CONSTANT_CNAME_3', ctype='long')
Like :class:`EnumList`, you can also add an alias to a constant, with same syntax as in :class:`EnumList`.
See test class :class:`theano.gof.tests.test_types.TestOpCEnumType` for a working example.
See test class :class:`theano.gof.tests.test_types.TestOpCEnumType` for a working example.
.. note::
.. note::
...
...
theano/gpuarray/dnn.py
浏览文件 @
4025a2dc
...
@@ -12,7 +12,7 @@ from theano import Op, Apply, tensor, config, Variable
...
@@ -12,7 +12,7 @@ from theano import Op, Apply, tensor, config, Variable
from
theano.scalar
import
as_scalar
,
constant
,
Log
,
get_scalar_type
from
theano.scalar
import
as_scalar
,
constant
,
Log
,
get_scalar_type
from
theano.tensor
import
as_tensor_variable
from
theano.tensor
import
as_tensor_variable
from
theano.gradient
import
DisconnectedType
,
grad_not_implemented
from
theano.gradient
import
DisconnectedType
,
grad_not_implemented
from
theano.gof
import
Optimizer
,
local_optimizer
,
COp
from
theano.gof
import
Optimizer
,
local_optimizer
,
COp
,
ParamsType
,
CEnumType
from
theano.gof.cmodule
import
GCC_compiler
from
theano.gof.cmodule
import
GCC_compiler
from
theano.gof.type
import
CDataType
,
Generic
from
theano.gof.type
import
CDataType
,
Generic
from
theano.compile
import
optdb
from
theano.compile
import
optdb
...
@@ -234,6 +234,11 @@ class DnnBase(COp):
...
@@ -234,6 +234,11 @@ class DnnBase(COp):
ptr
=
ctx
.
cudnn_handle
.
value
ptr
=
ctx
.
cudnn_handle
.
value
res
=
handle_type
.
make_value
(
ptr
)
res
=
handle_type
.
make_value
(
ptr
)
ctx
.
cudnn_handle_param
=
res
ctx
.
cudnn_handle_param
=
res
if
isinstance
(
self
.
params_type
,
ParamsType
):
if
not
self
.
params_type
.
has_type
(
handle_type
):
raise
TypeError
(
'DnnBase: params_type must take into account the cuDNN handle type.'
)
handle_field
=
self
.
params_type
.
get_field
(
handle_type
)
return
self
.
params_type
.
get_params
(
self
,
**
{
handle_field
:
ctx
.
cudnn_handle_param
})
return
ctx
.
cudnn_handle_param
return
ctx
.
cudnn_handle_param
def
__init__
(
self
,
files
=
None
,
c_func
=
None
):
def
__init__
(
self
,
files
=
None
,
c_func
=
None
):
...
@@ -1504,6 +1509,18 @@ class GpuDnnSoftmaxBase(DnnBase):
...
@@ -1504,6 +1509,18 @@ class GpuDnnSoftmaxBase(DnnBase):
"""
"""
__props__
=
(
'mode'
,
'algo'
)
__props__
=
(
'mode'
,
'algo'
)
# Neither inputs nor output types properties are used
# neither in dnn_base.c nor in dnn_softmax*.c,
# so we can disable input checking.
check_input
=
False
params_type
=
ParamsType
(
algo
=
CEnumType
((
'CUDNN_SOFTMAX_FAST'
,
'fast'
),
(
'CUDNN_SOFTMAX_LOG'
,
'log'
),
(
'CUDNN_SOFTMAX_ACCURATE'
,
'accurate'
),
ctype
=
'cudnnSoftmaxAlgorithm_t'
),
mode
=
CEnumType
((
'CUDNN_SOFTMAX_MODE_INSTANCE'
,
'instance'
),
(
'CUDNN_SOFTMAX_MODE_CHANNEL'
,
'channel'
),
ctype
=
'cudnnSoftmaxMode_t'
),
handle
=
handle_type
)
def
__init__
(
self
,
algo
,
mode
):
def
__init__
(
self
,
algo
,
mode
):
DnnBase
.
__init__
(
self
,
[
self
.
file
],
self
.
c_func
)
DnnBase
.
__init__
(
self
,
[
self
.
file
],
self
.
c_func
)
...
@@ -1520,21 +1537,6 @@ class GpuDnnSoftmaxBase(DnnBase):
...
@@ -1520,21 +1537,6 @@ class GpuDnnSoftmaxBase(DnnBase):
else
:
else
:
return
[
shape
[
1
]]
return
[
shape
[
1
]]
def
get_op_params
(
self
):
if
self
.
mode
==
'instance'
:
mode
=
"CUDNN_SOFTMAX_MODE_INSTANCE"
else
:
mode
=
"CUDNN_SOFTMAX_MODE_CHANNEL"
if
self
.
algo
==
'fast'
:
algo
=
"CUDNN_SOFTMAX_FAST"
elif
self
.
algo
==
'log'
:
algo
=
"CUDNN_SOFTMAX_LOG"
else
:
algo
=
"CUDNN_SOFTMAX_ACCURATE"
return
[(
"SOFTMAX_MODE"
,
mode
),
(
"SOFTMAX_ALGO"
,
algo
)]
class
GpuDnnSoftmax
(
GpuDnnSoftmaxBase
):
class
GpuDnnSoftmax
(
GpuDnnSoftmaxBase
):
...
...
theano/gpuarray/dnn_softmax.c
浏览文件 @
4025a2dc
...
@@ -35,7 +35,7 @@ if (APPLY_SPECIFIC(output) != NULL)
...
@@ -35,7 +35,7 @@ if (APPLY_SPECIFIC(output) != NULL)
int
APPLY_SPECIFIC
(
softmax
)(
PyGpuArrayObject
*
x
,
int
APPLY_SPECIFIC
(
softmax
)(
PyGpuArrayObject
*
x
,
PyGpuArrayObject
**
out
,
PyGpuArrayObject
**
out
,
cudnnHandle_t
_handle
)
{
PARAMS_TYPE
*
wrapper
)
{
PyGpuContextObject
*
c
=
x
->
context
;
PyGpuContextObject
*
c
=
x
->
context
;
cudnnStatus_t
err
;
cudnnStatus_t
err
;
...
@@ -83,9 +83,9 @@ int APPLY_SPECIFIC(softmax)(PyGpuArrayObject *x,
...
@@ -83,9 +83,9 @@ int APPLY_SPECIFIC(softmax)(PyGpuArrayObject *x,
cuda_wait
((
*
out
)
->
ga
.
data
,
GPUARRAY_CUDA_WAIT_WRITE
);
cuda_wait
((
*
out
)
->
ga
.
data
,
GPUARRAY_CUDA_WAIT_WRITE
);
err
=
cudnnSoftmaxForward
(
err
=
cudnnSoftmaxForward
(
_
handle
,
wrapper
->
handle
,
SOFTMAX_ALGO
,
wrapper
->
algo
,
SOFTMAX_MODE
,
wrapper
->
mode
,
alpha
,
alpha
,
APPLY_SPECIFIC
(
input
),
APPLY_SPECIFIC
(
input
),
PyGpuArray_DEV_DATA
(
x
),
PyGpuArray_DEV_DATA
(
x
),
...
...
theano/gpuarray/dnn_softmax_grad.c
浏览文件 @
4025a2dc
...
@@ -46,7 +46,7 @@ if (APPLY_SPECIFIC(dx) != NULL)
...
@@ -46,7 +46,7 @@ if (APPLY_SPECIFIC(dx) != NULL)
int
APPLY_SPECIFIC
(
softmax_grad
)(
PyGpuArrayObject
*
dy
,
int
APPLY_SPECIFIC
(
softmax_grad
)(
PyGpuArrayObject
*
dy
,
PyGpuArrayObject
*
sm
,
PyGpuArrayObject
*
sm
,
PyGpuArrayObject
**
dx
,
PyGpuArrayObject
**
dx
,
cudnnHandle_t
_handle
)
{
PARAMS_TYPE
*
wrapper
)
{
PyGpuContextObject
*
c
=
dy
->
context
;
PyGpuContextObject
*
c
=
dy
->
context
;
cudnnStatus_t
err
;
cudnnStatus_t
err
;
...
@@ -97,9 +97,9 @@ int APPLY_SPECIFIC(softmax_grad)(PyGpuArrayObject *dy,
...
@@ -97,9 +97,9 @@ int APPLY_SPECIFIC(softmax_grad)(PyGpuArrayObject *dy,
cuda_wait
((
*
dx
)
->
ga
.
data
,
GPUARRAY_CUDA_WAIT_WRITE
);
cuda_wait
((
*
dx
)
->
ga
.
data
,
GPUARRAY_CUDA_WAIT_WRITE
);
err
=
cudnnSoftmaxBackward
(
err
=
cudnnSoftmaxBackward
(
_
handle
,
wrapper
->
handle
,
SOFTMAX_ALGO
,
wrapper
->
algo
,
SOFTMAX_MODE
,
wrapper
->
mode
,
alpha
,
alpha
,
APPLY_SPECIFIC
(
sm
),
APPLY_SPECIFIC
(
sm
),
PyGpuArray_DEV_DATA
(
sm
),
PyGpuArray_DEV_DATA
(
sm
),
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论