Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
86dbc392
提交
86dbc392
authored
10月 21, 2020
作者:
Brandon T. Willard
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Apply pyupgrade to top-level modules in theano package
上级
d6477d58
隐藏空白字符变更
内嵌
并排
正在显示
10 个修改的文件
包含
111 行增加
和
148 行删除
+111
-148
_version.py
theano/_version.py
+8
-8
compat.py
theano/compat.py
+27
-42
configdefaults.py
theano/configdefaults.py
+13
-16
configparser.py
theano/configparser.py
+13
-15
gradient.py
theano/gradient.py
+20
-28
ifelse.py
theano/ifelse.py
+8
-12
pathparse.py
theano/pathparse.py
+1
-1
printing.py
theano/printing.py
+18
-23
raise_op.py
theano/raise_op.py
+1
-1
updates.py
theano/updates.py
+2
-2
没有找到文件。
theano/_version.py
浏览文件 @
86dbc392
...
...
@@ -84,7 +84,7 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=
stderr
=
(
subprocess
.
PIPE
if
hide_stderr
else
None
),
)
break
except
Environment
Error
:
except
OS
Error
:
e
=
sys
.
exc_info
()[
1
]
if
e
.
errno
==
errno
.
ENOENT
:
continue
...
...
@@ -94,7 +94,7 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=
return
None
,
None
else
:
if
verbose
:
print
(
"unable to find command, tried
%
s"
%
(
commands
,
))
print
(
"unable to find command, tried
{}"
.
format
(
commands
))
return
None
,
None
stdout
=
p
.
communicate
()[
0
]
.
strip
()
.
decode
()
if
p
.
returncode
!=
0
:
...
...
@@ -145,7 +145,7 @@ def git_get_keywords(versionfile_abs):
# _version.py.
keywords
=
{}
try
:
f
=
open
(
versionfile_abs
,
"r"
)
f
=
open
(
versionfile_abs
)
for
line
in
f
.
readlines
():
if
line
.
strip
()
.
startswith
(
"git_refnames ="
):
mo
=
re
.
search
(
r'=\s*"(.*)"'
,
line
)
...
...
@@ -160,7 +160,7 @@ def git_get_keywords(versionfile_abs):
if
mo
:
keywords
[
"date"
]
=
mo
.
group
(
1
)
f
.
close
()
except
Environment
Error
:
except
OS
Error
:
pass
return
keywords
...
...
@@ -184,11 +184,11 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose):
if
verbose
:
print
(
"keywords are unexpanded, not using"
)
raise
NotThisMethod
(
"unexpanded keywords, not a git-archive tarball"
)
refs
=
set
([
r
.
strip
()
for
r
in
refnames
.
strip
(
"()"
)
.
split
(
","
)])
refs
=
{
r
.
strip
()
for
r
in
refnames
.
strip
(
"()"
)
.
split
(
","
)}
# starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of
# just "foo-1.0". If we see a "tag: " prefix, prefer those.
TAG
=
"tag: "
tags
=
set
([
r
[
len
(
TAG
)
:]
for
r
in
refs
if
r
.
startswith
(
TAG
)])
tags
=
{
r
[
len
(
TAG
)
:]
for
r
in
refs
if
r
.
startswith
(
TAG
)}
if
not
tags
:
# Either we're using git < 1.8.3, or there really are no tags. We use
# a heuristic: assume all version tags have a digit. The old git %d
...
...
@@ -197,7 +197,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose):
# between branches and tags. By ignoring refnames without digits, we
# filter out many common branch names like "release" and
# "stabilization", as well as "HEAD" and "master".
tags
=
set
([
r
for
r
in
refs
if
re
.
search
(
r"\d"
,
r
)])
tags
=
{
r
for
r
in
refs
if
re
.
search
(
r"\d"
,
r
)}
if
verbose
:
print
(
"discarding '
%
s', no digits"
%
","
.
join
(
refs
-
tags
))
if
verbose
:
...
...
@@ -300,7 +300,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
if
verbose
:
fmt
=
"tag '
%
s' doesn't start with prefix '
%
s'"
print
(
fmt
%
(
full_tag
,
tag_prefix
))
pieces
[
"error"
]
=
"tag '
%
s' doesn't start with prefix '
%
s'"
%
(
pieces
[
"error"
]
=
"tag '
{}' doesn't start with prefix '{}'"
.
format
(
full_tag
,
tag_prefix
,
)
...
...
theano/compat.py
浏览文件 @
86dbc392
...
...
@@ -5,7 +5,7 @@
from
collections
import
OrderedDict
# Python 3.x compatibility
from
six
import
PY3
,
BytesIO
,
b
,
next
from
six
import
PY3
,
BytesIO
,
b
from
six.moves
import
configparser
from
six.moves
import
reload_module
as
reload
...
...
@@ -24,59 +24,44 @@ except ImportError:
__all__
=
[
"PY3"
,
"b"
,
"BytesIO"
,
"next"
,
"configparser"
,
"reload"
]
if
PY3
:
from
operator
import
truediv
as
operator_div
from
operator
import
truediv
as
operator_div
# In python 3.x, when an exception is reraised it saves original
# exception in its args, therefore in order to find the actual
# message, we need to unpack arguments recursively.
def
exc_message
(
e
):
msg
=
e
.
args
[
0
]
if
isinstance
(
msg
,
Exception
):
return
exc_message
(
msg
)
return
msg
def
cmp
(
x
,
y
):
"""Return -1 if x < y, 0 if x == y, 1 if x > y."""
return
(
x
>
y
)
-
(
x
<
y
)
# In python 3.x, when an exception is reraised it saves original
# exception in its args, therefore in order to find the actual
# message, we need to unpack arguments recursively.
def
exc_message
(
e
):
msg
=
e
.
args
[
0
]
if
isinstance
(
msg
,
Exception
):
return
exc_message
(
msg
)
return
msg
def
get_unbound_function
(
unbound
):
# Op.make_thunk isn't bound, so don't have a __func__ attr.
# But bound method, have a __func__ method that point to the
# not bound method. That is what we want.
if
hasattr
(
unbound
,
"__func__"
):
return
unbound
.
__func__
return
unbound
def
decode
(
x
):
return
x
.
decode
()
def
cmp
(
x
,
y
):
"""Return -1 if x < y, 0 if x == y, 1 if x > y."""
return
(
x
>
y
)
-
(
x
<
y
)
def
decode_iter
(
itr
):
for
x
in
itr
:
yield
x
.
decode
()
def
decode_with
(
x
,
encoding
):
return
x
.
decode
(
encoding
)
def
get_unbound_function
(
unbound
):
# Op.make_thunk isn't bound, so don't have a __func__ attr.
# But bound method, have a __func__ method that point to the
# not bound method. That is what we want.
if
hasattr
(
unbound
,
"__func__"
):
return
unbound
.
__func__
return
unbound
else
:
from
operator
import
div
as
operator_div
def
decode
(
x
)
:
return
x
.
decode
()
from
six
import
get_unbound_function
def
exc_message
(
e
):
return
e
[
0
]
def
decode_iter
(
itr
):
for
x
in
itr
:
yield
x
.
decode
()
cmp
=
cmp
def
decode
(
x
):
return
x
def
decode_iter
(
x
):
return
x
def
decode_with
(
x
,
encoding
):
return
x
def
decode_with
(
x
,
encoding
):
return
x
.
decode
(
encoding
)
__all__
+=
[
...
...
theano/configdefaults.py
浏览文件 @
86dbc392
...
...
@@ -10,7 +10,6 @@ import textwrap
import
warnings
import
numpy
as
np
from
six
import
string_types
import
theano
from
theano.compat
import
maybe_add_to_os_environ_pathlist
...
...
@@ -142,18 +141,16 @@ class DeviceParam(ConfigParam):
)
else
:
raise
ValueError
(
(
'Invalid value ("
%
s") for configuration '
'variable "
%
s". Valid options start with '
'one of "cpu", "opencl" or "cuda".'
%
(
val
,
self
.
fullname
)
)
'Invalid value ("
%
s") for configuration '
'variable "
%
s". Valid options start with '
'one of "cpu", "opencl" or "cuda".'
%
(
val
,
self
.
fullname
)
)
over
=
kwargs
.
get
(
"allow_override"
,
True
)
super
(
DeviceParam
,
self
)
.
__init__
(
default
,
filter
,
over
)
super
()
.
__init__
(
default
,
filter
,
over
)
def
__str__
(
self
):
return
"
%
s (
%
s, opencl*, cuda*) "
%
(
self
.
fullname
,
self
.
default
)
return
"
{} ({}, opencl*, cuda*) "
.
format
(
self
.
fullname
,
self
.
default
)
AddConfigVar
(
...
...
@@ -211,13 +208,13 @@ class ContextsParam(ConfigParam):
for
v
in
val
.
split
(
";"
):
s
=
v
.
split
(
"->"
)
if
len
(
s
)
!=
2
:
raise
ValueError
(
"Malformed context map:
%
s"
%
(
v
,
))
raise
ValueError
(
"Malformed context map:
{}"
.
format
(
v
))
if
(
s
[
0
]
==
"cpu"
or
s
[
0
]
.
startswith
(
"cuda"
)
or
s
[
0
]
.
startswith
(
"opencl"
)
):
raise
ValueError
(
"Cannot use
%
s as context name"
%
(
s
[
0
],
))
raise
ValueError
(
"Cannot use
{} as context name"
.
format
(
s
[
0
]
))
return
val
ConfigParam
.
__init__
(
self
,
""
,
filter
,
False
)
...
...
@@ -1409,7 +1406,7 @@ AddConfigVar(
def
is_valid_check_preallocated_output_param
(
param
):
if
not
isinstance
(
param
,
str
ing_types
):
if
not
isinstance
(
param
,
str
):
return
False
valid
=
[
"initial"
,
...
...
@@ -1821,7 +1818,7 @@ def default_blas_ldflags():
# we just pass the whole ldflags as the -l
# options part.
[
"-L
%
s
%
s
%
s"
%
(
path_wrapper
,
l
,
path_wrapper
)
"-L
{}{}{}"
.
format
(
path_wrapper
,
l
,
path_wrapper
)
for
l
in
blas_info
.
get
(
"library_dirs"
,
[])
]
+
[
"-l
%
s"
%
l
for
l
in
blas_info
.
get
(
"libraries"
,
[])]
...
...
@@ -1902,7 +1899,7 @@ def try_blas_flag(flags):
path_wrapper
=
'"'
if
os
.
name
==
"nt"
else
""
cflags
.
extend
(
[
"-L
%
s
%
s
%
s"
%
(
path_wrapper
,
d
,
path_wrapper
)
"-L
{}{}{}"
.
format
(
path_wrapper
,
d
,
path_wrapper
)
for
d
in
theano
.
gof
.
cmodule
.
std_lib_dirs
()
]
)
...
...
@@ -2311,11 +2308,11 @@ def filter_compiledir(path):
if
not
os
.
path
.
exists
(
init_file
):
try
:
open
(
init_file
,
"w"
)
.
close
()
except
IO
Error
as
e
:
except
OS
Error
as
e
:
if
os
.
path
.
exists
(
init_file
):
pass
# has already been created
else
:
e
.
args
+=
(
"
%
s exist?
%
s"
%
(
path
,
os
.
path
.
exists
(
path
)),)
e
.
args
+=
(
"
{} exist? {}"
.
format
(
path
,
os
.
path
.
exists
(
path
)),)
raise
return
path
...
...
@@ -2390,4 +2387,4 @@ AddConfigVar(
# Check if there are remaining flags provided by the user through THEANO_FLAGS.
for
key
in
THEANO_FLAGS_DICT
.
keys
():
warnings
.
warn
(
"Theano does not recognise this flag: {
0
}"
.
format
(
key
))
warnings
.
warn
(
"Theano does not recognise this flag: {}"
.
format
(
key
))
theano/configparser.py
浏览文件 @
86dbc392
...
...
@@ -10,7 +10,7 @@ import sys
import
warnings
from
functools
import
wraps
from
six
import
PY3
,
StringIO
,
string_types
from
six
import
PY3
,
StringIO
import
theano
from
theano.compat
import
configparser
as
ConfigParser
...
...
@@ -95,7 +95,7 @@ theano_raw_cfg = ConfigParser.RawConfigParser()
theano_raw_cfg
.
read
(
config_files
)
class
change_flags
(
object
)
:
class
change_flags
:
"""
Use this as a decorator or context manager to change the value of
Theano config variables.
...
...
@@ -204,12 +204,12 @@ def get_config_hash():
)
return
theano
.
gof
.
utils
.
hash_from_code
(
"
\n
"
.
join
(
[
"
%
s =
%
s"
%
(
cv
.
fullname
,
cv
.
__get__
(
True
,
None
))
for
cv
in
all_opts
]
[
"
{} = {}"
.
format
(
cv
.
fullname
,
cv
.
__get__
(
True
,
None
))
for
cv
in
all_opts
]
)
)
class
TheanoConfigParser
(
object
)
:
class
TheanoConfigParser
:
# properties are installed by AddConfigVar
_i_am_a_config_class
=
True
...
...
@@ -276,7 +276,7 @@ def AddConfigVar(name, doc, configparam, root=config, in_c_key=True):
if
not
hasattr
(
root
,
sections
[
0
]):
# every internal node in the config tree is an instance of its own
# unique class
class
SubObj
(
object
)
:
class
SubObj
:
_i_am_a_config_class
=
True
setattr
(
root
.
__class__
,
sections
[
0
],
SubObj
())
...
...
@@ -312,7 +312,7 @@ def AddConfigVar(name, doc, configparam, root=config, in_c_key=True):
_config_var_list
.
append
(
configparam
)
class
ConfigParam
(
object
)
:
class
ConfigParam
:
def
__init__
(
self
,
default
,
filter
=
None
,
allow_override
=
True
):
"""
If allow_override is False, we can't change the value after the import
...
...
@@ -368,7 +368,7 @@ class EnumStr(ConfigParam):
# All options should be strings
for
val
in
self
.
all
:
if
not
isinstance
(
val
,
str
ing_types
):
if
not
isinstance
(
val
,
str
):
raise
ValueError
(
"Valid values for an EnumStr parameter "
"should be strings"
,
val
,
...
...
@@ -384,17 +384,15 @@ class EnumStr(ConfigParam):
return
val
else
:
raise
ValueError
(
(
'Invalid value ("
%
s") for configuration variable "
%
s". '
"Valid options are
%
s"
%
(
val
,
self
.
fullname
,
self
.
all
)
)
'Invalid value ("
%
s") for configuration variable "
%
s". '
"Valid options are
%
s"
%
(
val
,
self
.
fullname
,
self
.
all
)
)
over
=
kwargs
.
get
(
"allow_override"
,
True
)
super
(
EnumStr
,
self
)
.
__init__
(
default
,
filter
,
over
)
super
()
.
__init__
(
default
,
filter
,
over
)
def
__str__
(
self
):
return
"
%
s (
%
s) "
%
(
self
.
fullname
,
self
.
all
)
return
"
{} ({}) "
.
format
(
self
.
fullname
,
self
.
all
)
class
TypedParam
(
ConfigParam
):
...
...
@@ -414,10 +412,10 @@ class TypedParam(ConfigParam):
)
return
cast_val
super
(
TypedParam
,
self
)
.
__init__
(
default
,
filter
,
allow_override
=
allow_override
)
super
()
.
__init__
(
default
,
filter
,
allow_override
=
allow_override
)
def
__str__
(
self
):
return
"
%
s (
%
s) "
%
(
self
.
fullname
,
self
.
mytype
)
return
"
{} ({}) "
.
format
(
self
.
fullname
,
self
.
mytype
)
def
StrParam
(
default
,
is_valid
=
None
,
allow_override
=
True
):
...
...
theano/gradient.py
浏览文件 @
86dbc392
...
...
@@ -130,20 +130,16 @@ class DisconnectedType(theano.gof.type.Type):
def
filter
(
self
,
data
,
strict
=
False
,
allow_downcast
=
None
):
raise
AssertionError
(
(
"If you're assigning to a DisconnectedType you're"
" doing something wrong. It should only be used as"
" a symbolic placeholder."
)
"If you're assigning to a DisconnectedType you're"
" doing something wrong. It should only be used as"
" a symbolic placeholder."
)
def
fiter_variable
(
self
,
other
):
raise
AssertionError
(
(
"If you're assigning to a DisconnectedType you're"
" doing something wrong. It should only be used as"
" a symbolic placeholder."
)
"If you're assigning to a DisconnectedType you're"
" doing something wrong. It should only be used as"
" a symbolic placeholder."
)
def
may_share_memory
(
a
,
b
):
...
...
@@ -151,11 +147,9 @@ class DisconnectedType(theano.gof.type.Type):
def
value_eq
(
a
,
b
,
force_same_dtype
=
True
):
raise
AssertionError
(
(
"If you're assigning to a DisconnectedType you're"
" doing something wrong. It should only be used as"
" a symbolic placeholder."
)
"If you're assigning to a DisconnectedType you're"
" doing something wrong. It should only be used as"
" a symbolic placeholder."
)
def
__str__
(
self
):
...
...
@@ -846,7 +840,7 @@ def _node_to_pattern(node):
raise
TypeError
(
"
%
s.connection_pattern should return"
%
node
.
op
+
" a list of lists, but element
%
d"
%
ii
+
"is
%
s of type
%
s."
%
(
output_pattern
,
type
(
output_pattern
))
+
"is
{} of type {}."
.
format
(
output_pattern
,
type
(
output_pattern
))
)
else
:
connection_pattern
=
[[
True
for
output
in
node
.
outputs
]
for
ipt
in
node
.
inputs
]
...
...
@@ -933,7 +927,7 @@ def _populate_var_to_app_to_idx(outputs, wrt, consider_constant):
# Note: we need to revisit the apply nodes repeatedly, because
# different outputs of the apply node are connected to
# different subsets of the inputs.
accounted_for
=
set
(
[]
)
accounted_for
=
set
()
def
account_for
(
var
):
# Don't visit the same variable twice
...
...
@@ -984,7 +978,7 @@ def _populate_var_to_app_to_idx(outputs, wrt, consider_constant):
# determine which variables have elements of wrt as a true
# ancestor. Do this with an upward pass starting from wrt,
# following only true connections
visited
=
set
(
[]
)
visited
=
set
()
def
visit
(
var
):
if
var
in
visited
:
...
...
@@ -1458,7 +1452,7 @@ def _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name=None):
grad_dict
[
var
]
=
disconnected_type
()
if
cost_name
is
not
None
and
var
.
name
is
not
None
:
grad_dict
[
var
]
.
name
=
"(d
%
s/d
%
s)"
%
(
cost_name
,
var
.
name
)
grad_dict
[
var
]
.
name
=
"(d
{}/d{})"
.
format
(
cost_name
,
var
.
name
)
else
:
# this variable isn't connected to the cost in the
# computational graph
...
...
@@ -1494,7 +1488,7 @@ def _float_ones_like(x):
return
x
.
ones_like
(
dtype
=
dtype
)
class
numeric_grad
(
object
)
:
class
numeric_grad
:
"""
Compute the numeric derivative of a scalar-valued function at a particular
point.
...
...
@@ -1818,13 +1812,11 @@ def verify_grad(
if
rng
is
None
:
raise
TypeError
(
(
"rng should be a valid instance of "
"numpy.random.RandomState. You may "
"want to use tests.unittest"
"_tools.verify_grad instead of "
"theano.gradient.verify_grad."
)
"rng should be a valid instance of "
"numpy.random.RandomState. You may "
"want to use tests.unittest"
"_tools.verify_grad instead of "
"theano.gradient.verify_grad."
)
# We allow input downcast in function, because numeric_grad works in the
...
...
@@ -1853,7 +1845,7 @@ def verify_grad(
if
isinstance
(
o_output
,
list
):
raise
NotImplementedError
(
(
"cant (yet) autotest gradient of fun "
"with multiple outputs"
)
"cant (yet) autotest gradient of fun "
"with multiple outputs"
)
# we could make loop over outputs making random projections R for each,
# but this doesn't handle the case where not all the outputs are
...
...
theano/ifelse.py
浏览文件 @
86dbc392
...
...
@@ -190,11 +190,9 @@ class IfElse(Op):
)
if
c
.
ndim
>
0
:
raise
TypeError
(
(
"Condition given to the op has to be a scalar "
"with 0 standing for False, anything else "
"for True"
)
"Condition given to the op has to be a scalar "
"with 0 standing for False, anything else "
"for True"
)
return
Apply
(
self
,
[
c
]
+
list
(
args
),
[
t
.
type
()
for
t
in
ts
])
...
...
@@ -401,13 +399,11 @@ def ifelse(condition, then_branch, else_branch, name=None):
if
len
(
then_branch
)
!=
len
(
else_branch
):
raise
ValueError
(
(
"The number of values on the `then` branch"
" should have the same number of variables as "
"the `else` branch : (variables on `then` "
"
%
d"
%
len
(
then_branch
)
+
", variables on `else` "
"
%
d"
%
len
(
else_branch
)
+
")"
)
"The number of values on the `then` branch"
" should have the same number of variables as "
"the `else` branch : (variables on `then` "
"
%
d"
%
len
(
then_branch
)
+
", variables on `else` "
"
%
d"
%
len
(
else_branch
)
+
")"
)
new_ifelse
=
IfElse
(
n_outs
=
len
(
then_branch
),
as_view
=
False
,
gpu
=
False
,
name
=
name
)
...
...
theano/pathparse.py
浏览文件 @
86dbc392
...
...
@@ -2,7 +2,7 @@ import os
import
sys
class
PathParser
(
object
)
:
class
PathParser
:
"""
Class that allows to modify system's PATH environment variable
at runtime. Currently used in ``theano.gpuarray.dnn`` module
...
...
theano/printing.py
浏览文件 @
86dbc392
...
...
@@ -12,7 +12,6 @@ from copy import copy
from
functools
import
reduce
import
numpy
as
np
from
six
import
integer_types
,
string_types
from
six.moves
import
StringIO
import
theano
...
...
@@ -55,7 +54,7 @@ except ImportError:
_logger
=
logging
.
getLogger
(
"theano.printing"
)
VALID_ASSOC
=
set
([
"left"
,
"right"
,
"either"
])
VALID_ASSOC
=
{
"left"
,
"right"
,
"either"
}
def
debugprint
(
...
...
@@ -121,7 +120,7 @@ def debugprint(
to the Apply's identifier, to indicate which output a line corresponds to.
"""
if
not
isinstance
(
depth
,
int
eger_types
):
if
not
isinstance
(
depth
,
int
):
raise
Exception
(
"depth parameter must be an int"
)
if
file
==
"str"
:
_file
=
StringIO
()
...
...
@@ -168,7 +167,7 @@ def debugprint(
smap
.
extend
([
getattr
(
obj
,
"storage_map"
,
None
)
for
item
in
obj
.
outputs
])
topo
=
obj
.
toposort
()
order
.
extend
([
topo
for
item
in
obj
.
outputs
])
elif
isinstance
(
obj
,
(
int
eger_types
,
float
,
np
.
ndarray
)):
elif
isinstance
(
obj
,
(
int
,
float
,
np
.
ndarray
)):
print
(
obj
,
file
=
_file
)
elif
isinstance
(
obj
,
(
theano
.
In
,
theano
.
Out
)):
results_to_print
.
append
(
obj
.
variable
)
...
...
@@ -239,14 +238,10 @@ N.B.:
else
:
inner_inputs
=
s
.
owner
.
op
.
inputs
outer_inputs
=
s
.
owner
.
inputs
inner_to_outer_inputs
=
dict
(
[
(
inner_inputs
[
i
],
outer_inputs
[
o
])
for
i
,
o
in
s
.
owner
.
op
.
var_mappings
[
"outer_inp_from_inner_inp"
]
.
items
()
]
)
inner_to_outer_inputs
=
{
inner_inputs
[
i
]:
outer_inputs
[
o
]
for
i
,
o
in
s
.
owner
.
op
.
var_mappings
[
"outer_inp_from_inner_inp"
]
.
items
()
}
print
(
""
,
file
=
_file
)
debugmode
.
debugprint
(
...
...
@@ -440,7 +435,7 @@ class PatternPrinter:
def
__init__
(
self
,
*
patterns
):
self
.
patterns
=
[]
for
pattern
in
patterns
:
if
isinstance
(
pattern
,
str
ing_types
):
if
isinstance
(
pattern
,
str
):
self
.
patterns
.
append
((
pattern
,
()))
else
:
self
.
patterns
.
append
((
pattern
[
0
],
pattern
[
1
:]))
...
...
@@ -469,13 +464,13 @@ class PatternPrinter:
return
r
d
=
dict
(
(
str
(
i
),
x
)
d
=
{
str
(
i
):
x
for
i
,
x
in
enumerate
(
pp_process
(
input
,
precedence
)
for
input
,
precedence
in
zip
(
node
.
inputs
,
precedences
)
)
)
}
r
=
pattern
%
d
pstate
.
memo
[
output
]
=
r
return
r
...
...
@@ -501,7 +496,7 @@ class FunctionPrinter:
try
:
old_precedence
=
getattr
(
pstate
,
"precedence"
,
None
)
pstate
.
precedence
=
new_precedence
r
=
"
%
s(
%
s)"
%
(
r
=
"
{}({})"
.
format
(
name
,
", "
.
join
([
pprinter
.
process
(
input
,
pstate
)
for
input
in
node
.
inputs
]),
)
...
...
@@ -556,7 +551,7 @@ class DefaultPrinter:
try
:
old_precedence
=
getattr
(
pstate
,
"precedence"
,
None
)
pstate
.
precedence
=
new_precedence
r
=
"
%
s(
%
s)"
%
(
r
=
"
{}({})"
.
format
(
str
(
node
.
op
),
", "
.
join
([
pprinter
.
process
(
input
,
pstate
)
for
input
in
node
.
inputs
]),
)
...
...
@@ -624,7 +619,7 @@ class PPrinter:
pprinter
=
self
.
clone_assign
(
lambda
pstate
,
r
:
r
.
name
is
not
None
and
r
is
not
current
,
leaf_printer
)
inv_updates
=
dict
((
b
,
a
)
for
(
a
,
b
)
in
updates
.
items
())
inv_updates
=
{
b
:
a
for
(
a
,
b
)
in
updates
.
items
()}
i
=
1
for
node
in
gof
.
graph
.
io_toposort
(
list
(
inputs
)
+
updates
.
keys
(),
list
(
outputs
)
+
updates
.
values
()
...
...
@@ -633,7 +628,7 @@ class PPrinter:
if
output
in
inv_updates
:
name
=
str
(
inv_updates
[
output
])
strings
.
append
(
(
i
+
1000
,
"
%
s <-
%
s"
%
(
name
,
pprinter
.
process
(
output
)))
(
i
+
1000
,
"
{} <- {}"
.
format
(
name
,
pprinter
.
process
(
output
)))
)
i
+=
1
if
output
.
name
is
not
None
or
output
in
outputs
:
...
...
@@ -653,7 +648,7 @@ class PPrinter:
strings
.
append
((
idx
,
"return
%
s"
%
pprinter
.
process
(
output
)))
else
:
strings
.
append
(
(
idx
,
"
%
s =
%
s"
%
(
name
,
pprinter
.
process
(
output
)))
(
idx
,
"
{} = {}"
.
format
(
name
,
pprinter
.
process
(
output
)))
)
i
+=
1
strings
.
sort
()
...
...
@@ -901,7 +896,7 @@ def pydotprint(
dstr
=
"val="
+
str
(
np
.
asarray
(
var
.
data
))
if
"
\n
"
in
dstr
:
dstr
=
dstr
[:
dstr
.
index
(
"
\n
"
)]
varstr
=
"
%
s
%
s"
%
(
dstr
,
str
(
var
.
type
))
varstr
=
"
{} {}"
.
format
(
dstr
,
str
(
var
.
type
))
elif
var
in
input_update
and
input_update
[
var
]
.
name
is
not
None
:
varstr
=
input_update
[
var
]
.
name
if
not
var_with_name_simple
:
...
...
@@ -933,7 +928,7 @@ def pydotprint(
pf
=
0
else
:
pf
=
time
*
100
/
profile
.
fct_call_time
prof_str
=
" (
%.3
fs,
%.3
f
%%
)"
%
(
time
,
pf
)
prof_str
=
" (
{:.3f}s,{:.3f}
%
)"
.
format
(
time
,
pf
)
applystr
=
str
(
node
.
op
)
.
replace
(
":"
,
"_"
)
applystr
+=
prof_str
if
(
applystr
in
all_strings
)
or
with_ids
:
...
...
theano/raise_op.py
浏览文件 @
86dbc392
...
...
@@ -25,7 +25,7 @@ class Raise(gof.Op):
self
.
exc
=
exc
def
__str__
(
self
):
return
"Raise{
%
s(
%
s)}"
%
(
self
.
exc
,
self
.
msg
)
return
"Raise{
{{}({})}}"
.
format
(
self
.
exc
,
self
.
msg
)
def
make_node
(
self
,
x
):
return
gof
.
Apply
(
self
,
[
x
],
[
x
.
type
()])
...
...
theano/updates.py
浏览文件 @
86dbc392
...
...
@@ -42,7 +42,7 @@ class OrderedUpdates(OrderedDict):
"an OrderedDict that is available at "
"theano.compat.OrderedDict for python 2.6+."
)
super
(
OrderedUpdates
,
self
)
.
__init__
(
*
key
,
**
kwargs
)
super
()
.
__init__
(
*
key
,
**
kwargs
)
for
key
in
self
:
if
not
isinstance
(
key
,
SharedVariable
):
raise
TypeError
(
...
...
@@ -59,7 +59,7 @@ class OrderedUpdates(OrderedDict):
# value. Should it be cast to a GPU value right away? Should
# literals be transformed into constants immediately?
return
super
(
OrderedUpdates
,
self
)
.
__setitem__
(
key
,
value
)
return
super
()
.
__setitem__
(
key
,
value
)
else
:
raise
TypeError
(
"OrderedUpdates keys must inherit from "
"SharedVariable"
,
key
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论