testgroup / pytensor

Commit 7320e1b1
Authored Aug 28, 2015 by abergeron

Merge pull request #3288 from abergeron/nouiz_mixed

Nouiz mixed

Parents: 1d13344e, 7f43e9f4

Showing 23 changed files with 293 additions and 61 deletions (+293 / -61)
Changed files:

doc/LICENSE.txt                                +3   -0
doc/index.txt                                  +2   -0
doc/install.txt                                +12  -0
doc/library/config.txt                         +19  -0
doc/tutorial/examples.txt                      +2   -2
theano/compile/builders.py                     +0   -5
theano/compile/mode.py                         +6   -1
theano/compile/nanguardmode.py                 +60  -5
theano/configdefaults.py                       +1   -0
theano/gof/optdb.py                            +1   -0
theano/gof/tests/test_opt.py                   +0   -1
theano/misc/elemwise_openmp_speedup.py         +7   -2
theano/misc/pycuda_example.py                  +5   -0
theano/sandbox/cuda/elemwise.py                +11  -3
theano/sandbox/cuda/opt.py                     +16  -7
theano/sandbox/cuda/tests/test_opt.py          +39  -4
theano/sandbox/gpuarray/elemwise.py            +2   -0
theano/sandbox/gpuarray/opt.py                 +2   -1
theano/sandbox/gpuarray/tests/test_opt.py      +70  -0
theano/scalar/basic.py                         +13  -3
theano/tensor/opt.py                           +21  -13
theano/tensor/tests/test_gc.py                 +1   -2
theano/tests/test_flake8.py                    +0   -12
doc/LICENSE.txt

@@ -9,6 +9,9 @@ All rights reserved.
 Contains code from NumPy, Copyright (c) 2005-2011, NumPy Developers.
 All rights reserved.
+Contain CnMeM under the same license with this copyright:
+Copyright (c) 2015, NVIDIA CORPORATION. All rights reserved.
+
 Redistribution and use in source and binary forms, with or without
 modification, are permitted provided that the following conditions are met:
doc/index.txt

@@ -21,6 +21,8 @@ Montreal).
 News
 ====

+* We added support for :ref:`CuDNN v3 <libdoc_cuda_dnn>`.
+
 * We added support for :attr:`CNMeM <config.lib.cnmem>` to speed up
   the GPU memory allocation.
doc/install.txt

@@ -308,6 +308,18 @@ to your ``Theano`` folder and execute the following command:
 You should update frequently, bugs are fixed on a very regular basis.

+Specific git commit
+~~~~~~~~~~~~~~~~~~~
+
+You can install a specific git commit by using the bleeding edge
+instruction and adding @COMMIT_ID to the pip command like:
+
+.. code-block:: bash
+
+    pip install --upgrade --no-deps git+git://github.com/Theano/Theano.git@07e9332a0932e90c47ed2a70fc3c7f8a55d2aa23
+
 .. _testing_installation:

 Testing your installation
doc/library/config.txt

@@ -705,6 +705,25 @@ import theano and print the config variable, as in:
 Generate a warning when the destroy_map or view_map tell that an op work
 inplace, but the op did not reuse the input for its output.

+.. attribute:: config.NanGuardMode.nan_is_error
+
+    Bool value, default: True
+
+    Controls whether NanGuardMode generates an error when it sees a nan.
+
+.. attribute:: config.NanGuardMode.inf_is_error
+
+    Bool value, default: True
+
+    Controls whether NanGuardMode generates an error when it sees an inf.
+
+.. attribute:: config.NanGuardMode.big_is_error
+
+    Bool value, default: True
+
+    Controls whether NanGuardMode generates an error when it sees a
+    big value (>1e10).
+
 .. attribute:: numpy

 This section contains different attributes for configuring numpy's
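The three flags documented above are independent predicates over the values flowing through a compiled function. A minimal pure-Python sketch of the checks (the function name and `big_threshold` parameter are illustrative stand-ins, not Theano API):

```python
import math

def find_bad_values(values, nan_is_error=True, inf_is_error=True,
                    big_is_error=True, big_threshold=1e10):
    """Return (kind, value) pairs for every value one of the flags rejects.

    Hypothetical helper mirroring NanGuardMode's three checks: nan,
    inf, and "big" (absolute value above 1e10).
    """
    bad = []
    for v in values:
        if nan_is_error and math.isnan(v):
            bad.append(("nan", v))
        elif inf_is_error and math.isinf(v):
            bad.append(("inf", v))
        elif big_is_error and abs(v) > big_threshold:
            bad.append(("big", v))
    return bad
```

Disabling a flag simply skips that predicate, matching the configuration semantics described above.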
doc/tutorial/examples.txt

@@ -500,8 +500,8 @@ It will be used repeatedly.
 training_steps = 10000

 # Declare Theano symbolic variables
-x = T.matrix("x")
-y = T.vector("y")
+x = T.dmatrix("x")
+y = T.dvector("y")
 w = theano.shared(rng.randn(feats), name="w")
 b = theano.shared(0., name="b")
 print "Initial model:"
theano/compile/builders.py

@@ -171,11 +171,6 @@ class OpFromGraph(gof.Op):
         return ret

     def grad(self, inputs, output_grads):
-        # OpFromGraph doesn't implement a connection_pattern, so for
-        # now we regard all inputs and outputs as connected. This will
-        # compute the right numerical value for the gradients but
-        # could fail to raise the disconnected inputs error in some
-        # cases.
         if hasattr(self, "grad_ops"):
             grad_ops = self.grad_ops
         else:
theano/compile/mode.py

@@ -387,12 +387,17 @@ def get_mode(orig_string):
                           default_mode_class):
         return instanciated_default_mode
-    if string in ['Mode', 'ProfileMode', 'DebugMode']:
+    if string in ['Mode', 'ProfileMode', 'DebugMode', 'NanGuardMode']:
         if string == 'DebugMode':
             # need to import later to break circular dependency.
             from .debugmode import DebugMode
             # DebugMode use its own linker.
             ret = DebugMode(optimizer=config.optimizer)
+        elif string == 'NanGuardMode':
+            # need to import later to break circular dependency.
+            from .nanguardmode import NanGuardMode
+            # DebugMode use its own linker.
+            ret = NanGuardMode(True, True, True, optimizer=config.optimizer)
         else:
             # This might be required if the string is 'ProfileMode'
             from .profilemode import ProfileMode  # noqa
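The hunk above extends a name-to-mode dispatch in which each branch lazily imports the mode class it constructs; 'NanGuardMode' now joins that list. A stripped-down sketch of the same dispatch pattern (the returned tuples are stand-ins for the real mode objects, and the lazy imports are omitted):

```python
def get_mode(string, config_optimizer="fast_run"):
    """Hypothetical sketch of name-based mode dispatch.

    In Theano the DebugMode/NanGuardMode branches import their classes
    lazily to break circular dependencies; here we build tagged tuples.
    """
    if string in ("Mode", "ProfileMode", "DebugMode", "NanGuardMode"):
        if string == "DebugMode":
            return ("DebugMode", config_optimizer)
        elif string == "NanGuardMode":
            # Mirrors NanGuardMode(True, True, True, optimizer=...),
            # i.e. all three guards enabled by default.
            return ("NanGuardMode", True, True, True, config_optimizer)
        else:
            return (string, config_optimizer)
    raise ValueError("unknown mode %r" % string)
```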
theano/compile/nanguardmode.py

@@ (imports)
-import logging
 import collections
+import logging

 import numpy as np

 import theano
+from theano.configparser import config, AddConfigVar, BoolParam
 import theano.tensor as T
 import theano.sandbox.cuda as cuda
 from theano.compile import Mode

+AddConfigVar('NanGuardMode.nan_is_error',
+             "Default value for nan_is_error",
+             BoolParam(True),
+             in_c_key=False)
+
+AddConfigVar('NanGuardMode.inf_is_error',
+             "Default value for inf_is_error",
+             BoolParam(True),
+             in_c_key=False)
+
+AddConfigVar('NanGuardMode.big_is_error',
+             "Default value for big_is_error",
+             BoolParam(True),
+             in_c_key=False)
+
 logger = logging.getLogger("theano.compile.nanguardmode")

@@ -110,26 +128,60 @@ class NanGuardMode(Mode):
     big_is_error : bool
         If True, raise an error when a value greater than 1e10 is encountered.
+
+    Note
+    ----
+        We ignore the linker parameter
     """
-    def __init__(self, nan_is_error, inf_is_error, big_is_error=True):
+    # We currently loose the 3 first params frequently, when calling
+    # mode.including() and variant.
+    def __init__(self, nan_is_error=None, inf_is_error=None, big_is_error=None,
+                 optimizer=None, linker=None):
+        self.provided_optimizer = optimizer
+        cuda_compile_failed = False
+        if nan_is_error is None:
+            nan_is_error = config.NanGuardMode.nan_is_error
+        if inf_is_error is None:
+            inf_is_error = config.NanGuardMode.inf_is_error
+        if big_is_error is None:
+            big_is_error = config.NanGuardMode.big_is_error
+        assert nan_is_error or inf_is_error or big_is_error
+
         if cuda.cuda_available:
             self.guard_input = cuda.fvector('nan_guard')
             if nan_is_error or inf_is_error:
-                self.gpumin = theano.function(
-                    [self.guard_input], T.min(self.guard_input),
-                    mode='FAST_RUN'
-                )
-            if inf_is_error:
+                try:
+                    self.gpumin = theano.function(
+                        [self.guard_input], T.min(self.guard_input),
+                        mode='FAST_RUN'
+                    )
+                except RuntimeError:
+                    # This can happen if cuda is available, but the
+                    # device is in exclusive mode and used by another
+                    # process.
+                    cuda_compile_failed = True
-                self.gpumax = theano.function(
-                    [self.guard_input], T.max(self.guard_input),
-                    mode='FAST_RUN'
-                )
-            if big_is_error:
+            if inf_is_error and not cuda_compile_failed:
+                try:
+                    self.gpumax = theano.function(
+                        [self.guard_input], T.max(self.guard_input),
+                        mode='FAST_RUN'
+                    )
+                except RuntimeError:
+                    # This can happen if cuda is available, but the
+                    # device is in exclusive mode and used by another
+                    # process.
+                    cuda_compile_failed = True
-                self.gpuabsmax = theano.function(
-                    [self.guard_input], T.max(T.abs_(self.guard_input)),
-                    mode='FAST_RUN'
-                )
+            if big_is_error and not cuda_compile_failed:
+                try:
+                    self.gpuabsmax = theano.function(
+                        [self.guard_input], T.max(T.abs_(self.guard_input)),
+                        mode='FAST_RUN'
+                    )
+                except RuntimeError:
+                    # This can happen if cuda is available, but the
+                    # device is in exclusive mode and used by another
+                    # process.
+                    cuda_compile_failed = True

         def do_check_on(var, nd, f, is_input):
             """

@@ -154,6 +206,9 @@ class NanGuardMode(Mode):
             if nan_is_error:
                 err = False
                 if cuda.cuda_available and isinstance(var, cuda.CudaNdarray):
-                    err = np.isnan(self.gpumin(var.reshape(var.size)))
+                    if not isinstance(nd.op,
+                                      # It store ints in float container
+                                      theano.sandbox.rng_mrg.GPU_mrg_uniform):
+                        err = np.isnan(self.gpumin(var.reshape(var.size)))
                 else:
                     err = contains_nan(var)

@@ -227,4 +282,4 @@ class NanGuardMode(Mode):
         wrap_linker = theano.gof.WrapLinker([theano.gof.OpWiseCLinker()],
                                             nan_check)
         super(NanGuardMode, self).__init__(wrap_linker,
-                                           optimizer=theano.config.optimizer)
+                                           optimizer=self.provided_optimizer)
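The rewritten `__init__` falls back to the `AddConfigVar` defaults whenever a flag arrives as None, because the first three positional parameters are frequently lost through `mode.including()` and similar calls. The fallback-and-validate pattern in isolation (the dict below stands in for the real `theano.config` object):

```python
# Stand-in for the AddConfigVar'd theano.config.NanGuardMode values.
CONFIG_DEFAULTS = {
    "nan_is_error": True,
    "inf_is_error": True,
    "big_is_error": True,
}

def resolve_flags(nan_is_error=None, inf_is_error=None, big_is_error=None):
    """Replace each None flag with its configured default, then validate.

    Mirrors the None-handling added to NanGuardMode.__init__: an explicit
    True/False wins, None means "use the config default", and at least
    one check must remain enabled.
    """
    if nan_is_error is None:
        nan_is_error = CONFIG_DEFAULTS["nan_is_error"]
    if inf_is_error is None:
        inf_is_error = CONFIG_DEFAULTS["inf_is_error"]
    if big_is_error is None:
        big_is_error = CONFIG_DEFAULTS["big_is_error"]
    assert nan_is_error or inf_is_error or big_is_error
    return nan_is_error, inf_is_error, big_is_error
```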
theano/configdefaults.py

@@ -150,6 +150,7 @@ AddConfigVar(
     'mode',
     "Default compilation mode",
     EnumStr('Mode', 'ProfileMode', 'DebugMode', 'FAST_RUN',
+            'NanGuardMode',
            'FAST_COMPILE', 'PROFILE_MODE', 'DEBUG_MODE'),
     in_c_key=False)
theano/gof/optdb.py

@@ -290,6 +290,7 @@ class SequenceDB(DB):
     def register(self, name, obj, position, *tags):
         super(SequenceDB, self).register(name, obj, *tags)
+        assert isinstance(position, (int, float))
         self.__position__[name] = position

     def query(self, *tags, **kwtags):
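The new assertion makes explicit that SequenceDB positions must be numeric, since queries sort registered entries by position. A toy version of such a position-ordered registry (class and method names are illustrative, not Theano's):

```python
class OrderedRegistry(object):
    """Toy sketch of a SequenceDB-style registry ordered by numeric position."""

    def __init__(self):
        self._positions = {}

    def register(self, name, position):
        # The commit adds exactly this guard: positions must sort numerically.
        assert isinstance(position, (int, float))
        self._positions[name] = position

    def query(self):
        # Entries come back in ascending position order.
        return sorted(self._positions, key=self._positions.get)
```

Without the guard, a stray string position would only fail (or, worse, sort inconsistently) much later, at query time.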
theano/gof/tests/test_opt.py

@@ -6,7 +6,6 @@ from theano.gof.opt import *  # noqa
 from theano.gof.fg import FunctionGraph as Env
 from theano.gof.toolbox import *  # noqa
-from theano.tensor.opt import Assert
 from theano import tensor as T
theano/misc/elemwise_openmp_speedup.py

@@ -49,7 +49,12 @@ if __name__ == '__main__':
     else:
         costlySpeed = costlyTimeOpenmp / costlyTime
         costlySpeedstring = "slowdown"
+    print("Timed with vector of %d elements" % options.N)
-    print("Fast op time without openmp %fs with openmp %fs %s %2.2f" % (cheapTime, cheapTimeOpenmp, cheapSpeedstring, cheapSpeed))
+    print("Fast op time without openmp %fs with openmp %fs %s %2.2f" % (
+        cheapTime, cheapTimeOpenmp, cheapSpeedstring, cheapSpeed))
-    print("Slow op time without openmp %fs with openmp %fs %s %2.2f" % (costlyTime, costlyTimeOpenmp, costlySpeedstring, costlySpeed))
+    print("Slow op time without openmp %fs with openmp %fs %s %2.2f" % (
+        costlyTime, costlyTimeOpenmp, costlySpeedstring, costlySpeed))
theano/misc/pycuda_example.py

@@ -285,6 +285,11 @@ class PycudaElemwiseSourceModuleMakeThunkOp(Op):
         self.scalar_op = scalar_op
         self.inplace_pattern = inplace_pattern

+    # As we have a dict in props, we need to implement __hash__
+    def __hash__(self):
+        return hash(type(self), hash(self.scalar_op),
+                    hash_from_dict(self.inplace_pattern))
+
     def __str__(self):
         if self.name is None:
             if self.inplace_pattern:
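The motivation for the hunk above is that dicts are unhashable, so an op that keeps a dict among its properties needs a custom `__hash__` built from a canonical serialization of that dict. A sketch of the idea (`hash_from_dict` here is a hypothetical stand-in for Theano's helper of the same name, and the class is a toy):

```python
def hash_from_dict(d):
    """Hash a dict by its sorted (key, value) pairs so equal dicts hash
    equally. Hypothetical stand-in for Theano's hash_from_dict helper."""
    return hash(tuple(sorted(d.items())))

class FakeOp(object):
    """Toy op carrying a dict property, hence the custom __hash__."""

    def __init__(self, name, inplace_pattern):
        self.name = name
        self.inplace_pattern = inplace_pattern

    def __hash__(self):
        # Combine the class, the name, and the dict's canonical hash
        # into a single tuple hash.
        return hash((type(self), self.name,
                     hash_from_dict(self.inplace_pattern)))
```

Two ops built from equal dicts then hash identically, which is what caches keyed on ops require.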
theano/sandbox/cuda/elemwise.py

@@ -66,7 +66,7 @@ class NaiveAlgo(object):
     def cache_version(self):
         ver = self.scalar_op.c_code_cache_version()
         if ver:
-            return (17, self.verbose, self.sync, ver)
+            return (18, self.verbose, self.sync, ver)
         else:
             return ver

@@ -142,6 +142,8 @@ class NaiveAlgo(object):
         # perform the scalar operation on the input and output references
         # TODO: What if the scalar_op needs support_code??
+        for ipos, i in enumerate(node.outputs):
+            print("npy_%s o%d_i;" % (i.dtype, ipos), file=sio)
         task_code = self.scalar_op.c_code(
             Apply(self.scalar_op,
                   [scalar.Scalar(dtype=input.type.dtype).make_variable()

@@ -150,9 +152,11 @@ class NaiveAlgo(object):
              for output in node.outputs]),
             nodename + '_scalar_',
             get_str_list_logical_scalar(node),
-            ['ii_o%i_data[0]' % ipos for ipos, i in enumerate(node.outputs)],
+            ['o%i_i' % ipos for ipos, i in enumerate(node.outputs)],
             sub=dict(fail='return;'))  # TODO: set a failure code somehow!!!
         print("     ", task_code, file=sio)
+        for ipos, _ in enumerate(node.outputs):
+            print("o%i_data[i] = o%i_i;" % (ipos, ipos), file=sio)
         print("    }", file=sio)
         #indent = " "*(4*d+7)

@@ -477,6 +481,8 @@ class NaiveAlgo(object):
         print("    for (int i = idx; i < numEls; i += numThreads) {", file=sio)
         # perform the scalar operation on the input and output references
         # TODO: What if the scalar_op needs support_code??
+        for ipos, i in enumerate(node.outputs):
+            print("npy_%s o%d_i;" % (i.dtype, ipos), file=sio)
         task_code = self.scalar_op.c_code(
             Apply(self.scalar_op,
                   [scalar.Scalar(dtype=input.type.dtype).make_variable()

@@ -486,9 +492,11 @@ class NaiveAlgo(object):
             , nodename + '_scalar_'
             #, ['i%i_data[i]'%ipos for ipos, i in enumerate(node.inputs)]
             , get_str_list_logical_scalar(node, data_str='i%i_data[i]')
-            , ['o%i_data[i]' % ipos for ipos, i in enumerate(node.outputs)]
+            , ['o%i_i' % ipos for ipos, i in enumerate(node.outputs)]
             , sub=dict(fail='return;'))  # TODO: set a failure code somehow!!!
         print("     ", task_code, file=sio)
+        for ipos, _ in enumerate(node.outputs):
+            print("o%i_data[i] = o%i_i;" % (ipos, ipos), file=sio)
         print("    }", file=sio)
         print("}", file=sio)
theano/sandbox/cuda/opt.py

@@ -279,7 +279,8 @@ def local_gpu_elemwise_0(node):
                 # TODO: change this when fusion makes Elemwise with
                 # multiple outputs
                 gpu_elemwise = new_op(*(gpu_from_host(i)
-                                        for i in node.inputs))
+                                        for i in node.inputs),
+                                      return_list=True)
             # case 2 - it is still ok if some inputs were upcast to float32
             elif all([i.type.dtype in upcastable
                       for i in node.inputs]):

@@ -292,18 +293,19 @@ def local_gpu_elemwise_0(node):
                 new_inputs = [gpu_from_host(tensor.cast(i, 'float32'))
                               for i in node.inputs]
-                gpu_elemwise = new_op(*new_inputs)
+                gpu_elemwise = new_op(*new_inputs, return_list=True)
             else:
                 return False
         else:
             return False

-        gpu_elemwise = split_huge_add_or_mul(gpu_elemwise.owner)
+        gpu_elemwise = split_huge_add_or_mul(gpu_elemwise[0].owner)
         if not gpu_elemwise:
             return False
-        if max_inputs_to_GpuElemwise(node) < len(gpu_elemwise.inputs):
+        if (max_inputs_to_GpuElemwise(node) < len(gpu_elemwise.inputs)):
             return False
-        return [host_from_gpu(gpu_elemwise.outputs[0])]
+        return [host_from_gpu(out) for out in gpu_elemwise.outputs]

 @register_opt()

@@ -785,7 +787,7 @@ def local_gpu_careduce(node):
         x, = node.inputs
         # Otherwise, is some corner case, we will try to move it
         # to the GPU later and this cause not wanted user warning.
-        if x.dtype != 'float32':
+        if x.dtype != 'float32' or node.outputs[0].dtype != "float32":
             return
         replace = False
         if x.owner and isinstance(x.owner.op, HostFromGpu):

@@ -1114,6 +1116,13 @@ def local_gpu_incsubtensor(node):
             incsubt = host_output.owner.op
             x, y = host_output.owner.inputs[0:2]
             coords = host_output.owner.inputs[2:]
+            if x.dtype != "float32":
+                return
+            if y.dtype != "float32":
+                # The IncSubtensor upcast to float32 y, so we do it
+                # explicitly to move it to the GPU.
+                y = y.astype('float32')
+
             return [GpuIncSubtensor(
                 incsubt.idx_list,
                 inplace=incsubt.inplace,

@@ -1124,7 +1133,7 @@ def local_gpu_incsubtensor(node):
     # Incrementing a float32 x results in a float32
     # output even if y is float64, so we can downcast
     # y to put it on GPU
-    if type(node.op) == tensor.IncSubtensor and \
+    elif type(node.op) == tensor.IncSubtensor and \
        node.inputs[0].dtype == "float32":
         x, y = node.inputs[0:2]
         assert isinstance(x.type, tensor.TensorType)
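The `return_list=True` changes above matter because an op call that historically returned a single variable now uniformly returns a list, which is the calling convention that lets the optimizer handle ops with multiple outputs. A sketch of that convention difference (the class and its one-output-per-input `make_node` are fabricated for illustration):

```python
class TinyOp(object):
    """Toy op whose __call__ mimics Theano's return_list keyword."""

    def make_node(self, *inputs):
        # Pretend the op produces one output per input.
        return list(inputs)

    def __call__(self, *inputs, **kwargs):
        return_list = kwargs.get("return_list", False)
        outputs = self.make_node(*inputs)
        if len(outputs) == 1 and not return_list:
            # Historical convenience: unwrap a single output.
            return outputs[0]
        return outputs
```

Code that always passes `return_list=True` can index `result[0]` unconditionally, which is exactly why the hunk above switches `gpu_elemwise.owner` to `gpu_elemwise[0].owner`.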
theano/sandbox/cuda/tests/test_opt.py

@@ -599,11 +599,11 @@ def test_local_gpu_elemwise_0():
     # Due to optimization order, this composite is created when all
     # the op are on the gpu.
-    f = theano.function([a, b, c], [a + b + c], mode=mode_with_gpu)
+    f = theano.function([a, b, c], a + b + c, mode=mode_with_gpu)
     topo = f.maker.fgraph.toposort()
     assert sum(isinstance(node.op, cuda.GpuElemwise) for node in topo) == 1
     assert sum(isinstance(node.op, tensor.Elemwise) for node in topo) == 1
-    f(a_v, b_v, c_v)
+    utt.assert_allclose(f(a_v, b_v, c_v), a_v + b_v + c_v)

     # Now test with the composite already on the cpu before we move it
     # to the gpu

@@ -612,11 +612,46 @@ def test_local_gpu_elemwise_0():
     c_s = theano.scalar.float32()
     out_s = theano.scalar.Composite([a_s, b_s, c_s], [a_s + b_s + c_s])
     out_op = tensor.Elemwise(out_s)
-    f = theano.function([a, b, c], [out_op(a, b, c)], mode=mode_with_gpu)
+    f = theano.function([a, b, c], out_op(a, b, c), mode=mode_with_gpu)
     topo = f.maker.fgraph.toposort()
     assert sum(isinstance(node.op, cuda.GpuElemwise) for node in topo) == 1
     assert sum(isinstance(node.op, tensor.Elemwise) for node in topo) == 1
-    f(a_v, b_v, c_v)
+    utt.assert_allclose(f(a_v, b_v, c_v), a_v + b_v + c_v)
+
+    # Test multiple output
+    a_s = theano.scalar.float32()
+    a = tensor.fmatrix()
+    from theano.scalar.basic import identity
+    out_s = theano.scalar.Composite([a_s, b_s, c_s],
+                                    [identity(a_s), identity(c_s),
+                                     identity(b_s)])
+    outs_op = tensor.Elemwise(out_s)
+    f = theano.function([a, b, c], outs_op(a, b, c), mode=mode_with_gpu)
+    topo = f.maker.fgraph.toposort()
+    assert sum(isinstance(node.op, cuda.GpuElemwise) for node in topo) == 1
+    assert sum(isinstance(node.op, tensor.Elemwise) for node in topo) == 0
+    out = f(a_v, b_v, c_v)
+    utt.assert_allclose(out[0], a_v)
+    utt.assert_allclose(out[1], c_v)
+    utt.assert_allclose(out[2], b_v)
+
+    # Test multiple output
+    out_s = theano.scalar.Composite([a_s, b_s, c_s], [a_s + b_s, a_s * c_s])
+    outs_op = tensor.Elemwise(out_s)
+    f = theano.function([a, b, c], outs_op(a, b, c), mode=mode_with_gpu)
+    topo = f.maker.fgraph.toposort()
+    assert sum(isinstance(node.op, cuda.GpuElemwise) for node in topo) == 1
+    assert sum(isinstance(node.op, tensor.Elemwise) for node in topo) == 0
+    out = f(a_v, b_v, c_v)
+    utt.assert_allclose(out[0], a_v + b_v)
+    utt.assert_allclose(out[1], a_v * c_v)
+
+    # Test non-contiguous input
+    c = cuda.shared_constructor(c_v)
+    f = theano.function([a, b], outs_op(a[::2], b[::2], c[::2]),
+                        mode=mode_with_gpu)
+    out = f(a_v, b_v)
+    utt.assert_allclose(out[0], a_v[::2] + b_v[::2])
+    utt.assert_allclose(out[1], a_v[::2] * c_v[::2])

 def test_elemwise_fusion():
theano/sandbox/gpuarray/elemwise.py

@@ -72,6 +72,8 @@ class GpuElemwise(HideC, Elemwise):
         res = Elemwise.make_node(self, *inputs)
         outputs = [GpuArrayType(broadcastable=o.type.broadcastable,
                                 dtype=o.type.dtype)() for o in res.outputs]
+        if len(outputs) > 1:
+            raise NotImplementedError()
         inputs = [as_gpuarray_variable(i) for i in inputs]
         node = Apply(self, inputs, outputs)
theano/sandbox/gpuarray/opt.py

@@ -270,7 +270,8 @@ def local_gpu_elemwise(node):
         name = op.name
         if name:
             name = 'Gpu' + name
-
+        if len(node.outputs) > 1:
+            return
         res = GpuElemwise(scal_op, name=name,
                           inplace_pattern=copy.copy(op.inplace_pattern),
                           nfunc_spec=op.nfunc_spec)
theano/sandbox/gpuarray/tests/test_opt.py

@@ -255,3 +255,73 @@ def test_local_gpu_subtensor():
     assert any([type(node.op) is tensor.Subtensor for node in topo])
     assert not any([isinstance(node.op, GpuSubtensor) for node in topo])
     assert any([isinstance(node.op, GpuElemwise) for node in topo])
+
+
+def test_local_gpu_elemwise():
+    """
+    Test local_gpu_elemwise when there is a dtype upcastable to float32
+    """
+    a = tensor.bmatrix()
+    b = tensor.fmatrix()
+    c = tensor.fmatrix()
+
+    a_v = (numpy.random.rand(4, 5) * 10).astype("int8")
+    b_v = (numpy.random.rand(4, 5) * 10).astype("float32")
+    c_v = (numpy.random.rand(4, 5) * 10).astype("float32")
+
+    # Due to optimization order, this composite is created when all
+    # the op are on the gpu.
+    f = theano.function([a, b, c], a + b + c, mode=mode_with_gpu)
+    topo = f.maker.fgraph.toposort()
+    assert sum(isinstance(node.op, GpuElemwise) for node in topo) == 1
+    assert sum(type(node.op) == tensor.Elemwise for node in topo) == 0
+    utt.assert_allclose(f(a_v, b_v, c_v), a_v + b_v + c_v)
+
+    # Now test with the composite already on the cpu before we move it
+    # to the gpu
+    a_s = theano.scalar.int8()
+    b_s = theano.scalar.float32()
+    c_s = theano.scalar.float32()
+    out_s = theano.scalar.Composite([a_s, b_s, c_s], [a_s + b_s + c_s])
+    out_op = tensor.Elemwise(out_s)
+    f = theano.function([a, b, c], out_op(a, b, c), mode=mode_with_gpu)
+    topo = f.maker.fgraph.toposort()
+    assert sum(isinstance(node.op, GpuElemwise) for node in topo) == 1
+    assert sum(type(node.op) == tensor.Elemwise for node in topo) == 0
+    utt.assert_allclose(f(a_v, b_v, c_v), a_v + b_v + c_v)
+
+    return  # Not yet implemeted
+    # Test multiple output
+    a_s = theano.scalar.float32()
+    a = tensor.fmatrix()
+    from theano.scalar.basic import identity
+    out_s = theano.scalar.Composite([a_s, b_s, c_s],
+                                    [identity(a_s), identity(c_s),
+                                     identity(b_s)])
+    outs_op = tensor.Elemwise(out_s)
+    f = theano.function([a, b, c], outs_op(a, b, c), mode=mode_with_gpu)
+    topo = f.maker.fgraph.toposort()
+    assert sum(isinstance(node.op, GpuElemwise) for node in topo) == 1
+    assert sum(type(node.op) == tensor.Elemwise for node in topo) == 0
+    out = f(a_v, b_v, c_v)
+    utt.assert_allclose(out[0], a_v)
+    utt.assert_allclose(out[1], c_v)
+    utt.assert_allclose(out[2], b_v)
+
+    # Test multiple output
+    out_s = theano.scalar.Composite([a_s, b_s, c_s], [a_s + b_s, a_s * b_s])
+    outs_op = tensor.Elemwise(out_s)
+    f = theano.function([a, b, c], outs_op(a, b, c), mode=mode_with_gpu)
+    topo = f.maker.fgraph.toposort()
+    assert sum(isinstance(node.op, GpuElemwise) for node in topo) == 1
+    assert sum(type(node.op) == tensor.Elemwise for node in topo) == 0
+    out = f(a_v, b_v, c_v)
+    utt.assert_allclose(out[0], a_v + b_v)
+    utt.assert_allclose(out[1], a_v * c_v)
+
+    # Test non-contiguous input
+    c = cuda.shared_constructor(numpy.asarray(c_v, dtype='float32'))
+    f = theano.function([a, b], outs_op(a[::2], b[::2], c[::2]),
+                        mode=mode_with_gpu)
+    out = f(a_v, b_v)
+    utt.assert_allclose(out[0], a_v[::2] + b_v[::2])
+    utt.assert_allclose(out[1], a_v[::2] * c_v[::2])
theano/scalar/basic.py
浏览文件 @
7320e1b1
...
@@ -724,7 +724,7 @@ def same_out_float_only(type):
 class transfer_type(gof.utils.object2):
     def __init__(self, *transfer):
-        assert all(type(x) == int for x in transfer)
+        assert all(type(x) in [int, str] or x is None for x in transfer)
         self.transfer = transfer

     def __str__(self):
...
@@ -736,6 +736,8 @@ class transfer_type(gof.utils.object2):
         for i in self.transfer:
             if i is None:
                 retval += [upcast]
+            elif isinstance(i, str):
+                retval += [i]
             else:
                 retval += [types[i]]
         return retval
...
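The updated `transfer_type` now accepts three kinds of entries: an `int` (reuse the type of that input), a `str` (an explicit dtype name, the new case), or `None` (fall back to the upcast type). A minimal plain-Python sketch of that mapping logic, with hypothetical names and no Theano required:

```python
# Hypothetical re-implementation of the updated transfer mapping:
# None -> upcast dtype, str -> that dtype directly, int -> i-th input's type.
def resolve_output_types(transfer, upcast, input_types):
    resolved = []
    for i in transfer:
        if i is None:
            resolved.append(upcast)          # default: upcast of the inputs
        elif isinstance(i, str):
            resolved.append(i)               # new: explicit dtype name
        else:
            resolved.append(input_types[i])  # reuse an input's type
    return resolved

assert resolve_output_types([None, 'float32', 0], 'float64',
                            ['int8', 'int16']) == ['float64', 'float32', 'int8']
```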
@@ -3410,7 +3412,10 @@ class Composite(ScalarOp):
                 return lambda inputs: r.data
             node = r.owner
             producers = [compose_impl(input) for input in node.inputs]
-            return lambda inputs: node.op.impl(*[p(inputs) for p in producers])
+
+            def f(inputs):
+                return node.op.impl(*[p(inputs) for p in producers])
+            return f
         self._impls = [compose_impl(r) for r in self.fgraph.outputs]

     def init_name(self):
...
@@ -3467,6 +3472,8 @@ class Composite(ScalarOp):
         # that will flatten Composite. We don't need to do this
         # recusively, as the way the fusion optimizer work, we have
         # only 1 new Composite each time at the output.
+        for i in inputs:
+            assert i not in outputs  # This isn't supported, use identity
         if len(outputs) > 1 or not any([isinstance(var.owner.op, Composite)
                                         for var in outputs]):
             # No inner Composite
...
@@ -3538,8 +3545,11 @@ class Composite(ScalarOp):
     def impl(self, *inputs):
         output_storage = [[None] for i in xrange(self.nout)]
         self.perform(None, inputs, output_storage)
-        return utils.to_return_values([storage[0] for storage in
-                                       output_storage])
+        ret = utils.to_return_values([storage[0] for storage in
+                                      output_storage])
+        if self.nout > 1:
+            ret = tuple(ret)
+        return ret

     def grad(self, inputs, output_grads):
         raise NotImplementedError("grad is not implemented for Composite")
...
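The change to `Composite.impl` makes the multi-output return convention explicit: a single output comes back as a bare value, several outputs as a tuple. A plain-Python stand-in for the new behaviour (hypothetical helper, no Theano required):

```python
# Stand-in for Composite.impl's return convention; `values` mimics what
# perform() would leave in output_storage.
def impl_like(nout, values):
    output_storage = [[v] for v in values]
    ret = [storage[0] for storage in output_storage]
    if nout > 1:
        return tuple(ret)   # several outputs: a real tuple
    return ret[0]           # single output: the bare value

assert impl_like(1, [7]) == 7
assert impl_like(2, [7, 8]) == (7, 8)
```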
theano/tensor/opt.py  View file @ 7320e1b1
...
@@ -296,6 +296,7 @@ def inplace_elemwise_optimizer_op(OP):
         # gpuarray GpuElemwise inherit from Elemwise
         if not type(op) == OP:
             continue
         baseline = op.inplace_pattern
         protected_inputs = [
             f.protected for f in node.fgraph._features if
...
@@ -331,8 +332,8 @@ def inplace_elemwise_optimizer_op(OP):
             if hasattr(op.scalar_op, "make_new_inplace"):
                 new_scal = op.scalar_op.make_new_inplace(
                     scalar.transfer_type(
-                        *[inplace_pattern.get(i, None)
-                          for i in xrange(len(node.outputs))]))
+                        *[inplace_pattern.get(i, o.dtype)
+                          for i, o in enumerate(node.outputs)]))
             else:
                 new_scal = op.scalar_op.__class__(
                     scalar.transfer_type(
...
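With string dtypes now allowed in `transfer_type`, the default for outputs absent from `inplace_pattern` changes from `None` to the output's own dtype string. A plain-Python illustration of the new lookup (all values hypothetical):

```python
# Outputs covered by the inplace pattern keep the input index they reuse;
# the rest now transfer to their own dtype name instead of None.
inplace_pattern = {0: 1}               # output 0 reuses input 1's storage
output_dtypes = ['float32', 'int64']   # stand-in for node.outputs dtypes
transfer = [inplace_pattern.get(i, o) for i, o in enumerate(output_dtypes)]
assert transfer == [1, 'int64']
```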
@@ -1507,7 +1508,11 @@ def local_subtensor_make_vector(node):
             # Python 2.4 wants to index only with Python integers
             v = int(v)
             # We don't need to copy over any stack traces here
-            return [x.owner.inputs[v]]
+            try:
+                ret = [x.owner.inputs[v]]
+            except IndexError:
+                raise NotScalarConstantError("Bad user graph!")
+            return ret
         except NotScalarConstantError:
             pass
     elif idx.ndim == 1 and isinstance(idx, T.Constant):
...
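The guard added here can be sketched in isolation (plain Python, hypothetical names): an out-of-range constant index into the MakeVector inputs is converted into the optimizer's own error type instead of escaping as a bare `IndexError`:

```python
# Minimal sketch of the new guard in local_subtensor_make_vector.
class NotScalarConstantError(Exception):
    """Stand-in for theano.tensor.NotScalarConstantError."""

def pick_input(inputs, v):
    try:
        return inputs[v]
    except IndexError:
        # A bad user graph is reported in the optimizer's own vocabulary.
        raise NotScalarConstantError("Bad user graph!")
```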
@@ -5867,15 +5872,17 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 32,
                     tmp_s_input.append(tmp)
                     tmp_input.append(ii)
                     tmp_scalar.append(tmp_s_input[-1])
-                s_op = i.owner.op.scalar_op(*tmp_s_input)
+                s_op = i.owner.op.scalar_op(*tmp_s_input, return_list=True)
                 # if the scalar_op don't have a c implementation,
                 # we skip its fusion to allow the fusion of the
                 # other ops.
-                i.owner.op.scalar_op.c_code(s_op.owner,
+                i.owner.op.scalar_op.c_code(s_op[0].owner,
                                             "test_presence_of_c_code",
                                             ["x" for x in i.owner.inputs],
-                                            "z", {})
+                                            ["z" for z in i.owner.outputs], {})
             except MethodNotDefined:
                 catch = True
             except NotImplementedError:
...
@@ -5906,7 +5913,7 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 32,
                 new_nb_input = new_nb_input_
                 inputs.extend(tmp_input)
                 s_inputs.extend(tmp_scalar)
-                s_g.append(s_op)
+                s_g.extend(s_op)
             else:
                 # We must support the case where the same variable appear many
                 # time in the inputs
...
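The switch from `append` to `extend` follows directly from `return_list=True`: the scalar op call now always returns a list of output variables, even for a single-output op, so the accumulator must be spliced flat rather than nested. A tiny plain-Python illustration (stand-in values):

```python
# s_g collects the flat list of scalar outputs feeding the Composite.
s_g = []
s_g.extend(['out0'])            # single-output op: still comes back as a list
s_g.extend(['out1', 'out2'])    # multi-output op: spliced in flat
assert s_g == ['out0', 'out1', 'out2']

# append would have produced nested lists instead:
nested = []
nested.append(['out0'])
assert nested == [['out0']]
```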
@@ -5934,25 +5941,26 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 32,
 fusion optimization. We skip this optimization. You can ignore this message,
 your code will run correctly, but may be slower.""")

-    s_new_out = node.op.scalar_op(*s_g)
+    s_new_out = node.op.scalar_op(*s_g, return_list=True)
     try:
-        s_new_out.owner.op.c_code(s_new_out.owner,
-                                  "test_presence_of_c_code",
-                                  ["x" for x in s_g],
-                                  "z", {})
+        s_new_out[0].owner.op.c_code(s_new_out[0].owner,
+                                     "test_presence_of_c_code",
+                                     ["x" for x in s_g],
+                                     ["z" for x in s_new_out], {})
     except MethodNotDefined:
         _logger.info(("%s does not implement the c_code function."
                       " As well as being potentially slow, this disables "
-                      "loop fusion of this op.") % str(s_new_out.owner.op))
+                      "loop fusion of this op.") % str(s_new_out[0].owner.op))
         return False
     except NotImplementedError:
         _logger.info(("%s does not implement the c_code function. As well"
                       " as being potentially slow, this disables loop"
-                      " fusion of this op.") % str(s_new_out.owner.op))
+                      " fusion of this op.") % str(s_new_out[0].owner.op))
         return False

     # create the composite op.
-    C = scalar.Composite(s_inputs, [s_new_out])
+    C = scalar.Composite(s_inputs, s_new_out)
     # create the new node.
     # Do not call make_node to have test_value
...
theano/tensor/tests/test_gc.py  View file @ 7320e1b1
-import sys
 import numpy
 import six.moves.cPickle as pickle
 from six.moves import xrange
...
@@ -120,4 +119,4 @@ def test_merge_opt_runtime():
     dt = time.time() - t
     # it should never take longer than 5 seconds to compile this graph
-    assert dt < 5.0
+    assert dt < 5.0, dt
theano/tests/test_flake8.py  View file @ 7320e1b1
...
@@ -205,18 +205,6 @@ whitelist_flake8 = [
     "sparse/sandbox/sp.py",
     "gof/unify.py",
     "gof/__init__.py",
-    "gof/tests/test_cmodule.py",
-    "gof/tests/test_destroyhandler.py",
-    "gof/tests/test_opt.py",
-    "gof/tests/test_lazy.py",
-    "gof/tests/test_toolbox.py",
-    "gof/tests/test_link.py",
-    "gof/tests/test_fg.py",
-    "gof/tests/test_sched.py",
-    "gof/tests/test_graph_opt_caching.py",
-    "gof/tests/test_graph.py",
-    "gof/tests/test_cc.py",
-    "gof/tests/test_compute_test_value.py",
     "gof/sandbox/equilibrium.py",
 ]
...