Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
b6b2c608
提交
b6b2c608
authored
11月 26, 2008
作者:
James Bergstra
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
code in a mess, but gemm-optimization works on more systematic test cases…
code in a mess, but gemm-optimization works on more systematic test cases including josephs NAACL graph
上级
43291f46
显示空白字符变更
内嵌
并排
正在显示
9 个修改的文件
包含
452 行增加
和
573 行删除
+452
-573
mode.py
theano/compile/mode.py
+10
-1
__init__.py
theano/gof/__init__.py
+8
-9
opt.py
theano/gof/opt.py
+32
-13
optdb.py
theano/gof/optdb.py
+7
-2
blas.py
theano/tensor/blas.py
+182
-138
elemwise.py
theano/tensor/elemwise.py
+30
-13
opt.py
theano/tensor/opt.py
+9
-323
test_blas.py
theano/tensor/tests/test_blas.py
+140
-46
test_joseph.py
theano/tensor/tests/test_joseph.py
+34
-28
没有找到文件。
theano/compile/mode.py
浏览文件 @
b6b2c608
...
@@ -63,11 +63,20 @@ def register_optimizer(name, opt):
...
@@ -63,11 +63,20 @@ def register_optimizer(name, opt):
raise
ValueError
(
'Optimizer name already taken:
%
s'
%
name
)
raise
ValueError
(
'Optimizer name already taken:
%
s'
%
name
)
predefined_optimizers
[
name
]
=
opt
predefined_optimizers
[
name
]
=
opt
class
AddDestroyHandler
(
gof
.
Optimizer
):
def
apply
(
self
,
env
):
pass
def
add_requirements
(
self
,
env
):
super
(
AddDestroyHandler
,
self
)
.
add_requirements
(
env
)
env
.
extend
(
gof
.
DestroyHandler
())
optdb
=
gof
.
SequenceDB
()
optdb
=
gof
.
SequenceDB
()
optdb
.
register
(
'merge1'
,
gof
.
MergeOptimizer
(),
0
,
'fast_run'
,
'fast_compile'
)
optdb
.
register
(
'merge1'
,
gof
.
MergeOptimizer
(),
0
,
'fast_run'
,
'fast_compile'
)
optdb
.
register
(
'canonicalize'
,
gof
.
EquilibriumDB
(),
1
,
'fast_run'
)
optdb
.
register
(
'canonicalize'
,
gof
.
EquilibriumDB
(),
1
,
'fast_run'
)
optdb
.
register
(
'specialize'
,
gof
.
EquilibriumDB
(),
2
,
'fast_run'
)
optdb
.
register
(
'specialize'
,
gof
.
EquilibriumDB
(),
2
,
'fast_run'
)
optdb
.
register
(
'merge2'
,
gof
.
EquilibriumDB
(),
100
,
'fast_run'
)
optdb
.
register
(
'merge2'
,
gof
.
EquilibriumDB
(),
49
,
'fast_run'
)
optdb
.
register
(
'add_destroy_handler'
,
AddDestroyHandler
(),
49.5
,
'fast_run'
,
'inplace'
)
optdb
.
register
(
'merge3'
,
gof
.
EquilibriumDB
(),
100
,
'fast_run'
)
class
Mode
(
object
):
class
Mode
(
object
):
...
...
theano/gof/__init__.py
浏览文件 @
b6b2c608
...
@@ -20,15 +20,14 @@ from link import \
...
@@ -20,15 +20,14 @@ from link import \
from
op
import
\
from
op
import
\
Op
Op
from
opt
import
\
from
opt
import
(
Optimizer
,
optimizer
,
SeqOptimizer
,
Optimizer
,
optimizer
,
SeqOptimizer
,
\
MergeOptimizer
,
MergeOptMerge
,
MergeOptimizer
,
MergeOptMerge
,
\
LocalOptimizer
,
local_optimizer
,
LocalOptGroup
,
LocalOptimizer
,
local_optimizer
,
LocalOptGroup
,
\
OpSub
,
OpRemove
,
PatternSub
,
OpSub
,
OpRemove
,
PatternSub
,
\
NavigatorOptimizer
,
TopoOptimizer
,
EquilibriumOptimizer
,
NavigatorOptimizer
,
TopoOptimizer
,
EquilibriumOptimizer
,
\
keep_going
,
warn
,
keep_going
,
warn
,
\
InplaceOptimizer
,
PureThenInplaceOptimizer
,
InplaceOptimizer
,
PureThenInplaceOptimizer
OpKeyOptimizer
)
#LocalOpKeyOptGroup, OpKeyOptimizer
from
optdb
import
\
from
optdb
import
\
DB
,
Query
,
\
DB
,
Query
,
\
...
...
theano/gof/opt.py
浏览文件 @
b6b2c608
...
@@ -265,6 +265,11 @@ class LocalOptimizer(object):
...
@@ -265,6 +265,11 @@ class LocalOptimizer(object):
raise
utils
.
AbstractFunctionError
()
raise
utils
.
AbstractFunctionError
()
def
add_requirements
(
self
,
env
):
"""If this local optimization wants to add some requirements to the env,
This is the place to do it."""
env
.
extend
(
toolbox
.
ReplaceValidate
())
class
FromFunctionLocalOptimizer
(
LocalOptimizer
):
class
FromFunctionLocalOptimizer
(
LocalOptimizer
):
"""WRITEME"""
"""WRITEME"""
...
@@ -273,8 +278,6 @@ class FromFunctionLocalOptimizer(LocalOptimizer):
...
@@ -273,8 +278,6 @@ class FromFunctionLocalOptimizer(LocalOptimizer):
self
.
_tracks
=
tracks
self
.
_tracks
=
tracks
def
tracks
(
self
):
def
tracks
(
self
):
return
self
.
_tracks
return
self
.
_tracks
def
add_requirements
(
self
,
env
):
env
.
extend
(
toolbox
.
ReplaceValidate
())
def
__str__
(
self
):
def
__str__
(
self
):
return
getattr
(
self
,
'name'
,
'<FromFunctionLocalOptimizer instance>'
)
return
getattr
(
self
,
'name'
,
'<FromFunctionLocalOptimizer instance>'
)
...
@@ -551,7 +554,7 @@ class NavigatorOptimizer(Optimizer):
...
@@ -551,7 +554,7 @@ class NavigatorOptimizer(Optimizer):
def
__init__
(
self
,
local_opt
,
ignore_newtrees
=
'auto'
,
failure_callback
=
None
):
def
__init__
(
self
,
local_opt
,
ignore_newtrees
=
'auto'
,
failure_callback
=
None
):
"""
"""
:param local_opt: a LocalOptimizer to apply over a Env.
:param local_opt: a LocalOptimizer to apply over a Env
(or None is Ok too)
.
:param ignore_newtrees:
:param ignore_newtrees:
- True: new subgraphs returned by an optimization is not a candidate for optimization
- True: new subgraphs returned by an optimization is not a candidate for optimization
- False: new subgraphs returned by an optimization is a candidate for optimization
- False: new subgraphs returned by an optimization is a candidate for optimization
...
@@ -617,6 +620,24 @@ class NavigatorOptimizer(Optimizer):
...
@@ -617,6 +620,24 @@ class NavigatorOptimizer(Optimizer):
env
.
remove_feature
(
u
)
env
.
remove_feature
(
u
)
def
process_node
(
self
,
env
,
node
,
lopt
=
None
):
def
process_node
(
self
,
env
,
node
,
lopt
=
None
):
"""
This function will use `lopt` to `transform` the `node`. The `transform` method will
return either False or a list of Results that are intended to replace `node.outputs`.
If the env accepts the replacement, then the optimization is successful, and this
function returns True.
If there are no replacement candidates or the env rejects the replacements, this
function returns False.
:param env: an Env
:param node: an Apply instance in `env`
:param lopt: a LocalOptimizer instance that may have a better idea for how to compute
node's outputs.
:rtype: Bool
:returns: True iff the `node`'s outputs were replaced in the `env`.
"""
lopt
=
lopt
or
self
.
local_opt
lopt
=
lopt
or
self
.
local_opt
try
:
try
:
replacements
=
lopt
.
transform
(
node
)
replacements
=
lopt
.
transform
(
node
)
...
@@ -633,23 +654,21 @@ class NavigatorOptimizer(Optimizer):
...
@@ -633,23 +654,21 @@ class NavigatorOptimizer(Optimizer):
env
.
replace_all_validate
(
repl_pairs
)
env
.
replace_all_validate
(
repl_pairs
)
return
True
return
True
except
Exception
,
e
:
except
Exception
,
e
:
# This means the replacements were rejected by the env.
#
# This is not supposed to happen. The default failure_callback will print a
# traceback as a warning.
if
self
.
failure_callback
is
not
None
:
if
self
.
failure_callback
is
not
None
:
self
.
failure_callback
(
e
,
self
,
repl_pairs
)
self
.
failure_callback
(
e
,
self
,
repl_pairs
)
#DEBUG DONT PUSH
#print lopt
#print dir(lopt)
#raise
#END
return
False
return
False
else
:
else
:
raise
raise
def
add_requirements
(
self
,
env
):
def
add_requirements
(
self
,
env
):
super
(
NavigatorOptimizer
,
self
)
.
add_requirements
(
env
)
env
.
extend
(
toolbox
.
ReplaceValidate
())
env
.
extend
(
toolbox
.
ReplaceValidate
())
if
self
.
local_opt
:
self
.
local_opt
.
add_requirements
(
env
)
class
TopoOptimizer
(
NavigatorOptimizer
):
class
TopoOptimizer
(
NavigatorOptimizer
):
"""WRITEME"""
"""WRITEME"""
...
@@ -722,7 +741,7 @@ class OpKeyOptimizer(NavigatorOptimizer):
...
@@ -722,7 +741,7 @@ class OpKeyOptimizer(NavigatorOptimizer):
- NodeFinder
- NodeFinder
- ReplaceValidate
- ReplaceValidate
"""
"""
NavigatorOptimizer
.
add_requirements
(
self
,
env
)
super
(
OpKeyOptimizer
,
self
)
.
add_requirements
(
env
)
env
.
extend
(
toolbox
.
NodeFinder
())
env
.
extend
(
toolbox
.
NodeFinder
())
...
...
theano/gof/optdb.py
浏览文件 @
b6b2c608
...
@@ -13,6 +13,8 @@ class DB(object):
...
@@ -13,6 +13,8 @@ class DB(object):
def
__init__
(
self
):
def
__init__
(
self
):
self
.
__db__
=
defaultdict
(
set
)
self
.
__db__
=
defaultdict
(
set
)
self
.
_names
=
set
()
self
.
_names
=
set
()
self
.
name
=
None
#will be reset by register
#(via obj.name by the thing doing the registering)
def
register
(
self
,
name
,
obj
,
*
tags
):
def
register
(
self
,
name
,
obj
,
*
tags
):
# N.B. obj is not an instance of class Optimizer.
# N.B. obj is not an instance of class Optimizer.
...
@@ -21,6 +23,8 @@ class DB(object):
...
@@ -21,6 +23,8 @@ class DB(object):
if
not
isinstance
(
obj
,
(
DB
,
opt
.
Optimizer
,
opt
.
LocalOptimizer
)):
if
not
isinstance
(
obj
,
(
DB
,
opt
.
Optimizer
,
opt
.
LocalOptimizer
)):
raise
Exception
(
'wtf'
,
obj
)
raise
Exception
(
'wtf'
,
obj
)
if
self
.
name
is
not
None
:
tags
=
tags
+
(
self
.
name
,)
obj
.
name
=
name
obj
.
name
=
name
if
name
in
self
.
__db__
:
if
name
in
self
.
__db__
:
raise
ValueError
(
'The name of the object cannot be an existing tag or the name of another existing object.'
,
obj
,
name
)
raise
ValueError
(
'The name of the object cannot be an existing tag or the name of another existing object.'
,
obj
,
name
)
...
@@ -118,9 +122,10 @@ class EquilibriumDB(DB):
...
@@ -118,9 +122,10 @@ class EquilibriumDB(DB):
class
SequenceDB
(
DB
):
class
SequenceDB
(
DB
):
def
__init__
(
self
):
def
__init__
(
self
,
failure_callback
=
opt
.
warn
):
super
(
SequenceDB
,
self
)
.
__init__
()
super
(
SequenceDB
,
self
)
.
__init__
()
self
.
__priority__
=
{}
self
.
__priority__
=
{}
self
.
failure_callback
=
failure_callback
def
register
(
self
,
name
,
obj
,
priority
,
*
tags
):
def
register
(
self
,
name
,
obj
,
priority
,
*
tags
):
super
(
SequenceDB
,
self
)
.
register
(
name
,
obj
,
*
tags
)
super
(
SequenceDB
,
self
)
.
register
(
name
,
obj
,
*
tags
)
...
@@ -130,6 +135,6 @@ class SequenceDB(DB):
...
@@ -130,6 +135,6 @@ class SequenceDB(DB):
opts
=
super
(
SequenceDB
,
self
)
.
query
(
*
tags
,
**
kwtags
)
opts
=
super
(
SequenceDB
,
self
)
.
query
(
*
tags
,
**
kwtags
)
opts
=
list
(
opts
)
opts
=
list
(
opts
)
opts
.
sort
(
key
=
lambda
obj
:
self
.
__priority__
[
obj
.
name
])
opts
.
sort
(
key
=
lambda
obj
:
self
.
__priority__
[
obj
.
name
])
return
opt
.
SeqOptimizer
(
opts
,
failure_callback
=
opt
.
warn
)
return
opt
.
SeqOptimizer
(
opts
,
failure_callback
=
self
.
failure_callback
)
theano/tensor/blas.py
浏览文件 @
b6b2c608
"""Ops and optimizations for using BLAS function calls to evaluate linear algebra expressions"""
"""Ops and optimizations for using BLAS function calls to evaluate linear algebra expressions"""
import
os
,
sys
import
os
,
sys
,
traceback
import
numpy
import
numpy
from
..gof
import
(
utils
,
Op
,
Apply
,
view_roots
,
PatternSub
,
from
..gof
import
(
utils
,
Op
,
Apply
,
view_roots
,
PatternSub
,
DestroyHandler
,
InplaceOptimizer
,
SeqOptimizer
,
warn
,
local_optimizer
)
SeqOptimizer
,
warn
,
local_optimizer
,
LocalOptimizer
,
OpKeyOptimizer
,
InconsistencyError
)
from
..printing
import
pprint
,
FunctionPrinter
from
..printing
import
pprint
,
FunctionPrinter
from
.opt
import
register_specialize
,
out2in
,
insert_inplace_optimizer
from
.opt
import
register_specialize
,
out2in
,
insert_inplace_optimizer
import
basic
as
T
import
basic
as
T
from
..tensor
import
as_tensor
#NB: this clobbers the builtin 'compile' symbol
#NB: this clobbers the builtin 'compile' symbol
from
..
import
compile
#to register the optimizer built by this file
from
..
import
compile
#to register the optimizer built by this file
from
.blas_headers
import
cblas_header_text
,
blas_header_text
from
.blas_headers
import
cblas_header_text
,
blas_header_text
JOSEPHS_BUG_SOLVED
=
False
@utils.memoize
@utils.memoize
def
ldflags
():
def
ldflags
():
"""Return a list of libraries against which an Op's object file should be
"""Return a list of libraries against which an Op's object file should be
...
@@ -270,7 +267,7 @@ class Gemm(GemmRelated):
...
@@ -270,7 +267,7 @@ class Gemm(GemmRelated):
E_z_uniq
=
'argument z aliased to x or y'
E_z_uniq
=
'argument z aliased to x or y'
destroy_map
=
{
0
:
[
0
]}
destroy_map
=
{
0
:
[
0
]}
def
make_node
(
self
,
*
inputs
):
def
make_node
(
self
,
*
inputs
):
inputs
=
map
(
as_tensor
,
inputs
)
inputs
=
map
(
T
.
as_tensor
,
inputs
)
if
len
(
inputs
)
!=
5
:
if
len
(
inputs
)
!=
5
:
raise
TypeError
(
"Wrong number of inputs for
%
s (expected 5, got
%
s)"
%
(
self
,
len
(
inputs
)))
raise
TypeError
(
"Wrong number of inputs for
%
s (expected 5, got
%
s)"
%
(
self
,
len
(
inputs
)))
z
,
a
,
x
,
y
,
b
=
inputs
z
,
a
,
x
,
y
,
b
=
inputs
...
@@ -348,87 +345,110 @@ class Gemm(GemmRelated):
...
@@ -348,87 +345,110 @@ class Gemm(GemmRelated):
#undef REAL
#undef REAL
"""
"""
def
c_code
(
self
,
node
,
name
,
(
_z
,
_a
,
_x
,
_y
,
_b
),
(
_zout
,
),
sub
):
def
c_code
(
self
,
node
,
name
,
(
_z
,
_a
,
_x
,
_y
,
_b
),
(
_zout
,
),
sub
):
#DEBUG
full_code
=
self
.
build_gemm_call
()
%
dict
(
locals
(),
**
sub
)
full_code
=
self
.
build_gemm_call
()
%
dict
(
locals
(),
**
sub
)
return
full_code
return
full_code
gemm
=
Gemm
()
gemm
=
Gemm
()
pprint
.
assign
(
gemm
,
FunctionPrinter
(
'gemm'
))
pprint
.
assign
(
gemm
,
FunctionPrinter
(
'gemm'
))
class
Dot22
(
GemmRelated
):
def
res_is_a
(
node
,
op
,
maxclients
=
None
):
"""Compute a matrix-matrix product.
return
node
.
owner
\
This is a specialization of the more general Dot()
and
node
.
owner
.
op
==
op
\
and
(
len
(
node
.
clients
)
<=
maxclients
if
maxclients
is
not
None
else
True
)
class
GemmLocalOptimizer
(
LocalOptimizer
):
"""This is a massive beast for recognizing all the ways that a subtraction could be
replaced by a GEMM
It depends on `local_transposed_dot` to canonicalize the graph a bit by swapping
dot(a,b).T -> dot(b.T, a.T)
"""
"""
def
make_node
(
self
,
x
,
y
):
assert
_is_real_matrix
(
x
)
assert
y
.
type
==
x
.
type
#makes sure y is a matrix
bz
=
[
False
,
False
]
outputs
=
[
T
.
tensor
(
x
.
type
.
dtype
,
bz
)]
return
Apply
(
self
,
[
x
,
y
],
outputs
)
def
perform
(
self
,
node
,
(
x
,
y
),
(
z
,
)):
def
__init__
(
self
):
try
:
super
(
LocalOptimizer
,
self
)
.
__init__
()
z
[
0
]
=
numpy
.
asarray
(
numpy
.
dot
(
x
,
y
))
except
ValueError
,
e
:
# The error raised by numpy has no shape information, we mean to add that
e
.
args
=
e
.
args
+
(
x
.
shape
,
y
.
shape
)
raise
def
__str__
(
self
):
return
"_dot22"
setup_z_Nz_Sz
=
"""
def
op_key
(
self
):
if ((NULL ==
%(_z)
s)
return
[
T
.
add
,
T
.
sub
]
|| (
%(_z)
s->dimensions[0] !=
%(_x)
s->dimensions[0])
|| (
%(_z)
s->dimensions[1] !=
%(_y)
s->dimensions[1]))
{
if (NULL !=
%(_z)
s) Py_XDECREF(
%(_z)
s);
npy_intp dims[2];
dims[0] =
%(_x)
s->dimensions[0];
dims[1] =
%(_y)
s->dimensions[1];
%(_z)
s = (PyArrayObject*)PyArray_SimpleNew(2, dims, type_num_
%(_x)
s);
if(!
%(_z)
s) {
PyErr_SetString(PyExc_MemoryError, "failed to alloc dot22 output");
%(fail)
s
}
}
Nz =
%(_z)
s->dimensions;
Sz =
%(_z)
s->strides;
"""
def
add_requirements
(
self
,
env
):
check_ab_double_or_float
=
""
super
(
GemmLocalOptimizer
,
self
)
.
add_requirements
(
env
)
case_float_ab_constants
=
"""
env
.
extend
(
DestroyHandler
())
float a = 1.0;
float b = 0.0;
"""
case_double_ab_constants
=
"""
double a = 1.0;
double b = 0.0;
"""
def
c_code
(
self
,
node
,
name
,
(
_x
,
_y
),
(
_z
,
),
sub
):
full_code
=
self
.
build_gemm_call
()
%
dict
(
locals
(),
**
sub
)
return
full_code
_dot22
=
Dot22
()
@local_optimizer
([
T
.
dot
])
def
transform
(
self
,
node
):
def
local_dot_to_dot22
(
node
):
_as_scalar
,
_is_real_matrix
,
_as_isolated_scalar_times_matrix
,
beta_L_plus_alpha_M
\
if
node
.
op
==
T
.
dot
:
=
(
GemmLocalOptimizer
.
_as_scalar
,
x
,
y
=
node
.
inputs
GemmLocalOptimizer
.
_is_real_matrix
,
if
_is_real_matrix
(
x
)
and
y
.
type
==
x
.
type
:
GemmLocalOptimizer
.
_as_isolated_scalar_times_matrix
,
return
[
_dot22
(
*
node
.
inputs
)]
GemmLocalOptimizer
.
beta_L_plus_alpha_M
)
if
node
.
op
==
T
.
sub
:
L
,
R
=
node
.
inputs
if
not
_is_real_matrix
(
L
):
return
False
if
not
_is_real_matrix
(
R
):
return
False
tmp
=
_as_isolated_scalar_times_matrix
(
L
)
try
:
sL
,
mL
=
tmp
except
:
sL
,
mL
=
1.0
,
L
tmp
=
_as_isolated_scalar_times_matrix
(
R
)
try
:
sR
,
mR
=
tmp
except
:
sR
,
mR
=
1.0
,
R
rval
=
beta_L_plus_alpha_M
(
sL
,
mL
,
-
sR
,
mR
)
return
rval
if
node
.
op
==
T
.
add
:
sM_list
=
[]
other_inputs
=
[]
for
input
in
node
.
inputs
:
tmp
=
_as_isolated_scalar_times_matrix
(
input
)
if
tmp
:
sM_list
.
append
(
tmp
)
elif
_is_real_matrix
(
input
):
sM_list
.
append
((
1.0
,
input
))
else
:
other_inputs
.
append
(
input
)
if
len
(
sM_list
)
==
2
:
(
sL
,
mL
),
(
sR
,
mR
)
=
sM_list
gemm_of_sM_list
=
beta_L_plus_alpha_M
(
sL
,
mL
,
sR
,
mR
)
if
gemm_of_sM_list
:
#we turned the two candidates into a gemm
# now we have to add the other_inputs and return the replacement graph
if
other_inputs
:
return
[
T
.
add
(
*
(
other_inputs
+
gemm_of_sM_list
))]
else
:
return
gemm_of_sM_list
else
:
else
:
for
i
in
xrange
(
len
(
sM_list
)
-
1
):
for
j
in
xrange
(
i
+
1
,
len
(
sM_list
)):
sL
,
mL
=
sM_list
[
i
]
sR
,
mR
=
sM_list
[
j
]
gemm_of_sM_list
=
beta_L_plus_alpha_M
(
sL
,
mL
,
sR
,
mR
)
if
gemm_of_sM_list
:
assert
len
(
gemm_of_sM_list
)
==
1
inputs_without_ij
=
\
[
input
for
k
,
input
in
enumerate
(
node
.
inputs
)
if
k
not
in
(
i
,
j
)]
return
[
T
.
add
(
*
(
inputs_without_ij
+
gemm_of_sM_list
+
other_inputs
))]
return
False
return
False
if
JOSEPHS_BUG_SOLVED
:
register_specialize
(
local_dot_to_dot22
)
def
_is_a
(
node
,
op
,
maxclients
=
None
):
@staticmethod
return
node
.
owner
\
def
failure_callback
(
exc
,
nav
,
repl_pairs
):
and
node
.
owner
.
op
==
op
\
"""WRITEME"""
and
len
(
node
.
clients
)
<=
maxclients
if
maxclients
is
not
None
else
True
if
not
isinstance
(
exc
,
InconsistencyError
):
traceback
.
print_exc
()
else
:
print
'GEMM caused cycle, forget it.'
def
_as_scalar
(
res
):
@staticmethod
def
_as_scalar
(
res
):
"""Return None or a TensorResult whose type is in T.float_scalar_types"""
"""Return None or a TensorResult whose type is in T.float_scalar_types"""
if
res
.
owner
and
isinstance
(
res
.
owner
.
op
,
T
.
DimShuffle
):
if
res
.
owner
and
isinstance
(
res
.
owner
.
op
,
T
.
DimShuffle
):
return
_as_scalar
(
res
.
owner
.
inputs
[
0
])
return
GemmLocalOptimizer
.
_as_scalar
(
res
.
owner
.
inputs
[
0
])
elif
res
.
type
in
T
.
float_scalar_types
:
elif
res
.
type
in
T
.
float_scalar_types
:
return
res
return
res
elif
isinstance
(
res
,
T
.
Constant
)
and
res
.
data
.
size
==
1
:
elif
isinstance
(
res
,
T
.
Constant
)
and
res
.
data
.
size
==
1
:
...
@@ -436,13 +456,20 @@ def _as_scalar(res):
...
@@ -436,13 +456,20 @@ def _as_scalar(res):
else
:
else
:
return
None
return
None
def
_is_real_matrix
(
res
):
@staticmethod
def
_is_real_matrix
(
res
):
return
res
.
type
in
T
.
float_matrix_types
\
return
res
.
type
in
T
.
float_matrix_types
\
and
res
.
broadcastable
[
0
]
==
False
\
and
res
.
broadcastable
[
0
]
==
False
\
and
res
.
broadcastable
[
1
]
==
False
#cope with tuple vs. list
and
res
.
broadcastable
[
1
]
==
False
#cope with tuple vs. list
def
_as_isolated_scalar_times_matrix
(
res
):
@staticmethod
if
_is_a
(
res
,
T
.
mul
,
1
):
def
_as_isolated_scalar_times_matrix
(
res
):
_as_scalar
,
_is_real_matrix
,
_as_isolated_scalar_times_matrix
,
beta_L_plus_alpha_M
\
=
(
GemmLocalOptimizer
.
_as_scalar
,
GemmLocalOptimizer
.
_is_real_matrix
,
GemmLocalOptimizer
.
_as_isolated_scalar_times_matrix
,
GemmLocalOptimizer
.
beta_L_plus_alpha_M
)
if
res_is_a
(
res
,
T
.
mul
,
1
):
if
len
(
res
.
owner
.
inputs
)
==
2
:
if
len
(
res
.
owner
.
inputs
)
==
2
:
L
,
R
=
res
.
owner
.
inputs
L
,
R
=
res
.
owner
.
inputs
sL
=
_as_scalar
(
L
)
sL
=
_as_scalar
(
L
)
...
@@ -466,105 +493,122 @@ def _as_isolated_scalar_times_matrix(res):
...
@@ -466,105 +493,122 @@ def _as_isolated_scalar_times_matrix(res):
rval
=
(
T
.
mul
(
*
scalars
),
matrices
[
0
])
rval
=
(
T
.
mul
(
*
scalars
),
matrices
[
0
])
return
rval
return
rval
@staticmethod
def
beta_L_plus_alpha_M
(
beta
,
L
,
alpha
,
M
,
recurse_flip
=
True
):
def
beta_L_plus_alpha_M
(
beta
,
L
,
alpha
,
M
,
recurse_flip
=
True
):
#print 'BETA L + ALPHA M', beta, L, alpha, M, recurse_flip
#print 'BETA L + ALPHA M', beta, L, alpha, M, recurse_flip
#EXPRESSION: (beta * L) + (alpha * M)
#EXPRESSION: (beta * L) + (alpha * M)
if
_is_a
(
M
,
_dot22
,
1
):
if
True
:
if
res_is_a
(
L
,
T
.
sqrt
):
print
'CLIENTS OF L'
,
L
,
L
.
clients
if
res_is_a
(
M
,
_dot22
,
1
):
Ml
,
Mr
=
M
.
owner
.
inputs
Ml
,
Mr
=
M
.
owner
.
inputs
rval
=
[
gemm
(
L
,
alpha
,
Ml
,
Mr
,
beta
)]
rval
=
[
gemm
(
L
,
alpha
,
Ml
,
Mr
,
beta
)]
print
'GEMM 0'
,
rval
,
beta
,
L
,
alpha
,
M
return
rval
return
rval
if
_is_a
(
M
,
gemm
,
1
):
if
False
and
res
_is_a
(
M
,
gemm
,
1
):
#EXPRESSION: (beta * L) + (alpha * (gemm(G, a, u, v, b)))
#EXPRESSION: (beta * L) + (alpha * (gemm(G, a, u, v, b)))
#EXPRESSION: (beta * L) + alpha * (b * G) + alpha * a * dot(u, v)
#EXPRESSION: (beta * L) + alpha * (b * G) + alpha * a * dot(u, v)
G
,
a
,
u
,
v
,
b
=
M
.
owner
.
inputs
G
,
a
,
u
,
v
,
b
=
M
.
owner
.
inputs
#print 'GEMM', G, L
#print 'GEMM', G, L
if
_is_a
(
G
,
_dot22
,
1
):
if
res
_is_a
(
G
,
_dot22
,
1
):
#EXPRESSION: (beta * L) + (alpha * (gemm(dot(x,y), a, u, v, b)))
#EXPRESSION: (beta * L) + (alpha * (gemm(dot(x,y), a, u, v, b)))
x
,
y
=
G
.
owner
.
inputs
x
,
y
=
G
.
owner
.
inputs
#EXPRESSION: (beta * L) + (alpha * ((b*dot(x,y) + (a * dot(u, v)))))
#EXPRESSION: (beta * L) + (alpha * ((b*dot(x,y) + (a * dot(u, v)))))
#EXPRESSION: (beta * L) + (alpha*b*dot(x,y)) + (alpha * a * dot(u, v))
#EXPRESSION: (beta * L) + (alpha*b*dot(x,y)) + (alpha * a * dot(u, v))
#print 'GEMM 1', G, L
rval
=
[
gemm
(
gemm
(
L
,
alpha
*
b
,
x
,
y
,
beta
),
alpha
*
a
,
u
,
v
,
1.0
)]
rval
=
[
gemm
(
gemm
(
L
,
alpha
*
b
,
x
,
y
,
beta
),
alpha
*
a
,
u
,
v
,
1.0
)]
print
'GEMM 1'
,
rval
return
rval
return
rval
elif
G
is
L
:
if
(
G
is
L
)
:
#EXPRESSION: (beta * L) + (alpha*b*L) + (alpha * a * dot(u, v))
#EXPRESSION: (beta * L) + (alpha*b*L) + (alpha * a * dot(u, v))
rval
=
[
gemm
(
L
,
alpha
*
a
,
u
,
v
,
alpha
*
b
+
beta
)]
rval
=
[
gemm
(
L
,
alpha
*
a
,
u
,
v
,
alpha
*
b
+
beta
)]
#
print 'GEMM 2', rval
print
'GEMM 2'
,
rval
return
rval
return
rval
elif
1.0
!=
alpha
:
if
(
1.0
!=
alpha
)
:
#at the very least, move the alpha inside the gemm
#at the very least, move the alpha inside the gemm
rval
=
[
beta
*
L
+
gemm
(
G
,
alpha
*
a
,
u
,
v
,
alpha
*
b
)]
rval
=
[
beta
*
L
+
gemm
(
G
,
alpha
*
a
,
u
,
v
,
alpha
*
b
)]
#print 'GEMM 3', G, L
print
'GEMM 3'
,
rval
return
rval
return
rval
if
recurse_flip
:
if
recurse_flip
:
return
beta_L_plus_alpha_M
(
alpha
,
M
,
beta
,
L
,
recurse_flip
=
False
)
return
GemmLocalOptimizer
.
beta_L_plus_alpha_M
(
alpha
,
M
,
beta
,
L
,
recurse_flip
=
False
)
else
:
else
:
return
False
return
False
@local_optimizer
([
T
.
sub
])
#I think that three passes should suffice to catch all the GEMMs.
def
local_sub_to_gemm
(
node
):
# TODO: This could be an equilibriumOptmizer, but I don't know how to combine an OpKeyOptimizer and
if
node
.
op
==
T
.
sub
:
# an EquilibriumOptimizer.
L
,
R
=
node
.
inputs
compile
.
optdb
.
register
(
'inplace_gemm_0'
,
OpKeyOptimizer
(
GemmLocalOptimizer
(),
if
not
_is_real_matrix
(
L
):
failure_callback
=
GemmLocalOptimizer
.
failure_callback
),
70.00
,
'fast_run'
,
'inplace'
)
return
False
compile
.
optdb
.
register
(
'inplace_gemm_1'
,
OpKeyOptimizer
(
GemmLocalOptimizer
(),
if
not
_is_real_matrix
(
R
):
failure_callback
=
GemmLocalOptimizer
.
failure_callback
),
70.01
,
'fast_run'
,
'inplace'
)
return
False
compile
.
optdb
.
register
(
'inplace_gemm_2'
,
OpKeyOptimizer
(
GemmLocalOptimizer
(),
failure_callback
=
GemmLocalOptimizer
.
failure_callback
),
70.02
,
'fast_run'
,
'inplace'
)
tmp
=
_as_isolated_scalar_times_matrix
(
L
)
class
Dot22
(
GemmRelated
):
try
:
"""Compute a matrix-matrix product.
sL
,
mL
=
tmp
This is a specialization of the more general Dot()
except
:
"""
sL
,
mL
=
1.0
,
L
def
make_node
(
self
,
x
,
y
):
assert
GemmLocalOptimizer
.
_is_real_matrix
(
x
)
assert
y
.
type
==
x
.
type
#makes sure y is a matrix
bz
=
[
False
,
False
]
outputs
=
[
T
.
tensor
(
x
.
type
.
dtype
,
bz
)]
return
Apply
(
self
,
[
x
,
y
],
outputs
)
tmp
=
_as_isolated_scalar_times_matrix
(
R
)
def
perform
(
self
,
node
,
(
x
,
y
),
(
z
,
)):
try
:
try
:
sR
,
mR
=
tmp
z
[
0
]
=
numpy
.
asarray
(
numpy
.
dot
(
x
,
y
))
except
:
except
ValueError
,
e
:
sR
,
mR
=
1.0
,
R
# The error raised by numpy has no shape information, we mean to add that
rval
=
beta_L_plus_alpha_M
(
sL
,
mL
,
-
sR
,
mR
)
e
.
args
=
e
.
args
+
(
x
.
shape
,
y
.
shape
)
return
rval
raise
return
False
def
__str__
(
self
):
if
JOSEPHS_BUG_SOLVED
:
return
"_dot22"
register_specialize
(
local_sub_to_gemm
)
@local_optimizer
([
T
.
add
])
setup_z_Nz_Sz
=
"""
def
local_add_to_gemm
(
node
):
if ((NULL ==
%(_z)
s)
"""This is a massive beast for recognizing all the ways that a subtraction could be
|| (
%(_z)
s->dimensions[0] !=
%(_x)
s->dimensions[0])
replaced by a GEMM
|| (
%(_z)
s->dimensions[1] !=
%(_y)
s->dimensions[1]))
{
if (NULL !=
%(_z)
s) Py_XDECREF(
%(_z)
s);
npy_intp dims[2];
dims[0] =
%(_x)
s->dimensions[0];
dims[1] =
%(_y)
s->dimensions[1];
%(_z)
s = (PyArrayObject*)PyArray_SimpleNew(2, dims, type_num_
%(_x)
s);
if(!
%(_z)
s) {
PyErr_SetString(PyExc_MemoryError, "failed to alloc dot22 output");
%(fail)
s
}
}
Nz =
%(_z)
s->dimensions;
Sz =
%(_z)
s->strides;
It depends on `local_transposed_dot` to canonicalize the graph a bit by swapping
dot(a,b).T -> dot(b.T, a.T)
"""
"""
if
node
.
op
==
T
.
add
:
check_ab_double_or_float
=
""
sM_list
=
[]
case_float_ab_constants
=
"""
for
input
in
node
.
inputs
:
float a = 1.0;
tmp
=
_as_isolated_scalar_times_matrix
(
input
)
float b = 0.0;
if
tmp
:
"""
sM_list
.
append
(
tmp
)
case_double_ab_constants
=
"""
elif
_is_real_matrix
(
input
):
double a = 1.0;
sM_list
.
append
((
1.0
,
input
))
double b = 0.0;
"""
def
c_code
(
self
,
node
,
name
,
(
_x
,
_y
),
(
_z
,
),
sub
):
#DEBUG
full_code
=
self
.
build_gemm_call
()
%
dict
(
locals
(),
**
sub
)
return
full_code
_dot22
=
Dot22
()
if
len
(
sM_list
)
==
2
:
@local_optimizer
([
T
.
dot
])
sL
,
mL
=
sM_list
[
0
]
def
local_dot_to_dot22
(
node
):
sR
,
mR
=
sM_list
[
1
]
if
node
.
op
==
T
.
dot
:
return
beta_L_plus_alpha_M
(
sL
,
mL
,
sR
,
mR
)
x
,
y
=
node
.
inputs
if
GemmLocalOptimizer
.
_is_real_matrix
(
x
)
and
y
.
type
==
x
.
type
:
return
[
_dot22
(
*
node
.
inputs
)]
else
:
else
:
for
i
in
xrange
(
len
(
sM_list
)
-
1
):
for
j
in
xrange
(
i
+
1
,
len
(
sM_list
)):
sL
,
mL
=
sM_list
[
i
]
sR
,
mR
=
sM_list
[
j
]
rval
=
beta_L_plus_alpha_M
(
sL
,
mL
,
sR
,
mR
)
if
rval
:
assert
len
(
rval
)
==
1
inputs_without_ij
=
\
[
input
for
k
,
input
in
enumerate
(
node
.
inputs
)
if
k
not
in
(
i
,
j
)]
return
[
T
.
add
(
*
(
inputs_without_ij
+
rval
))]
return
False
return
False
if
JOSEPHS_BUG_SOLVED
:
register_specialize
(
local_dot_to_dot22
)
register_specialize
(
local_add_to_gemm
)
theano/tensor/elemwise.py
浏览文件 @
b6b2c608
...
@@ -316,7 +316,7 @@ class Elemwise(Op):
...
@@ -316,7 +316,7 @@ class Elemwise(Op):
scalars
scalars
* inplace_pattern: a dictionary that maps the index of an output to the
* inplace_pattern: a dictionary that maps the index of an output to the
index of an input so the output is calculated inplace using
index of an input so the output is calculated inplace using
the input's storage.
the input's storage.
(Just like destroymap, but without the lists.)
"""
"""
self
.
name
=
name
self
.
name
=
name
self
.
scalar_op
=
scalar_op
self
.
scalar_op
=
scalar_op
...
@@ -357,16 +357,21 @@ class Elemwise(Op):
...
@@ -357,16 +357,21 @@ class Elemwise(Op):
args
.
append
(
input
)
args
.
append
(
input
)
else
:
else
:
# TODO: use LComplete instead
# TODO: use LComplete instead
args
.
append
(
DimShuffle
(
input
.
type
.
broadcastable
,
[
'x'
]
*
difference
+
range
(
length
),
inplace
=
True
)(
input
))
args
.
append
(
DimShuffle
(
input
.
type
.
broadcastable
,
[
'x'
]
*
difference
+
range
(
length
),
inplace
=
True
)(
input
))
inputs
=
args
inputs
=
args
# # Following conditions should always be true?
#HERE: all the broadcast dims have the same length now
# try:
# assert len(set([len(input.type.broadcastable) for input in inputs])) == 1
# except (AssertionError, AttributeError):
# raise TypeError("All inputs to a Broadcast subclass must be Tensor instances and their broadcastable fields must all have the same length.", inputs)
#cleverness: we iterate over the first, second, third broadcast flag of all inputs in
#parallel... the all() gives us each output broadcastable bit in turn.
#it is multiplied by nout because Elemwise supports multiple outputs (nout of them)
out_broadcastables
=
[[
all
(
bcast
)
for
bcast
in
zip
(
*
[
input
.
type
.
broadcastable
for
input
in
inputs
])]]
*
shadow
.
nout
out_broadcastables
=
[[
all
(
bcast
)
for
bcast
in
zip
(
*
[
input
.
type
.
broadcastable
for
input
in
inputs
])]]
*
shadow
.
nout
#inplace_pattern maps output idx -> input idx
inplace_pattern
=
self
.
inplace_pattern
inplace_pattern
=
self
.
inplace_pattern
if
inplace_pattern
:
if
inplace_pattern
:
for
overwriter
,
overwritten
in
inplace_pattern
.
items
():
for
overwriter
,
overwritten
in
inplace_pattern
.
items
():
...
@@ -374,21 +379,32 @@ class Elemwise(Op):
...
@@ -374,21 +379,32 @@ class Elemwise(Op):
if
ib
and
not
ob
:
if
ib
and
not
ob
:
raise
ValueError
(
"Operation cannot be done inplace on an input with broadcasted dimensions."
)
raise
ValueError
(
"Operation cannot be done inplace on an input with broadcasted dimensions."
)
out_dtypes
=
[
o
.
type
.
dtype
for
o
in
shadow
.
outputs
]
out_dtypes
=
[
o
.
type
.
dtype
for
o
in
shadow
.
outputs
]
if
any
(
inputs
[
i
]
.
type
.
dtype
!=
out_dtypes
[
o
]
for
i
,
o
in
inplace_pattern
.
items
()):
if
any
(
inputs
[
i
]
.
type
.
dtype
!=
out_dtypes
[
o
]
for
o
,
i
in
inplace_pattern
.
items
()):
raise
TypeError
(
"Cannot do an inplace operation on incompatible data types."
,
[
i
.
type
.
dtype
for
i
in
inputs
],
out_dtypes
)
raise
TypeError
(
"Cannot do an inplace operation on incompatible data types."
,
([
i
.
type
.
dtype
for
i
in
inputs
],
out_dtypes
,
inplace_pattern
))
outputs
=
[
Tensor
(
dtype
=
dtype
,
broadcastable
=
broadcastable
)()
for
dtype
,
broadcastable
in
zip
(
out_dtypes
,
out_broadcastables
)]
outputs
=
[
Tensor
(
dtype
=
dtype
,
broadcastable
=
broadcastable
)()
for
dtype
,
broadcastable
in
zip
(
out_dtypes
,
out_broadcastables
)]
return
Apply
(
self
,
inputs
,
outputs
)
return
Apply
(
self
,
inputs
,
outputs
)
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
return
type
(
self
)
==
type
(
other
)
and
self
.
scalar_op
==
other
.
scalar_op
and
self
.
inplace_pattern
==
other
.
inplace_pattern
if
type
(
self
)
==
type
(
other
):
items
=
self
.
inplace_pattern
.
items
()
other_items
=
other
.
inplace_pattern
.
items
()
items
.
sort
()
other_items
.
sort
()
return
self
.
scalar_op
==
other
.
scalar_op
and
items
==
other_items
return
False
def
__hash__
(
self
):
def
__hash__
(
self
):
return
hash
(
self
.
scalar_op
)
^
hash
(
tuple
(
self
.
inplace_pattern
.
items
()))
items
=
self
.
inplace_pattern
.
items
()
items
.
sort
()
return
hash
(
self
.
scalar_op
)
^
hash
(
tuple
(
items
))
def
__str__
(
self
):
def
__str__
(
self
):
if
self
.
name
is
None
:
if
self
.
name
is
None
:
if
self
.
inplace_pattern
:
if
self
.
inplace_pattern
:
return
"Elemwise{
%
s}
%
s"
%
(
self
.
scalar_op
,
str
(
self
.
inplace_pattern
))
items
=
self
.
inplace_pattern
.
items
()
items
.
sort
()
return
"Elemwise{
%
s}
%
s"
%
(
self
.
scalar_op
,
str
(
items
))
else
:
else
:
return
"Elemwise{
%
s}"
%
(
self
.
scalar_op
)
return
"Elemwise{
%
s}"
%
(
self
.
scalar_op
)
else
:
else
:
...
@@ -467,6 +483,7 @@ class Elemwise(Op):
...
@@ -467,6 +483,7 @@ class Elemwise(Op):
storage
[
0
]
=
odat
storage
[
0
]
=
odat
else
:
else
:
for
i
,
(
output
,
storage
)
in
enumerate
(
zip
(
node
.
outputs
,
output_storage
)):
for
i
,
(
output
,
storage
)
in
enumerate
(
zip
(
node
.
outputs
,
output_storage
)):
#i is an output idx
if
i
in
self
.
inplace_pattern
:
if
i
in
self
.
inplace_pattern
:
odat
=
inputs
[
self
.
inplace_pattern
[
i
]]
odat
=
inputs
[
self
.
inplace_pattern
[
i
]]
else
:
else
:
...
@@ -500,7 +517,7 @@ class Elemwise(Op):
...
@@ -500,7 +517,7 @@ class Elemwise(Op):
defines
=
""
defines
=
""
undefs
=
""
undefs
=
""
dmap
=
dict
([(
node
.
outputs
[
i
],
[
node
.
inputs
[
o
]])
for
i
,
o
in
self
.
inplace_pattern
.
items
()])
dmap
=
dict
([(
node
.
outputs
[
o
],
[
node
.
inputs
[
i
]])
for
o
,
i
in
self
.
inplace_pattern
.
items
()])
idtypes
=
[
input
.
type
.
dtype_specs
()[
1
]
for
input
in
inputs
]
idtypes
=
[
input
.
type
.
dtype_specs
()[
1
]
for
input
in
inputs
]
...
...
theano/tensor/opt.py
浏览文件 @
b6b2c608
...
@@ -5,7 +5,7 @@
...
@@ -5,7 +5,7 @@
from
..
import
gof
from
..
import
gof
from
..gof
import
opt
from
..gof
import
opt
,
InconsistencyError
from
elemwise
import
Elemwise
,
DimShuffle
from
elemwise
import
Elemwise
,
DimShuffle
from
..
import
scalar
from
..
import
scalar
import
basic
as
T
import
basic
as
T
...
@@ -32,7 +32,8 @@ def in2out(*local_opts, **kwargs):
...
@@ -32,7 +32,8 @@ def in2out(*local_opts, **kwargs):
def
_insert_inplace_optimizer
(
env
):
@gof.optimizer
def
insert_inplace_optimizer
(
env
):
"""
"""
Usage: inplace_optimizer.optimize(env)
Usage: inplace_optimizer.optimize(env)
...
@@ -59,17 +60,18 @@ def _insert_inplace_optimizer(env):
...
@@ -59,17 +60,18 @@ def _insert_inplace_optimizer(env):
try
:
try
:
new
=
Elemwise
(
new
=
Elemwise
(
op
.
scalar_op
.
__class__
(
op
.
scalar_op
.
__class__
(
scalar
.
transfer_type
(
*
[
inplace_pattern
.
get
(
i
,
None
)
for
i
in
xrange
(
len
(
node
.
outputs
))])),
scalar
.
transfer_type
(
*
[
inplace_pattern
.
get
(
i
,
None
)
\
for
i
in
xrange
(
len
(
node
.
outputs
))])),
inplace_pattern
)
.
make_node
(
*
node
.
inputs
)
inplace_pattern
)
.
make_node
(
*
node
.
inputs
)
env
.
replace_all_validate
(
zip
(
node
.
outputs
,
new
.
outputs
))
env
.
replace_all_validate
(
zip
(
node
.
outputs
,
new
.
outputs
))
except
Exception
,
e
:
except
(
ValueError
,
TypeError
,
InconsistencyError
)
,
e
:
continue
continue
candidate_inputs
.
remove
(
candidate_input
)
candidate_inputs
.
remove
(
candidate_input
)
node
=
new
node
=
new
baseline
=
inplace_pattern
baseline
=
inplace_pattern
break
break
insert_inplace_optimizer
=
gof
.
optimizer
(
_insert_inplace_optimizer
)
compile
.
optdb
.
register
(
'inplace_opt'
,
insert_inplace_optimizer
,
75
,
'fast_run'
,
'inplace'
)
def
register_canonicalize
(
lopt
,
*
tags
,
**
kwargs
):
def
register_canonicalize
(
lopt
,
*
tags
,
**
kwargs
):
name
=
(
kwargs
and
kwargs
.
pop
(
'name'
))
or
lopt
.
__name__
name
=
(
kwargs
and
kwargs
.
pop
(
'name'
))
or
lopt
.
__name__
...
@@ -310,7 +312,7 @@ def local_fill_cut(node):
...
@@ -310,7 +312,7 @@ def local_fill_cut(node):
register_canonicalize
(
local_fill_cut
)
register_canonicalize
(
local_fill_cut
)
register_canonicalize
(
gof
.
OpRemove
(
T
.
tensor_copy
),
name
=
'remove_tensor_copy'
)
#register_canonicalize(gof.OpRemove(T.tensor_copy), name='remove_tensor_copy' ) #DEBUG
@gof.local_optimizer
([
None
,
T
.
fill
])
@gof.local_optimizer
([
None
,
T
.
fill
])
def
local_fill_sink
(
node
):
def
local_fill_sink
(
node
):
...
@@ -650,38 +652,6 @@ def local_mul_specialize(node):
...
@@ -650,38 +652,6 @@ def local_mul_specialize(node):
return
False
return
False
register_specialize
(
local_mul_specialize
)
register_specialize
(
local_mul_specialize
)
if
0
:
#TODO: replace this with a c version of any InplaceDimShuffle
class
_TransposeInplace
(
T
.
Op
):
view_map
=
{
0
:
[
0
]}
def
make_node
(
self
,
input
):
return
T
.
Apply
(
self
,
[
input
],
[
T
.
tensor
(
dtype
=
input
.
type
.
dtype
,
broadcastable
=
reversed
(
input
.
type
.
broadcastable
))])
def
perform
(
self
,
node
,
(
x
,
),
(
z
,
)):
z
[
0
]
=
x
.
T
def
c_code
(
self
,
node
,
name
,
(
x
,
),
(
z
,
),
sub
):
return
"""
PyArrayObject* transposed = (PyArrayObject*)PyArray_Transpose(
%(x)
s, NULL);
if (
%(z)
s) {
Py_XDECREF(
%(z)
s);
}
%(z)
s = transposed;
"""
%
locals
()
def
__str__
(
self
):
return
"_TransposeInplace"
_transpose_inplace
=
_TransposeInplace
()
@gof.local_optimizer
([
T
.
DimShuffle
([
False
,
False
],[
1
,
0
],
inplace
=
True
)])
def
local_dimshuffle_transposeinplace
(
node
):
if
node
.
op
==
T
.
DimShuffle
([
False
,
False
],[
1
,
0
],
inplace
=
True
):
return
[
_transpose_inplace
(
node
.
inputs
[
0
])]
return
False
register_specialize
(
local_dimshuffle_transposeinplace
)
register_canonicalize
(
local_mul_canonizer
,
name
=
'local_mul_canonizer'
)
register_canonicalize
(
local_mul_canonizer
,
name
=
'local_mul_canonizer'
)
...
@@ -844,287 +814,3 @@ local_transposed_dot = gof.PatternSub((inplace_matrix_transpose, (T.dot, 'x', 'y
...
@@ -844,287 +814,3 @@ local_transposed_dot = gof.PatternSub((inplace_matrix_transpose, (T.dot, 'x', 'y
register_canonicalize
(
local_transposed_dot
,
name
=
'local_transposed_dot'
)
register_canonicalize
(
local_transposed_dot
,
name
=
'local_transposed_dot'
)
# def _math_optimizer():
# pass_1 = in2out(local_fill_sink)
# pass_2 = out2in(local_dimshuffle_lift, local_shape_lift, local_fill_lift)#, local_fill_cut)
# pass_3 = out2in(local_subtensor_make_vector, local_fill_cut)
# canonizer = in2out(local_add_canonizer,
# local_mul_canonizer,
# local_fill_sink)
# pass_4 = out2in(local_greedy_distributor)
# return gof.SeqOptimizer(pass_1,
# pass_2,
# pass_3,
# neg_to_mul,
# canonizer,
# pass_4,
# mul_to_neg)
# math_optimizer = _math_optimizer()
# compile.register_optimizer('math',
# gof.MergeOptMerge(
# gof.PureThenInplaceOptimizer(
# math_optimizer,
# inplace_optimizer)))
# compile.register_mode('SANITY_CHECK', compile.Mode('c&py', 'math'))
# compile.register_mode('FAST_RUN', compile.Mode('c|py', 'math'))
# compile.register_mode('EXPENSIVE_OPTIMIZATIONS', compile.Mode('c|py', 'math'))
# @gof.local_optimizer
# def local_clique_fusion(node):
# aaaaaaaaaaaaaaaaaaaaaaa
# def find_cliques(env, through_broadcast = False):
# """
# Usage: find_cliques(env, through_broadcast = False)
# Returns a list of pairs where each pair contains a list
# of inputs and a list of outputs such that Env(inputs, outputs)
# contains nothing but Broadcast Ops.
# If through_broadcast is False, the cliques will only be
# allowed to broadcast over the inputs, which means, for
# example, that vector operations will not be mixed with
# matrix operations.
# """
# def seek_from(r):
# # walks through the graph until it encounters a
# # non-Broadcast operation or (if through_broadcast
# # is False) a Result which needs to be broadcasted.
# op = r.owner
# if env.edge(r) \
# or not isinstance(op, Broadcast) \
# or len(op.outputs) > 1:
# # todo: handle multiple-output broadcast ops
# # (needs to update the clique's outputs)
# return None
# ret = set()
# if not through_broadcast:
# # check each dimension over all the inputs - if the broadcastable
# # fields are not all 0 or all 1 for a particular dimension, then
# # broadcasting will be performed along it on the inputs where the
# # value is 1 and we will stop.
# if any(any(bc) and not all(bc)
# for bc in zip(*[input.broadcastable for input in op.inputs])):
# ret.update(op.inputs)
# return ret
# for input in op.inputs:
# res = seek_from(input)
# if res is None:
# # input is a leaf of our search
# ret.add(input)
# else:
# ret.update(res)
# return ret
# cliques = []
# def find_cliques_helper(r):
# if env.edge(r):
# return
# clique_inputs = seek_from(r)
# if clique_inputs is None:
# # Not in a clique, keep going
# op = r.owner
# if op is not None:
# for input in op.inputs:
# find_cliques_helper(input)
# else:
# # We found a clique, add it to the list and
# # jump to the leaves.
# cliques.append((clique_inputs, [r]))
# for input in clique_inputs:
# find_cliques_helper(input)
# for output in env.outputs:
# find_cliques_helper(output)
# # todo: merge the cliques if possible
# return cliques
# class CliqueOptimizer(opt.Optimizer):
# """
# Usage: CliqueOptimizer(through_broadcast = False,
# scalar_optimizer = None,
# make_composite = False).optimize(env)
# Finds cliques of Broadcast operations in the env and does either
# or both of two things:
# * Apply scalar_optimizer on the clique as if the clique was a
# group of scalar operations. scalar_optimizer can be any optimization
# which applies on scalars. If it is None, no optimization is done.
# * Replace the clique with a single Op, optimized to perform the
# computations properly. If make_composite is False, no such replacement
# is done.
# Note: it is recommended to run the lift_dimshuffle optimization before
# this one.
# """
# def __init__(self, through_broadcast = False, scalar_optimizer = None, make_composite = False):
# self.through_broadcast = through_broadcast
# self.scalar_optimizer = scalar_optimizer
# self.make_composite = make_composite
# def apply(self, env):
# if self.scalar_optimizer is None and not self.make_composite:
# # there's nothing to do with the cliques...
# return
# cliques = find_cliques(env, self.through_broadcast)
# opt = self.scalar_optimizer
# def build_scalar_clique(r, env, equiv):
# # Maps a clique of Broadcast Ops to a clique of Scalar Ops with the same
# # structure and equivalent operations. equiv contains the mapping.
# if r in equiv:
# return equiv[r]
# op = r.owner
# if env.edge(r):
# # For each leave we make a Scalar of the corresponding dtype
# s = scalar.Scalar(dtype = r.dtype)
# _r = r
# if isinstance(r.owner, DimShuffle) and all(x == 'x' for x in r.owner.new_order):
# _r = r.owner.inputs[0]
# if (getattr(r, 'constant', False) or getattr(_r, 'constant', False)) \
# and _r.broadcastable == ():
# # If we have a constant tensor we map it to a constant scalar.
# s.data = _r.data
# s.constant = True
# equiv[r] = s
# return s
# s_op = op.scalar_opclass(*[build_scalar_clique(input, env, equiv) for input in op.inputs])
# equiv[op] = s_op
# for output, s_output in zip(op.outputs, s_op.outputs):
# equiv[output] = s_output
# return equiv[r]
# for c_in, c_out in cliques:
# equiv = dict()
# g = Env(c_in, c_out)
# for output in c_out:
# build_scalar_clique(output, g, equiv)
# s_g = Env([equiv[r] for r in g.inputs],
# [equiv[r] for r in g.outputs])
# if opt is not None:
# equiv2 = dict() # reverse mapping, from Scalar Op to Tensor Op
# for k, v in equiv.items():
# equiv2[v] = k
# def transform(op, equiv):
# # We get a scalar op and we return an equivalent op on tensors.
# return Broadcast(op.__class__, [equiv[input] for input in op.inputs])
# s_g.add_feature(sync_to(env, equiv2, transform)) # Any change to s_g will now be transferred to g
# opt.optimize(s_g)
# if self.make_composite:
# def follow_inplace(r):
# # Tries to find the earliest r2 in g such that r destroys r2
# # If no such r2 is found, returns None
# op = r.owner
# if op is None or r in g.inputs or r in g.orphans():
# return None
# assert isinstance(op, Broadcast)
# destroyed = op.destroy_map().get(r, None)
# if destroyed is None:
# return None
# else:
# r2 = destroyed[0]
# ret = follow_inplace(r2)
# if ret is None:
# return r2
# else:
# return ret
# inplace_pattern = {}
# for i, output in enumerate(g.outputs):
# destroyed = follow_inplace(output)
# if destroyed is not None and destroyed in g.inputs:
# # we transfer the inplace operation only if it is
# # an input that is destroyed
# inplace_pattern[i] = g.inputs.index(destroyed)
# C = scalar.composite(s_g.inputs, s_g.outputs)
# ec = Broadcast(C, g.inputs, inplace_pattern = inplace_pattern)
# env.replace_all(dict((o, eco) for o, eco in zip(c_out, ec.outputs)))
# def sync_to(target, equiv, transform):
# """
# Usage: sync_to(target, equiv, transform)
# * target: an Env
# * equiv: a dictionary that maps results and ops to results and ops
# in target
# * transform: a function that takes (op, equiv) as inputs and
# returns a new op.
# Returns a Feature that can be added to an Env and mirrors all
# modifications to that env with modifications to the target env.
# """
# class Synchronize(gof.Listener, gof.Constraint):
# def __init__(self, source):
# self.source = source
# self.target = target
# self.equiv = equiv
# self.transform = transform
# self.inconsistencies = []
# def on_import(self, op1):
# if op1 not in self.equiv:
# op2 = self.transform(op1, self.equiv)
# self.equiv[op1] = op2
# for o1, o2 in zip(op1.outputs, op2.outputs):
# self.equiv[o1] = o2
# def on_prune(self, op1):
# if op1 in self.equiv:
# op2 = self.equiv[op1]
# del self.equiv[op1]
# for o1, o2 in zip(op1.outputs, op2.outputs):
# del self.equiv[o1]
# def on_rewire(self, clients1, r1, new_r1):
# if (new_r1, r1) in self.inconsistencies:
# self.inconsistencies.remove((new_r1, r1))
# return
# if not self.source.clients(r1):
# try:
# target.replace(self.equiv[r1], self.equiv[new_r1])
# except:
# self.inconsistencies.append((r1, new_r1))
# def validate(self):
# if self.inconsistencies:
# raise InconsistencyError("Could not synchronize when replacing the following pairs: %s" % self.inconsistencies)
# return True
# return Synchronize
theano/tensor/tests/test_blas.py
浏览文件 @
b6b2c608
...
@@ -3,10 +3,13 @@ import theano.tensor as T
...
@@ -3,10 +3,13 @@ import theano.tensor as T
from
...gof
import
Env
from
...gof
import
Env
import
numpy
import
numpy
from
theano.tensor.blas
import
*
from
theano.tensor.blas
import
*
from
theano.tensor.blas
import
_
as_scalar
,
_dot22
,
_is_real_matrix
from
theano.tensor.blas
import
_
dot22
,
res_is_a
from
unittest
import
TestCase
from
unittest
import
TestCase
from
copy
import
copy
from
copy
import
copy
_as_scalar
=
GemmLocalOptimizer
.
_as_scalar
_is_real_matrix
=
GemmLocalOptimizer
.
_is_real_matrix
from
theano
import
In
,
Out
from
theano
import
In
,
Out
from
.test_basic
import
(
_approx_eq
,
as_tensor
,
function
,
from
.test_basic
import
(
_approx_eq
,
as_tensor
,
function
,
compile
,
value
,
constant
,
inplace
,
eval_outputs
)
compile
,
value
,
constant
,
inplace
,
eval_outputs
)
...
@@ -185,6 +188,15 @@ class t_gemm(TestCase):
...
@@ -185,6 +188,15 @@ class t_gemm(TestCase):
return
return
self
.
fail
()
self
.
fail
()
def
test_res_is_a
():
X
,
Y
,
Z
,
a
,
b
=
XYZab
()
assert
not
res_is_a
(
a
,
T
.
sqrt
)
assert
not
res_is_a
(
a
+
a
,
T
.
sqrt
)
assert
res_is_a
(
T
.
sqrt
(
a
+
a
),
T
.
sqrt
)
#leave the maxclients stuff untested because it requires being in an env.
class
t_as_scalar
(
TestCase
):
class
t_as_scalar
(
TestCase
):
def
test0
(
self
):
def
test0
(
self
):
"""Test that it works on scalar constants"""
"""Test that it works on scalar constants"""
...
@@ -227,85 +239,167 @@ class T_real_matrix(TestCase):
...
@@ -227,85 +239,167 @@ class T_real_matrix(TestCase):
self
.
failUnless
(
_is_real_matrix
(
T
.
DimShuffle
([
False
,
False
],
[
1
,
0
])(
T
.
dmatrix
())))
self
.
failUnless
(
_is_real_matrix
(
T
.
DimShuffle
([
False
,
False
],
[
1
,
0
])(
T
.
dmatrix
())))
self
.
failUnless
(
not
_is_real_matrix
(
T
.
DimShuffle
([
False
],
[
'x'
,
0
])(
T
.
dvector
())))
self
.
failUnless
(
not
_is_real_matrix
(
T
.
DimShuffle
([
False
],
[
'x'
,
0
])(
T
.
dvector
())))
if
JOSEPHS_BUG_SOLVED
:
def
fail
(
msg
):
class
T_gemm_opt
(
TestCase
):
print
'FAIL'
,
msg
"""This test suite ensures that Gemm is inserted where it belongs, and that the resulting
assert
False
functions compute the same things as the originals."""
def
XYZab
(
self
):
"""This test suite ensures that Gemm is inserted where it belongs, and that the resulting
functions compute the same things as the originals."""
def
XYZab
():
return
T
.
dmatrix
(),
T
.
dmatrix
(),
T
.
dmatrix
(),
T
.
dscalar
(),
T
.
dscalar
()
return
T
.
dmatrix
(),
T
.
dmatrix
(),
T
.
dmatrix
(),
T
.
dscalar
(),
T
.
dscalar
()
def
just_gemm
(
self
,
i
,
o
,
ishapes
=
[(
4
,
3
),
(
3
,
5
),
(
4
,
5
),
(),
()]):
class
Failure
(
Exception
):
def
on_fail
():
pass
for
node
in
f
.
maker
.
env
.
toposort
():
print
'GRAPH'
,
node
self
.
fail
()
class
Warning
(
Exception
):
pass
def
just_gemm
(
i
,
o
,
ishapes
=
[(
4
,
3
),
(
3
,
5
),
(
4
,
5
),
(),
()]):
try
:
f
=
function
([
In
(
ii
,
mutable
=
True
)
for
ii
in
i
],
o
,
mode
=
'FAST_RUN'
)
f
=
function
([
In
(
ii
,
mutable
=
True
)
for
ii
in
i
],
o
,
mode
=
'FAST_RUN'
)
for
node
in
f
.
maker
.
env
.
nodes
:
for
node
in
f
.
maker
.
env
.
nodes
:
if
node
.
op
==
T
.
dot
:
on_fail
(
)
if
node
.
op
==
T
.
dot
:
raise
Warning
(
'dot in graph'
)
if
node
.
op
==
_dot22
:
on_fail
(
)
if
node
.
op
==
_dot22
:
raise
Warning
(
'_dot22 in graph'
)
g
=
function
(
i
,
o
,
mode
=
'FAST_COMPILE'
)
g
=
function
(
i
,
o
,
mode
=
compile
.
Mode
(
linker
=
'py'
,
optimizer
=
None
)
)
for
node
in
g
.
maker
.
env
.
nodes
:
for
node
in
g
.
maker
.
env
.
nodes
:
if
node
.
op
==
gemm
:
on_fail
(
)
if
node
.
op
==
gemm
:
raise
Warning
(
'gemm in graph'
)
rng
=
numpy
.
random
.
RandomState
(
234
)
rng
=
numpy
.
random
.
RandomState
(
234
)
r0
=
f
(
*
[
rng
.
randn
(
*
sh
)
for
sh
in
ishapes
])
r0
=
f
(
*
[
rng
.
randn
(
*
sh
)
for
sh
in
ishapes
])
rng
=
numpy
.
random
.
RandomState
(
234
)
rng
=
numpy
.
random
.
RandomState
(
234
)
r1
=
g
(
*
[
rng
.
randn
(
*
sh
)
for
sh
in
ishapes
])
r1
=
g
(
*
[
rng
.
randn
(
*
sh
)
for
sh
in
ishapes
])
if
numpy
.
max
(
numpy
.
abs
(
r0
[
0
]
-
r1
[
0
]))
>
1.0e-8
:
max_abs_err
=
numpy
.
max
(
numpy
.
abs
(
r0
[
0
]
-
r1
[
0
]))
self
.
fail
()
if
max_abs_err
>
1.0e-8
:
raise
Failure
(
'GEMM is computing the wrong output. max_rel_err ='
,
max_abs_err
)
except
Failure
:
for
node
in
f
.
maker
.
env
.
toposort
():
print
'GRAPH'
,
node
raise
except
Warning
:
for
node
in
f
.
maker
.
env
.
toposort
():
print
'GRAPH'
,
node
def
test0
(
self
):
def
test_gemm_opt0
():
"""Many subgraphs whose dots can be eliminated"""
"""Many subgraphs whose dots can be eliminated"""
X
,
Y
,
Z
,
a
,
b
=
self
.
XYZab
()
X
,
Y
,
Z
,
a
,
b
=
XYZab
()
self
.
just_gemm
([
X
,
Y
,
Z
,
a
,
b
],
[
T
.
dot
(
X
,
Y
)
*
a
+
Z
*
b
])
just_gemm
([
X
,
Y
,
Z
,
a
,
b
],
[
T
.
dot
(
X
,
Y
)
*
a
+
Z
*
b
])
self
.
just_gemm
([
X
,
Y
,
Z
,
a
,
b
],
[
a
*
T
.
dot
(
X
,
Y
)
+
b
*
Z
])
just_gemm
([
X
,
Y
,
Z
,
a
,
b
],
[
a
*
T
.
dot
(
X
,
Y
)
+
b
*
Z
])
self
.
just_gemm
([
X
,
Y
,
Z
,
a
,
b
],
[
b
*
Z
+
a
*
T
.
dot
(
X
,
Y
)])
just_gemm
([
X
,
Y
,
Z
,
a
,
b
],
[
b
*
Z
+
a
*
T
.
dot
(
X
,
Y
)])
self
.
just_gemm
([
X
,
Y
,
Z
,
a
,
b
],
[
T
.
dot
(
X
,
Y
)
*
a
-
Z
*
b
])
just_gemm
([
X
,
Y
,
Z
,
a
,
b
],
[
T
.
dot
(
X
,
Y
)
*
a
-
Z
*
b
])
self
.
just_gemm
([
X
,
Y
,
Z
,
a
,
b
],
[
a
*
T
.
dot
(
X
,
Y
)
-
b
*
Z
])
just_gemm
([
X
,
Y
,
Z
,
a
,
b
],
[
a
*
T
.
dot
(
X
,
Y
)
-
b
*
Z
])
self
.
just_gemm
([
X
,
Y
,
Z
,
a
,
b
],
[
b
*
Z
-
a
*
T
.
dot
(
X
,
Y
)])
just_gemm
([
X
,
Y
,
Z
,
a
,
b
],
[
b
*
Z
-
a
*
T
.
dot
(
X
,
Y
)])
#with transposes (transposes should be pushed through dot in canonicalize)
#with transposes (transposes should be pushed through dot in canonicalize)
self
.
just_gemm
([
X
,
Y
,
Z
,
a
,
b
],
[
b
*
Z
.
T
-
a
*
T
.
dot
(
Y
.
T
,
X
.
T
)])
just_gemm
([
X
,
Y
,
Z
,
a
,
b
],
[
b
*
Z
.
T
-
a
*
T
.
dot
(
Y
.
T
,
X
.
T
)])
self
.
just_gemm
([
X
,
Y
,
Z
,
a
,
b
],
[
b
*
Z
.
T
+
a
*
b
*
T
.
dot
(
X
,
Y
)
.
T
])
just_gemm
([
X
,
Y
,
Z
,
a
,
b
],
[
b
*
Z
.
T
+
a
*
b
*
T
.
dot
(
X
,
Y
)
.
T
])
#with N multiplications instead of just one
#with N multiplications instead of just one
self
.
just_gemm
([
X
,
Y
,
Z
,
a
,
b
],
[(
b
*
b
)
*
Z
*
a
+
(
a
*
a
)
*
T
.
dot
(
X
,
Y
)
*
b
])
just_gemm
([
X
,
Y
,
Z
,
a
,
b
],
[(
b
*
b
)
*
Z
*
a
+
(
a
*
a
)
*
T
.
dot
(
X
,
Y
)
*
b
])
self
.
just_gemm
([
X
,
Y
,
Z
,
a
,
b
],
[
Z
+
T
.
dot
(
X
,
Y
)])
just_gemm
([
X
,
Y
,
Z
,
a
,
b
],
[
Z
+
T
.
dot
(
X
,
Y
)])
self
.
just_gemm
([
X
,
Y
,
Z
,
a
,
b
],
[
Z
*
b
+
T
.
dot
(
X
,
Y
)])
just_gemm
([
X
,
Y
,
Z
,
a
,
b
],
[
Z
*
b
+
T
.
dot
(
X
,
Y
)])
self
.
just_gemm
([
X
,
Y
,
Z
,
a
,
b
],
[
Z
+
a
*
b
*
a
*
T
.
dot
(
X
,
Y
)])
just_gemm
([
X
,
Y
,
Z
,
a
,
b
],
[
Z
+
a
*
b
*
a
*
T
.
dot
(
X
,
Y
)])
self
.
just_gemm
([
X
,
Y
,
Z
,
a
,
b
],
[(
b
*
b
)
*
Z
*
a
-
(
a
*
a
)
*
T
.
dot
(
X
,
Y
)
*
b
])
just_gemm
([
X
,
Y
,
Z
,
a
,
b
],
[(
b
*
b
)
*
Z
*
a
-
(
a
*
a
)
*
T
.
dot
(
X
,
Y
)
*
b
])
self
.
just_gemm
([
X
,
Y
,
Z
,
a
,
b
],
[
Z
-
T
.
dot
(
X
,
Y
)])
just_gemm
([
X
,
Y
,
Z
,
a
,
b
],
[
Z
-
T
.
dot
(
X
,
Y
)])
self
.
just_gemm
([
X
,
Y
,
Z
,
a
,
b
],
[
Z
*
b
-
T
.
dot
(
X
,
Y
)])
just_gemm
([
X
,
Y
,
Z
,
a
,
b
],
[
Z
*
b
-
T
.
dot
(
X
,
Y
)])
self
.
just_gemm
([
X
,
Y
,
Z
,
a
,
b
],
[
Z
-
a
*
b
*
a
*
T
.
dot
(
X
,
Y
)])
just_gemm
([
X
,
Y
,
Z
,
a
,
b
],
[
Z
-
a
*
b
*
a
*
T
.
dot
(
X
,
Y
)])
# with > 2 terms in the overall addition
self
.
just_gemm
([
X
,
Y
,
Z
,
a
,
b
],
[
Z
+
Z
+
T
.
dot
(
X
,
Y
)
+
Z
])
def
test_double_gemm
(
self
):
def
test_gemm_opt_double_gemm
(
):
"""This is the pattern that shows up in the autoencoder"""
"""This is the pattern that shows up in the autoencoder"""
X
,
Y
,
Z
,
a
,
b
=
T
.
dmatrix
(),
T
.
dmatrix
(),
T
.
dmatrix
(),
T
.
dscalar
(),
T
.
dscalar
()
X
,
Y
,
Z
,
a
,
b
=
T
.
dmatrix
(),
T
.
dmatrix
(),
T
.
dmatrix
(),
T
.
dscalar
(),
T
.
dscalar
()
R
,
S
,
c
=
T
.
dmatrix
(),
T
.
dmatrix
(),
T
.
dscalar
()
R
,
S
,
c
=
T
.
dmatrix
(),
T
.
dmatrix
(),
T
.
dscalar
()
self
.
just_gemm
([
X
,
Y
,
Z
,
a
,
b
,
R
,
S
,
c
],
[
Z
*
c
+
a
*
T
.
dot
(
X
,
Y
)
+
b
*
T
.
dot
(
R
,
S
)
.
T
],
just_gemm
([
X
,
Y
,
Z
,
a
,
b
,
R
,
S
,
c
],
[
Z
*
c
+
a
*
T
.
dot
(
X
,
Y
)
+
b
*
T
.
dot
(
R
,
S
)
.
T
],
ishapes
=
[(
4
,
3
),
(
3
,
5
),
(
4
,
5
),
(),
(),
(
5
,
9
),
(
9
,
4
),
()])
ishapes
=
[(
4
,
3
),
(
3
,
5
),
(
4
,
5
),
(),
(),
(
5
,
9
),
(
9
,
4
),
()])
def
wishlist
(
self
):
ishapes
=
[(
4
,
3
),
(
3
,
5
),
(
4
,
5
),
(),
(),
(
5
,
9
),
(
9
,
4
),
()]
i
=
[
X
,
Y
,
Z
,
a
,
b
,
R
,
S
,
c
]
o
=
[
a
*
T
.
dot
(
X
,
Y
)
+
gemm
(
Z
,
b
,
S
.
T
,
R
.
T
,
1.0
)]
try
:
f
=
function
([
In
(
ii
,
mutable
=
True
)
for
ii
in
i
],
o
,
mode
=
'FAST_RUN'
)
for
node
in
f
.
maker
.
env
.
nodes
:
if
node
.
op
==
T
.
dot
:
raise
Failure
(
'dot in graph'
)
if
node
.
op
==
_dot22
:
raise
Failure
(
'_dot22 in graph'
)
g
=
function
(
i
,
o
,
mode
=
compile
.
Mode
(
linker
=
'py'
,
optimizer
=
None
))
#for node in g.maker.env.nodes:
# if node.op == gemm: raise Failure('gemm in graph')
rng
=
numpy
.
random
.
RandomState
(
234
)
r0
=
f
(
*
[
rng
.
randn
(
*
sh
)
for
sh
in
ishapes
])
rng
=
numpy
.
random
.
RandomState
(
234
)
r1
=
g
(
*
[
rng
.
randn
(
*
sh
)
for
sh
in
ishapes
])
max_abs_err
=
numpy
.
max
(
numpy
.
abs
(
r0
[
0
]
-
r1
[
0
]))
if
max_abs_err
>
1.0e-8
:
raise
Failure
(
'GEMM is computing the wrong output. max_rel_err ='
,
max_abs_err
)
except
Failure
:
for
node
in
f
.
maker
.
env
.
toposort
():
print
'GRAPH'
,
node
raise
def
wishlist_gemm_opt
():
X
,
Y
,
Z
,
a
,
b
=
T
.
dmatrix
(),
T
.
dmatrix
(),
T
.
dmatrix
(),
T
.
dscalar
(),
T
.
dscalar
()
X
,
Y
,
Z
,
a
,
b
=
T
.
dmatrix
(),
T
.
dmatrix
(),
T
.
dmatrix
(),
T
.
dscalar
(),
T
.
dscalar
()
#with >2 additions of the same T.dot(X,Y term
#with >2 additions of the same T.dot(X,Y term
self
.
just_gemm
([
X
,
Y
,
Z
,
a
,
b
],
[
Z
+
T
.
dot
(
X
,
Y
)
+
T
.
dot
(
X
,
Y
)])
just_gemm
([
X
,
Y
,
Z
,
a
,
b
],
[
Z
+
T
.
dot
(
X
,
Y
)
+
T
.
dot
(
X
,
Y
)])
self
.
just_gemm
([
X
,
Y
,
Z
,
a
,
b
],
[(
b
*
b
)
*
Z
*
a
+
(
a
*
a
)
*
T
.
dot
(
X
,
Y
)
+
b
*
T
.
dot
(
X
,
Y
)])
just_gemm
([
X
,
Y
,
Z
,
a
,
b
],
[(
b
*
b
)
*
Z
*
a
+
(
a
*
a
)
*
T
.
dot
(
X
,
Y
)
+
b
*
T
.
dot
(
X
,
Y
)])
def
test_gemm_with_vector
():
"""Many subgraphs whose dots can be eliminated.
This adds a vector two the previous test, which triggers the long-sought GEMM bug.
"""
X
,
Y
,
Z
,
a
,
b
=
XYZab
()
v
=
T
.
vector
()
def
my_just_gemm
(
o
):
i
=
[
X
,
Y
,
Z
,
a
,
b
,
v
]
ishapes
=
[(
4
,
3
),
(
3
,
5
),
(
4
,
5
),
(),
(),
(
5
,)]
rval
=
just_gemm
(
i
,
o
,
ishapes
=
ishapes
)
my_just_gemm
([
v
+
T
.
dot
(
X
,
Y
)
*
a
+
Z
*
b
])
my_just_gemm
([
v
+
a
*
T
.
dot
(
X
,
Y
)
+
b
*
Z
])
my_just_gemm
([
v
+
b
*
Z
+
a
*
T
.
dot
(
X
,
Y
)])
my_just_gemm
([
v
+
T
.
dot
(
X
,
Y
)
*
a
-
Z
*
b
])
my_just_gemm
([
v
+
a
*
T
.
dot
(
X
,
Y
)
-
b
*
Z
])
my_just_gemm
([
v
+
b
*
Z
-
a
*
T
.
dot
(
X
,
Y
)])
def
test_vector_stuff
(
self
):
#with N multiplications instead of just one
my_just_gemm
([
v
+
(
b
*
b
)
*
Z
*
a
+
(
a
*
a
)
*
T
.
dot
(
X
,
Y
)
*
b
])
my_just_gemm
([
v
+
Z
+
T
.
dot
(
X
,
Y
)])
my_just_gemm
([
v
+
Z
*
b
+
T
.
dot
(
X
,
Y
)])
my_just_gemm
([
v
+
Z
+
a
*
b
*
a
*
T
.
dot
(
X
,
Y
)])
my_just_gemm
([
v
+
(
b
*
b
)
*
Z
*
a
-
(
a
*
a
)
*
T
.
dot
(
X
,
Y
)
*
b
])
my_just_gemm
([
Z
-
T
.
dot
(
X
,
Y
)
+
v
])
my_just_gemm
([
Z
*
b
-
T
.
dot
(
X
,
Y
)
+
v
])
my_just_gemm
([
Z
-
a
*
b
*
a
*
T
.
dot
(
X
,
Y
)
+
v
])
def
test_gemm_opt_vector_stuff
():
X
,
Y
,
Z
,
a
,
b
=
T
.
dmatrix
(),
T
.
dmatrix
(),
T
.
dmatrix
(),
T
.
dscalar
(),
T
.
dscalar
()
X
,
Y
,
Z
,
a
,
b
=
T
.
dmatrix
(),
T
.
dmatrix
(),
T
.
dmatrix
(),
T
.
dscalar
(),
T
.
dscalar
()
u
,
v
=
T
.
dvector
(),
T
.
dvector
()
u
,
v
=
T
.
dvector
(),
T
.
dvector
()
f
=
function
([
a
,
u
,
v
],
a
+
T
.
dot
(
u
,
v
),
mode
=
'FAST_RUN'
)
f
=
function
([
a
,
u
,
v
],
a
+
T
.
dot
(
u
,
v
),
mode
=
'FAST_RUN'
)
self
.
failIf
(
gemm
in
[
n
.
op
for
n
in
f
.
maker
.
env
.
nodes
])
if
gemm
in
[
n
.
op
for
n
in
f
.
maker
.
env
.
nodes
]:
raise
Failure
(
'gemm in graph'
)
f
=
function
([
a
,
u
,
X
,
Y
],
a
*
u
+
T
.
dot
(
X
,
Y
),
mode
=
'FAST_RUN'
)
f
=
function
([
a
,
u
,
X
,
Y
],
a
*
u
+
T
.
dot
(
X
,
Y
),
mode
=
'FAST_RUN'
)
self
.
failIf
(
gemm
in
[
n
.
op
for
n
in
f
.
maker
.
env
.
nodes
])
if
(
gemm
in
[
n
.
op
for
n
in
f
.
maker
.
env
.
nodes
]):
raise
Failure
(
'gemm in graph'
)
def
test_inplace0
():
#should fail to insert gemm because gemm would create cycles
X
,
Y
,
Z
,
a
,
b
=
T
.
dmatrix
(),
T
.
dmatrix
(),
T
.
dmatrix
(),
T
.
dscalar
(),
T
.
dscalar
()
R
,
S
,
c
=
T
.
dmatrix
(),
T
.
dmatrix
(),
T
.
dscalar
()
f
=
function
([
X
,
Y
,
Z
,
a
,
b
,
R
,
S
,
c
],
[
Z
*
(
Z
*
c
+
a
*
T
.
dot
(
X
,
Y
)
+
b
*
T
.
dot
(
R
,
S
)
.
T
)],
mode
=
'FAST_RUN'
)
if
(
gemm
in
[
n
.
op
for
n
in
f
.
maker
.
env
.
nodes
]):
raise
Failure
(
'gemm in graph'
)
def
test_inplace1
():
X
,
Y
,
Z
,
a
,
b
=
XYZab
()
# with > 2 terms in the overall addition
f
=
function
([
X
,
Y
,
Z
,
a
,
b
],
[
Z
+
Z
+
T
.
dot
(
X
,
Y
)],
mode
=
'FAST_RUN'
)
if
(
gemm
in
[
n
.
op
for
n
in
f
.
maker
.
env
.
nodes
]):
raise
Failure
(
'gemm in graph'
)
theano/tensor/tests/test_joseph.py
浏览文件 @
b6b2c608
...
@@ -155,14 +155,14 @@ class QuadraticDenoisingAA(T.RModule):
...
@@ -155,14 +155,14 @@ class QuadraticDenoisingAA(T.RModule):
updates
=
dict
((
p
,
p
-
self
.
lr
*
g
)
for
p
,
g
in
zip
(
self
.
params
,
gradients
))
updates
=
dict
((
p
,
p
-
self
.
lr
*
g
)
for
p
,
g
in
zip
(
self
.
params
,
gradients
))
# INTERFACE METHODS
# INTERFACE METHODS
self
.
update
=
theano
.
Method
(
self
.
input
,
self
.
ncost
,
updates
)
#
self.update = theano.Method(self.input, self.ncost, updates)
self
.
compute_cost
=
theano
.
Method
(
self
.
input
,
self
.
cost
)
#
self.compute_cost = theano.Method(self.input, self.cost)
self
.
noisify
=
theano
.
Method
(
self
.
input
,
self
.
corrupted_input
)
#
self.noisify = theano.Method(self.input, self.corrupted_input)
self
.
reconstruction
=
theano
.
Method
(
self
.
input
,
self
.
output
)
#
self.reconstruction = theano.Method(self.input, self.output)
self
.
representation
=
theano
.
Method
(
self
.
input
,
self
.
hidden
)
#
self.representation = theano.Method(self.input, self.hidden)
self
.
reconstruction_through_noise
=
theano
.
Method
(
self
.
input
,
[
self
.
corrupted_input
,
self
.
noutput
])
#
self.reconstruction_through_noise = theano.Method(self.input, [self.corrupted_input, self.noutput])
self
.
validate
=
theano
.
Method
(
self
.
input
,
[
self
.
cost
,
self
.
output
])
#
self.validate = theano.Method(self.input, [self.cost, self.output])
def
_instance_initialize
(
self
,
obj
,
input_size
,
hidden_size
,
seed
,
lr
,
qfilter_relscale
):
def
_instance_initialize
(
self
,
obj
,
input_size
,
hidden_size
,
seed
,
lr
,
qfilter_relscale
):
"""
"""
...
@@ -291,16 +291,16 @@ class Module_Nclass(module.FancyModule):
...
@@ -291,16 +291,16 @@ class Module_Nclass(module.FancyModule):
#define the apply method
#define the apply method
self
.
pred
=
T
.
argmax
(
linear_output
,
axis
=
1
)
self
.
pred
=
T
.
argmax
(
linear_output
,
axis
=
1
)
self
.
apply
=
module
.
Method
([
self
.
input
],
self
.
pred
)
#
self.apply = module.Method([self.input], self.pred)
self
.
validate
=
module
.
Method
([
self
.
input
,
self
.
targ
],
[
self
.
cost
,
self
.
argmax
,
self
.
max_pr
])
#
self.validate = module.Method([self.input, self.targ], [self.cost, self.argmax, self.max_pr])
self
.
softmax_output
=
module
.
Method
([
self
.
input
],
self
.
softmax_unsupervised
)
#
self.softmax_output = module.Method([self.input], self.softmax_unsupervised)
if
self
.
params
:
if
self
.
params
:
gparams
=
T
.
grad
(
sum_xent
,
self
.
params
)
gparams
=
T
.
grad
(
sum_xent
,
self
.
params
)
self
.
update
=
module
.
Method
([
self
.
input
,
self
.
targ
],
sum_xent
,
#
self.update = module.Method([self.input, self.targ], sum_xent,
updates
=
dict
((
p
,
p
-
self
.
lr
*
g
)
for
p
,
g
in
zip
(
self
.
params
,
gparams
)))
#
updates = dict((p, p - self.lr * g) for p, g in zip(self.params, gparams)))
class
ConvolutionalMLPInstance
(
module
.
FancyModuleInstance
,
Loss01
):
class
ConvolutionalMLPInstance
(
module
.
FancyModuleInstance
,
Loss01
):
#initialize is called by Module.make
#initialize is called by Module.make
...
@@ -366,11 +366,6 @@ class ConvolutionalMLP(module.FancyModule):
...
@@ -366,11 +366,6 @@ class ConvolutionalMLP(module.FancyModule):
)
)
)
)
# to_update = []
# all_kits = []
# input_update = self.input_representations[0].update
# input_update.resolve_all()
for
i
in
self
.
inputs
[
1
:]:
for
i
in
self
.
inputs
[
1
:]:
self
.
input_representations
.
append
(
self
.
input_representations
.
append
(
QDAA
(
QDAA
(
...
@@ -411,11 +406,17 @@ class ConvolutionalMLP(module.FancyModule):
...
@@ -411,11 +406,17 @@ class ConvolutionalMLP(module.FancyModule):
]
+
self
.
hidden
.
qfilters
]
+
self
.
hidden
.
qfilters
input_pretraining_cost
=
sum
(
i
.
ncost
for
i
in
self
.
input_representations
)
input_pretraining_cost
=
sum
(
i
.
ncost
for
i
in
self
.
input_representations
)
hidden_pretraining_cost
=
self
.
hidden
.
ncost
hidden_pretraining_cost
=
self
.
hidden
.
ncost
input_pretraining_gradients
=
T
.
grad
(
input_pretraining_cost
,
input_pretraining_params
)
input_pretraining_gradients
=
T
.
grad
(
input_pretraining_cost
,
input_pretraining_params
)
hidden_pretraining_gradients
=
T
.
grad
(
hidden_pretraining_cost
,
hidden_pretraining_params
)
hidden_pretraining_gradients
=
T
.
grad
(
hidden_pretraining_cost
,
hidden_pretraining_params
)
pretraining_updates
=
dict
((
p
,
p
-
self
.
lr
*
g
)
for
p
,
g
in
zip
(
input_pretraining_params
,
input_pretraining_gradients
)
+
pretraining_updates
=
\
zip
(
hidden_pretraining_params
,
hidden_pretraining_gradients
))
dict
((
p
,
p
-
self
.
lr
*
g
)
for
p
,
g
in
\
self
.
pretraining_update
=
module
.
Method
(
self
.
inputs
,
[
input_pretraining_cost
,
hidden_pretraining_cost
],
pretraining_updates
)
zip
(
input_pretraining_params
,
input_pretraining_gradients
)
\
+
zip
(
hidden_pretraining_params
,
hidden_pretraining_gradients
))
self
.
pretraining_update
=
module
.
Method
(
self
.
inputs
,
[
input_pretraining_cost
,
hidden_pretraining_cost
],
pretraining_updates
)
finetuning_params
=
\
finetuning_params
=
\
[
self
.
input_representations
[
0
]
.
w1
,
self
.
input_representations
[
0
]
.
b1
]
+
self
.
input_representations
[
0
]
.
qfilters
+
\
[
self
.
input_representations
[
0
]
.
w1
,
self
.
input_representations
[
0
]
.
b1
]
+
self
.
input_representations
[
0
]
.
qfilters
+
\
...
@@ -426,9 +427,8 @@ class ConvolutionalMLP(module.FancyModule):
...
@@ -426,9 +427,8 @@ class ConvolutionalMLP(module.FancyModule):
finetuning_updates
=
dict
((
p
,
p
-
self
.
lr
*
g
)
for
p
,
g
in
zip
(
finetuning_params
,
finetuning_gradients
))
finetuning_updates
=
dict
((
p
,
p
-
self
.
lr
*
g
)
for
p
,
g
in
zip
(
finetuning_params
,
finetuning_gradients
))
self
.
finetuning_update
=
module
.
Method
(
self
.
inputs
+
[
self
.
targ
],
self
.
output
.
cost
,
finetuning_updates
)
self
.
finetuning_update
=
module
.
Method
(
self
.
inputs
+
[
self
.
targ
],
self
.
output
.
cost
,
finetuning_updates
)
#self.validate = module.Method(self.inputs + [self.targ], [self.output.cost, self.output.argmax, self.output.max_pr])
self
.
validate
=
module
.
Method
(
self
.
inputs
+
[
self
.
targ
],
[
self
.
output
.
cost
,
self
.
output
.
argmax
,
self
.
output
.
max_pr
])
#self.softmax_output = module.Method(self.inputs, self.output.softmax_unsupervised)
self
.
softmax_output
=
module
.
Method
(
self
.
inputs
,
self
.
output
.
softmax_unsupervised
)
def
create
(
window_size
=
3
,
def
create
(
window_size
=
3
,
input_dimension
=
9
,
input_dimension
=
9
,
...
@@ -462,15 +462,21 @@ JTEST = theano.compile.mode.optdb.query(*sys.argv[2:])
...
@@ -462,15 +462,21 @@ JTEST = theano.compile.mode.optdb.query(*sys.argv[2:])
print
'JTEST'
,
JTEST
print
'JTEST'
,
JTEST
theano
.
compile
.
register_optimizer
(
'JTEST'
,
JTEST
)
theano
.
compile
.
register_optimizer
(
'JTEST'
,
JTEST
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
optimizer
=
eval
(
sys
.
argv
[
1
])
optimizer
=
eval
(
sys
.
argv
[
1
])
m
=
create
(
compile_mode
=
theano
.
Mode
(
linker
=
'c|py'
,
optimizer
=
optimizer
))
m
=
create
(
compile_mode
=
theano
.
Mode
(
linker
=
'c|py'
,
optimizer
=
optimizer
))
prog_str
=
[]
prog_str
=
[]
for
i
,
node
in
enumerate
(
m
.
finetuning_update
.
maker
.
env
.
toposort
()):
idx_of_node
=
{}
#print ' ', i, node
for
i
,
node
in
enumerate
(
m
.
pretraining_update
.
maker
.
env
.
toposort
()):
idx_of_node
[
node
]
=
i
if
False
and
i
>
-
1
:
print
' '
,
i
,
node
,
[(
ii
,
idx_of_node
.
get
(
ii
.
owner
,
'IN'
))
for
ii
in
node
.
inputs
]
prog_str
.
append
(
str
(
node
))
prog_str
.
append
(
str
(
node
))
print
"PROGRAM LEN
%
i HASH
%
i"
%
(
len
(
m
.
finetuning_update
.
maker
.
env
.
nodes
),
reduce
(
lambda
a
,
b
:
hash
(
a
)
^
hash
(
b
),
prog_str
))
#print input_pretraining_gradients[4].owner.inputs
#print input_pretraining_gradients[4].owner.inputs[1].owner.inputs
#sys.exit()
print
"PROGRAM LEN
%
i HASH
%
i"
%
(
len
(
m
.
pretraining_update
.
maker
.
env
.
nodes
),
reduce
(
lambda
a
,
b
:
hash
(
a
)
^
hash
(
b
),
prog_str
))
rng
=
N
.
random
.
RandomState
(
23904
)
rng
=
N
.
random
.
RandomState
(
23904
)
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论