Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
f9a67241
提交
f9a67241
authored
9月 29, 2009
作者:
Olivier Delalleau
浏览文件
操作
浏览文件
下载
差异文件
Merged
上级
ab249ff2
2dc07279
隐藏空白字符变更
内嵌
并排
正在显示
11 个修改的文件
包含
246 行增加
和
111 行删除
+246
-111
theano-cache
bin/theano-cache
+5
-4
debugmode.py
theano/compile/debugmode.py
+14
-33
profilemode.py
theano/compile/profilemode.py
+5
-5
__init__.py
theano/compile/sandbox/__init__.py
+1
-1
cc.py
theano/gof/cc.py
+37
-21
cmodule.py
theano/gof/cmodule.py
+15
-5
basic.py
theano/scalar/basic.py
+14
-11
basic.py
theano/tensor/basic.py
+79
-26
blas.py
theano/tensor/blas.py
+10
-1
opt.py
theano/tensor/opt.py
+62
-3
test_basic.py
theano/tensor/tests/test_basic.py
+4
-1
没有找到文件。
bin/theano-c
ompiledir
→
bin/theano-c
ache
浏览文件 @
f9a67241
#!/usr/bin/env python
#!/usr/bin/env python
from
theano.gof
import
cc
import
sys
import
sys
compiledir
=
cc
.
get_compiledir
()
from
theano.gof.cc
import
get_compiledir
,
get_module_cache
if
len
(
sys
.
argv
)
==
1
:
if
len
(
sys
.
argv
)
==
1
:
print
compiledir
print
get_compiledir
()
elif
sys
.
argv
[
1
]
in
(
'clear'
):
elif
sys
.
argv
[
1
]
in
(
'clear'
):
cc
.
clear_compiledi
r
()
get_module_cache
()
.
clea
r
()
else
:
else
:
print
'command "
%
s" not recognized'
%
sys
.
argv
[
1
]
print
'command "
%
s" not recognized'
%
sys
.
argv
[
1
]
print
'Type "theano-cache" to print the cache location'
print
'Type "theano-cache clear" to erase the cache'
sys
.
exit
(
1
)
sys
.
exit
(
1
)
theano/compile/debugmode.py
浏览文件 @
f9a67241
"""Provides `DebugMode`, an evaluation mode for debugging theano internals."""
"""Provides `DebugMode`, an evaluation mode for debugging theano internals."""
__docformat__
=
"restructuredtext en"
__docformat__
=
"restructuredtext en"
import
time
,
copy
,
sys
,
copy_reg
,
gc
import
time
,
copy
,
sys
,
copy_reg
,
gc
,
os
from
StringIO
import
StringIO
from
StringIO
import
StringIO
import
numpy
import
numpy
...
@@ -298,7 +298,7 @@ class InvalidValueError(DebugModeError):
...
@@ -298,7 +298,7 @@ class InvalidValueError(DebugModeError):
def
_
debugprint
(
r
,
prefix
=
''
,
depth
=-
1
,
done
=
None
,
file
=
sys
.
stdout
):
def
debugprint
(
r
,
prefix
=
''
,
depth
=-
1
,
done
=
None
,
file
=
sys
.
stdout
):
"""Print the graph leading to `r` to given depth.
"""Print the graph leading to `r` to given depth.
:param r: Variable instance
:param r: Variable instance
...
@@ -322,7 +322,7 @@ def _debugprint(r, prefix='', depth=-1, done=None, file=sys.stdout):
...
@@ -322,7 +322,7 @@ def _debugprint(r, prefix='', depth=-1, done=None, file=sys.stdout):
if
id
(
a
)
not
in
done
:
if
id
(
a
)
not
in
done
:
done
.
add
(
id
(
a
))
done
.
add
(
id
(
a
))
for
i
in
a
.
inputs
:
for
i
in
a
.
inputs
:
_
debugprint
(
i
,
prefix
+
' '
,
depth
=
depth
-
1
,
done
=
done
,
file
=
file
)
debugprint
(
i
,
prefix
+
' '
,
depth
=
depth
-
1
,
done
=
done
,
file
=
file
)
else
:
else
:
#this is a variable
#this is a variable
print
>>
file
,
prefix
,
r
,
id
(
r
)
print
>>
file
,
prefix
,
r
,
id
(
r
)
...
@@ -772,12 +772,12 @@ class _VariableEquivalenceTracker(object):
...
@@ -772,12 +772,12 @@ class _VariableEquivalenceTracker(object):
append_reason
=
False
append_reason
=
False
if
append_reason
:
if
append_reason
:
# N.B. compute the
_
debugprint now, because future optimizations will change the
# N.B. compute the debugprint now, because future optimizations will change the
# graph
# graph
self
.
reasons
[
new_r
]
.
append
((
reason
self
.
reasons
[
new_r
]
.
append
((
reason
,
r
,
r
,
_
debugprint
(
r
,
prefix
=
' '
,
depth
=
6
,
file
=
StringIO
())
.
getvalue
()
,
debugprint
(
r
,
prefix
=
' '
,
depth
=
6
,
file
=
StringIO
())
.
getvalue
()
,
_
debugprint
(
new_r
,
prefix
=
' '
,
depth
=
6
,
file
=
StringIO
())
.
getvalue
()))
,
debugprint
(
new_r
,
prefix
=
' '
,
depth
=
6
,
file
=
StringIO
())
.
getvalue
()))
self
.
replaced_by
[
r
]
.
append
((
reason
,
new_r
))
self
.
replaced_by
[
r
]
.
append
((
reason
,
new_r
))
if
r
in
self
.
equiv
:
if
r
in
self
.
equiv
:
...
@@ -856,29 +856,10 @@ class _Linker(gof.link.LocalLinker):
...
@@ -856,29 +856,10 @@ class _Linker(gof.link.LocalLinker):
if
not
self
.
maker
.
mode
.
check_c_code
:
if
not
self
.
maker
.
mode
.
check_c_code
:
raise
utils
.
MethodNotDefined
()
raise
utils
.
MethodNotDefined
()
e
=
Env
(
*
graph
.
clone
(
node
.
inputs
,
node
.
outputs
))
e
=
Env
(
*
graph
.
clone
(
node
.
inputs
,
node
.
outputs
))
e
.
toposort
=
lambda
:
e
.
nodes
#WARNING: STOCHASTIC ORDER
e
.
toposort
=
lambda
:
e
.
nodes
#WARNING: STOCHASTIC ORDER
# Specifically... e.nodes is a set, but of only 1 element
if
any
(
isinstance
(
input
,
graph
.
Value
)
for
input
in
node
.
inputs
):
cl
=
CLinker
()
.
accept
(
e
,
[
r
for
r
,
r2
in
zip
(
e
.
outputs
,
node
.
outputs
)
if
r2
in
no_recycling
])
desc
=
None
else
:
desc
=
(
node
.
op
,
tuple
(
input
.
type
for
input
in
node
.
inputs
),
tuple
(
input
.
type
for
input
in
node
.
inputs
),
tuple
(
output
in
no_recycling
for
output
in
node
.
outputs
),
tuple
(
node
.
inputs
.
count
(
input
)
for
input
in
node
.
inputs
))
try
:
cl
=
self
.
__cache__
.
get
(
desc
)
except
Exception
,
exc
:
#print >> sys.stderr, "INFO: failed to hash %s: %s. Node will not be cached." % (node, exc)
cl
=
None
if
cl
is
None
:
cl
=
CLinker
()
.
accept
(
e
,
[
r
for
r
,
r2
in
zip
(
e
.
outputs
,
node
.
outputs
)
if
r2
in
no_recycling
])
if
desc
is
not
None
:
try
:
self
.
__cache__
[
desc
]
=
cl
except
:
pass
thunk
,
node_input_filters
,
node_output_filters
=
cl
.
make_thunk
(
thunk
,
node_input_filters
,
node_output_filters
=
cl
.
make_thunk
(
input_storage
=
node_input_storage
,
input_storage
=
node_input_storage
,
...
@@ -1371,27 +1352,27 @@ class DebugMode(Mode):
...
@@ -1371,27 +1352,27 @@ class DebugMode(Mode):
"""
"""
stability_patience
=
10
stability_patience
=
int
(
os
.
getenv
(
'THEANO_DEBUGMODE_PATIENCE'
,
10
))
"""
"""
When checking for the stability of optimization, recompile the graph this many times.
When checking for the stability of optimization, recompile the graph this many times.
"""
"""
check_c_code
=
True
check_c_code
=
bool
(
int
(
os
.
getenv
(
'THEANO_DEBUGMODE_CHECK_C'
,
1
)))
"""
"""
Should we evaluate (and check) the `c_code` implementations?
Should we evaluate (and check) the `c_code` implementations?
"""
"""
check_py_code
=
True
check_py_code
=
bool
(
int
(
os
.
getenv
(
'THEANO_DEBUGMODE_CHECK_PY'
,
1
)))
"""
"""
Should we evaluate (and check) the `perform` implementations?
Should we evaluate (and check) the `perform` implementations?
"""
"""
check_isfinite
=
True
check_isfinite
=
bool
(
int
(
os
.
getenv
(
'THEANO_DEBUGMODE_CHECK_FINITE'
,
1
)))
"""
"""
Should we check for (and complain about) NaN/Inf ndarray elements?
Should we check for (and complain about) NaN/Inf ndarray elements?
"""
"""
require_matching_strides
=
False
require_matching_strides
=
bool
(
int
(
os
.
getenv
(
'THEANO_DEBUGMODE_CHECK_STRIDES'
,
0
)))
"""
"""
Should we check for (and complain about) Ops whose python and C outputs are ndarrays with
Should we check for (and complain about) Ops whose python and C outputs are ndarrays with
different strides? (This can catch bugs, but is generally overly strict.)
different strides? (This can catch bugs, but is generally overly strict.)
...
...
theano/compile/profilemode.py
浏览文件 @
f9a67241
...
@@ -102,7 +102,7 @@ class ProfileMode(Mode):
...
@@ -102,7 +102,7 @@ class ProfileMode(Mode):
def
print_diff_summary
(
self
,
other
,
n_apply_to_print
=
15
,
n_ops_to_print
=
20
):
def
print_diff_summary
(
self
,
other
,
n_apply_to_print
=
15
,
n_ops_to_print
=
20
):
""" As print_summary, but print the
absolute
difference on two different profile mode.
""" As print_summary, but print the difference on two different profile mode.
TODO: Also we don't print the Apply-wise summary as it don't work for now.
TODO: Also we don't print the Apply-wise summary as it don't work for now.
TODO: make comparaison with gpu code.
TODO: make comparaison with gpu code.
...
@@ -119,7 +119,7 @@ class ProfileMode(Mode):
...
@@ -119,7 +119,7 @@ class ProfileMode(Mode):
for
a
,
ta
in
a_time
.
items
():
for
a
,
ta
in
a_time
.
items
():
r
.
setdefault
(
a
,
0
)
r
.
setdefault
(
a
,
0
)
tb
=
b_time
.
pop
(
a
,
0
)
tb
=
b_time
.
pop
(
a
,
0
)
r
[
a
]
+=
abs
(
ta
-
tb
)
r
[
a
]
+=
ta
-
tb
#they are missing in a
#they are missing in a
for
a
,
t
in
b_time
.
items
():
for
a
,
t
in
b_time
.
items
():
...
@@ -133,7 +133,7 @@ class ProfileMode(Mode):
...
@@ -133,7 +133,7 @@ class ProfileMode(Mode):
for
a
,
ta
in
a_time
.
items
():
for
a
,
ta
in
a_time
.
items
():
tb
=
b_time
.
pop
(
a
,
0
)
tb
=
b_time
.
pop
(
a
,
0
)
if
hasattr
(
a
,
'flops'
):
if
hasattr
(
a
,
'flops'
):
flops
[
a
]
=
a
bs
(
a
.
flops
*
a_call
[
a
]
/
ta
-
a
.
flops
*
b_call
[
a
]
/
tb
)
/
1e6
flops
[
a
]
=
a
.
flops
*
a_call
[
a
]
/
ta
-
a
.
flops
*
b_call
[
a
]
/
tb
/
1e6
#they are missing in a
#they are missing in a
for
b
,
tb
in
b_time
.
items
():
for
b
,
tb
in
b_time
.
items
():
...
@@ -142,8 +142,8 @@ class ProfileMode(Mode):
...
@@ -142,8 +142,8 @@ class ProfileMode(Mode):
return
flops
return
flops
local_time
=
abs
(
self
.
local_time
[
0
]
-
other
.
local_time
[
0
])
local_time
=
self
.
local_time
[
0
]
-
other
.
local_time
[
0
]
compile_time
=
abs
(
self
.
compile_time
-
other
.
compile_time
)
compile_time
=
self
.
compile_time
-
other
.
compile_time
apply_time
=
diff_dict
(
self
.
apply_time
,
other
.
apply_time
)
apply_time
=
diff_dict
(
self
.
apply_time
,
other
.
apply_time
)
apply_call
=
diff_dict
(
self
.
apply_call
,
other
.
apply_call
)
apply_call
=
diff_dict
(
self
.
apply_call
,
other
.
apply_call
)
op_time
=
diff_dict
(
self
.
op_time
,
other
.
op_time
)
op_time
=
diff_dict
(
self
.
op_time
,
other
.
op_time
)
...
...
theano/compile/sandbox/__init__.py
浏览文件 @
f9a67241
from
.sharedvalue
import
shared
from
.sharedvalue
import
shared
,
shared_constructor
from
.pfunc
import
pfunc
from
.pfunc
import
pfunc
theano/gof/cc.py
浏览文件 @
f9a67241
...
@@ -36,16 +36,12 @@ import cmodule
...
@@ -36,16 +36,12 @@ import cmodule
import
logging
import
logging
_logger
=
logging
.
getLogger
(
"theano.gof.cc"
)
_logger
=
logging
.
getLogger
(
"theano.gof.cc"
)
def
info
(
*
args
):
def
info
(
*
args
):
#sys.stderr.write('INFO:'+ ' '.join(str(a) for a in args)+'\n')
_logger
.
info
(
' '
.
join
(
str
(
a
)
for
a
in
args
))
_logger
.
info
(
' '
.
join
(
str
(
a
)
for
a
in
args
))
def
debug
(
*
args
):
def
debug
(
*
args
):
#sys.stderr.write('DEBUG:'+ ' '.join(str(a) for a in args)+'\n')
_logger
.
debug
(
' '
.
join
(
str
(
a
)
for
a
in
args
))
_logger
.
debug
(
' '
.
join
(
str
(
a
)
for
a
in
args
))
def
warning
(
*
args
):
def
warning
(
*
args
):
sys
.
stderr
.
write
(
'WARNING:'
+
' '
.
join
(
str
(
a
)
for
a
in
args
)
+
'
\n
'
)
_logger
.
warning
(
' '
.
join
(
str
(
a
)
for
a
in
args
))
_logger
.
warning
(
' '
.
join
(
str
(
a
)
for
a
in
args
))
def
error
(
*
args
):
def
error
(
*
args
):
sys
.
stderr
.
write
(
'ERROR:'
+
' '
.
join
(
str
(
a
)
for
a
in
args
)
+
'
\n
'
)
_logger
.
error
(
' '
.
join
(
str
(
a
)
for
a
in
args
))
_logger
.
error
(
' '
.
join
(
str
(
a
)
for
a
in
args
))
from
theano.gof.callcache
import
CallCache
from
theano.gof.callcache
import
CallCache
...
@@ -526,6 +522,8 @@ class CLinker(link.Linker):
...
@@ -526,6 +522,8 @@ class CLinker(link.Linker):
# List of arg names for use in struct_gen. Note the call to uniq: duplicate inputs
# List of arg names for use in struct_gen. Note the call to uniq: duplicate inputs
# must only be passed once because they are mapped to the same name.
# must only be passed once because they are mapped to the same name.
# Duplicates are defined by (a is b), rather than (a==b) since Constant instances can
# compare equal to equivalent Constant instances.
args
=
[]
args
=
[]
args
+=
[
"storage_
%
s"
%
symbol
[
variable
]
for
variable
in
utils
.
uniq
(
self
.
inputs
+
self
.
outputs
+
self
.
orphans
)]
args
+=
[
"storage_
%
s"
%
symbol
[
variable
]
for
variable
in
utils
.
uniq
(
self
.
inputs
+
self
.
outputs
+
self
.
orphans
)]
...
@@ -783,15 +781,21 @@ class CLinker(link.Linker):
...
@@ -783,15 +781,21 @@ class CLinker(link.Linker):
``a`` is the topological position of the input's owner (-1 for graph inputs),
``a`` is the topological position of the input's owner (-1 for graph inputs),
``b`` is the index of the variable in the owner's output list.
``b`` is the index of the variable in the owner's output list.
The graph position of a Constant is defined as its signature.
The graph position of a Constant instance is defined as its signature, together with
two integers: the topological position of the first Apply using that Constant instance,
and the lowest index into that Apply's inputs that refers to that Constant. (These two
integers are a surrogate for the id() of the Constant. The integers are important
because merge-able constants have the same signature, but require separate containers
in C code.)
If the Op of any Apply in the Env does not have c_code_cache_ok()==True, then this
If the Op of any Apply in the Env does not have c_code_cache_ok()==True, then this
function raises a KeyError exception.
function raises a KeyError exception.
"""
"""
order
=
list
(
self
.
env
.
toposort
())
order
=
list
(
self
.
env
.
toposort
())
env_inputs_
se
t
=
dict
((
i
,
(
-
1
,
pos
))
for
pos
,
i
in
enumerate
(
self
.
env
.
inputs
))
env_inputs_
dic
t
=
dict
((
i
,
(
-
1
,
pos
))
for
pos
,
i
in
enumerate
(
self
.
env
.
inputs
))
env_computed_set
=
set
()
env_computed_set
=
set
()
constant_ids
=
dict
()
op_pos
=
{}
# Apply -> topological position
op_pos
=
{}
# Apply -> topological position
rval
=
[
'CLinker.cmodule_key'
]
# will be cast to tuple on return
rval
=
[
'CLinker.cmodule_key'
]
# will be cast to tuple on return
rval
.
append
(
tuple
(
self
.
compile_args
()))
rval
.
append
(
tuple
(
self
.
compile_args
()))
...
@@ -802,11 +806,15 @@ class CLinker(link.Linker):
...
@@ -802,11 +806,15 @@ class CLinker(link.Linker):
# - an env input
# - an env input
# - an output from a node in the Env
# - an output from a node in the Env
# - a Constant
# - a Constant
def
graphpos
(
i
):
def
graphpos
(
i
,
topological_pos
,
i_idx
):
if
isinstance
(
i
,
graph
.
Constant
):
if
isinstance
(
i
,
graph
.
Constant
):
return
i
.
signature
()
if
id
(
i
)
not
in
constant_ids
:
elif
i
in
env_inputs_set
:
constant_ids
[
id
(
i
)]
=
(
i
.
signature
(),
topological_pos
,
i_idx
)
return
env_inputs_set
[
i
]
return
constant_ids
[
id
(
i
)]
#print 'SIGNATURE', i.signature()
#return i.signature()
elif
i
in
env_inputs_dict
:
return
env_inputs_dict
[
i
]
else
:
else
:
if
i
.
owner
is
None
:
if
i
.
owner
is
None
:
assert
all
(
all
(
out
is
not
None
for
out
in
o
.
outputs
)
for
o
in
order
)
assert
all
(
all
(
out
is
not
None
for
out
in
o
.
outputs
)
for
o
in
order
)
...
@@ -814,15 +822,16 @@ class CLinker(link.Linker):
...
@@ -814,15 +822,16 @@ class CLinker(link.Linker):
raise
Exception
(
'what is this?'
,
(
i
,
type
(
i
),
i
.
clients
,
self
.
env
))
raise
Exception
(
'what is this?'
,
(
i
,
type
(
i
),
i
.
clients
,
self
.
env
))
return
(
op_pos
[
i
.
owner
],
i
.
owner
.
outputs
.
index
(
i
))
return
(
op_pos
[
i
.
owner
],
i
.
owner
.
outputs
.
index
(
i
))
for
opos
,
o
in
enumerate
(
order
):
for
node_pos
,
node
in
enumerate
(
order
):
version
.
append
(
o
.
op
.
c_code_cache_version_apply
(
o
))
version
.
append
(
node
.
op
.
c_code_cache_version_apply
(
node
))
for
i
in
o
.
inputs
:
for
i
in
node
.
inputs
:
version
.
append
(
i
.
type
.
c_code_cache_version
())
version
.
append
(
i
.
type
.
c_code_cache_version
())
for
i
in
o
.
outputs
:
for
o
in
node
.
outputs
:
version
.
append
(
i
.
type
.
c_code_cache_version
())
version
.
append
(
o
.
type
.
c_code_cache_version
())
rval
.
append
((
o
.
op
,
tuple
((
i
.
type
,
graphpos
(
i
))
for
i
in
o
.
inputs
)))
rval
.
append
((
node
.
op
,
tuple
((
i
.
type
,
graphpos
(
i
,
node_pos
,
ipos
))
op_pos
[
o
]
=
opos
for
ipos
,
i
in
enumerate
(
node
.
inputs
))))
env_computed_set
.
update
(
o
.
outputs
)
op_pos
[
node
]
=
node_pos
env_computed_set
.
update
(
node
.
outputs
)
for
v
in
version
:
for
v
in
version
:
if
not
v
:
#one of the ops or types here is unversioned
if
not
v
:
#one of the ops or types here is unversioned
...
@@ -855,12 +864,12 @@ class CLinker(link.Linker):
...
@@ -855,12 +864,12 @@ class CLinker(link.Linker):
def
build_dynamic_module
(
self
):
def
build_dynamic_module
(
self
):
"""
Generate the code for this module, compile it, return the imported dynamic module
.
"""
Return a cmodule.DynamicModule instance full of the code for our env
.
"""
"""
self
.
code_gen
()
self
.
code_gen
()
module_name
=
self
.
hash
module_name
=
self
.
hash
cthunk
=
object
()
# dummy so weave can get the type
cthunk
=
object
()
# dummy so weave can get the type
##TODO: REMOVE ME
mod
=
cmodule
.
DynamicModule
(
module_name
)
mod
=
cmodule
.
DynamicModule
(
module_name
)
# The code of instantiate
# The code of instantiate
...
@@ -931,7 +940,7 @@ class CLinker(link.Linker):
...
@@ -931,7 +940,7 @@ class CLinker(link.Linker):
orphd
=
[[
orphan
.
data
]
for
orphan
in
self
.
orphans
]
orphd
=
[[
orphan
.
data
]
for
orphan
in
self
.
orphans
]
ret
=
module
.
instantiate
(
error_storage
,
*
(
in_storage
+
out_storage
+
orphd
))
ret
=
module
.
instantiate
(
error_storage
,
*
(
in_storage
+
out_storage
+
orphd
))
return
ret
return
ret
def
instantiate_code
(
self
,
n_args
):
def
instantiate_code
(
self
,
n_args
):
...
@@ -998,6 +1007,11 @@ class OpWiseCLinker(link.LocalLinker):
...
@@ -998,6 +1007,11 @@ class OpWiseCLinker(link.LocalLinker):
no_recycling can contain a list of Variables that belong to the env.
no_recycling can contain a list of Variables that belong to the env.
If a Variable is in no_recycling, CLinker will clear the output storage
If a Variable is in no_recycling, CLinker will clear the output storage
associated to it prior to computation (to avoid reusing it).
associated to it prior to computation (to avoid reusing it).
:note: This is in a sense the 'default' linker for Theano. The overhead of using the
OpWiseCLinker as compared with the CLinker is only noticeable for graphs of very small
tensors (such as 20 elements or less)
"""
"""
__cache__
=
{}
__cache__
=
{}
...
@@ -1044,6 +1058,8 @@ class OpWiseCLinker(link.LocalLinker):
...
@@ -1044,6 +1058,8 @@ class OpWiseCLinker(link.LocalLinker):
try
:
try
:
e
=
Env
(
*
graph
.
clone
(
node
.
inputs
,
node
.
outputs
))
e
=
Env
(
*
graph
.
clone
(
node
.
inputs
,
node
.
outputs
))
# TODO: 20090926 Replace this code with th cl = CLinker().... line. Trust
# ModuleCache for cache mechanism.
if
any
(
isinstance
(
input
,
graph
.
Value
)
for
input
in
node
.
inputs
):
if
any
(
isinstance
(
input
,
graph
.
Value
)
for
input
in
node
.
inputs
):
desc
=
None
desc
=
None
else
:
else
:
...
...
theano/gof/cmodule.py
浏览文件 @
f9a67241
...
@@ -10,16 +10,12 @@ _logger=logging.getLogger("theano.gof.cmodule")
...
@@ -10,16 +10,12 @@ _logger=logging.getLogger("theano.gof.cmodule")
_logger
.
setLevel
(
logging
.
WARN
)
_logger
.
setLevel
(
logging
.
WARN
)
def
error
(
*
args
):
def
error
(
*
args
):
#sys.stderr.write('ERROR:'+ ' '.join(str(a) for a in args)+'\n')
_logger
.
error
(
"ERROR: "
+
' '
.
join
(
str
(
a
)
for
a
in
args
))
_logger
.
error
(
"ERROR: "
+
' '
.
join
(
str
(
a
)
for
a
in
args
))
def
warning
(
*
args
):
def
warning
(
*
args
):
#sys.stderr.write('WARNING:'+ ' '.join(str(a) for a in args)+'\n')
_logger
.
warning
(
"WARNING: "
+
' '
.
join
(
str
(
a
)
for
a
in
args
))
_logger
.
warning
(
"WARNING: "
+
' '
.
join
(
str
(
a
)
for
a
in
args
))
def
info
(
*
args
):
def
info
(
*
args
):
#sys.stderr.write('INFO:'+ ' '.join(str(a) for a in args)+'\n')
_logger
.
info
(
"INFO: "
+
' '
.
join
(
str
(
a
)
for
a
in
args
))
_logger
.
info
(
"INFO: "
+
' '
.
join
(
str
(
a
)
for
a
in
args
))
def
debug
(
*
args
):
def
debug
(
*
args
):
#sys.stderr.write('DEBUG:'+ ' '.join(str(a) for a in args)+'\n')
_logger
.
debug
(
"DEBUG: "
+
' '
.
join
(
str
(
a
)
for
a
in
args
))
_logger
.
debug
(
"DEBUG: "
+
' '
.
join
(
str
(
a
)
for
a
in
args
))
METH_VARARGS
=
"METH_VARARGS"
METH_VARARGS
=
"METH_VARARGS"
...
@@ -470,7 +466,8 @@ class ModuleCache(object):
...
@@ -470,7 +466,8 @@ class ModuleCache(object):
"""
"""
Clear all the elements of the cache
Clear all the elements of the cache
"""
"""
return
self
.
clear_old
(
-
1.0
)
self
.
clear_old
(
-
1.0
)
self
.
clear_unversioned
()
def
clear_unversioned
(
self
):
def
clear_unversioned
(
self
):
"""Delete unversioned dynamic modules from the internal dictionaries and from the
"""Delete unversioned dynamic modules from the internal dictionaries and from the
...
@@ -493,6 +490,17 @@ class ModuleCache(object):
...
@@ -493,6 +490,17 @@ class ModuleCache(object):
info
(
"clear_unversioned removing cache dir"
,
parent
)
info
(
"clear_unversioned removing cache dir"
,
parent
)
_rmtree
(
parent
)
_rmtree
(
parent
)
for
filename
in
os
.
listdir
(
self
.
dirname
):
if
filename
.
startswith
(
'tmp'
):
try
:
open
(
os
.
path
.
join
(
self
.
dirname
,
filename
,
'key.pkl'
))
.
close
()
has_key
=
True
except
IOError
:
has_key
=
False
if
not
has_key
:
info
(
"clear_unversioned removing cache dir"
,
filename
)
_rmtree
(
os
.
path
.
join
(
self
.
dirname
,
filename
))
def
_on_atexit
(
self
):
def
_on_atexit
(
self
):
self
.
refresh
()
self
.
refresh
()
self
.
clear_old
()
self
.
clear_old
()
...
@@ -514,6 +522,8 @@ def get_module_cache(dirname, force_fresh=None):
...
@@ -514,6 +522,8 @@ def get_module_cache(dirname, force_fresh=None):
if
_module_cache
is
None
:
if
_module_cache
is
None
:
_module_cache
=
ModuleCache
(
dirname
,
force_fresh
=
force_fresh
)
_module_cache
=
ModuleCache
(
dirname
,
force_fresh
=
force_fresh
)
atexit
.
register
(
_module_cache
.
_on_atexit
)
atexit
.
register
(
_module_cache
.
_on_atexit
)
if
_module_cache
.
dirname
!=
dirname
:
warning
(
"Returning module cache instance with different dirname than you requested"
)
return
_module_cache
return
_module_cache
def
get_lib_extension
():
def
get_lib_extension
():
...
...
theano/scalar/basic.py
浏览文件 @
f9a67241
...
@@ -187,28 +187,28 @@ class Scalar(Type):
...
@@ -187,28 +187,28 @@ class Scalar(Type):
};
};
"""
"""
operator_eq
=
"""
operator_eq
=
"""
template <>
%(mytype)
s &
%(mytype)
s::operator
=
(const npy_int8 & y)
template <>
%(mytype)
s &
%(mytype)
s::operator
=<npy_int8>
(const npy_int8 & y)
{ this->real=y; this->imag=0; return *this; }
{ this->real=y; this->imag=0; return *this; }
template <>
%(mytype)
s &
%(mytype)
s::operator
=
(const npy_int16 & y)
template <>
%(mytype)
s &
%(mytype)
s::operator
=<npy_int16>
(const npy_int16 & y)
{ this->real=y; this->imag=0; return *this; }
{ this->real=y; this->imag=0; return *this; }
template <>
%(mytype)
s &
%(mytype)
s::operator
=
(const npy_int32 & y)
template <>
%(mytype)
s &
%(mytype)
s::operator
=<npy_int32>
(const npy_int32 & y)
{ this->real=y; this->imag=0; return *this; }
{ this->real=y; this->imag=0; return *this; }
template <>
%(mytype)
s &
%(mytype)
s::operator
=
(const npy_int64 & y)
template <>
%(mytype)
s &
%(mytype)
s::operator
=<npy_int64>
(const npy_int64 & y)
{ this->real=y; this->imag=0; return *this; }
{ this->real=y; this->imag=0; return *this; }
template <>
%(mytype)
s &
%(mytype)
s::operator
=
(const npy_float32 & y)
template <>
%(mytype)
s &
%(mytype)
s::operator
=<npy_float32>
(const npy_float32 & y)
{ this->real=y; this->imag=0; return *this; }
{ this->real=y; this->imag=0; return *this; }
template <>
%(mytype)
s &
%(mytype)
s::operator
=
(const npy_float64 & y)
template <>
%(mytype)
s &
%(mytype)
s::operator
=<npy_float64>
(const npy_float64 & y)
{ this->real=y; this->imag=0; return *this; }
{ this->real=y; this->imag=0; return *this; }
template <>
%(mytype)
s &
%(mytype)
s::operator
=
(const theano_complex128 & y)
template <>
%(mytype)
s &
%(mytype)
s::operator
=<theano_complex128>
(const theano_complex128 & y)
{ this->real=y.real; this->imag=y.imag; return *this; }
{ this->real=y.real; this->imag=y.imag; return *this; }
template <>
%(mytype)
s &
%(mytype)
s::operator
=
(const theano_complex64 & y)
template <>
%(mytype)
s &
%(mytype)
s::operator
=<theano_complex64>
(const theano_complex64 & y)
{ this->real=y.real; this->imag=y.imag; return *this; }
{ this->real=y.real; this->imag=y.imag; return *this; }
"""
"""
...
@@ -219,7 +219,8 @@ class Scalar(Type):
...
@@ -219,7 +219,8 @@ class Scalar(Type):
+
operator_eq
%
dict
(
mytype
=
'theano_complex64'
)
+
operator_eq
%
dict
(
mytype
=
'theano_complex64'
)
def
c_code_cache_version
(
self
):
def
c_code_cache_version
(
self
):
return
(
2
,)
return
(
3
,)
#explicit T given in specialization of operator= lines. This makes it compile with open64
#2,
int8
=
Scalar
(
'int8'
)
int8
=
Scalar
(
'int8'
)
...
@@ -666,10 +667,10 @@ class Mul(ScalarOp):
...
@@ -666,10 +667,10 @@ class Mul(ScalarOp):
retval
=
[]
retval
=
[]
for
input
in
inputs
:
for
input
in
inputs
:
if
input
.
type
in
grad_types
:
if
input
.
type
in
grad_types
:
retval
+=
[
mul
(
*
([
gz
]
+
utils
.
difference
(
inputs
,
[
input
]))
)]
retval
+=
[
cast
(
mul
(
*
([
gz
]
+
utils
.
difference
(
inputs
,
[
input
]))),
input
.
type
.
dtype
)]
else
:
else
:
retval
+=
[
None
]
retval
+=
[
None
]
return
retval
return
retval
#return [(mul(*([gz] + utils.difference(inputs, [input])))
#return [(mul(*([gz] + utils.difference(inputs, [input])))
...
@@ -849,6 +850,8 @@ class Cast(UnaryScalarOp):
...
@@ -849,6 +850,8 @@ class Cast(UnaryScalarOp):
super
(
Cast
,
self
)
.
__init__
(
specific_out
(
o_type
),
name
=
name
)
super
(
Cast
,
self
)
.
__init__
(
specific_out
(
o_type
),
name
=
name
)
self
.
o_type
=
o_type
self
.
o_type
=
o_type
self
.
ctor
=
getattr
(
numpy
,
o_type
.
dtype
)
self
.
ctor
=
getattr
(
numpy
,
o_type
.
dtype
)
def
__str__
(
self
):
return
'
%
s{
%
s}'
%
(
self
.
__class__
.
__name__
,
self
.
o_type
.
dtype
)
def
impl
(
self
,
input
):
def
impl
(
self
,
input
):
return
self
.
ctor
(
input
)
return
self
.
ctor
(
input
)
def
c_code
(
self
,
node
,
name
,
(
x
,
),
(
z
,
),
sub
):
def
c_code
(
self
,
node
,
name
,
(
x
,
),
(
z
,
),
sub
):
...
...
theano/tensor/basic.py
浏览文件 @
f9a67241
...
@@ -4,6 +4,7 @@ __docformat__ = "restructuredtext en"
...
@@ -4,6 +4,7 @@ __docformat__ = "restructuredtext en"
import
__builtin__
import
__builtin__
import
sys
# for sys.maxint
import
sys
# for sys.maxint
import
os
# for getenv THEANO_CMP_SLOPPY
import
traceback
#for overriding Op.__call__
import
traceback
#for overriding Op.__call__
if
sys
.
version_info
>=
(
2
,
5
):
if
sys
.
version_info
>=
(
2
,
5
):
import
functools
import
functools
...
@@ -199,20 +200,33 @@ def _wrap_tensor_into_member(x):
...
@@ -199,20 +200,33 @@ def _wrap_tensor_into_member(x):
return
compile
.
module
.
Member
(
constant
(
x
))
return
compile
.
module
.
Member
(
constant
(
x
))
compile
.
module
.
register_wrapper
(
_obj_is_wrappable_as_tensor
,
_wrap_tensor_into_member
)
compile
.
module
.
register_wrapper
(
_obj_is_wrappable_as_tensor
,
_wrap_tensor_into_member
)
#If you change those value in test don't forget to put them back when the test end.
if
int
(
os
.
getenv
(
'THEANO_CMP_SLOPPY'
,
0
)):
#Don't forget the case when the test fail.
# This environment variable is a quick-and-dirty way to get low-precision comparisons.
float_atol
=
1e-5
# For a more precise setting of these tolerances set them explicitly in your user code by
float_rtol
=
1e-3
# Sensible??
# assigning, for example, "theano.tensor.basic.float32_atol = ..."
float32_atol
=
1e-4
float32_rtol
=
1e-3
float64_rtol
=
1e-4
float64_atol
=
1e-3
else
:
#If you change those value in test don't forget to put them back when the test end.
#Don't forget the case when the test fail.
float32_atol
=
1e-5
float32_rtol
=
1e-3
# defaults in numpy.allclose
float64_rtol
=
1.0000000000000001e-05
float64_atol
=
1e-8
def
_allclose
(
a
,
b
):
def
_allclose
(
a
,
b
):
narrow
=
'float32'
,
'complex64'
narrow
=
'float32'
,
'complex64'
if
(
str
(
a
.
dtype
)
in
narrow
)
or
(
str
(
b
.
dtype
)
in
narrow
):
if
(
str
(
a
.
dtype
)
in
narrow
)
or
(
str
(
b
.
dtype
)
in
narrow
):
atol
=
float_atol
atol
=
float32_atol
rtol
=
float_rtol
rtol
=
float32_rtol
return
numpy
.
allclose
(
a
,
b
,
atol
=
atol
,
rtol
=
rtol
)
else
:
else
:
# keep defaults of in numpy.allclose
atol
=
float64_atol
return
numpy
.
allclose
(
a
,
b
)
rtol
=
float64_rtol
return
numpy
.
allclose
(
a
,
b
,
atol
=
atol
,
rtol
=
rtol
)
class
TensorType
(
Type
):
class
TensorType
(
Type
):
"""Symbolic `Type` representing a numpy.ndarray value."""
"""Symbolic `Type` representing a numpy.ndarray value."""
...
@@ -756,13 +770,29 @@ class TensorVariable(Variable, _tensor_py_operators):
...
@@ -756,13 +770,29 @@ class TensorVariable(Variable, _tensor_py_operators):
"""Subclass to add the tensor operators to the basic `Variable` class."""
"""Subclass to add the tensor operators to the basic `Variable` class."""
class
TensorConstantSignature
(
tuple
):
class
TensorConstantSignature
(
tuple
):
"""A Signature object for comparing TensorConstant instances
An instance is a pair: (Type instance, ndarray).
"""
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
(
a
,
b
),
(
x
,
y
)
=
self
,
other
try
:
(
t0
,
d0
),
(
t1
,
d1
)
=
self
,
other
except
:
return
False
#N.B. compare shape to ensure no broadcasting in ==
#N.B. compare shape to ensure no broadcasting in ==
return
(
x
==
a
)
and
(
b
.
shape
==
y
.
shape
)
and
(
numpy
.
all
(
b
==
y
))
return
(
t0
==
t1
)
and
(
d0
.
shape
==
d1
.
shape
)
\
and
(
self
.
sum
==
other
.
sum
)
and
(
numpy
.
all
(
d0
==
d1
))
def
__hash__
(
self
):
def
__hash__
(
self
):
a
,
b
=
self
t
,
d
=
self
return
hashtype
(
self
)
^
hash
(
a
)
^
hash
(
b
.
shape
)
return
hashtype
(
self
)
^
hash
(
t
)
^
hash
(
d
.
shape
)
^
hash
(
self
.
sum
)
def
_get_sum
(
self
):
try
:
return
self
.
_sum
except
:
self
.
_sum
=
self
[
1
]
.
sum
()
return
self
.
_sum
sum
=
property
(
_get_sum
)
class
TensorConstant
(
Constant
,
_tensor_py_operators
):
class
TensorConstant
(
Constant
,
_tensor_py_operators
):
"""Subclass to add the tensor operators to the basic `Constant` class.
"""Subclass to add the tensor operators to the basic `Constant` class.
...
@@ -943,6 +973,9 @@ _cast_mapping = {'int8': _convert_to_int8,
...
@@ -943,6 +973,9 @@ _cast_mapping = {'int8': _convert_to_int8,
@constructor
@constructor
def
cast
(
x
,
dtype
):
def
cast
(
x
,
dtype
):
"""Symbolically cast `x` to a Tensor of type `dtype`."""
"""Symbolically cast `x` to a Tensor of type `dtype`."""
_x
=
as_tensor_variable
(
x
)
if
_x
.
type
.
dtype
==
dtype
:
return
_x
if
x
.
type
.
dtype
.
startswith
(
'complex'
)
and
not
dtype
.
startswith
(
'complex'
):
if
x
.
type
.
dtype
.
startswith
(
'complex'
)
and
not
dtype
.
startswith
(
'complex'
):
raise
TypeError
(
'Casting from complex to real is ambiguous: consider real(), imag(), angle() or abs()'
)
raise
TypeError
(
'Casting from complex to real is ambiguous: consider real(), imag(), angle() or abs()'
)
return
_cast_mapping
[
dtype
](
x
)
return
_cast_mapping
[
dtype
](
x
)
...
@@ -1417,15 +1450,19 @@ def mean(input, axis = None):
...
@@ -1417,15 +1450,19 @@ def mean(input, axis = None):
if
str
(
input
.
dtype
)
.
startswith
(
'int'
):
if
str
(
input
.
dtype
)
.
startswith
(
'int'
):
# we need to cast eventually anyway, and this helps
# we need to cast eventually anyway, and this helps
# to prevents overflow
# to prevents overflow
input
=
c
onvert_to_float64
(
input
)
input
=
c
ast
(
input
,
'float64'
)
s
=
sum
(
input
,
axis
)
s
=
sum
(
input
,
axis
)
shp
=
shape
(
input
)
shp
=
shape
(
input
)
if
input
.
dtype
==
'float32'
:
shp
=
cast
(
shp
,
'float32'
)
if
axis
is
None
:
if
axis
is
None
:
axis
=
range
(
input
.
type
.
ndim
)
axis
=
range
(
input
.
type
.
ndim
)
elif
isinstance
(
axis
,
int
):
elif
isinstance
(
axis
,
int
):
axis
=
[
axis
]
axis
=
[
axis
]
for
i
in
axis
:
for
i
in
axis
:
s
=
s
/
shp
[
i
]
s
=
s
/
shp
[
i
]
if
input
.
dtype
.
startswith
(
'float'
):
assert
input
.
dtype
==
s
.
dtype
return
s
return
s
@constructor
@constructor
...
@@ -1587,6 +1624,12 @@ class Subtensor(Op):
...
@@ -1587,6 +1624,12 @@ class Subtensor(Op):
inputs array is the tensor x, followed by scalar integer variables.
inputs array is the tensor x, followed by scalar integer variables.
@todo: add support for advanced tensor indexing (in Subtensor_dx too).
@todo: add support for advanced tensor indexing (in Subtensor_dx too).
The idx_list is a tuple similar in structure to the sort of key you might expect in numpy's
basic indexing mode. It has one element for each explicitly named dimension. In numpy, the elements
can be either integers or slices containing integers and None. In Subtensor, each element
can additionally be a Scalar instance, and slice components can also be Scalar instances
too.
"""
"""
e_invalid
=
'The index list is longer than the number of dimensions of the tensor.'
e_invalid
=
'The index list is longer than the number of dimensions of the tensor.'
e_subslice
=
'nested slicing is not supported'
e_subslice
=
'nested slicing is not supported'
...
@@ -1707,7 +1750,7 @@ class Subtensor(Op):
...
@@ -1707,7 +1750,7 @@ class Subtensor(Op):
def
grad
(
self
,
inputs
,
(
gz
,)):
def
grad
(
self
,
inputs
,
(
gz
,)):
x
=
inputs
[
0
]
x
=
inputs
[
0
]
rest
=
inputs
[
1
:]
rest
=
inputs
[
1
:]
return
[
Set
Subtensor
(
self
.
idx_list
)(
zeros_like
(
x
),
gz
,
*
rest
)]
+
[
None
]
*
len
(
rest
)
return
[
Inc
Subtensor
(
self
.
idx_list
)(
zeros_like
(
x
),
gz
,
*
rest
)]
+
[
None
]
*
len
(
rest
)
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
return
type
(
self
)
==
type
(
other
)
and
self
.
idx_list
==
other
.
idx_list
return
type
(
self
)
==
type
(
other
)
and
self
.
idx_list
==
other
.
idx_list
...
@@ -1794,13 +1837,14 @@ pprint.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, Subtensor), S
...
@@ -1794,13 +1837,14 @@ pprint.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, Subtensor), S
class
Set
Subtensor
(
Op
):
class
Inc
Subtensor
(
Op
):
"""
Set just some elements of a larger TensorType
.
"""
Increment a subtensor
.
This is like numpy's
This is like numpy's
z[i,j,k] = <something>
z[i,j,k]
+
= <something>
It is used internally to implement the gradient on SubTensor.
"""
"""
def
__init__
(
self
,
idx_list
,
inplace
=
False
):
def
__init__
(
self
,
idx_list
,
inplace
=
False
):
...
@@ -1858,7 +1902,7 @@ class SetSubtensor(Op):
...
@@ -1858,7 +1902,7 @@ class SetSubtensor(Op):
broadcastable
=
[
bc
for
p
,
bc
in
zip
(
padded
,
x
.
type
.
broadcastable
)
if
isinstance
(
p
,
slice
)]
broadcastable
=
[
bc
for
p
,
bc
in
zip
(
padded
,
x
.
type
.
broadcastable
)
if
isinstance
(
p
,
slice
)]
if
y
.
type
.
broadcastable
!=
tuple
(
broadcastable
):
if
y
.
type
.
broadcastable
!=
tuple
(
broadcastable
):
raise
TypeError
(
"Invalid broadcastable pattern for y in
Set
Subtensor.make_node"
)
raise
TypeError
(
"Invalid broadcastable pattern for y in
Inc
Subtensor.make_node"
)
input_types
=
Subtensor
.
collapse
(
idx_list
,
lambda
entry
:
isinstance
(
entry
,
gof
.
Type
))
input_types
=
Subtensor
.
collapse
(
idx_list
,
lambda
entry
:
isinstance
(
entry
,
gof
.
Type
))
if
len
(
inputs
)
!=
len
(
input_types
):
if
len
(
inputs
)
!=
len
(
input_types
):
...
@@ -1890,7 +1934,13 @@ class SetSubtensor(Op):
...
@@ -1890,7 +1934,13 @@ class SetSubtensor(Op):
cdata
=
cdata
[
0
]
cdata
=
cdata
[
0
]
if
not
self
.
inplace
:
if
not
self
.
inplace
:
x
=
x
.
copy
()
x
=
x
.
copy
()
x
.
__setitem__
(
cdata
,
y
)
sub_x
=
x
.
__getitem__
(
cdata
)
if
sub_x
.
shape
:
# we've sliced out an N-D tensor with N > 0
sub_x
+=
y
else
:
# scalar case
x
.
__setitem__
(
cdata
,
sub_x
+
y
)
out
[
0
]
=
x
out
[
0
]
=
x
def
split
(
x
,
splits_size
,
n_splits
,
axis
=
0
):
def
split
(
x
,
splits_size
,
n_splits
,
axis
=
0
):
...
@@ -2543,12 +2593,15 @@ class Dot(Op):
...
@@ -2543,12 +2593,15 @@ class Dot(Op):
def
grad
(
self
,
(
x
,
y
),
(
gz
,)):
def
grad
(
self
,
(
x
,
y
),
(
gz
,)):
if
gz
.
type
.
ndim
==
0
:
if
gz
.
type
.
ndim
==
0
:
return
gz
*
y
,
gz
*
x
rval
=
gz
*
y
,
gz
*
x
if
x
.
type
.
ndim
==
1
and
y
.
type
.
ndim
>
1
:
elif
x
.
type
.
ndim
==
1
and
y
.
type
.
ndim
>
1
:
return
dot
(
gz
,
y
.
T
),
outer
(
x
.
T
,
gz
)
rval
=
dot
(
gz
,
y
.
T
),
outer
(
x
.
T
,
gz
)
if
x
.
type
.
ndim
>
1
and
y
.
type
.
ndim
==
1
:
elif
x
.
type
.
ndim
>
1
and
y
.
type
.
ndim
==
1
:
return
outer
(
gz
,
y
.
T
),
dot
(
x
.
T
,
gz
)
rval
=
outer
(
gz
,
y
.
T
),
dot
(
x
.
T
,
gz
)
return
dot
(
gz
,
y
.
T
),
dot
(
x
.
T
,
gz
)
else
:
rval
=
dot
(
gz
,
y
.
T
),
dot
(
x
.
T
,
gz
)
return
cast
(
rval
[
0
],
x
.
dtype
),
cast
(
rval
[
1
],
y
.
dtype
)
def
__str__
(
self
):
def
__str__
(
self
):
return
"dot"
return
"dot"
dot
=
Dot
()
dot
=
Dot
()
...
...
theano/tensor/blas.py
浏览文件 @
f9a67241
"""Ops and optimizations for using BLAS function calls to evaluate linear algebra expressions"""
"""Ops and optimizations for using BLAS function calls to evaluate linear algebra expressions"""
import
os
,
sys
,
traceback
import
os
,
sys
,
traceback
,
logging
import
numpy
import
numpy
from
theano.gof
import
(
utils
,
Op
,
Apply
,
view_roots
,
PatternSub
,
DestroyHandler
,
from
theano.gof
import
(
utils
,
Op
,
Apply
,
view_roots
,
PatternSub
,
DestroyHandler
,
...
@@ -17,6 +17,13 @@ from theano import compile #to register the optimizer built by this file
...
@@ -17,6 +17,13 @@ from theano import compile #to register the optimizer built by this file
from
theano.tensor.blas_headers
import
cblas_header_text
,
blas_header_text
from
theano.tensor.blas_headers
import
cblas_header_text
,
blas_header_text
_logger
=
logging
.
getLogger
(
'theano.tensor.blas'
)
def
debug
(
*
msg
):
_logger
.
debug
(
' '
.
join
(
str
(
m
)
for
m
in
msg
))
def
info
(
*
msg
):
_logger
.
info
(
' '
.
join
(
str
(
m
)
for
m
in
msg
))
def
warn
(
*
msg
):
_logger
.
warn
(
' '
.
join
(
str
(
m
)
for
m
in
msg
))
def
warning
(
*
msg
):
_logger
.
warning
(
' '
.
join
(
str
(
m
)
for
m
in
msg
))
def
error
(
*
msg
):
_logger
.
error
(
' '
.
join
(
str
(
m
)
for
m
in
msg
))
@utils.memoize
@utils.memoize
def
ldflags
(
libs
=
True
,
flags
=
False
):
def
ldflags
(
libs
=
True
,
flags
=
False
):
"""Return a list of libraries against which an Op's object file should be
"""Return a list of libraries against which an Op's object file should be
...
@@ -655,6 +662,8 @@ def local_dot_to_dot22(node):
...
@@ -655,6 +662,8 @@ def local_dot_to_dot22(node):
x
,
y
=
node
.
inputs
x
,
y
=
node
.
inputs
if
_is_real_matrix
(
x
)
and
y
.
type
==
x
.
type
:
if
_is_real_matrix
(
x
)
and
y
.
type
==
x
.
type
:
return
[
_dot22
(
*
node
.
inputs
)]
return
[
_dot22
(
*
node
.
inputs
)]
else
:
info
(
'Not optimizing dot with inputs'
,
x
,
y
)
else
:
else
:
return
False
return
False
register_specialize
(
local_dot_to_dot22
)
register_specialize
(
local_dot_to_dot22
)
...
...
theano/tensor/opt.py
浏览文件 @
f9a67241
...
@@ -16,7 +16,6 @@ import itertools
...
@@ -16,7 +16,6 @@ import itertools
import
sys
import
sys
from
theano
import
compile
#to register the optimizer built by this file
from
theano
import
compile
#to register the optimizer built by this file
from
theano.compile.debugmode
import
_debugprint
from
theano.gof.python25
import
any
,
all
from
theano.gof.python25
import
any
,
all
# Utilities
# Utilities
...
@@ -292,13 +291,73 @@ def local_subtensor_make_vector(node):
...
@@ -292,13 +291,73 @@ def local_subtensor_make_vector(node):
register_canonicalize
(
local_subtensor_make_vector
)
register_canonicalize
(
local_subtensor_make_vector
)
@register_canonicalize
@gof.local_optimizer
([
None
])
def
local_IncSubtensor_serialize
(
node
):
"""
When using Subtensor, gradient graphs can be ugly.
If we ask for grad(f(a[0]), a), we are going to get something like
IncSubtensor(Elemwise{second}(a, 0), g(f(a[0])), [0])
This might be ugly, but at least it's as fast as you could want. If we ask for
grad(f(a[0], a[1], a[2]), a), it's much worse...
Elemwise{Add}
IncSubtensor(Elemwise{second}(a, 0), g(f(a[0])), [0])
IncSubtensor(Elemwise{second}(a, 0), g(f(a[1])), [1])
IncSubtensor(Elemwise{second}(a, 0), g(f(a[2])), [2])
This is much worse because this time we have to produce 3 matrices the size of 'a', just so
we can add them together.
This Op rearranges IncSubtensor's that all work on the same initial argument (here,
Elemwise{second}(a,0)) into a chain. The advantage of the chain structure is that each one
can be optimized later in the pipeline to operate inplace.
Ideally, the op will do something like this:
#
# add(x, incsubtensor(b, c), incsubtensor(b, d))
# -> incsubtensor(incsubtensor(add(x,b), c), d)
"""
def
movable
(
i
):
# Return True iff this is a incsubtensor that we can move
return
i
.
owner
\
and
isinstance
(
i
.
owner
.
op
,
T
.
IncSubtensor
)
\
and
i
.
type
==
o_type
\
and
len
(
i
.
clients
)
==
1
if
node
.
op
==
T
.
add
:
o_type
=
node
.
outputs
[
0
]
.
type
movable_inputs
=
[
i
for
i
in
node
.
inputs
if
movable
(
i
)]
if
movable_inputs
:
new_inputs
=
[
i
for
i
in
node
.
inputs
if
not
movable
(
i
)]
\
+
[
mi
.
owner
.
inputs
[
0
]
for
mi
in
movable_inputs
]
new_add
=
T
.
add
(
*
new_inputs
)
# stack up the new incsubtensors
tip
=
new_add
for
mi
in
movable_inputs
:
assert
tip
.
type
==
o_type
assert
tip
.
type
==
mi
.
owner
.
inputs
[
0
]
.
type
tip
=
mi
.
owner
.
op
(
tip
,
*
mi
.
owner
.
inputs
[
1
:])
return
[
tip
]
#print incsub_inputs, [id(i.owner.inputs[0]) for i in incsub_inputs]
#after priority 50 Destructive inplace operations
#after priority 50 Destructive inplace operations
#gemm is the first one now, at priority 70
#gemm is the first one now, at priority 70
@gof.local_optimizer
([
None
])
@gof.local_optimizer
([
None
])
def
local_inplace_setsubtensor
(
node
):
def
local_inplace_setsubtensor
(
node
):
if
isinstance
(
node
.
op
,
T
.
Set
Subtensor
)
and
not
node
.
op
.
inplace
:
if
isinstance
(
node
.
op
,
T
.
Inc
Subtensor
)
and
not
node
.
op
.
inplace
:
new_op
=
T
.
Set
Subtensor
(
node
.
op
.
idx_list
,
inplace
=
True
)
new_op
=
T
.
Inc
Subtensor
(
node
.
op
.
idx_list
,
inplace
=
True
)
new_node
=
new_op
(
*
node
.
inputs
)
new_node
=
new_op
(
*
node
.
inputs
)
return
[
new_node
]
return
[
new_node
]
return
False
return
False
...
...
theano/tensor/tests/test_basic.py
浏览文件 @
f9a67241
...
@@ -852,7 +852,10 @@ class T_subtensor(unittest.TestCase):
...
@@ -852,7 +852,10 @@ class T_subtensor(unittest.TestCase):
n
=
as_tensor_variable
(
data
)
n
=
as_tensor_variable
(
data
)
t
=
n
[
1
,
0
]
t
=
n
[
1
,
0
]
gn
=
grad
(
sum
(
exp
(
t
)),
n
)
gn
=
grad
(
sum
(
exp
(
t
)),
n
)
gval
=
eval_outputs
([
gn
])
f
=
function
([],
gn
,
mode
=
None
)
print
'toposort'
,
f
.
maker
.
env
.
toposort
()
gval
=
f
()
print
gval
good
=
numpy
.
zeros_like
(
data
)
good
=
numpy
.
zeros_like
(
data
)
good
[
1
,
0
]
=
numpy
.
exp
(
data
[
1
,
0
])
good
[
1
,
0
]
=
numpy
.
exp
(
data
[
1
,
0
])
self
.
failUnless
(
numpy
.
all
(
gval
==
good
),
(
gval
,
good
))
self
.
failUnless
(
numpy
.
all
(
gval
==
good
),
(
gval
,
good
))
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论