Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
7bac82c1
提交
7bac82c1
authored
11月 20, 2008
作者:
Frederic Bastien
浏览文件
操作
浏览文件
下载
差异文件
Automated merge with
ssh://projects@lgcm.iro.umontreal.ca/hg/theano
上级
cfaf41e0
f28d5cb9
隐藏空白字符变更
内嵌
并排
正在显示
19 个修改的文件
包含
745 行增加
和
191 行删除
+745
-191
Makefile
benchmark/autoencoder/Makefile
+1
-1
aa.cc
benchmark/autoencoder/aa.cc
+21
-3
aa.py
benchmark/autoencoder/aa.py
+19
-5
aa_numpy.py
benchmark/autoencoder/aa_numpy.py
+25
-4
__init__.py
theano/gof/__init__.py
+3
-2
cc.py
theano/gof/cc.py
+13
-7
graph.py
theano/gof/graph.py
+14
-2
link.py
theano/gof/link.py
+26
-8
opt.py
theano/gof/opt.py
+137
-38
optdb.py
theano/gof/optdb.py
+1
-1
test_opt.py
theano/gof/tests/test_opt.py
+2
-2
gradient.py
theano/gradient.py
+5
-4
printing.py
theano/printing.py
+21
-11
wraplinker.py
theano/sandbox/wraplinker.py
+41
-10
basic.py
theano/tensor/basic.py
+3
-31
elemwise.py
theano/tensor/elemwise.py
+72
-4
inplace.py
theano/tensor/inplace.py
+6
-4
opt.py
theano/tensor/opt.py
+335
-4
test_basic.py
theano/tensor/tests/test_basic.py
+0
-50
没有找到文件。
benchmark/autoencoder/Makefile
浏览文件 @
7bac82c1
aa.x
:
aa.cc
g++
-O3
-ffast-math
aa.cc
-o
aa.x
-L
${
PUB_PREFIX
}
/lib
-lgsl
-lcblas
-lgoto
-lgfortran
-lm
g++
-O3
-ffast-math
aa.cc
-o
aa.x
-L
${
PUB_PREFIX
}
/lib
-lgsl
${
THEANO_BLAS_LDFLAGS
}
clean
:
rm
aa.x
benchmark/autoencoder/aa.cc
浏览文件 @
7bac82c1
...
...
@@ -28,6 +28,7 @@ int main(int argc, char **argv)
int
neg
=
strtol
(
argv
[
1
],
0
,
0
);
int
nout
=
strtol
(
argv
[
2
],
0
,
0
);
int
nin
=
nout
;
int
nhid
=
strtol
(
argv
[
3
],
0
,
0
);
int
niter
=
strtol
(
argv
[
4
],
0
,
0
);
double
lr
=
0.01
;
...
...
@@ -35,8 +36,8 @@ int main(int argc, char **argv)
gsl_rng_set
(
rng
,
234
);
gsl_matrix
*
x
=
gsl_matrix_alloc
(
neg
,
n
out
);
gsl_matrix
*
w
=
gsl_matrix_alloc
(
n
out
,
nhid
);
gsl_matrix
*
x
=
gsl_matrix_alloc
(
neg
,
n
in
);
gsl_matrix
*
w
=
gsl_matrix_alloc
(
n
in
,
nhid
);
gsl_vector
*
a
=
gsl_vector_alloc
(
nhid
);
gsl_vector
*
b
=
gsl_vector_alloc
(
nout
);
gsl_matrix
*
xw
=
gsl_matrix_alloc
(
neg
,
nhid
);
...
...
@@ -59,11 +60,17 @@ int main(int argc, char **argv)
struct
timeval
tv0
,
tv1
;
struct
timeval
tdot0
,
tdot1
;
double
time_of_dot
=
0.0
;
gettimeofday
(
&
tv0
,
0
);
double
err
=
0.0
;
for
(
int
iter
=
0
;
iter
<
niter
;
++
iter
)
{
gettimeofday
(
&
tdot0
,
0
);
gsl_blas_dgemm
(
CblasNoTrans
,
CblasNoTrans
,
1.0
,
x
,
w
,
0.0
,
xw
);
gettimeofday
(
&
tdot1
,
0
);
time_of_dot
+=
pytime
(
&
tdot1
)
-
pytime
(
&
tdot0
);
for
(
int
i
=
0
;
i
<
neg
;
++
i
)
for
(
int
j
=
0
;
j
<
nhid
;
++
j
)
...
...
@@ -72,7 +79,10 @@ int main(int argc, char **argv)
hid
->
data
[
i
*
nhid
+
j
]
=
tanh
(
act
);
}
gettimeofday
(
&
tdot0
,
0
);
gsl_blas_dgemm
(
CblasNoTrans
,
CblasTrans
,
1.0
,
hid
,
w
,
0.0
,
hidwt
);
gettimeofday
(
&
tdot1
,
0
);
time_of_dot
+=
pytime
(
&
tdot1
)
-
pytime
(
&
tdot0
);
for
(
int
i
=
0
;
i
<
nout
;
++
i
)
g_b
->
data
[
i
]
=
0.0
;
err
=
0.0
;
...
...
@@ -90,8 +100,11 @@ int main(int argc, char **argv)
if
(
1
)
{
gettimeofday
(
&
tdot0
,
0
);
gsl_blas_dgemm
(
CblasNoTrans
,
CblasNoTrans
,
1.0
,
g_hidwt
,
w
,
0.0
,
g_hid
);
gsl_blas_dgemm
(
CblasTrans
,
CblasNoTrans
,
1.0
,
g_hidwt
,
hid
,
0.0
,
g_w
);
gettimeofday
(
&
tdot1
,
0
);
time_of_dot
+=
pytime
(
&
tdot1
)
-
pytime
(
&
tdot0
);
for
(
int
i
=
0
;
i
<
neg
;
++
i
)
...
...
@@ -101,14 +114,19 @@ int main(int argc, char **argv)
a
->
data
[
j
]
-=
lr
*
g_hid
->
data
[
i
*
nhid
+
j
];
}
gettimeofday
(
&
tdot0
,
0
);
gsl_blas_dgemm
(
CblasTrans
,
CblasNoTrans
,
-
lr
,
x
,
g_hid
,
1.0
,
w
);
gettimeofday
(
&
tdot1
,
0
);
time_of_dot
+=
pytime
(
&
tdot1
)
-
pytime
(
&
tdot0
);
for
(
int
i
=
0
;
i
<
nout
*
nhid
;
++
i
)
w
->
data
[
i
]
-=
lr
*
g_w
->
data
[
i
];
}
}
gettimeofday
(
&
tv1
,
0
);
fprintf
(
stdout
,
"took = %lfs to get err %lf
\n
"
,
pytime
(
&
tv1
)
-
pytime
(
&
tv0
),
0.5
*
err
);
double
total_time
=
pytime
(
&
tv1
)
-
pytime
(
&
tv0
);
fprintf
(
stdout
,
"took = %lfs to get err %lf
\n
"
,
total_time
,
0.5
*
err
);
fprintf
(
stdout
,
"... of which %.2lfs was spent in dgemm (fraction: %.2lf)
\n
"
,
time_of_dot
,
time_of_dot
/
total_time
);
//skip freeing
return
0
;
}
...
...
benchmark/autoencoder/aa.py
浏览文件 @
7bac82c1
...
...
@@ -10,6 +10,13 @@ import theano.sandbox
import
theano.sandbox.wraplinker
from
theano.compile
import
module
,
Mode
from
theano.sandbox.wraplinker
import
ProfileMode
from
theano
import
gof
,
Op
,
Apply
from
theano.tensor
import
blas
,
opt
# numpy: aa_numpy.py
# c : aa.cc
if
0
:
class
Opt
(
object
):
...
...
@@ -131,7 +138,7 @@ if 0:
self
.
merge
(
env
)
def
linker
(
print_prog
=
True
):
def
print_graph_
linker
(
print_prog
=
True
):
if
1
:
imap
=
{
None
:
'-'
}
def
blah
(
i
,
node
,
thunk
):
...
...
@@ -146,7 +153,6 @@ def linker(print_prog=True):
print
'node '
,
i
,
node
,
print
':'
.
join
([
imap
[
inp
.
owner
]
for
inp
in
node
.
inputs
])
#print theano.sandbox.pprint.pp.process_graph(inputs, outputs)
return
theano
.
sandbox
.
wraplinker
.
WrapLinkerMany
(
[
theano
.
gof
.
OpWiseCLinker
()],
[
theano
.
sandbox
.
wraplinker
.
run_all
...
...
@@ -184,8 +190,11 @@ class M(module.Module):
self
.
step
=
module
.
Method
([
x
],
err
,
updates
=
dict
(
updates
))
mod
=
M
()
#m = mod.make(mode='FAST_RUN')
mode
=
ProfileMode
(
optimizer
=
'fast_run'
,
linker
=
theano
.
gof
.
OpWiseCLinker
())
mode
=
'FAST_RUN'
#mode = ProfileMode(optimizer='fast_run', linker=theano.gof.OpWiseCLinker())
mode
=
Mode
(
optimizer
=
'fast_run'
,
linker
=
theano
.
gof
.
OpWiseCLinker
(
nice_errors
=
True
))
mode
=
Mode
(
optimizer
=
'fast_run'
,
linker
=
'c'
)
print
mod
.
pretty
(
mode
=
mode
)
m
=
mod
.
make
(
mode
=
mode
)
neg
,
nout
,
nhid
,
niter
=
[
int
(
a
)
for
a
in
sys
.
argv
[
1
:]]
...
...
@@ -200,5 +209,10 @@ t = time.time()
for
i
in
xrange
(
niter
):
err
=
m
.
step
(
x
)
print
'time: '
,
time
.
time
()
-
t
,
'err: '
,
err
mode
.
print_summary
()
try
:
mode
.
print_summary
()
pass
except
:
pass
benchmark/autoencoder/aa_numpy.py
浏览文件 @
7bac82c1
...
...
@@ -4,6 +4,8 @@ import numpy as N
import
sys
import
time
# c: aa.cc
neg
,
nout
,
nhid
,
niter
=
[
int
(
a
)
for
a
in
sys
.
argv
[
1
:]]
lr
=
0.01
...
...
@@ -14,12 +16,20 @@ a = rng.randn(nhid) * 0.0
b
=
rng
.
randn
(
nout
)
*
0.0
x
=
(
rng
.
rand
(
neg
,
nout
)
-
0.5
)
*
1.5
dot_time
=
0.0
t
=
time
.
time
()
for
i
in
xrange
(
niter
):
hid
=
N
.
tanh
(
N
.
dot
(
x
,
w
)
+
a
)
tt
=
time
.
time
()
d
=
N
.
dot
(
x
,
w
)
dot_time
+=
time
.
time
()
-
tt
hid
=
N
.
tanh
(
d
+
a
)
out
=
N
.
tanh
(
N
.
dot
(
hid
,
w
.
T
)
+
b
)
tt
=
time
.
time
()
d
=
N
.
dot
(
hid
,
w
.
T
)
dot_time
+=
time
.
time
()
-
tt
out
=
N
.
tanh
(
d
+
b
)
g_out
=
out
-
x
err
=
0.5
*
N
.
sum
(
g_out
**
2
)
...
...
@@ -28,12 +38,23 @@ for i in xrange(niter):
b
-=
lr
*
N
.
sum
(
g_hidwt
,
axis
=
0
)
tt
=
time
.
time
()
g_hid
=
N
.
dot
(
g_hidwt
,
w
)
dot_time
+=
time
.
time
()
-
tt
g_hidin
=
g_hid
*
(
1.0
-
hid
**
2
)
w
-=
lr
*
(
N
.
dot
(
g_hidwt
.
T
,
hid
)
+
N
.
dot
(
x
.
T
,
g_hidin
))
tt
=
time
.
time
()
d
=
N
.
dot
(
g_hidwt
.
T
,
hid
)
dd
=
N
.
dot
(
x
.
T
,
g_hidin
)
dot_time
+=
time
.
time
()
-
tt
gw
=
(
d
+
dd
)
w
-=
lr
*
gw
a
-=
lr
*
N
.
sum
(
g_hidin
,
axis
=
0
)
print
'time: '
,
time
.
time
()
-
t
,
'err: '
,
err
total_time
=
time
.
time
()
-
t
print
'time: '
,
total_time
,
'err: '
,
err
print
' of which'
,
dot_time
,
'was spent on dot. Fraction:'
,
dot_time
/
total_time
theano/gof/__init__.py
浏览文件 @
7bac82c1
...
...
@@ -23,11 +23,12 @@ from op import \
from
opt
import
\
Optimizer
,
optimizer
,
SeqOptimizer
,
\
MergeOptimizer
,
MergeOptMerge
,
\
LocalOptimizer
,
local_optimizer
,
LocalOptGroup
,
LocalOpKeyOptGroup
,
\
LocalOptimizer
,
local_optimizer
,
LocalOptGroup
,
\
OpSub
,
OpRemove
,
PatternSub
,
\
NavigatorOptimizer
,
TopoOptimizer
,
OpKeyOptimizer
,
EquilibriumOptimizer
,
\
NavigatorOptimizer
,
TopoOptimizer
,
EquilibriumOptimizer
,
\
keep_going
,
warn
,
\
InplaceOptimizer
,
PureThenInplaceOptimizer
#LocalOpKeyOptGroup, OpKeyOptimizer
from
optdb
import
\
DB
,
Query
,
\
...
...
theano/gof/cc.py
浏览文件 @
7bac82c1
...
...
@@ -686,14 +686,15 @@ class CLinker(link.Linker):
instantiate
.
customize
.
add_support_code
(
support_code
)
instantiate
.
customize
.
add_support_code
(
self
.
struct_code
)
instantiate
.
customize
.
add_support_code
(
static
)
for
extra_arg
in
(
"-w"
,
#-w means supress all warnings
):
#"-O3",
#"-ffast-math",
for
extra_arg
in
(
"-O2"
,
"-ffast-math"
,
#"-fprefetch-loop-arrays",
#"-ftree-vect-loop-version",
#"-ftree-loop-optimize",
#"-ftree-vectorize"):
"-w"
#-w means supress all warnings
):
instantiate
.
customize
.
add_extra_compile_arg
(
extra_arg
)
for
arg
in
self
.
compile_args
():
instantiate
.
customize
.
add_extra_compile_arg
(
arg
)
...
...
@@ -747,6 +748,7 @@ def _execute(cthunk, init_tasks, tasks, error_storage):
exc_value
=
exc_type
(
_exc_value
,
task
)
exc_value
.
__thunk_trace__
=
trace
# this can be used to retrieve the location the Op was declared
raise
exc_type
,
exc_value
,
exc_trace
execute
.
cthunk
=
cthunk
return
execute
...
...
@@ -769,9 +771,12 @@ class OpWiseCLinker(link.LocalLinker):
__cache__
=
{}
def
__init__
(
self
,
fallback_on_perform
=
True
):
def
__init__
(
self
,
fallback_on_perform
=
True
,
nice_errors
=
True
):
self
.
env
=
None
self
.
fallback_on_perform
=
fallback_on_perform
self
.
nice_errors
=
nice_errors
def
accept
(
self
,
env
,
no_recycling
=
[]):
if
self
.
env
is
not
None
and
self
.
env
is
not
env
:
...
...
@@ -841,7 +846,9 @@ class OpWiseCLinker(link.LocalLinker):
else
:
no_recycling
=
[
storage_map
[
r
]
for
r
in
no_recycling
if
r
not
in
env
.
inputs
]
f
=
link
.
streamline
(
env
,
thunks
,
order
,
no_recycling
=
no_recycling
,
profiler
=
profiler
)
f
=
link
.
streamline
(
env
,
thunks
,
order
,
no_recycling
=
no_recycling
,
nice_errors
=
self
.
nice_errors
)
return
f
,
[
link
.
Container
(
input
,
storage
)
for
input
,
storage
in
zip
(
env
.
inputs
,
input_storage
)],
\
[
link
.
Container
(
output
,
storage
,
True
)
for
output
,
storage
in
zip
(
env
.
outputs
,
output_storage
)],
\
...
...
@@ -849,7 +856,6 @@ class OpWiseCLinker(link.LocalLinker):
def
_default_checker
(
x
,
y
):
"""WRITEME
Default checker for DualLinker. This checks that the
...
...
theano/gof/graph.py
浏览文件 @
7bac82c1
...
...
@@ -13,6 +13,7 @@ from collections import deque
import
utils
_creation_idx
=
[
0
]
class
Apply
(
utils
.
object2
):
"""
...
...
@@ -121,6 +122,13 @@ class Apply(utils.object2):
def
__asapply__
(
self
):
return
self
def
__hash__
(
self
):
if
not
hasattr
(
self
,
'_creation_idx'
):
self
.
_creation_idx
=
_creation_idx
[
0
]
_creation_idx
[
0
]
+=
1
return
self
.
_creation_idx
def
clone
(
self
):
"""Duplicate this Apply instance with inputs = self.inputs.
...
...
@@ -567,7 +575,10 @@ def general_toposort(r_out, deps, debug_print = False):
deps(i) should behave like a pure function (no funny business with internal state)
:note:
deps(i) can/should be cached by the deps function to be fast
deps(i) will be cached by this function (to be fast)
:note:
The order of the return value list is determined by the order of nodes returned by the deps() function.
"""
deps_cache
=
{}
def
_deps
(
io
):
...
...
@@ -611,8 +622,9 @@ def general_toposort(r_out, deps, debug_print = False):
def
io_toposort
(
i
,
o
,
orderings
=
{}):
"""WRITEME
"""
#the inputs are used only here in the function that decides what 'predecessors' to explore
iset
=
set
(
i
)
def
deps
(
obj
):
def
deps
(
obj
):
rval
=
[]
if
obj
not
in
iset
:
if
isinstance
(
obj
,
Result
):
...
...
theano/gof/link.py
浏览文件 @
7bac82c1
...
...
@@ -5,6 +5,7 @@ from type import Type
import
sys
,
traceback
from
copy
import
copy
from
cutils
import
run_cthunk
__excepthook
=
sys
.
excepthook
...
...
@@ -225,9 +226,27 @@ def clear_storage_thunk(stg):
thunk
.
inputs
=
[
stg
]
return
thunk
def
streamline
(
env
,
thunks
,
order
,
no_recycling
=
[],
profiler
=
None
):
"""WRITEME"""
if
profiler
is
None
:
def
streamline
(
env
,
thunks
,
order
,
no_recycling
=
[],
profiler
=
None
,
nice_errors
=
True
):
"""WRITEME
:param env:
:param thunks: the list of program instructions
:param order: the list of apply instances that gave rise to the thunks (same order as thunks)
:param no_recycling: storage elements that cannot be 'recycled' by repeatedly executing the
program. These storage elements are cleared before re-running.
:param profiler: deprecated
:param nice_errors: run in such a way that the double-traceback is printed. This costs a
bit of performance in the inner python loop.
"""
if
profiler
is
not
None
:
raise
NotImplementedError
()
if
nice_errors
:
def
f
():
for
x
in
no_recycling
:
x
[
0
]
=
None
...
...
@@ -237,14 +256,13 @@ def streamline(env, thunks, order, no_recycling = [], profiler = None):
except
:
raise_with_op
(
node
)
else
:
# don't worry about raise_with_op, just go a little faster.
#there is a mix of python and c thunks
def
f
():
for
x
in
no_recycling
:
x
[
0
]
=
None
def
g
():
for
thunk
,
node
in
zip
(
thunks
,
order
):
profiler
.
profile_node
(
thunk
,
node
)
profiler
.
profile_env
(
g
,
env
)
f
.
profiler
=
profiler
for
thunk
in
thunks
:
thunk
()
return
f
class
LocalLinker
(
Linker
):
...
...
theano/gof/opt.py
浏览文件 @
7bac82c1
...
...
@@ -17,6 +17,9 @@ import sys
_optimizer_idx
=
[
0
]
def
_list_of_nodes
(
env
):
return
graph
.
io_toposort
(
env
.
inputs
,
env
.
outputs
)
class
Optimizer
(
object
):
"""WRITEME
An L{Optimizer} can be applied to an L{Env} to transform it.
...
...
@@ -73,7 +76,7 @@ class FromFunctionOptimizer(Optimizer):
env
.
extend
(
toolbox
.
ReplaceValidate
())
def
optimizer
(
f
):
"""
WRITEME
"""
"""
decorator for FromFunctionOptimizer
"""
return
FromFunctionOptimizer
(
f
)
...
...
@@ -137,6 +140,10 @@ class _metadict:
try
:
self
.
d
[
item
]
=
value
except
:
for
i
,
(
key
,
val
)
in
enumerate
(
self
.
l
):
if
key
==
item
:
self
.
l
[
i
]
=
(
item
,
value
)
return
self
.
l
.
append
((
item
,
value
))
def
get
(
self
,
item
,
default
):
try
:
...
...
@@ -191,7 +198,7 @@ class MergeOptimizer(Optimizer):
cid
[
r
]
=
i
inv_cid
[
i
]
=
r
for
node
in
graph
.
io_toposort
(
env
.
inputs
,
env
.
outputs
):
for
node
in
_list_of_nodes
(
env
):
node_cid
=
(
node
.
op
,
tuple
([
cid
[
input
]
for
input
in
node
.
inputs
]))
dup
=
inv_cid
.
get
(
node_cid
,
None
)
success
=
False
...
...
@@ -229,10 +236,33 @@ def MergeOptMerge(opt):
### Local Optimizers ###
########################
class
LocalOptimizer
(
Optimizer
,
utils
.
object2
):
"""WRITEME"""
class
LocalOptimizer
(
object
):
"""A class for node-based optimizations.
Instances should implement the transform function,
and be passed to configure a env-based Optimizer instance.
"""
def
__hash__
(
self
):
if
not
hasattr
(
self
,
'_optimizer_idx'
):
self
.
_optimizer_idx
=
_optimizer_idx
[
0
]
_optimizer_idx
[
0
]
+=
1
return
self
.
_optimizer_idx
def
transform
(
self
,
node
):
"""Transform a subgraph whose output is `node`.
Subclasses should implement this function so that it returns one of two
kinds of things:
- False to indicate that no optimization can be applied to this `node`; or
- <list of results> to use in place of `node`'s outputs in the greater graph.
:type node: an Apply instance
"""
raise
utils
.
AbstractFunctionError
()
...
...
@@ -272,7 +302,7 @@ class LocalOptGroup(LocalOptimizer):
return
repl
class
LocalOpKeyOptGroup
(
LocalOptGroup
):
class
_
LocalOpKeyOptGroup
(
LocalOptGroup
):
"""WRITEME"""
def
__init__
(
self
,
optimizers
):
...
...
@@ -515,9 +545,29 @@ class PatternSub(LocalOptimizer):
class
NavigatorOptimizer
(
Optimizer
):
"""WRITEME"""
"""Abstract class
"""
def
__init__
(
self
,
local_opt
,
ignore_newtrees
=
'auto'
,
failure_callback
=
None
):
"""
:param local_opt: a LocalOptimizer to apply over a Env.
:param ignore_newtrees:
- True: new subgraphs returned by an optimization is not a candidate for optimization
- False: new subgraphs returned by an optimization is a candidate for optimization
- 'auto': let the local_opt set this parameter via its 'reentrant' attribute.
:param failure_callback:
a function that takes (exception, navigator, [(old, new),
(old,new),...]) and we call it if there's an exception.
If the trouble is from local_opt.transform(), the new variables will be 'None'.
If the trouble is from validation (the new types don't match for
example) then the new variables will be the ones created by
transform().
If this parameter is None, then exceptions are not caught here (raised normally).
"""
self
.
local_opt
=
local_opt
if
ignore_newtrees
==
'auto'
:
self
.
ignore_newtrees
=
not
getattr
(
local_opt
,
'reentrant'
,
True
)
...
...
@@ -526,9 +576,18 @@ class NavigatorOptimizer(Optimizer):
self
.
failure_callback
=
failure_callback
def
attach_updater
(
self
,
env
,
importer
,
pruner
,
chin
=
None
):
"""Install some Env listeners to help the navigator deal with the ignore_trees-related functionality.
:param importer: function that will be called whenever when optimizations add stuff to the graph.
:param pruner: function to be called when optimizations remove stuff from graph.
:param chin: "on change input" called whenever an node's inputs change.
:returns: The Env plugin that handles the three tasks. Keep this around so that you can detach later!
"""
if
self
.
ignore_newtrees
:
importer
=
None
if
importer
is
None
and
pruner
is
None
:
return
None
...
...
@@ -542,12 +601,18 @@ class NavigatorOptimizer(Optimizer):
if
chin
is
not
None
:
def
on_change_input
(
self
,
env
,
node
,
i
,
r
,
new_r
):
chin
(
node
,
i
,
r
,
new_r
)
u
=
Updater
()
env
.
extend
(
u
)
return
u
def
detach_updater
(
self
,
env
,
u
):
"""Undo the work of attach_updater.
:param u: a return-value of attach_updater
:returns: None.
"""
if
u
is
not
None
:
env
.
remove_feature
(
u
)
...
...
@@ -610,7 +675,7 @@ class TopoOptimizer(NavigatorOptimizer):
except
:
self
.
detach_updater
(
env
,
u
)
raise
self
.
detach_updater
(
env
,
u
)
class
OpKeyOptimizer
(
NavigatorOptimizer
):
...
...
@@ -642,6 +707,7 @@ class OpKeyOptimizer(NavigatorOptimizer):
except
:
self
.
detach_updater
(
env
,
u
)
raise
self
.
detach_updater
(
env
,
u
)
def
add_requirements
(
self
,
env
):
"""
...
...
@@ -654,38 +720,70 @@ class OpKeyOptimizer(NavigatorOptimizer):
# class EquilibriumOptimizer(NavigatorOptimizer):
# """WRITEME"""
from
utils
import
D
# def __init__(self, local_optimizers, failure_callback = None):
# NavigatorOptimizer.__init__(self, local_opt, ignore_newtrees, failure_callback)
# def apply(self, env):
# op = self.local_opt.op_key()
# if isinstance(op, (list, tuple)):
# q = reduce(list.__iadd__, map(env.get_nodes, op))
# else:
# q = list(env.get_nodes(op))
# def importer(node):
# if node.op == op: q.append(node)
# def pruner(node):
# if node is not current_node and node.op == op:
# try: q.remove(node)
# except ValueError: pass
# u = self.attach_updater(env, importer, pruner)
# try:
# while q:
# node = q.pop()
# current_node = node
# self.process_node(env, node)
# except:
# self.detach_updater(env, u)
# raise
class
EquilibriumOptimizer
(
NavigatorOptimizer
):
def
__init__
(
self
,
local_optimizers
,
failure_callback
=
None
,
max_depth
=
None
,
max_use_ratio
=
None
):
"""
:param max_use_ratio: each optimizer can be applied at most (size of graph * this number)
"""
from
utils
import
D
super
(
EquilibriumOptimizer
,
self
)
.
__init__
(
None
,
ignore_newtrees
=
True
,
failure_callback
=
failure_callback
)
class
EquilibriumOptimizer
(
NavigatorOptimizer
):
self
.
local_optimizers
=
local_optimizers
self
.
max_depth
=
max_depth
self
.
max_use_ratio
=
max_use_ratio
def
apply
(
self
,
env
,
start_from
=
None
):
if
start_from
is
None
:
start_from
=
env
.
outputs
changed
=
True
max_use_abort
=
False
process_count
=
{}
while
changed
and
not
max_use_abort
:
changed
=
False
q
=
deque
(
graph
.
io_toposort
(
env
.
inputs
,
start_from
))
max_use
=
len
(
q
)
*
self
.
max_use_ratio
def
importer
(
node
):
q
.
append
(
node
)
def
pruner
(
node
):
if
node
is
not
current_node
:
try
:
q
.
remove
(
node
)
except
ValueError
:
pass
u
=
self
.
attach_updater
(
env
,
importer
,
pruner
)
try
:
while
q
:
node
=
q
.
pop
()
current_node
=
node
for
lopt
in
self
.
local_optimizers
:
process_count
.
setdefault
(
lopt
,
0
)
if
process_count
[
lopt
]
>
max_use
:
max_use_abort
=
True
else
:
lopt_change
=
self
.
process_node
(
env
,
node
,
lopt
)
process_count
[
lopt
]
+=
1
if
lopt_change
else
0
changed
|=
lopt_change
except
:
self
.
detach_updater
(
env
,
u
)
raise
self
.
detach_updater
(
env
,
u
)
if
max_use_abort
:
print
>>
sys
.
stderr
,
"WARNING: EquilibriumOptimizer max'ed out"
class
_EquilibriumOptimizer
(
NavigatorOptimizer
):
def
__init__
(
self
,
local_optimizers
,
...
...
@@ -780,10 +878,11 @@ class EquilibriumOptimizer(NavigatorOptimizer):
# importer(node)
for
node
in
env
.
nodes
:
for
node
in
env
.
toposort
()
:
tasks
[
node
]
.
extend
(
lopt
for
track
,
i
,
lopt
in
self
.
fetch_tracks0
(
node
.
op
))
u
=
self
.
attach_updater
(
env
,
importer
,
pruner
,
chin
)
print
'KEYS'
,
map
(
hash
,
tasks
.
keys
())
while
tasks
:
for
node
in
tasks
.
iterkeys
():
todo
=
tasks
.
pop
(
node
)
...
...
theano/gof/optdb.py
浏览文件 @
7bac82c1
...
...
@@ -18,7 +18,7 @@ class DB(object):
# N.B. obj is not an instance of class Optimizer.
# It is an instance of a DB.In the tests for example,
# this is not always the case.
if
not
isinstance
(
obj
,
(
DB
,
opt
.
Optimizer
)):
if
not
isinstance
(
obj
,
(
DB
,
opt
.
Optimizer
,
opt
.
LocalOptimizer
)):
raise
Exception
(
'wtf'
,
obj
)
obj
.
name
=
name
...
...
theano/gof/tests/test_opt.py
浏览文件 @
7bac82c1
...
...
@@ -375,7 +375,7 @@ class TestEquilibrium(object):
x
,
y
,
z
=
map
(
MyResult
,
'xyz'
)
e
=
op3
(
op4
(
x
,
y
))
g
=
Env
([
x
,
y
,
z
],
[
e
])
print
g
print
'before'
,
g
sys
.
stderr
=
sys
.
stdout
# display pesky warnings along with stdout
opt
=
EquilibriumOptimizer
(
[
PatternSub
((
op1
,
'x'
,
'y'
),
(
op2
,
'x'
,
'y'
)),
...
...
@@ -384,7 +384,7 @@ class TestEquilibrium(object):
],
max_use_ratio
=
1.
/
len
(
g
.
nodes
))
# each opt can only be applied once
opt
.
optimize
(
g
)
print
g
print
'after'
,
g
assert
str
(
g
)
==
'[Op4(x, y)]'
...
...
theano/gradient.py
浏览文件 @
7bac82c1
...
...
@@ -2,6 +2,7 @@ import gof #, gof.result
import
numpy
#for numeric_grad
from
gof.python25
import
all
import
gof.utils
_msg_retType
=
'op.grad(...) returned a non-list'
_msg_badlen
=
'op.grad(...) returned wrong number of gradients'
...
...
@@ -55,17 +56,17 @@ def grad_sources_inputs(sources, graph_inputs):
else
:
gmap
[
r
]
=
g_r
graph_outputs
=
g
map
.
keys
(
)
graph_outputs
=
g
of
.
utils
.
uniq
([
r
for
r
,
g
in
sources
]
)
if
graph_inputs
is
None
:
graph_inputs
=
gof
.
graph
.
inputs
(
graph_outputs
)
for
node
in
gof
.
graph
.
io_toposort
(
graph_inputs
,
graph_outputs
)
.
__reversed__
():
g_outputs
=
[
gmap
.
get
(
o
,
None
)
for
o
in
node
.
outputs
]
#if all output gradients are None, continue
if
all
(
map
(
lambda
x
:
x
is
None
,
g_outputs
)):
continue
output_arg
=
g_outputs
input_arg
=
node
.
inputs
...
...
theano/printing.py
浏览文件 @
7bac82c1
...
...
@@ -235,17 +235,27 @@ class PPrinter:
else
:
raise
TypeError
(
'Not enough arguments to call.'
)
special
=
dict
(
middle_dot
=
u"
\u00B7
"
,
big_sigma
=
u"
\u03A3
"
)
greek
=
dict
(
alpha
=
u"
\u03B1
"
,
beta
=
u"
\u03B2
"
,
gamma
=
u"
\u03B3
"
,
delta
=
u"
\u03B4
"
,
epsilon
=
u"
\u03B5
"
)
use_ascii
=
True
if
use_ascii
:
special
=
dict
(
middle_dot
=
"
\
dot"
,
big_sigma
=
"
\
Sigma"
)
greek
=
dict
(
alpha
=
"
\a
lpha"
,
beta
=
"
\b
eta"
,
gamma
=
"
\
gamma"
,
delta
=
"
\
delta"
,
epsilon
=
"
\
epsilon"
)
else
:
special
=
dict
(
middle_dot
=
u"
\u00B7
"
,
big_sigma
=
u"
\u03A3
"
)
greek
=
dict
(
alpha
=
u"
\u03B1
"
,
beta
=
u"
\u03B2
"
,
gamma
=
u"
\u03B3
"
,
delta
=
u"
\u03B4
"
,
epsilon
=
u"
\u03B5
"
)
pprint
=
PPrinter
()
...
...
theano/sandbox/wraplinker.py
浏览文件 @
7bac82c1
...
...
@@ -2,6 +2,7 @@ from __future__ import absolute_import
import
time
import
numpy
from
..gof.cutils
import
run_cthunk
from
..gof.link
import
WrapLinker
from
..compile.mode
import
Mode
...
...
@@ -107,19 +108,42 @@ class ProfileMode(Mode):
local_time
=
[
0.0
]
apply_time
=
{}
op_time
=
{}
op_cimpl
=
{}
def
blah
(
i
,
node
,
*
thunks
):
t0
=
time
.
time
()
for
th
in
thunks
:
th
()
dt
=
time
.
time
()
-
t0
if
0
:
t0
=
time
.
time
()
for
th
in
thunks
:
th
()
dt
=
time
.
time
()
-
t0
elif
0
:
#more precise timing
for
th
in
thunks
:
t0
=
time
.
time
()
th
()
dt
=
time
.
time
()
-
t0
elif
1
:
for
th
in
thunks
:
if
hasattr
(
th
,
'cthunk'
):
t0
=
time
.
time
()
run_cthunk
(
th
.
cthunk
)
dt
=
time
.
time
()
-
t0
else
:
t0
=
time
.
time
()
th
()
dt
=
time
.
time
()
-
t0
elif
1
:
pass
else
:
raise
Exception
(
'one of the cases has to run the thunks!'
)
local_time
[
0
]
+=
dt
apply_time
[(
i
,
node
.
op
)]
=
apply_time
.
get
((
i
,
node
.
op
),
0.0
)
+
dt
op_time
[
node
.
op
]
=
op_time
.
get
(
node
.
op
,
0.0
)
+
dt
op_cimpl
[
node
.
op
]
=
hasattr
(
thunks
[
0
],
'cthunk'
)
self
.
local_time
=
local_time
self
.
apply_time
=
apply_time
self
.
op_time
=
op_time
self
.
op_cimpl
=
op_cimpl
wrap_linker
=
WrapLinkerMany
([
linker
],
[
blah
])
if
optimizer
:
...
...
@@ -142,13 +166,20 @@ class ProfileMode(Mode):
atimes
.
sort
()
atimes
.
reverse
()
for
t
,
a
in
atimes
[:
15
]:
print
' '
,
t
,
a
print
' ... (ignoring
%
i other Apply instances)'
%
max
(
0
,
len
(
atimes
)
-
15
)
print
'
\t
%.3
f
\t
%
i
\t
%
s'
%
(
t
,
a
[
0
],
a
[
1
])
print
' ... (remaining
%
i Apply instances account for
%.2
f of the runtime)'
\
%
(
max
(
0
,
len
(
atimes
)
-
15
),
sum
(
t
for
t
,
a
in
atimes
[
15
:]))
n_ops_to_print
=
20
print
'Op-wise summary: <fraction of local_time spent on this kind of Op> <Op name>'
otimes
=
[(
t
/
local_time
,
a
)
for
a
,
t
in
op_time
.
items
()]
otimes
=
[(
t
/
local_time
,
a
,
self
.
op_cimpl
[
a
]
)
for
a
,
t
in
op_time
.
items
()]
otimes
.
sort
()
otimes
.
reverse
()
for
t
,
a
in
otimes
[:
15
]:
print
' '
,
t
,
a
print
' ... (ignoring
%
i other kinds Ops)'
%
max
(
0
,
len
(
otimes
)
-
15
)
for
t
,
a
,
ci
in
otimes
[:
n_ops_to_print
]:
print
'
\t
%.3
f
\t
%
s
%
s'
%
(
t
,
'*'
if
ci
else
' '
,
a
)
print
' ... (remaining
%
i Ops account for
%.2
f of the runtime)'
\
%
(
max
(
0
,
len
(
otimes
)
-
n_ops_to_print
),
sum
(
t
for
t
,
a
,
ci
in
otimes
[
n_ops_to_print
:]))
print
'(*) Op is running a c implementation'
theano/tensor/basic.py
浏览文件 @
7bac82c1
...
...
@@ -1089,38 +1089,9 @@ pprint.assign(pow, printing.OperatorPrinter('**', 1, 'right'))
# View Operations
##########################
class
TransposeInplace
(
Op
):
view_map
=
{
0
:
[
0
]}
def
make_node
(
self
,
input
):
return
Apply
(
self
,
[
input
],
[
tensor
(
dtype
=
input
.
type
.
dtype
,
broadcastable
=
reversed
(
input
.
type
.
broadcastable
))])
def
perform
(
self
,
node
,
(
x
,
),
(
z
,
)):
z
[
0
]
=
x
.
T
def
grad
(
self
,
(
x
,),
(
gz
,)):
return
transpose
(
gz
),
def
c_code
(
self
,
node
,
name
,
(
x
,
),
(
z
,
),
sub
):
return
"""
PyArrayObject* transposed = (PyArrayObject*)PyArray_Transpose(
%(x)
s, NULL);
if (
%(z)
s) {
Py_XDECREF(
%(z)
s);
}
%(z)
s = transposed;
"""
%
locals
()
def
__str__
(
self
):
return
"TransposeView"
_transpose_inplace
=
TransposeInplace
()
def
transpose
(
x
,
**
kwargs
):
"""WRITEME"""
return
_transpose_inplace
(
tensor_copy
(
x
),
**
kwargs
)
dims
=
range
(
x
.
ndim
-
1
,
-
1
,
-
1
)
return
DimShuffle
(
x
.
broadcastable
,
dims
,
inplace
=
True
)(
tensor_copy
(
x
))
class
Subtensor
(
Op
):
...
...
@@ -1781,6 +1752,7 @@ class Dot(Op):
# The error raised by numpy has no shape information, we mean to add that
e
.
args
=
e
.
args
+
(
x
.
shape
,
y
.
shape
)
raise
def
grad
(
self
,
(
x
,
y
),
(
gz
,)):
if
gz
.
type
.
ndim
==
0
:
return
gz
*
y
,
gz
*
x
...
...
theano/tensor/elemwise.py
浏览文件 @
7bac82c1
...
...
@@ -103,16 +103,18 @@ class DimShuffle(Op):
for
i
,
b
in
enumerate
(
input_broadcastable
):
if
i
not
in
new_order
:
# we want to drop this dimension because it's not a value in new_order
if
b
==
1
:
if
b
==
1
:
# 1 aka True
self
.
drop
.
append
(
i
)
else
:
# we cannot drop non-broadcastable dimensions
raise
NotImplemented
Error
(
"You cannot drop a non-broadcastable dimension."
)
raise
Value
Error
(
"You cannot drop a non-broadcastable dimension."
)
else
:
i2j
[
i
]
=
j
j
+=
1
# transposition of non-broadcastable dimensions
# This is how the dimensions will be permuted, without accounting for the extra
# 'x' broadcastable dimensions to insert.
self
.
shuffle
=
[
i2j
[
x
]
for
x
in
new_order
if
x
!=
'x'
]
# list of dimensions of the output that are broadcastable and were not in the original input
...
...
@@ -144,7 +146,8 @@ class DimShuffle(Op):
and
self
.
input_broadcastable
==
other
.
input_broadcastable
def
__hash__
(
self
):
return
hash
(
self
.
inplace
)
^
hash
(
self
.
new_order
)
^
hash
(
self
.
input_broadcastable
)
return
hash
(
type
(
self
))
^
hash
(
self
.
inplace
)
\
^
hash
(
self
.
new_order
)
^
hash
(
self
.
input_broadcastable
)
def
__str__
(
self
):
if
self
.
inplace
:
...
...
@@ -175,13 +178,78 @@ class DimShuffle(Op):
storage
[
0
]
=
res
def
c_code
(
self
,
node
,
name
,
(
input
,),
(
res
,),
sub
):
def
statements
(
lst
):
return
';
\n
'
.
join
(
lst
)
+
';'
nd_in
=
len
(
self
.
input_broadcastable
)
nd_out
=
len
(
self
.
new_order
)
check_input_nd
=
[(
'if (
%(input)
s->nd != '
+
str
(
nd_in
)
+
')'
'{PyErr_SetString(PyExc_NotImplementedError, "input nd");
%(fail)
s;}'
)]
clear_output
=
[
'if (
%(res)
s) {Py_XDECREF(
%(res)
s);}'
]
shape_statements
=
[
'npy_intp dimensions[
%
i]'
%
nd_out
]
shape_statements
+=
[(
'dimensions['
+
str
(
i
)
+
'] =
%(input)
s->dimensions['
+
str
(
o
)
+
']'
)
if
o
!=
'x'
else
(
'dimensions['
+
str
(
i
)
+
'] = 1'
)
for
i
,
o
in
enumerate
(
self
.
new_order
)]
strides_statements
=
[
'npy_intp strides[
%
i]'
%
nd_out
]
strides_statements
+=
[(
'strides['
+
str
(
i
)
+
'] =
%(input)
s->strides['
+
str
(
o
)
+
']'
)
if
o
!=
'x'
else
(
'strides['
+
str
(
i
)
+
'] = 0'
)
for
i
,
o
in
enumerate
(
self
.
new_order
)]
if
self
.
inplace
:
get_base
=
[
'{ PyArrayObject * base =
%(input)
s'
,
'Py_INCREF((PyObject*)base)'
]
else
:
get_base
=
[(
'{ PyArrayObject * base = (PyArrayObject*)PyArray_FromAny((PyObject*)
%(input)
s, NULL,'
'0, 0, NPY_ALIGNED|NPY_ENSURECOPY, NULL)'
)]
alloc_output
=
[(
'
%(res)
s = (PyArrayObject*)PyArray_New(&PyArray_Type, '
''
+
str
(
nd_out
)
+
', dimensions, '
'PyArray_TYPE(base), strides, '
'base->data, base->descr->elsize, '
'PyArray_FLAGS(base), NULL)'
),
'
%(res)
s->base = (PyObject*)base'
,
'}'
]
full_code
=
statements
(
check_input_nd
+
clear_output
+
shape_statements
+
strides_statements
+
get_base
+
alloc_output
)
if
0
:
print
'C_CODE'
print
''
print
self
print
"IN BROAD"
,
self
.
input_broadcastable
print
"NEW ORDER"
,
self
.
new_order
print
"SHUFFLE"
,
self
.
shuffle
print
"AUGMENT"
,
self
.
augment
print
'------------'
print
''
print
full_code
if
0
:
import
sys
sys
.
exit
()
return
full_code
%
dict
(
locals
(),
**
sub
)
def
grad
(
self
,
(
x
,
),
(
gz
,
)):
gz
=
as_tensor
(
gz
)
grad_order
=
[
'x'
]
*
len
(
x
.
type
.
broadcastable
)
for
i
,
v
in
enumerate
(
self
.
new_order
):
if
v
!=
'x'
:
grad_order
[
v
]
=
i
return
DimShuffle
(
gz
.
type
.
broadcastable
,
grad_order
)(
gz
),
return
[
DimShuffle
(
gz
.
type
.
broadcastable
,
grad_order
,
inplace
=
True
)(
Elemwise
(
scalar
.
identity
)(
gz
))]
...
...
theano/tensor/inplace.py
浏览文件 @
7bac82c1
from
basic
import
_scal_elemwise
,
_transpose_inplace
from
basic
import
_scal_elemwise
#
, _transpose_inplace
from
..
import
scalar
as
scal
import
elemwise
from
..
import
printing
...
...
@@ -183,9 +183,11 @@ pprint.assign(div_inplace, printing.OperatorPrinter('/=', -1, 'left'))
pprint
.
assign
(
pow_inplace
,
printing
.
OperatorPrinter
(
'**='
,
1
,
'right'
))
transpose_inplace
=
_transpose_inplace
"""WRITEME"""
def
transpose_inplace
(
x
,
**
kwargs
):
"""Perform a transpose on a tensor without copying the underlying storage"""
dims
=
range
(
x
.
ndim
-
1
,
-
1
,
-
1
)
return
elemwise
.
DimShuffle
(
x
.
broadcastable
,
dims
,
inplace
=
True
)(
x
)
pprint
.
assign
(
transpose_inplace
,
printing
.
MemberPrinter
(
'T'
))
#
pprint.assign(transpose_inplace, printing.MemberPrinter('T'))
theano/tensor/opt.py
浏览文件 @
7bac82c1
...
...
@@ -50,7 +50,8 @@ dot_to_gemm = gof.PatternSub((T.dot, 'a', 'b'),
(
T
.
Subtensor
([
slice
(
0
,
1
)]),
(
T
.
shape
,
'a'
)),
(
T
.
Subtensor
([
slice
(
1
,
2
)]),
(
T
.
shape
,
'b'
)))),
T
.
constant
(
1.0
),
'a'
,
'b'
,
T
.
constant
(
1.0
)),
allow_multiple_clients
=
False
)
allow_multiple_clients
=
False
)
def
_insert_inplace_optimizer
(
env
):
...
...
@@ -216,6 +217,13 @@ register_canonicalize(local_shape_lift_dot)
################
def
encompasses_broadcastable
(
b1
,
b2
):
"""
Returns True if the broadcastable patterns b1 and b2 are such that b2 is
broadcasted to b1's shape and not the opposite.
:param b1: the broadcastable attribute of a tensor type
:param b2: the broadcastable attribute of a tensor type
"""
if
len
(
b1
)
<
len
(
b2
):
return
False
b1
=
b1
[
-
len
(
b2
):]
...
...
@@ -330,6 +338,7 @@ def local_fill_cut(node):
register_canonicalize
(
local_fill_cut
)
register_canonicalize
(
gof
.
OpRemove
(
T
.
tensor_copy
),
name
=
'remove_tensor_copy'
)
@gof.local_optimizer
([
None
,
T
.
fill
])
def
local_fill_sink
(
node
):
...
...
@@ -550,6 +559,7 @@ def local_neg_to_mul(node):
return
[
-
1
*
node
.
inputs
[
0
]]
else
:
return
False
register_canonicalize
(
local_neg_to_mul
)
@gof.local_optimizer
([
T
.
mul
])
def
local_mul_to_neg
(
node
):
...
...
@@ -557,6 +567,7 @@ def local_mul_to_neg(node):
return
[
-
local_mul_canonizer
.
merge_num_denum
(
node
.
inputs
[
1
:],
[])]
else
:
return
False
register_specialize
(
local_mul_to_neg
)
@gof.local_optimizer
([
T
.
div
])
def
local_div_to_inv
(
node
):
...
...
@@ -564,10 +575,88 @@ def local_div_to_inv(node):
return
[
T
.
inv
(
local_mul_canonizer
.
merge_num_denum
(
node
.
inputs
[
1
:],
[]))]
else
:
return
False
register_canonicalize
(
local_neg_to_mul
)
register_specialize
(
local_mul_to_neg
)
register_specialize
(
local_div_to_inv
)
@gof.local_optimizer([T.inv])
def local_inv_canon(node):
    """Canonicalize ``inv(x)`` into ``pow(x, -1.0)``.

    Writing the inverse as a power gives later pow-oriented
    optimizations a single canonical form to work with.
    """
    if node.op != T.inv:
        return False
    base = node.inputs[0]
    return [T.pow(base, -1.0)]

register_canonicalize(local_inv_canon)
@gof.local_optimizer([T.pow])
def local_pow_canonicalize(node):
    """Simplify ``pow(x, 1)`` and ``pow(x, 0)`` at canonicalization time.

    The replacements are wrapped in ``fill`` so the broadcast shape of
    the result is identical to the original ``pow`` output.
    """
    if node.op != T.pow:
        return False
    exponent = local_mul_canonizer.get_constant(node.inputs[1])
    if N.all(exponent == 1.0):
        # pow(x, 1) == x; fill keeps the output shape unchanged.
        return [T.fill(node.inputs[1], node.inputs[0])]
    if N.all(exponent == 0.0):
        # pow(x, 0) == 1; extra fills here are to make sure the size of
        # the output stays constant.
        return [T.fill(node.inputs[0], T.fill(node.inputs[1], 1.0))]
    # Exponent is not a recognized constant: no replacement.

register_canonicalize(local_pow_canonicalize)
@gof.local_optimizer([T.pow])
def local_pow_specialize(node):
    """Replace ``pow(x, c)`` for a few small constant exponents ``c``
    with cheaper specialized graphs (sqr, sqrt, inv, ...).

    This runs after canonicalization, so we do not want to put in
    un-necessary fills here; the rewrite is only applied when x's
    broadcastable pattern encompasses y's, i.e. the result shape is x's.
    """
    if node.op != T.pow:
        return False
    # We are looking at pow(xsym, ysym) with ysym a known constant.
    xsym, ysym = node.inputs
    y = local_mul_canonizer.get_constant(ysym)
    if y is None:
        return
    if not encompasses_broadcastable(xsym.type.broadcastable,
                                     ysym.type.broadcastable):
        return
    # Ordered table of (exponent, replacement builder).
    specializations = [
        (2.0, lambda x: T.sqr(x)),
        (1.0, lambda x: x),
        (0.0, lambda x: T.fill(x, 1.0)),
        (0.5, lambda x: T.sqrt(x)),
        (-0.5, lambda x: T.inv(T.sqrt(x))),
        (-1.0, lambda x: T.inv(x)),
        (-2.0, lambda x: T.inv(T.sqr(x))),
    ]
    for exponent, build in specializations:
        if N.all(y == exponent):
            return [build(xsym)]

register_specialize(local_pow_specialize)
if
0
:
#TODO: replace this with a c version of any InplaceDimShuffle
class
_TransposeInplace
(
T
.
Op
):
view_map
=
{
0
:
[
0
]}
def
make_node
(
self
,
input
):
return
T
.
Apply
(
self
,
[
input
],
[
T
.
tensor
(
dtype
=
input
.
type
.
dtype
,
broadcastable
=
reversed
(
input
.
type
.
broadcastable
))])
def
perform
(
self
,
node
,
(
x
,
),
(
z
,
)):
z
[
0
]
=
x
.
T
def
c_code
(
self
,
node
,
name
,
(
x
,
),
(
z
,
),
sub
):
return
"""
PyArrayObject* transposed = (PyArrayObject*)PyArray_Transpose(
%(x)
s, NULL);
if (
%(z)
s) {
Py_XDECREF(
%(z)
s);
}
%(z)
s = transposed;
"""
%
locals
()
def
__str__
(
self
):
return
"_TransposeInplace"
_transpose_inplace
=
_TransposeInplace
()
@gof.local_optimizer
([
T
.
DimShuffle
([
False
,
False
],[
1
,
0
],
inplace
=
True
)])
def
local_dimshuffle_transposeinplace
(
node
):
if
node
.
op
==
T
.
DimShuffle
([
False
,
False
],[
1
,
0
],
inplace
=
True
):
return
[
_transpose_inplace
(
node
.
inputs
[
0
])]
return
False
register_specialize
(
local_dimshuffle_transposeinplace
)
register_canonicalize
(
local_mul_canonizer
,
name
=
'local_mul_canonizer'
)
...
...
@@ -724,6 +813,248 @@ def constant_folding(node):
register_canonicalize
(
constant_folding
)
#################
# BLAS-related
#################
import
blas
class
_Dot22
(
gof
.
Op
):
"""Compute a matrix-matrix product.
This is a specialization of the more general Dot()
"""
def
make_node
(
self
,
x
,
y
):
assert
x
.
type
in
T
.
float_matrix_types
#makes sure x is a matrix
assert
y
.
type
==
x
.
type
#makes sure y is a matrix
bz
=
[
x
.
type
.
broadcastable
[
0
],
y
.
type
.
broadcastable
[
1
]]
outputs
=
[
T
.
tensor
(
x
.
type
.
dtype
,
bz
)]
return
gof
.
Apply
(
self
,
[
x
,
y
],
outputs
)
def
perform
(
self
,
node
,
(
x
,
y
),
(
z
,
)):
try
:
z
[
0
]
=
numpy
.
asarray
(
numpy
.
dot
(
x
,
y
))
except
ValueError
,
e
:
# The error raised by numpy has no shape information, we mean to add that
e
.
args
=
e
.
args
+
(
x
.
shape
,
y
.
shape
)
raise
def
__str__
(
self
):
return
"_dot22"
def
c_support_code
(
self
):
#return blas.cblas_header_text()
mod_str
=
"""
#ifndef MOD
#define MOD
%
#endif
"""
return
blas
.
blas_proto
()
+
mod_str
def
c_headers
(
self
):
return
[
'<iostream>'
]
def
c_libraries
(
self
):
return
blas
.
ldflags
()
def
c_code
(
self
,
node
,
name
,
(
_x
,
_y
),
(
_z
,
),
sub
):
return
"""
int unit = 0;
int type_num =
%(_x)
s->descr->type_num;
int type_size =
%(_x)
s->descr->elsize; // in bytes
npy_intp* Nx =
%(_x)
s->dimensions;
npy_intp* Ny =
%(_y)
s->dimensions;
npy_intp* Nz = 0; //
%(_z)
s->dimensions;
npy_intp* Sx =
%(_x)
s->strides;
npy_intp* Sy =
%(_y)
s->strides;
npy_intp* Sz = 0;//
%(_z)
s->strides;
//strides for x, y, z in dimensions 0, 1
int sx_0, sx_1, sy_0, sy_1, sz_0, sz_1;
if ((NULL ==
%(_z)
s)
|| (
%(_z)
s->dimensions[0] !=
%(_x)
s->dimensions[0])
|| (
%(_z)
s->dimensions[1] !=
%(_y)
s->dimensions[1]))
{
if (NULL !=
%(_z)
s) Py_XDECREF(
%(_z)
s);
npy_intp dims[2];
dims[0] =
%(_x)
s->dimensions[0];
dims[1] =
%(_y)
s->dimensions[1];
%(_z)
s = (PyArrayObject*)PyArray_SimpleNew(2, dims, type_num_
%(_x)
s);
if(!
%(_z)
s) {
PyErr_SetString(PyExc_MemoryError, "failed to alloc dot22 output");
%(fail)
s
}
}
Nz =
%(_z)
s->dimensions;
Sz =
%(_z)
s->strides;
if (
%(_x)
s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(x) != 2");
%(fail)
s;}
if (
%(_y)
s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(y) != 2");
%(fail)
s;}
if (
%(_z)
s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(z) != 2");
%(fail)
s;}
if ((
%(_x)
s->descr->type_num != PyArray_DOUBLE)
&& (
%(_x)
s->descr->type_num != PyArray_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(x) is not double or float");
%(fail)
s;}
if ((
%(_y)
s->descr->type_num != PyArray_DOUBLE)
&& (
%(_y)
s->descr->type_num != PyArray_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(y) is not double or float");
%(fail)
s;}
if ((
%(_z)
s->descr->type_num != PyArray_DOUBLE)
&& (
%(_z)
s->descr->type_num != PyArray_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(z) is not double or float");
%(fail)
s;}
if ((
%(_x)
s->descr->type_num !=
%(_y)
s->descr->type_num)
||(
%(_x)
s->descr->type_num !=
%(_z)
s->descr->type_num))
{ PyErr_SetString(PyExc_NotImplementedError, "type(z), type(y), type(z) are not all the same");
%(fail)
s; }
if ((Nx[0] != Nz[0]) || (Nx[1] != Ny[0]) || (Ny[1] != Nz[1]))
{
PyErr_SetString(PyExc_ValueError, "Input dimensions do not agree");
%(fail)
s;
}
if ((Sx[0] < 1) || (Sx[1] < 1) || (Sx[0] MOD type_size) || (Sx[1] MOD type_size)
|| (Sy[0] < 1) || (Sy[1] < 1) || (Sy[0] MOD type_size) || (Sy[1] MOD type_size)
|| (Sz[0] < 1) || (Sz[1] < 1) || (Sz[0] MOD type_size) || (Sz[1] MOD type_size))
{
PyErr_SetString(PyExc_ValueError, "stride is not multiple of element size");
%(fail)
s;
}
/*
encode the stride structure of _x,_y,_z into a single integer
*/
unit |= ((Sx[1] == type_size) ? 0x0 : (Sx[0] == type_size) ? 0x1 : 0x2) << 8;
unit |= ((Sy[1] == type_size) ? 0x0 : (Sy[0] == type_size) ? 0x1 : 0x2) << 4;
unit |= ((Sz[1] == type_size) ? 0x0 : (Sz[0] == type_size) ? 0x1 : 0x2) << 0;
/* create appropriate strides for malformed matrices that are row or column
* vectors
*/
sx_0 = (Nx[0] > 1) ? Sx[0]/type_size : Nx[1];
sx_1 = (Nx[1] > 1) ? Sx[1]/type_size : Nx[0];
sy_0 = (Ny[0] > 1) ? Sy[0]/type_size : Ny[1];
sy_1 = (Ny[1] > 1) ? Sy[1]/type_size : Ny[0];
sz_0 = (Nz[0] > 1) ? Sz[0]/type_size : Nz[1];
sz_1 = (Nz[1] > 1) ? Sz[1]/type_size : Nz[0];
switch (type_num)
{
case PyArray_FLOAT:
{
float a = 1.0;
float b = 0.0;
float* x = (float*)PyArray_DATA(
%(_x)
s);
float* y = (float*)PyArray_DATA(
%(_y)
s);
float* z = (float*)PyArray_DATA(
%(_z)
s);
char N = 'N';
char T = 'T';
int Nz0 = Nz[0], Nz1 = Nz[1], Nx1 = Nx[1];
//std::cerr << (unit/256) MOD 16 << (unit / 16) MOD 16 << unit MOD 16<< '
\\
n';
switch(unit)
{
case 0x000: sgemm_(&N, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_0, &b, z, &sz_0); break;
case 0x100: sgemm_(&N, &T, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_1, &b, z, &sz_0); break;
case 0x010: sgemm_(&T, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_1, x, &sx_0, &b, z, &sz_0); break;
case 0x110: sgemm_(&T, &T, &Nz1, &Nz0, &Nx1, &a, y, &sy_1, x, &sx_1, &b, z, &sz_0); break;
case 0x001: sgemm_(&T, &T, &Nz0, &Nz1, &Nx1, &a, x, &sx_0, y, &sy_0, &b, z, &sz_1); break;
case 0x101: sgemm_(&N, &T, &Nz0, &Nz1, &Nx1, &a, x, &sx_1, y, &sy_0, &b, z, &sz_1); break;
case 0x011: sgemm_(&T, &N, &Nz0, &Nz1, &Nx1, &a, x, &sx_0, y, &sy_1, &b, z, &sz_1); break;
case 0x111: sgemm_(&N, &N, &Nz0, &Nz1, &Nx1, &a, x, &sx_1, y, &sy_1, &b, z, &sz_1); break;
default: PyErr_SetString(PyExc_ValueError, "some matrix has no unit stride");
%(fail)
s;
};
#undef REAL
}
break;
case PyArray_DOUBLE:
{
double a = 1.0;
double b = 0.0;
double* x = (double*)PyArray_DATA(
%(_x)
s);
double* y = (double*)PyArray_DATA(
%(_y)
s);
double* z = (double*)PyArray_DATA(
%(_z)
s);
char N = 'N';
char T = 'T';
int Nz0 = Nz[0], Nz1 = Nz[1], Nx1 = Nx[1];
//std::cerr << (unit/256) MOD 16 << (unit / 16) MOD 16 << unit MOD 16<< '
\\
n';
switch(unit)
{
case 0x000: dgemm_(&N, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_0, &b, z, &sz_0); break;
case 0x100: dgemm_(&N, &T, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_1, &b, z, &sz_0); break;
case 0x010: dgemm_(&T, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_1, x, &sx_0, &b, z, &sz_0); break;
case 0x110: dgemm_(&T, &T, &Nz1, &Nz0, &Nx1, &a, y, &sy_1, x, &sx_1, &b, z, &sz_0); break;
case 0x001: dgemm_(&T, &T, &Nz0, &Nz1, &Nx1, &a, x, &sx_0, y, &sy_0, &b, z, &sz_1); break;
case 0x101: dgemm_(&N, &T, &Nz0, &Nz1, &Nx1, &a, x, &sx_1, y, &sy_0, &b, z, &sz_1); break;
case 0x011: dgemm_(&T, &N, &Nz0, &Nz1, &Nx1, &a, x, &sx_0, y, &sy_1, &b, z, &sz_1); break;
case 0x111: dgemm_(&N, &N, &Nz0, &Nz1, &Nx1, &a, x, &sx_1, y, &sy_1, &b, z, &sz_1); break;
default: PyErr_SetString(PyExc_ValueError, "some matrix has no unit stride");
%(fail)
s;
};
#undef REAL
}
break;
}
"""
%
dict
(
locals
(),
**
sub
)
_dot22
=
_Dot22
()
@gof.local_optimizer([T.dot])
def local_dot_to_dot22(node):
    """Specialize the generic ``T.dot`` into the BLAS-backed ``_dot22`` op."""
    if node.op != T.dot:
        return False
    return [_dot22(*node.inputs)]

register_specialize(local_dot_to_dot22)
@gof.local_optimizer([T.sub])
def local_sub_to_gemm(node):
    """This is a massive beast for recognizing all the ways that a subtraction could be
    replaced by a GEMM

    Each recognized pattern of ``lhs - <something built on _dot22>`` is
    collapsed into one (possibly nested) ``T.gemm`` call.
    """
    if node.op != T.sub:
        return False
    lhs, rhs = node.inputs
    # EXPRESSION: lhs - rhs
    if rhs.owner and rhs.owner.op == _dot22:
        # lhs - dot(a, b)  ->  gemm(lhs, -1, a, b, 1)
        a, b = rhs.owner.inputs
        return [T.gemm(lhs, -1.0, a, b, 1.0)]
    if rhs.owner and rhs.owner.op == T.mul:
        mul_l, mul_r = rhs.owner.inputs
        # EXPRESSION: lhs - (mul_l * mul_r)
        # TODO: we actually want to get any scalar here, not necessarily a constant
        scale = local_mul_canonizer.get_constant(mul_l)
        if scale is not None and scale.size == 1:
            scale = scale.flatten()[0]
            # EXPRESSION: lhs - (scale * ?)
            if mul_r.owner and mul_r.owner.op == T.add:
                # EXPRESSION: lhs - (scale * (? + ?))
                add_l, add_r = mul_r.owner.inputs
                if add_r.owner and add_r.owner.op == T.DimShuffle([False, False], [1, 0]):
                    # EXPRESSION: lhs - (scale * (? + ?.T))
                    raise NotImplementedError()
                if add_r.owner and add_r.owner.op == T.DimShuffle([False, False], [1, 0], inplace=True):
                    # EXPRESSION: lhs - (scale * (? + ?.T))
                    transposed = add_r.owner.inputs[0]
                    if transposed.owner and transposed.owner.op == _dot22:
                        x, y = transposed.owner.inputs
                        # EXPRESSION: lhs - (scale * (add_l + dot(x, y).T))
                        if add_l.owner and add_l.owner.op == _dot22:
                            u, v = add_l.owner.inputs
                            # EXPRESSION: lhs - (scale * (dot(u, v) + dot(x, y).T))
                            inner = T.gemm(lhs, -scale, y.T, x.T, 1.0)
                            return [T.gemm(inner, -scale, u, v, 1.0)]
            if mul_r.owner and mul_r.owner.op == _dot22:
                a, b = mul_r.owner.inputs
                # EXPRESSION: lhs - (scale * dot(a, b))
                return [T.gemm(lhs, -scale, a, b, 1.0)]
        scale = local_mul_canonizer.get_constant(mul_r)
        if scale is not None and scale.size == 1:
            scale = scale.flatten()[0]
            # EXPRESSION: lhs - (? * scale)
            if mul_l.owner and mul_l.owner.op == _dot22:
                a, b = mul_l.owner.inputs
                # EXPRESSION: lhs - (dot(a, b) * scale)
                return [T.gemm(lhs, -scale, a, b, 1.0)]
    return False

register_specialize(local_sub_to_gemm)
...
...
theano/tensor/tests/test_basic.py
浏览文件 @
7bac82c1
...
...
@@ -662,56 +662,6 @@ class T_max_and_argmax(unittest.TestCase):
self
.
failUnless
(
i
.
shape
==
(
2
,
3
))
class
T_transpose
(
unittest
.
TestCase
):
def
test0
(
self
):
n
=
as_tensor
(
numpy
.
ones
(()))
t
=
transpose
(
n
)
self
.
failUnless
(
t
.
owner
.
op
==
inplace
.
transpose_inplace
)
f
=
function
([
n
],
t
)
tval
=
f
(
n
.
data
)
self
.
failUnless
(
tval
.
shape
==
n
.
data
.
shape
)
#test aliasing
tval
+=
55.0
self
.
failUnless
(
n
.
data
==
1.0
)
def
test1
(
self
):
n
=
as_tensor
(
numpy
.
ones
(
5
))
t
=
transpose
(
n
)
self
.
failUnless
(
t
.
owner
.
op
==
inplace
.
transpose_inplace
)
f
=
function
([
n
],
t
)
tval
=
f
(
n
.
data
)
self
.
failUnless
(
tval
.
shape
==
n
.
data
.
shape
)
#test aliasing
tval
+=
55.0
self
.
failUnless
(
n
.
data
[
0
]
==
1.0
)
def
test2
(
self
):
n
=
as_tensor
(
numpy
.
ones
((
5
,
3
)))
t
=
transpose
(
n
)
self
.
failUnless
(
t
.
owner
.
op
==
inplace
.
transpose_inplace
)
f
=
function
([
n
],
t
)
tval
=
f
(
n
.
data
)
self
.
failUnless
(
tval
.
shape
==
(
3
,
5
))
#test aliasing
tval
+=
55.0
self
.
failUnless
(
n
.
data
[
0
,
0
]
==
1.0
)
def
test3
(
self
):
"""Test transpose of tensor, inplace version"""
n
=
as_tensor
(
numpy
.
ones
((
5
,
3
,
2
)))
t
=
inplace
.
transpose_inplace
(
n
)
self
.
failUnless
(
t
.
owner
.
op
==
inplace
.
transpose_inplace
)
f
=
function
([
n
],
t
)
tval
=
f
(
n
.
data
)
self
.
failUnless
(
tval
.
shape
==
(
2
,
3
,
5
))
#test aliasing
tval
+=
55.0
self
.
failUnless
(
n
.
data
[
0
,
0
,
0
]
==
56.0
)
def
test_grad
(
self
):
verify_grad
(
self
,
inplace
.
transpose_inplace
,
[
numpy
.
random
.
rand
(
2
,
3
)])
verify_grad
(
self
,
inplace
.
transpose_inplace
,
[
numpy
.
ones
(
3
)])
class
T_subtensor
(
unittest
.
TestCase
):
def
setUp
(
self
):
Subtensor
.
debug
=
False
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论