Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
86dbc392
提交
86dbc392
authored
10月 21, 2020
作者:
Brandon T. Willard
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Apply pyupgrade to top-level modules in theano package
上级
d6477d58
隐藏空白字符变更
内嵌
并排
正在显示
10 个修改的文件
包含
111 行增加
和
148 行删除
+111
-148
_version.py
theano/_version.py
+8
-8
compat.py
theano/compat.py
+27
-42
configdefaults.py
theano/configdefaults.py
+13
-16
configparser.py
theano/configparser.py
+13
-15
gradient.py
theano/gradient.py
+20
-28
ifelse.py
theano/ifelse.py
+8
-12
pathparse.py
theano/pathparse.py
+1
-1
printing.py
theano/printing.py
+18
-23
raise_op.py
theano/raise_op.py
+1
-1
updates.py
theano/updates.py
+2
-2
没有找到文件。
theano/_version.py
浏览文件 @
86dbc392
...
@@ -84,7 +84,7 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=
...
@@ -84,7 +84,7 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=
stderr
=
(
subprocess
.
PIPE
if
hide_stderr
else
None
),
stderr
=
(
subprocess
.
PIPE
if
hide_stderr
else
None
),
)
)
break
break
except
Environment
Error
:
except
OS
Error
:
e
=
sys
.
exc_info
()[
1
]
e
=
sys
.
exc_info
()[
1
]
if
e
.
errno
==
errno
.
ENOENT
:
if
e
.
errno
==
errno
.
ENOENT
:
continue
continue
...
@@ -94,7 +94,7 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=
...
@@ -94,7 +94,7 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=
return
None
,
None
return
None
,
None
else
:
else
:
if
verbose
:
if
verbose
:
print
(
"unable to find command, tried
%
s"
%
(
commands
,
))
print
(
"unable to find command, tried
{}"
.
format
(
commands
))
return
None
,
None
return
None
,
None
stdout
=
p
.
communicate
()[
0
]
.
strip
()
.
decode
()
stdout
=
p
.
communicate
()[
0
]
.
strip
()
.
decode
()
if
p
.
returncode
!=
0
:
if
p
.
returncode
!=
0
:
...
@@ -145,7 +145,7 @@ def git_get_keywords(versionfile_abs):
...
@@ -145,7 +145,7 @@ def git_get_keywords(versionfile_abs):
# _version.py.
# _version.py.
keywords
=
{}
keywords
=
{}
try
:
try
:
f
=
open
(
versionfile_abs
,
"r"
)
f
=
open
(
versionfile_abs
)
for
line
in
f
.
readlines
():
for
line
in
f
.
readlines
():
if
line
.
strip
()
.
startswith
(
"git_refnames ="
):
if
line
.
strip
()
.
startswith
(
"git_refnames ="
):
mo
=
re
.
search
(
r'=\s*"(.*)"'
,
line
)
mo
=
re
.
search
(
r'=\s*"(.*)"'
,
line
)
...
@@ -160,7 +160,7 @@ def git_get_keywords(versionfile_abs):
...
@@ -160,7 +160,7 @@ def git_get_keywords(versionfile_abs):
if
mo
:
if
mo
:
keywords
[
"date"
]
=
mo
.
group
(
1
)
keywords
[
"date"
]
=
mo
.
group
(
1
)
f
.
close
()
f
.
close
()
except
Environment
Error
:
except
OS
Error
:
pass
pass
return
keywords
return
keywords
...
@@ -184,11 +184,11 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose):
...
@@ -184,11 +184,11 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose):
if
verbose
:
if
verbose
:
print
(
"keywords are unexpanded, not using"
)
print
(
"keywords are unexpanded, not using"
)
raise
NotThisMethod
(
"unexpanded keywords, not a git-archive tarball"
)
raise
NotThisMethod
(
"unexpanded keywords, not a git-archive tarball"
)
refs
=
set
([
r
.
strip
()
for
r
in
refnames
.
strip
(
"()"
)
.
split
(
","
)])
refs
=
{
r
.
strip
()
for
r
in
refnames
.
strip
(
"()"
)
.
split
(
","
)}
# starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of
# starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of
# just "foo-1.0". If we see a "tag: " prefix, prefer those.
# just "foo-1.0". If we see a "tag: " prefix, prefer those.
TAG
=
"tag: "
TAG
=
"tag: "
tags
=
set
([
r
[
len
(
TAG
)
:]
for
r
in
refs
if
r
.
startswith
(
TAG
)])
tags
=
{
r
[
len
(
TAG
)
:]
for
r
in
refs
if
r
.
startswith
(
TAG
)}
if
not
tags
:
if
not
tags
:
# Either we're using git < 1.8.3, or there really are no tags. We use
# Either we're using git < 1.8.3, or there really are no tags. We use
# a heuristic: assume all version tags have a digit. The old git %d
# a heuristic: assume all version tags have a digit. The old git %d
...
@@ -197,7 +197,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose):
...
@@ -197,7 +197,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose):
# between branches and tags. By ignoring refnames without digits, we
# between branches and tags. By ignoring refnames without digits, we
# filter out many common branch names like "release" and
# filter out many common branch names like "release" and
# "stabilization", as well as "HEAD" and "master".
# "stabilization", as well as "HEAD" and "master".
tags
=
set
([
r
for
r
in
refs
if
re
.
search
(
r"\d"
,
r
)])
tags
=
{
r
for
r
in
refs
if
re
.
search
(
r"\d"
,
r
)}
if
verbose
:
if
verbose
:
print
(
"discarding '
%
s', no digits"
%
","
.
join
(
refs
-
tags
))
print
(
"discarding '
%
s', no digits"
%
","
.
join
(
refs
-
tags
))
if
verbose
:
if
verbose
:
...
@@ -300,7 +300,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
...
@@ -300,7 +300,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
if
verbose
:
if
verbose
:
fmt
=
"tag '
%
s' doesn't start with prefix '
%
s'"
fmt
=
"tag '
%
s' doesn't start with prefix '
%
s'"
print
(
fmt
%
(
full_tag
,
tag_prefix
))
print
(
fmt
%
(
full_tag
,
tag_prefix
))
pieces
[
"error"
]
=
"tag '
%
s' doesn't start with prefix '
%
s'"
%
(
pieces
[
"error"
]
=
"tag '
{}' doesn't start with prefix '{}'"
.
format
(
full_tag
,
full_tag
,
tag_prefix
,
tag_prefix
,
)
)
...
...
theano/compat.py
浏览文件 @
86dbc392
...
@@ -5,7 +5,7 @@
...
@@ -5,7 +5,7 @@
from
collections
import
OrderedDict
from
collections
import
OrderedDict
# Python 3.x compatibility
# Python 3.x compatibility
from
six
import
PY3
,
BytesIO
,
b
,
next
from
six
import
PY3
,
BytesIO
,
b
from
six.moves
import
configparser
from
six.moves
import
configparser
from
six.moves
import
reload_module
as
reload
from
six.moves
import
reload_module
as
reload
...
@@ -24,59 +24,44 @@ except ImportError:
...
@@ -24,59 +24,44 @@ except ImportError:
__all__
=
[
"PY3"
,
"b"
,
"BytesIO"
,
"next"
,
"configparser"
,
"reload"
]
__all__
=
[
"PY3"
,
"b"
,
"BytesIO"
,
"next"
,
"configparser"
,
"reload"
]
if
PY3
:
from
operator
import
truediv
as
operator_div
from
operator
import
truediv
as
operator_div
# In python 3.x, when an exception is reraised it saves original
# exception in its args, therefore in order to find the actual
# message, we need to unpack arguments recursively.
def
exc_message
(
e
):
msg
=
e
.
args
[
0
]
if
isinstance
(
msg
,
Exception
):
return
exc_message
(
msg
)
return
msg
def
cmp
(
x
,
y
):
# In python 3.x, when an exception is reraised it saves original
"""Return -1 if x < y, 0 if x == y, 1 if x > y."""
# exception in its args, therefore in order to find the actual
return
(
x
>
y
)
-
(
x
<
y
)
# message, we need to unpack arguments recursively.
def
exc_message
(
e
):
msg
=
e
.
args
[
0
]
if
isinstance
(
msg
,
Exception
):
return
exc_message
(
msg
)
return
msg
def
get_unbound_function
(
unbound
):
# Op.make_thunk isn't bound, so don't have a __func__ attr.
# But bound method, have a __func__ method that point to the
# not bound method. That is what we want.
if
hasattr
(
unbound
,
"__func__"
):
return
unbound
.
__func__
return
unbound
def
decode
(
x
):
def
cmp
(
x
,
y
):
return
x
.
decode
()
"""Return -1 if x < y, 0 if x == y, 1 if x > y."""
return
(
x
>
y
)
-
(
x
<
y
)
def
decode_iter
(
itr
):
for
x
in
itr
:
yield
x
.
decode
()
def
decode_with
(
x
,
encoding
):
def
get_unbound_function
(
unbound
):
return
x
.
decode
(
encoding
)
# Op.make_thunk isn't bound, so don't have a __func__ attr.
# But bound method, have a __func__ method that point to the
# not bound method. That is what we want.
if
hasattr
(
unbound
,
"__func__"
):
return
unbound
.
__func__
return
unbound
else
:
def
decode
(
x
)
:
from
operator
import
div
as
operator_div
return
x
.
decode
()
from
six
import
get_unbound_function
def
exc_message
(
e
):
def
decode_iter
(
itr
):
return
e
[
0
]
for
x
in
itr
:
yield
x
.
decode
()
cmp
=
cmp
def
decode
(
x
):
def
decode_with
(
x
,
encoding
):
return
x
return
x
.
decode
(
encoding
)
def
decode_iter
(
x
):
return
x
def
decode_with
(
x
,
encoding
):
return
x
__all__
+=
[
__all__
+=
[
...
...
theano/configdefaults.py
浏览文件 @
86dbc392
...
@@ -10,7 +10,6 @@ import textwrap
...
@@ -10,7 +10,6 @@ import textwrap
import
warnings
import
warnings
import
numpy
as
np
import
numpy
as
np
from
six
import
string_types
import
theano
import
theano
from
theano.compat
import
maybe_add_to_os_environ_pathlist
from
theano.compat
import
maybe_add_to_os_environ_pathlist
...
@@ -142,18 +141,16 @@ class DeviceParam(ConfigParam):
...
@@ -142,18 +141,16 @@ class DeviceParam(ConfigParam):
)
)
else
:
else
:
raise
ValueError
(
raise
ValueError
(
(
'Invalid value ("
%
s") for configuration '
'Invalid value ("
%
s") for configuration '
'variable "
%
s". Valid options start with '
'variable "
%
s". Valid options start with '
'one of "cpu", "opencl" or "cuda".'
%
(
val
,
self
.
fullname
)
'one of "cpu", "opencl" or "cuda".'
%
(
val
,
self
.
fullname
)
)
)
)
over
=
kwargs
.
get
(
"allow_override"
,
True
)
over
=
kwargs
.
get
(
"allow_override"
,
True
)
super
(
DeviceParam
,
self
)
.
__init__
(
default
,
filter
,
over
)
super
()
.
__init__
(
default
,
filter
,
over
)
def
__str__
(
self
):
def
__str__
(
self
):
return
"
%
s (
%
s, opencl*, cuda*) "
%
(
self
.
fullname
,
self
.
default
)
return
"
{} ({}, opencl*, cuda*) "
.
format
(
self
.
fullname
,
self
.
default
)
AddConfigVar
(
AddConfigVar
(
...
@@ -211,13 +208,13 @@ class ContextsParam(ConfigParam):
...
@@ -211,13 +208,13 @@ class ContextsParam(ConfigParam):
for
v
in
val
.
split
(
";"
):
for
v
in
val
.
split
(
";"
):
s
=
v
.
split
(
"->"
)
s
=
v
.
split
(
"->"
)
if
len
(
s
)
!=
2
:
if
len
(
s
)
!=
2
:
raise
ValueError
(
"Malformed context map:
%
s"
%
(
v
,
))
raise
ValueError
(
"Malformed context map:
{}"
.
format
(
v
))
if
(
if
(
s
[
0
]
==
"cpu"
s
[
0
]
==
"cpu"
or
s
[
0
]
.
startswith
(
"cuda"
)
or
s
[
0
]
.
startswith
(
"cuda"
)
or
s
[
0
]
.
startswith
(
"opencl"
)
or
s
[
0
]
.
startswith
(
"opencl"
)
):
):
raise
ValueError
(
"Cannot use
%
s as context name"
%
(
s
[
0
],
))
raise
ValueError
(
"Cannot use
{} as context name"
.
format
(
s
[
0
]
))
return
val
return
val
ConfigParam
.
__init__
(
self
,
""
,
filter
,
False
)
ConfigParam
.
__init__
(
self
,
""
,
filter
,
False
)
...
@@ -1409,7 +1406,7 @@ AddConfigVar(
...
@@ -1409,7 +1406,7 @@ AddConfigVar(
def
is_valid_check_preallocated_output_param
(
param
):
def
is_valid_check_preallocated_output_param
(
param
):
if
not
isinstance
(
param
,
str
ing_types
):
if
not
isinstance
(
param
,
str
):
return
False
return
False
valid
=
[
valid
=
[
"initial"
,
"initial"
,
...
@@ -1821,7 +1818,7 @@ def default_blas_ldflags():
...
@@ -1821,7 +1818,7 @@ def default_blas_ldflags():
# we just pass the whole ldflags as the -l
# we just pass the whole ldflags as the -l
# options part.
# options part.
[
[
"-L
%
s
%
s
%
s"
%
(
path_wrapper
,
l
,
path_wrapper
)
"-L
{}{}{}"
.
format
(
path_wrapper
,
l
,
path_wrapper
)
for
l
in
blas_info
.
get
(
"library_dirs"
,
[])
for
l
in
blas_info
.
get
(
"library_dirs"
,
[])
]
]
+
[
"-l
%
s"
%
l
for
l
in
blas_info
.
get
(
"libraries"
,
[])]
+
[
"-l
%
s"
%
l
for
l
in
blas_info
.
get
(
"libraries"
,
[])]
...
@@ -1902,7 +1899,7 @@ def try_blas_flag(flags):
...
@@ -1902,7 +1899,7 @@ def try_blas_flag(flags):
path_wrapper
=
'"'
if
os
.
name
==
"nt"
else
""
path_wrapper
=
'"'
if
os
.
name
==
"nt"
else
""
cflags
.
extend
(
cflags
.
extend
(
[
[
"-L
%
s
%
s
%
s"
%
(
path_wrapper
,
d
,
path_wrapper
)
"-L
{}{}{}"
.
format
(
path_wrapper
,
d
,
path_wrapper
)
for
d
in
theano
.
gof
.
cmodule
.
std_lib_dirs
()
for
d
in
theano
.
gof
.
cmodule
.
std_lib_dirs
()
]
]
)
)
...
@@ -2311,11 +2308,11 @@ def filter_compiledir(path):
...
@@ -2311,11 +2308,11 @@ def filter_compiledir(path):
if
not
os
.
path
.
exists
(
init_file
):
if
not
os
.
path
.
exists
(
init_file
):
try
:
try
:
open
(
init_file
,
"w"
)
.
close
()
open
(
init_file
,
"w"
)
.
close
()
except
IO
Error
as
e
:
except
OS
Error
as
e
:
if
os
.
path
.
exists
(
init_file
):
if
os
.
path
.
exists
(
init_file
):
pass
# has already been created
pass
# has already been created
else
:
else
:
e
.
args
+=
(
"
%
s exist?
%
s"
%
(
path
,
os
.
path
.
exists
(
path
)),)
e
.
args
+=
(
"
{} exist? {}"
.
format
(
path
,
os
.
path
.
exists
(
path
)),)
raise
raise
return
path
return
path
...
@@ -2390,4 +2387,4 @@ AddConfigVar(
...
@@ -2390,4 +2387,4 @@ AddConfigVar(
# Check if there are remaining flags provided by the user through THEANO_FLAGS.
# Check if there are remaining flags provided by the user through THEANO_FLAGS.
for
key
in
THEANO_FLAGS_DICT
.
keys
():
for
key
in
THEANO_FLAGS_DICT
.
keys
():
warnings
.
warn
(
"Theano does not recognise this flag: {
0
}"
.
format
(
key
))
warnings
.
warn
(
"Theano does not recognise this flag: {}"
.
format
(
key
))
theano/configparser.py
浏览文件 @
86dbc392
...
@@ -10,7 +10,7 @@ import sys
...
@@ -10,7 +10,7 @@ import sys
import
warnings
import
warnings
from
functools
import
wraps
from
functools
import
wraps
from
six
import
PY3
,
StringIO
,
string_types
from
six
import
PY3
,
StringIO
import
theano
import
theano
from
theano.compat
import
configparser
as
ConfigParser
from
theano.compat
import
configparser
as
ConfigParser
...
@@ -95,7 +95,7 @@ theano_raw_cfg = ConfigParser.RawConfigParser()
...
@@ -95,7 +95,7 @@ theano_raw_cfg = ConfigParser.RawConfigParser()
theano_raw_cfg
.
read
(
config_files
)
theano_raw_cfg
.
read
(
config_files
)
class
change_flags
(
object
)
:
class
change_flags
:
"""
"""
Use this as a decorator or context manager to change the value of
Use this as a decorator or context manager to change the value of
Theano config variables.
Theano config variables.
...
@@ -204,12 +204,12 @@ def get_config_hash():
...
@@ -204,12 +204,12 @@ def get_config_hash():
)
)
return
theano
.
gof
.
utils
.
hash_from_code
(
return
theano
.
gof
.
utils
.
hash_from_code
(
"
\n
"
.
join
(
"
\n
"
.
join
(
[
"
%
s =
%
s"
%
(
cv
.
fullname
,
cv
.
__get__
(
True
,
None
))
for
cv
in
all_opts
]
[
"
{} = {}"
.
format
(
cv
.
fullname
,
cv
.
__get__
(
True
,
None
))
for
cv
in
all_opts
]
)
)
)
)
class
TheanoConfigParser
(
object
)
:
class
TheanoConfigParser
:
# properties are installed by AddConfigVar
# properties are installed by AddConfigVar
_i_am_a_config_class
=
True
_i_am_a_config_class
=
True
...
@@ -276,7 +276,7 @@ def AddConfigVar(name, doc, configparam, root=config, in_c_key=True):
...
@@ -276,7 +276,7 @@ def AddConfigVar(name, doc, configparam, root=config, in_c_key=True):
if
not
hasattr
(
root
,
sections
[
0
]):
if
not
hasattr
(
root
,
sections
[
0
]):
# every internal node in the config tree is an instance of its own
# every internal node in the config tree is an instance of its own
# unique class
# unique class
class
SubObj
(
object
)
:
class
SubObj
:
_i_am_a_config_class
=
True
_i_am_a_config_class
=
True
setattr
(
root
.
__class__
,
sections
[
0
],
SubObj
())
setattr
(
root
.
__class__
,
sections
[
0
],
SubObj
())
...
@@ -312,7 +312,7 @@ def AddConfigVar(name, doc, configparam, root=config, in_c_key=True):
...
@@ -312,7 +312,7 @@ def AddConfigVar(name, doc, configparam, root=config, in_c_key=True):
_config_var_list
.
append
(
configparam
)
_config_var_list
.
append
(
configparam
)
class
ConfigParam
(
object
)
:
class
ConfigParam
:
def
__init__
(
self
,
default
,
filter
=
None
,
allow_override
=
True
):
def
__init__
(
self
,
default
,
filter
=
None
,
allow_override
=
True
):
"""
"""
If allow_override is False, we can't change the value after the import
If allow_override is False, we can't change the value after the import
...
@@ -368,7 +368,7 @@ class EnumStr(ConfigParam):
...
@@ -368,7 +368,7 @@ class EnumStr(ConfigParam):
# All options should be strings
# All options should be strings
for
val
in
self
.
all
:
for
val
in
self
.
all
:
if
not
isinstance
(
val
,
str
ing_types
):
if
not
isinstance
(
val
,
str
):
raise
ValueError
(
raise
ValueError
(
"Valid values for an EnumStr parameter "
"should be strings"
,
"Valid values for an EnumStr parameter "
"should be strings"
,
val
,
val
,
...
@@ -384,17 +384,15 @@ class EnumStr(ConfigParam):
...
@@ -384,17 +384,15 @@ class EnumStr(ConfigParam):
return
val
return
val
else
:
else
:
raise
ValueError
(
raise
ValueError
(
(
'Invalid value ("
%
s") for configuration variable "
%
s". '
'Invalid value ("
%
s") for configuration variable "
%
s". '
"Valid options are
%
s"
%
(
val
,
self
.
fullname
,
self
.
all
)
"Valid options are
%
s"
%
(
val
,
self
.
fullname
,
self
.
all
)
)
)
)
over
=
kwargs
.
get
(
"allow_override"
,
True
)
over
=
kwargs
.
get
(
"allow_override"
,
True
)
super
(
EnumStr
,
self
)
.
__init__
(
default
,
filter
,
over
)
super
()
.
__init__
(
default
,
filter
,
over
)
def
__str__
(
self
):
def
__str__
(
self
):
return
"
%
s (
%
s) "
%
(
self
.
fullname
,
self
.
all
)
return
"
{} ({}) "
.
format
(
self
.
fullname
,
self
.
all
)
class
TypedParam
(
ConfigParam
):
class
TypedParam
(
ConfigParam
):
...
@@ -414,10 +412,10 @@ class TypedParam(ConfigParam):
...
@@ -414,10 +412,10 @@ class TypedParam(ConfigParam):
)
)
return
cast_val
return
cast_val
super
(
TypedParam
,
self
)
.
__init__
(
default
,
filter
,
allow_override
=
allow_override
)
super
()
.
__init__
(
default
,
filter
,
allow_override
=
allow_override
)
def
__str__
(
self
):
def
__str__
(
self
):
return
"
%
s (
%
s) "
%
(
self
.
fullname
,
self
.
mytype
)
return
"
{} ({}) "
.
format
(
self
.
fullname
,
self
.
mytype
)
def
StrParam
(
default
,
is_valid
=
None
,
allow_override
=
True
):
def
StrParam
(
default
,
is_valid
=
None
,
allow_override
=
True
):
...
...
theano/gradient.py
浏览文件 @
86dbc392
...
@@ -130,20 +130,16 @@ class DisconnectedType(theano.gof.type.Type):
...
@@ -130,20 +130,16 @@ class DisconnectedType(theano.gof.type.Type):
def
filter
(
self
,
data
,
strict
=
False
,
allow_downcast
=
None
):
def
filter
(
self
,
data
,
strict
=
False
,
allow_downcast
=
None
):
raise
AssertionError
(
raise
AssertionError
(
(
"If you're assigning to a DisconnectedType you're"
"If you're assigning to a DisconnectedType you're"
" doing something wrong. It should only be used as"
" doing something wrong. It should only be used as"
" a symbolic placeholder."
" a symbolic placeholder."
)
)
)
def
fiter_variable
(
self
,
other
):
def
fiter_variable
(
self
,
other
):
raise
AssertionError
(
raise
AssertionError
(
(
"If you're assigning to a DisconnectedType you're"
"If you're assigning to a DisconnectedType you're"
" doing something wrong. It should only be used as"
" doing something wrong. It should only be used as"
" a symbolic placeholder."
" a symbolic placeholder."
)
)
)
def
may_share_memory
(
a
,
b
):
def
may_share_memory
(
a
,
b
):
...
@@ -151,11 +147,9 @@ class DisconnectedType(theano.gof.type.Type):
...
@@ -151,11 +147,9 @@ class DisconnectedType(theano.gof.type.Type):
def
value_eq
(
a
,
b
,
force_same_dtype
=
True
):
def
value_eq
(
a
,
b
,
force_same_dtype
=
True
):
raise
AssertionError
(
raise
AssertionError
(
(
"If you're assigning to a DisconnectedType you're"
"If you're assigning to a DisconnectedType you're"
" doing something wrong. It should only be used as"
" doing something wrong. It should only be used as"
" a symbolic placeholder."
" a symbolic placeholder."
)
)
)
def
__str__
(
self
):
def
__str__
(
self
):
...
@@ -846,7 +840,7 @@ def _node_to_pattern(node):
...
@@ -846,7 +840,7 @@ def _node_to_pattern(node):
raise
TypeError
(
raise
TypeError
(
"
%
s.connection_pattern should return"
%
node
.
op
"
%
s.connection_pattern should return"
%
node
.
op
+
" a list of lists, but element
%
d"
%
ii
+
" a list of lists, but element
%
d"
%
ii
+
"is
%
s of type
%
s."
%
(
output_pattern
,
type
(
output_pattern
))
+
"is
{} of type {}."
.
format
(
output_pattern
,
type
(
output_pattern
))
)
)
else
:
else
:
connection_pattern
=
[[
True
for
output
in
node
.
outputs
]
for
ipt
in
node
.
inputs
]
connection_pattern
=
[[
True
for
output
in
node
.
outputs
]
for
ipt
in
node
.
inputs
]
...
@@ -933,7 +927,7 @@ def _populate_var_to_app_to_idx(outputs, wrt, consider_constant):
...
@@ -933,7 +927,7 @@ def _populate_var_to_app_to_idx(outputs, wrt, consider_constant):
# Note: we need to revisit the apply nodes repeatedly, because
# Note: we need to revisit the apply nodes repeatedly, because
# different outputs of the apply node are connected to
# different outputs of the apply node are connected to
# different subsets of the inputs.
# different subsets of the inputs.
accounted_for
=
set
(
[]
)
accounted_for
=
set
()
def
account_for
(
var
):
def
account_for
(
var
):
# Don't visit the same variable twice
# Don't visit the same variable twice
...
@@ -984,7 +978,7 @@ def _populate_var_to_app_to_idx(outputs, wrt, consider_constant):
...
@@ -984,7 +978,7 @@ def _populate_var_to_app_to_idx(outputs, wrt, consider_constant):
# determine which variables have elements of wrt as a true
# determine which variables have elements of wrt as a true
# ancestor. Do this with an upward pass starting from wrt,
# ancestor. Do this with an upward pass starting from wrt,
# following only true connections
# following only true connections
visited
=
set
(
[]
)
visited
=
set
()
def
visit
(
var
):
def
visit
(
var
):
if
var
in
visited
:
if
var
in
visited
:
...
@@ -1458,7 +1452,7 @@ def _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name=None):
...
@@ -1458,7 +1452,7 @@ def _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name=None):
grad_dict
[
var
]
=
disconnected_type
()
grad_dict
[
var
]
=
disconnected_type
()
if
cost_name
is
not
None
and
var
.
name
is
not
None
:
if
cost_name
is
not
None
and
var
.
name
is
not
None
:
grad_dict
[
var
]
.
name
=
"(d
%
s/d
%
s)"
%
(
cost_name
,
var
.
name
)
grad_dict
[
var
]
.
name
=
"(d
{}/d{})"
.
format
(
cost_name
,
var
.
name
)
else
:
else
:
# this variable isn't connected to the cost in the
# this variable isn't connected to the cost in the
# computational graph
# computational graph
...
@@ -1494,7 +1488,7 @@ def _float_ones_like(x):
...
@@ -1494,7 +1488,7 @@ def _float_ones_like(x):
return
x
.
ones_like
(
dtype
=
dtype
)
return
x
.
ones_like
(
dtype
=
dtype
)
class
numeric_grad
(
object
)
:
class
numeric_grad
:
"""
"""
Compute the numeric derivative of a scalar-valued function at a particular
Compute the numeric derivative of a scalar-valued function at a particular
point.
point.
...
@@ -1818,13 +1812,11 @@ def verify_grad(
...
@@ -1818,13 +1812,11 @@ def verify_grad(
if
rng
is
None
:
if
rng
is
None
:
raise
TypeError
(
raise
TypeError
(
(
"rng should be a valid instance of "
"rng should be a valid instance of "
"numpy.random.RandomState. You may "
"numpy.random.RandomState. You may "
"want to use tests.unittest"
"want to use tests.unittest"
"_tools.verify_grad instead of "
"_tools.verify_grad instead of "
"theano.gradient.verify_grad."
"theano.gradient.verify_grad."
)
)
)
# We allow input downcast in function, because numeric_grad works in the
# We allow input downcast in function, because numeric_grad works in the
...
@@ -1853,7 +1845,7 @@ def verify_grad(
...
@@ -1853,7 +1845,7 @@ def verify_grad(
if
isinstance
(
o_output
,
list
):
if
isinstance
(
o_output
,
list
):
raise
NotImplementedError
(
raise
NotImplementedError
(
(
"cant (yet) autotest gradient of fun "
"with multiple outputs"
)
"cant (yet) autotest gradient of fun "
"with multiple outputs"
)
)
# we could make loop over outputs making random projections R for each,
# we could make loop over outputs making random projections R for each,
# but this doesn't handle the case where not all the outputs are
# but this doesn't handle the case where not all the outputs are
...
...
theano/ifelse.py
浏览文件 @
86dbc392
...
@@ -190,11 +190,9 @@ class IfElse(Op):
...
@@ -190,11 +190,9 @@ class IfElse(Op):
)
)
if
c
.
ndim
>
0
:
if
c
.
ndim
>
0
:
raise
TypeError
(
raise
TypeError
(
(
"Condition given to the op has to be a scalar "
"Condition given to the op has to be a scalar "
"with 0 standing for False, anything else "
"with 0 standing for False, anything else "
"for True"
"for True"
)
)
)
return
Apply
(
self
,
[
c
]
+
list
(
args
),
[
t
.
type
()
for
t
in
ts
])
return
Apply
(
self
,
[
c
]
+
list
(
args
),
[
t
.
type
()
for
t
in
ts
])
...
@@ -401,13 +399,11 @@ def ifelse(condition, then_branch, else_branch, name=None):
...
@@ -401,13 +399,11 @@ def ifelse(condition, then_branch, else_branch, name=None):
if
len
(
then_branch
)
!=
len
(
else_branch
):
if
len
(
then_branch
)
!=
len
(
else_branch
):
raise
ValueError
(
raise
ValueError
(
(
"The number of values on the `then` branch"
"The number of values on the `then` branch"
" should have the same number of variables as "
" should have the same number of variables as "
"the `else` branch : (variables on `then` "
"the `else` branch : (variables on `then` "
"
%
d"
%
len
(
then_branch
)
+
", variables on `else` "
"
%
d"
%
len
(
then_branch
)
+
", variables on `else` "
"
%
d"
%
len
(
else_branch
)
+
")"
"
%
d"
%
len
(
else_branch
)
+
")"
)
)
)
new_ifelse
=
IfElse
(
n_outs
=
len
(
then_branch
),
as_view
=
False
,
gpu
=
False
,
name
=
name
)
new_ifelse
=
IfElse
(
n_outs
=
len
(
then_branch
),
as_view
=
False
,
gpu
=
False
,
name
=
name
)
...
...
theano/pathparse.py
浏览文件 @
86dbc392
...
@@ -2,7 +2,7 @@ import os
...
@@ -2,7 +2,7 @@ import os
import
sys
import
sys
class
PathParser
(
object
)
:
class
PathParser
:
"""
"""
Class that allows to modify system's PATH environment variable
Class that allows to modify system's PATH environment variable
at runtime. Currently used in ``theano.gpuarray.dnn`` module
at runtime. Currently used in ``theano.gpuarray.dnn`` module
...
...
theano/printing.py
浏览文件 @
86dbc392
...
@@ -12,7 +12,6 @@ from copy import copy
...
@@ -12,7 +12,6 @@ from copy import copy
from
functools
import
reduce
from
functools
import
reduce
import
numpy
as
np
import
numpy
as
np
from
six
import
integer_types
,
string_types
from
six.moves
import
StringIO
from
six.moves
import
StringIO
import
theano
import
theano
...
@@ -55,7 +54,7 @@ except ImportError:
...
@@ -55,7 +54,7 @@ except ImportError:
_logger
=
logging
.
getLogger
(
"theano.printing"
)
_logger
=
logging
.
getLogger
(
"theano.printing"
)
VALID_ASSOC
=
set
([
"left"
,
"right"
,
"either"
])
VALID_ASSOC
=
{
"left"
,
"right"
,
"either"
}
def
debugprint
(
def
debugprint
(
...
@@ -121,7 +120,7 @@ def debugprint(
...
@@ -121,7 +120,7 @@ def debugprint(
to the Apply's identifier, to indicate which output a line corresponds to.
to the Apply's identifier, to indicate which output a line corresponds to.
"""
"""
if
not
isinstance
(
depth
,
int
eger_types
):
if
not
isinstance
(
depth
,
int
):
raise
Exception
(
"depth parameter must be an int"
)
raise
Exception
(
"depth parameter must be an int"
)
if
file
==
"str"
:
if
file
==
"str"
:
_file
=
StringIO
()
_file
=
StringIO
()
...
@@ -168,7 +167,7 @@ def debugprint(
...
@@ -168,7 +167,7 @@ def debugprint(
smap
.
extend
([
getattr
(
obj
,
"storage_map"
,
None
)
for
item
in
obj
.
outputs
])
smap
.
extend
([
getattr
(
obj
,
"storage_map"
,
None
)
for
item
in
obj
.
outputs
])
topo
=
obj
.
toposort
()
topo
=
obj
.
toposort
()
order
.
extend
([
topo
for
item
in
obj
.
outputs
])
order
.
extend
([
topo
for
item
in
obj
.
outputs
])
elif
isinstance
(
obj
,
(
int
eger_types
,
float
,
np
.
ndarray
)):
elif
isinstance
(
obj
,
(
int
,
float
,
np
.
ndarray
)):
print
(
obj
,
file
=
_file
)
print
(
obj
,
file
=
_file
)
elif
isinstance
(
obj
,
(
theano
.
In
,
theano
.
Out
)):
elif
isinstance
(
obj
,
(
theano
.
In
,
theano
.
Out
)):
results_to_print
.
append
(
obj
.
variable
)
results_to_print
.
append
(
obj
.
variable
)
...
@@ -239,14 +238,10 @@ N.B.:
...
@@ -239,14 +238,10 @@ N.B.:
else
:
else
:
inner_inputs
=
s
.
owner
.
op
.
inputs
inner_inputs
=
s
.
owner
.
op
.
inputs
outer_inputs
=
s
.
owner
.
inputs
outer_inputs
=
s
.
owner
.
inputs
inner_to_outer_inputs
=
dict
(
inner_to_outer_inputs
=
{
[
inner_inputs
[
i
]:
outer_inputs
[
o
]
(
inner_inputs
[
i
],
outer_inputs
[
o
])
for
i
,
o
in
s
.
owner
.
op
.
var_mappings
[
"outer_inp_from_inner_inp"
]
.
items
()
for
i
,
o
in
s
.
owner
.
op
.
var_mappings
[
}
"outer_inp_from_inner_inp"
]
.
items
()
]
)
print
(
""
,
file
=
_file
)
print
(
""
,
file
=
_file
)
debugmode
.
debugprint
(
debugmode
.
debugprint
(
...
@@ -440,7 +435,7 @@ class PatternPrinter:
...
@@ -440,7 +435,7 @@ class PatternPrinter:
def
__init__
(
self
,
*
patterns
):
def
__init__
(
self
,
*
patterns
):
self
.
patterns
=
[]
self
.
patterns
=
[]
for
pattern
in
patterns
:
for
pattern
in
patterns
:
if
isinstance
(
pattern
,
str
ing_types
):
if
isinstance
(
pattern
,
str
):
self
.
patterns
.
append
((
pattern
,
()))
self
.
patterns
.
append
((
pattern
,
()))
else
:
else
:
self
.
patterns
.
append
((
pattern
[
0
],
pattern
[
1
:]))
self
.
patterns
.
append
((
pattern
[
0
],
pattern
[
1
:]))
...
@@ -469,13 +464,13 @@ class PatternPrinter:
...
@@ -469,13 +464,13 @@ class PatternPrinter:
return
r
return
r
d
=
dict
(
d
=
{
(
str
(
i
),
x
)
str
(
i
):
x
for
i
,
x
in
enumerate
(
for
i
,
x
in
enumerate
(
pp_process
(
input
,
precedence
)
pp_process
(
input
,
precedence
)
for
input
,
precedence
in
zip
(
node
.
inputs
,
precedences
)
for
input
,
precedence
in
zip
(
node
.
inputs
,
precedences
)
)
)
)
}
r
=
pattern
%
d
r
=
pattern
%
d
pstate
.
memo
[
output
]
=
r
pstate
.
memo
[
output
]
=
r
return
r
return
r
...
@@ -501,7 +496,7 @@ class FunctionPrinter:
...
@@ -501,7 +496,7 @@ class FunctionPrinter:
try
:
try
:
old_precedence
=
getattr
(
pstate
,
"precedence"
,
None
)
old_precedence
=
getattr
(
pstate
,
"precedence"
,
None
)
pstate
.
precedence
=
new_precedence
pstate
.
precedence
=
new_precedence
r
=
"
%
s(
%
s)"
%
(
r
=
"
{}({})"
.
format
(
name
,
name
,
", "
.
join
([
pprinter
.
process
(
input
,
pstate
)
for
input
in
node
.
inputs
]),
", "
.
join
([
pprinter
.
process
(
input
,
pstate
)
for
input
in
node
.
inputs
]),
)
)
...
@@ -556,7 +551,7 @@ class DefaultPrinter:
...
@@ -556,7 +551,7 @@ class DefaultPrinter:
try
:
try
:
old_precedence
=
getattr
(
pstate
,
"precedence"
,
None
)
old_precedence
=
getattr
(
pstate
,
"precedence"
,
None
)
pstate
.
precedence
=
new_precedence
pstate
.
precedence
=
new_precedence
r
=
"
%
s(
%
s)"
%
(
r
=
"
{}({})"
.
format
(
str
(
node
.
op
),
str
(
node
.
op
),
", "
.
join
([
pprinter
.
process
(
input
,
pstate
)
for
input
in
node
.
inputs
]),
", "
.
join
([
pprinter
.
process
(
input
,
pstate
)
for
input
in
node
.
inputs
]),
)
)
...
@@ -624,7 +619,7 @@ class PPrinter:
...
@@ -624,7 +619,7 @@ class PPrinter:
pprinter
=
self
.
clone_assign
(
pprinter
=
self
.
clone_assign
(
lambda
pstate
,
r
:
r
.
name
is
not
None
and
r
is
not
current
,
leaf_printer
lambda
pstate
,
r
:
r
.
name
is
not
None
and
r
is
not
current
,
leaf_printer
)
)
inv_updates
=
dict
((
b
,
a
)
for
(
a
,
b
)
in
updates
.
items
())
inv_updates
=
{
b
:
a
for
(
a
,
b
)
in
updates
.
items
()}
i
=
1
i
=
1
for
node
in
gof
.
graph
.
io_toposort
(
for
node
in
gof
.
graph
.
io_toposort
(
list
(
inputs
)
+
updates
.
keys
(),
list
(
outputs
)
+
updates
.
values
()
list
(
inputs
)
+
updates
.
keys
(),
list
(
outputs
)
+
updates
.
values
()
...
@@ -633,7 +628,7 @@ class PPrinter:
...
@@ -633,7 +628,7 @@ class PPrinter:
if
output
in
inv_updates
:
if
output
in
inv_updates
:
name
=
str
(
inv_updates
[
output
])
name
=
str
(
inv_updates
[
output
])
strings
.
append
(
strings
.
append
(
(
i
+
1000
,
"
%
s <-
%
s"
%
(
name
,
pprinter
.
process
(
output
)))
(
i
+
1000
,
"
{} <- {}"
.
format
(
name
,
pprinter
.
process
(
output
)))
)
)
i
+=
1
i
+=
1
if
output
.
name
is
not
None
or
output
in
outputs
:
if
output
.
name
is
not
None
or
output
in
outputs
:
...
@@ -653,7 +648,7 @@ class PPrinter:
...
@@ -653,7 +648,7 @@ class PPrinter:
strings
.
append
((
idx
,
"return
%
s"
%
pprinter
.
process
(
output
)))
strings
.
append
((
idx
,
"return
%
s"
%
pprinter
.
process
(
output
)))
else
:
else
:
strings
.
append
(
strings
.
append
(
(
idx
,
"
%
s =
%
s"
%
(
name
,
pprinter
.
process
(
output
)))
(
idx
,
"
{} = {}"
.
format
(
name
,
pprinter
.
process
(
output
)))
)
)
i
+=
1
i
+=
1
strings
.
sort
()
strings
.
sort
()
...
@@ -901,7 +896,7 @@ def pydotprint(
...
@@ -901,7 +896,7 @@ def pydotprint(
dstr
=
"val="
+
str
(
np
.
asarray
(
var
.
data
))
dstr
=
"val="
+
str
(
np
.
asarray
(
var
.
data
))
if
"
\n
"
in
dstr
:
if
"
\n
"
in
dstr
:
dstr
=
dstr
[:
dstr
.
index
(
"
\n
"
)]
dstr
=
dstr
[:
dstr
.
index
(
"
\n
"
)]
varstr
=
"
%
s
%
s"
%
(
dstr
,
str
(
var
.
type
))
varstr
=
"
{} {}"
.
format
(
dstr
,
str
(
var
.
type
))
elif
var
in
input_update
and
input_update
[
var
]
.
name
is
not
None
:
elif
var
in
input_update
and
input_update
[
var
]
.
name
is
not
None
:
varstr
=
input_update
[
var
]
.
name
varstr
=
input_update
[
var
]
.
name
if
not
var_with_name_simple
:
if
not
var_with_name_simple
:
...
@@ -933,7 +928,7 @@ def pydotprint(
...
@@ -933,7 +928,7 @@ def pydotprint(
pf
=
0
pf
=
0
else
:
else
:
pf
=
time
*
100
/
profile
.
fct_call_time
pf
=
time
*
100
/
profile
.
fct_call_time
prof_str
=
" (
%.3
fs,
%.3
f
%%
)"
%
(
time
,
pf
)
prof_str
=
" (
{:.3f}s,{:.3f}
%
)"
.
format
(
time
,
pf
)
applystr
=
str
(
node
.
op
)
.
replace
(
":"
,
"_"
)
applystr
=
str
(
node
.
op
)
.
replace
(
":"
,
"_"
)
applystr
+=
prof_str
applystr
+=
prof_str
if
(
applystr
in
all_strings
)
or
with_ids
:
if
(
applystr
in
all_strings
)
or
with_ids
:
...
...
theano/raise_op.py
浏览文件 @
86dbc392
...
@@ -25,7 +25,7 @@ class Raise(gof.Op):
...
@@ -25,7 +25,7 @@ class Raise(gof.Op):
self
.
exc
=
exc
self
.
exc
=
exc
def
__str__
(
self
):
def
__str__
(
self
):
return
"Raise{
%
s(
%
s)}"
%
(
self
.
exc
,
self
.
msg
)
return
"Raise{
{{}({})}}"
.
format
(
self
.
exc
,
self
.
msg
)
def
make_node
(
self
,
x
):
def
make_node
(
self
,
x
):
return
gof
.
Apply
(
self
,
[
x
],
[
x
.
type
()])
return
gof
.
Apply
(
self
,
[
x
],
[
x
.
type
()])
...
...
theano/updates.py
浏览文件 @
86dbc392
...
@@ -42,7 +42,7 @@ class OrderedUpdates(OrderedDict):
...
@@ -42,7 +42,7 @@ class OrderedUpdates(OrderedDict):
"an OrderedDict that is available at "
"an OrderedDict that is available at "
"theano.compat.OrderedDict for python 2.6+."
"theano.compat.OrderedDict for python 2.6+."
)
)
super
(
OrderedUpdates
,
self
)
.
__init__
(
*
key
,
**
kwargs
)
super
()
.
__init__
(
*
key
,
**
kwargs
)
for
key
in
self
:
for
key
in
self
:
if
not
isinstance
(
key
,
SharedVariable
):
if
not
isinstance
(
key
,
SharedVariable
):
raise
TypeError
(
raise
TypeError
(
...
@@ -59,7 +59,7 @@ class OrderedUpdates(OrderedDict):
...
@@ -59,7 +59,7 @@ class OrderedUpdates(OrderedDict):
# value. Should it be cast to a GPU value right away? Should
# value. Should it be cast to a GPU value right away? Should
# literals be transformed into constants immediately?
# literals be transformed into constants immediately?
return
super
(
OrderedUpdates
,
self
)
.
__setitem__
(
key
,
value
)
return
super
()
.
__setitem__
(
key
,
value
)
else
:
else
:
raise
TypeError
(
raise
TypeError
(
"OrderedUpdates keys must inherit from "
"SharedVariable"
,
key
"OrderedUpdates keys must inherit from "
"SharedVariable"
,
key
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论