Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
ffa5e139
提交
ffa5e139
authored
10月 21, 2020
作者:
Brandon T. Willard
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Apply pyupgrade to theano.compile
上级
86dbc392
显示空白字符变更
内嵌
并排
正在显示
11 个修改的文件
包含
77 行增加
和
99 行删除
+77
-99
debugmode.py
theano/compile/debugmode.py
+18
-22
function.py
theano/compile/function.py
+2
-4
function_module.py
theano/compile/function_module.py
+15
-20
io.py
theano/compile/io.py
+7
-9
mode.py
theano/compile/mode.py
+9
-11
monitormode.py
theano/compile/monitormode.py
+3
-3
nanguardmode.py
theano/compile/nanguardmode.py
+1
-3
ops.py
theano/compile/ops.py
+10
-11
pfunc.py
theano/compile/pfunc.py
+1
-3
profiling.py
theano/compile/profiling.py
+10
-10
sharedvalue.py
theano/compile/sharedvalue.py
+1
-3
没有找到文件。
theano/compile/debugmode.py
浏览文件 @
ffa5e139
...
@@ -37,7 +37,7 @@ _logger = logging.getLogger("theano.compile.debugmode")
...
@@ -37,7 +37,7 @@ _logger = logging.getLogger("theano.compile.debugmode")
# Filter to avoid duplicating optimization warnings
# Filter to avoid duplicating optimization warnings
class
NoDuplicateOptWarningFilter
(
logging
.
Filter
):
class
NoDuplicateOptWarningFilter
(
logging
.
Filter
):
prev_msgs
=
set
(
[]
)
prev_msgs
=
set
()
def
filter
(
self
,
record
):
def
filter
(
self
,
record
):
msg
=
record
.
getMessage
()
msg
=
record
.
getMessage
()
...
@@ -64,8 +64,6 @@ class DebugModeError(Exception):
...
@@ -64,8 +64,6 @@ class DebugModeError(Exception):
"""
"""
pass
class
BadThunkOutput
(
DebugModeError
):
class
BadThunkOutput
(
DebugModeError
):
"""
"""
...
@@ -99,7 +97,7 @@ class BadThunkOutput(DebugModeError):
...
@@ -99,7 +97,7 @@ class BadThunkOutput(DebugModeError):
"""
"""
def
__init__
(
self
,
r
,
thunk1
,
val1
,
thunk2
,
val2
,
inputs_val
=
()):
def
__init__
(
self
,
r
,
thunk1
,
val1
,
thunk2
,
val2
,
inputs_val
=
()):
super
(
BadThunkOutput
,
self
)
.
__init__
()
super
()
.
__init__
()
self
.
r
=
r
self
.
r
=
r
self
.
thunk1
=
thunk1
self
.
thunk1
=
thunk1
self
.
val1
=
val1
self
.
val1
=
val1
...
@@ -170,7 +168,7 @@ class BadDestroyMap(DebugModeError):
...
@@ -170,7 +168,7 @@ class BadDestroyMap(DebugModeError):
"""
"""
def
__init__
(
self
,
node
,
idx
,
old_val
,
new_val
,
perform
):
def
__init__
(
self
,
node
,
idx
,
old_val
,
new_val
,
perform
):
super
(
BadDestroyMap
,
self
)
.
__init__
()
super
()
.
__init__
()
self
.
node
=
node
self
.
node
=
node
self
.
idx
=
idx
self
.
idx
=
idx
self
.
old_val
=
old_val
self
.
old_val
=
old_val
...
@@ -254,7 +252,7 @@ class BadViewMap(DebugModeError):
...
@@ -254,7 +252,7 @@ class BadViewMap(DebugModeError):
def
__init__
(
def
__init__
(
self
,
node
,
output_idx
,
out_storage
,
in_alias_idx
=
None
,
out_alias_idx
=
None
self
,
node
,
output_idx
,
out_storage
,
in_alias_idx
=
None
,
out_alias_idx
=
None
):
):
super
(
BadViewMap
,
self
)
.
__init__
()
super
()
.
__init__
()
self
.
node
=
node
self
.
node
=
node
self
.
output_idx
=
output_idx
self
.
output_idx
=
output_idx
self
.
out_storage
=
out_storage
self
.
out_storage
=
out_storage
...
@@ -290,8 +288,6 @@ class StochasticOrder(DebugModeError):
...
@@ -290,8 +288,6 @@ class StochasticOrder(DebugModeError):
"""
"""
pass
class
InvalidValueError
(
DebugModeError
):
class
InvalidValueError
(
DebugModeError
):
"""
"""
...
@@ -304,7 +300,7 @@ class InvalidValueError(DebugModeError):
...
@@ -304,7 +300,7 @@ class InvalidValueError(DebugModeError):
"""
"""
def
__init__
(
self
,
r
,
v
=
None
,
client_node
=
None
,
hint
=
"none"
,
specific_hint
=
"none"
):
def
__init__
(
self
,
r
,
v
=
None
,
client_node
=
None
,
hint
=
"none"
,
specific_hint
=
"none"
):
super
(
InvalidValueError
,
self
)
.
__init__
()
super
()
.
__init__
()
self
.
r
=
r
self
.
r
=
r
self
.
v
=
v
self
.
v
=
v
self
.
client_node
=
client_node
self
.
client_node
=
client_node
...
@@ -718,7 +714,7 @@ def debugprint(
...
@@ -718,7 +714,7 @@ def debugprint(
else
:
else
:
outer_id_str
=
get_id_str
(
outer_r
)
outer_id_str
=
get_id_str
(
outer_r
)
print
(
print
(
"
%
s
%
s
%
s
%
s ->
%
s"
%
(
prefix
,
r
,
id_str
,
type_str
,
outer_id_str
),
"
{}{} {}{} -> {}"
.
format
(
prefix
,
r
,
id_str
,
type_str
,
outer_id_str
),
file
=
file
,
file
=
file
,
)
)
else
:
else
:
...
@@ -727,7 +723,7 @@ def debugprint(
...
@@ -727,7 +723,7 @@ def debugprint(
if
smap
:
if
smap
:
data
=
" "
+
str
(
smap
.
get
(
r
,
""
))
data
=
" "
+
str
(
smap
.
get
(
r
,
""
))
id_str
=
get_id_str
(
r
)
id_str
=
get_id_str
(
r
)
print
(
"
%
s
%
s
%
s
%
s
%
s"
%
(
prefix
,
r
,
id_str
,
type_str
,
data
),
file
=
file
)
print
(
"
{}{} {}{}{}"
.
format
(
prefix
,
r
,
id_str
,
type_str
,
data
),
file
=
file
)
return
file
return
file
...
@@ -1073,9 +1069,9 @@ def _find_bad_optimizations1(order, reasons, r_vals):
...
@@ -1073,9 +1069,9 @@ def _find_bad_optimizations1(order, reasons, r_vals):
for
i
,
node
in
enumerate
(
order
):
for
i
,
node
in
enumerate
(
order
):
program_position
[
node
]
=
i
program_position
[
node
]
=
i
for
new_r
in
node
.
outputs
:
for
new_r
in
node
.
outputs
:
equivalence_sets
.
setdefault
(
new_r
,
set
([
new_r
])
)
equivalence_sets
.
setdefault
(
new_r
,
{
new_r
}
)
for
reason
,
r
,
old_graph_str
,
new_graph_str
in
reasons
[
new_r
]:
for
reason
,
r
,
old_graph_str
,
new_graph_str
in
reasons
[
new_r
]:
equivalence_sets
[
new_r
]
.
update
(
equivalence_sets
.
setdefault
(
r
,
set
([
r
])
))
equivalence_sets
[
new_r
]
.
update
(
equivalence_sets
.
setdefault
(
r
,
{
r
}
))
for
er
in
equivalence_sets
[
r
]:
for
er
in
equivalence_sets
[
r
]:
equivalence_sets
[
er
]
=
equivalence_sets
[
new_r
]
equivalence_sets
[
er
]
=
equivalence_sets
[
new_r
]
...
@@ -1474,7 +1470,7 @@ def _check_preallocated_output(
...
@@ -1474,7 +1470,7 @@ def _check_preallocated_output(
):
):
_logger
.
debug
(
" name =
%
s"
,
name
)
_logger
.
debug
(
" name =
%
s"
,
name
)
thunk_name
=
"
%
s with
%
s output"
%
(
perform
,
name
)
thunk_name
=
"
{} with {} output"
.
format
(
perform
,
name
)
if
not
out_map
:
if
not
out_map
:
# Map is empty, there is no need to execute thunk() again
# Map is empty, there is no need to execute thunk() again
...
@@ -1541,7 +1537,7 @@ def _check_preallocated_output(
...
@@ -1541,7 +1537,7 @@ def _check_preallocated_output(
fn
.
maker
.
mode
=
backup_mode
fn
.
maker
.
mode
=
backup_mode
class
_FunctionGraphEvent
(
object
)
:
class
_FunctionGraphEvent
:
"""
"""
A record of an event in the life of an FunctionGraph.
A record of an event in the life of an FunctionGraph.
...
@@ -1613,7 +1609,7 @@ class _FunctionGraphEvent(object):
...
@@ -1613,7 +1609,7 @@ class _FunctionGraphEvent(object):
return
not
(
self
==
other
)
return
not
(
self
==
other
)
class
_VariableEquivalenceTracker
(
object
)
:
class
_VariableEquivalenceTracker
:
"""
"""
A FunctionGraph Feature that keeps tabs on an FunctionGraph and
A FunctionGraph Feature that keeps tabs on an FunctionGraph and
tries to detect problems.
tries to detect problems.
...
@@ -1684,7 +1680,7 @@ class _VariableEquivalenceTracker(object):
...
@@ -1684,7 +1680,7 @@ class _VariableEquivalenceTracker(object):
else
:
else
:
for
r
in
node
.
outputs
:
for
r
in
node
.
outputs
:
assert
r
not
in
self
.
equiv
assert
r
not
in
self
.
equiv
self
.
equiv
[
r
]
=
set
([
r
])
self
.
equiv
[
r
]
=
{
r
}
self
.
all_variables_ever
.
append
(
r
)
self
.
all_variables_ever
.
append
(
r
)
self
.
reasons
.
setdefault
(
r
,
[])
self
.
reasons
.
setdefault
(
r
,
[])
self
.
replaced_by
.
setdefault
(
r
,
[])
self
.
replaced_by
.
setdefault
(
r
,
[])
...
@@ -1740,13 +1736,13 @@ class _VariableEquivalenceTracker(object):
...
@@ -1740,13 +1736,13 @@ class _VariableEquivalenceTracker(object):
if
r
in
self
.
equiv
:
if
r
in
self
.
equiv
:
r_set
=
self
.
equiv
[
r
]
r_set
=
self
.
equiv
[
r
]
else
:
else
:
r_set
=
self
.
equiv
.
setdefault
(
r
,
set
([
r
])
)
r_set
=
self
.
equiv
.
setdefault
(
r
,
{
r
}
)
self
.
all_variables_ever
.
append
(
r
)
self
.
all_variables_ever
.
append
(
r
)
if
new_r
in
self
.
equiv
:
if
new_r
in
self
.
equiv
:
new_r_set
=
self
.
equiv
[
new_r
]
new_r_set
=
self
.
equiv
[
new_r
]
else
:
else
:
new_r_set
=
self
.
equiv
.
setdefault
(
new_r
,
set
([
new_r
])
)
new_r_set
=
self
.
equiv
.
setdefault
(
new_r
,
{
new_r
}
)
self
.
all_variables_ever
.
append
(
new_r
)
self
.
all_variables_ever
.
append
(
new_r
)
assert
new_r
in
new_r_set
assert
new_r
in
new_r_set
...
@@ -1779,7 +1775,7 @@ default_make_thunk = [get_unbound_function(theano.gof.Op.make_thunk)]
...
@@ -1779,7 +1775,7 @@ default_make_thunk = [get_unbound_function(theano.gof.Op.make_thunk)]
# the external requirements of the .linker attribute of a mode
# the external requirements of the .linker attribute of a mode
# 1) it's a class instance
# 1) it's a class instance
# 2) it a has a .clone() method
# 2) it a has a .clone() method
class
_DummyLinker
(
object
)
:
class
_DummyLinker
:
# This is not a real linker anyway
# This is not a real linker anyway
def
clone
(
self
,
allow_gc
=
None
):
def
clone
(
self
,
allow_gc
=
None
):
return
self
return
self
...
@@ -2746,7 +2742,7 @@ class DebugMode(Mode):
...
@@ -2746,7 +2742,7 @@ class DebugMode(Mode):
linker
,
linker
,
)
)
super
(
DebugMode
,
self
)
.
__init__
(
optimizer
=
optimizer
,
linker
=
linker
)
super
()
.
__init__
(
optimizer
=
optimizer
,
linker
=
linker
)
if
stability_patience
is
not
None
:
if
stability_patience
is
not
None
:
self
.
stability_patience
=
stability_patience
self
.
stability_patience
=
stability_patience
...
@@ -2771,7 +2767,7 @@ class DebugMode(Mode):
...
@@ -2771,7 +2767,7 @@ class DebugMode(Mode):
raise
ValueError
(
"DebugMode has to check at least one of c and py "
"code"
)
raise
ValueError
(
"DebugMode has to check at least one of c and py "
"code"
)
def
__str__
(
self
):
def
__str__
(
self
):
return
"DebugMode(linker=
%
s, optimizer=
%
s)"
%
(
return
"DebugMode(linker=
{}, optimizer={})"
.
format
(
self
.
provided_linker
,
self
.
provided_linker
,
self
.
provided_optimizer
,
self
.
provided_optimizer
,
)
)
...
...
theano/compile/function.py
浏览文件 @
ffa5e139
...
@@ -9,8 +9,6 @@ import re
...
@@ -9,8 +9,6 @@ import re
import
traceback
as
tb
import
traceback
as
tb
import
warnings
import
warnings
from
six
import
string_types
from
theano
import
compat
from
theano
import
compat
from
theano.compile.function_module
import
orig_function
from
theano.compile.function_module
import
orig_function
from
theano.compile.pfunc
import
pfunc
from
theano.compile.pfunc
import
pfunc
...
@@ -67,7 +65,7 @@ def function_dump(
...
@@ -67,7 +65,7 @@ def function_dump(
`['annotations', 'replacement_of', 'aggregation_scheme', 'roles']`
`['annotations', 'replacement_of', 'aggregation_scheme', 'roles']`
"""
"""
assert
isinstance
(
filename
,
str
ing_types
)
assert
isinstance
(
filename
,
str
)
d
=
dict
(
d
=
dict
(
inputs
=
inputs
,
inputs
=
inputs
,
outputs
=
outputs
,
outputs
=
outputs
,
...
@@ -258,7 +256,7 @@ def function(
...
@@ -258,7 +256,7 @@ def function(
output_items
=
list
(
outputs
.
items
())
output_items
=
list
(
outputs
.
items
())
for
item_pair
in
output_items
:
for
item_pair
in
output_items
:
assert
isinstance
(
item_pair
[
0
],
str
ing_types
)
assert
isinstance
(
item_pair
[
0
],
str
)
output_items_sorted
=
sorted
(
output_items
)
output_items_sorted
=
sorted
(
output_items
)
...
...
theano/compile/function_module.py
浏览文件 @
ffa5e139
...
@@ -12,7 +12,6 @@ from itertools import chain
...
@@ -12,7 +12,6 @@ from itertools import chain
import
numpy
as
np
import
numpy
as
np
import
six.moves.copyreg
as
copyreg
import
six.moves.copyreg
as
copyreg
import
six.moves.cPickle
as
pickle
import
six.moves.cPickle
as
pickle
from
six
import
string_types
import
theano
import
theano
import
theano.compile.profiling
import
theano.compile.profiling
...
@@ -35,8 +34,6 @@ class UnusedInputError(Exception):
...
@@ -35,8 +34,6 @@ class UnusedInputError(Exception):
"""
"""
pass
def
alias_root
(
v
):
def
alias_root
(
v
):
"""
"""
...
@@ -94,7 +91,7 @@ def infer_reuse_pattern(fgraph, outputs_to_disown):
...
@@ -94,7 +91,7 @@ def infer_reuse_pattern(fgraph, outputs_to_disown):
for
o
in
outputs_to_disown
:
for
o
in
outputs_to_disown
:
view_tree_set
(
alias_root
(
o
),
rval
)
view_tree_set
(
alias_root
(
o
),
rval
)
# remove from rval all of the inputs, constants, values.
# remove from rval all of the inputs, constants, values.
rval
=
set
(
r
for
r
in
rval
if
r
.
owner
is
not
None
)
rval
=
{
r
for
r
in
rval
if
r
.
owner
is
not
None
}
return
rval
return
rval
...
@@ -219,8 +216,6 @@ class AliasedMemoryError(Exception):
...
@@ -219,8 +216,6 @@ class AliasedMemoryError(Exception):
"""
"""
pass
###
###
# Function
# Function
...
@@ -230,7 +225,7 @@ class AliasedMemoryError(Exception):
...
@@ -230,7 +225,7 @@ class AliasedMemoryError(Exception):
DUPLICATE
=
[
"DUPLICATE"
]
DUPLICATE
=
[
"DUPLICATE"
]
class
Function
(
object
)
:
class
Function
:
"""
"""
Type of the functions returned by theano.function or
Type of the functions returned by theano.function or
theano.FunctionMaker.create.
theano.FunctionMaker.create.
...
@@ -478,7 +473,7 @@ class Function(object):
...
@@ -478,7 +473,7 @@ class Function(object):
# this class is important in overriding the square-bracket notation:
# this class is important in overriding the square-bracket notation:
# fn.value[x]
# fn.value[x]
# self reference is available via the closure on the class
# self reference is available via the closure on the class
class
ValueAttribute
(
object
)
:
class
ValueAttribute
:
def
__getitem__
(
self
,
item
):
def
__getitem__
(
self
,
item
):
try
:
try
:
s
=
finder
[
item
]
s
=
finder
[
item
]
...
@@ -501,7 +496,9 @@ class Function(object):
...
@@ -501,7 +496,9 @@ class Function(object):
except
KeyError
:
except
KeyError
:
# Print informative error message.
# Print informative error message.
msg
=
get_info_on_inputs
(
named_inputs
,
n_unnamed_inputs
)
msg
=
get_info_on_inputs
(
named_inputs
,
n_unnamed_inputs
)
raise
TypeError
(
"Unknown input or state:
%
s.
%
s"
%
(
str
(
item
),
msg
))
raise
TypeError
(
"Unknown input or state: {}. {}"
.
format
(
str
(
item
),
msg
)
)
if
s
is
DUPLICATE
:
if
s
is
DUPLICATE
:
raise
TypeError
(
raise
TypeError
(
"Ambiguous name:
%
s - please check the "
"Ambiguous name:
%
s - please check the "
...
@@ -520,7 +517,7 @@ class Function(object):
...
@@ -520,7 +517,7 @@ class Function(object):
# this class is important in overriding the square-bracket notation:
# this class is important in overriding the square-bracket notation:
# fn.container[x]
# fn.container[x]
# self reference is available via the closure on the class
# self reference is available via the closure on the class
class
ContainerAttribute
(
object
)
:
class
ContainerAttribute
:
def
__getitem__
(
self
,
item
):
def
__getitem__
(
self
,
item
):
return
finder
[
item
]
return
finder
[
item
]
...
@@ -1065,10 +1062,10 @@ class Function(object):
...
@@ -1065,10 +1062,10 @@ class Function(object):
if
output_subset
is
None
:
if
output_subset
is
None
:
return
dict
(
zip
(
self
.
output_keys
,
outputs
))
return
dict
(
zip
(
self
.
output_keys
,
outputs
))
else
:
else
:
return
dict
(
return
{
(
self
.
output_keys
[
index
],
outputs
[
index
])
self
.
output_keys
[
index
]:
outputs
[
index
]
for
index
in
output_subset
for
index
in
output_subset
)
}
if
output_subset
is
None
:
if
output_subset
is
None
:
return
outputs
return
outputs
...
@@ -1201,13 +1198,11 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
...
@@ -1201,13 +1198,11 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
assert
len
(
wrapped_inputs
)
==
len
(
fgraph
.
inputs
)
assert
len
(
wrapped_inputs
)
==
len
(
fgraph
.
inputs
)
assert
len
(
wrapped_outputs
)
==
len
(
fgraph
.
outputs
)
assert
len
(
wrapped_outputs
)
==
len
(
fgraph
.
outputs
)
reason
=
"insert_deepcopy"
reason
=
"insert_deepcopy"
updated_fgraph_inputs
=
set
(
updated_fgraph_inputs
=
{
[
fgraph_i
fgraph_i
for
i
,
fgraph_i
in
zip
(
wrapped_inputs
,
fgraph
.
inputs
)
for
i
,
fgraph_i
in
zip
(
wrapped_inputs
,
fgraph
.
inputs
)
if
getattr
(
i
,
"update"
,
False
)
if
getattr
(
i
,
"update"
,
False
)
]
}
)
# We can't use fgraph.inputs as this don't include Constant Value.
# We can't use fgraph.inputs as this don't include Constant Value.
all_graph_inputs
=
gof
.
graph
.
inputs
(
fgraph
.
outputs
)
all_graph_inputs
=
gof
.
graph
.
inputs
(
fgraph
.
outputs
)
...
@@ -1286,7 +1281,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
...
@@ -1286,7 +1281,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
NODEFAULT
=
[
"NODEFAULT"
]
NODEFAULT
=
[
"NODEFAULT"
]
class
FunctionMaker
(
object
)
:
class
FunctionMaker
:
"""
"""
`FunctionMaker` is the class to `create` `Function` instances.
`FunctionMaker` is the class to `create` `Function` instances.
...
@@ -2041,7 +2036,7 @@ def convert_function_input(input):
...
@@ -2041,7 +2036,7 @@ def convert_function_input(input):
orig
=
input
orig
=
input
if
not
input
:
if
not
input
:
raise
TypeError
(
"Nonsensical input specification:
%
s"
%
input
)
raise
TypeError
(
"Nonsensical input specification:
%
s"
%
input
)
if
isinstance
(
input
[
0
],
str
ing_types
):
if
isinstance
(
input
[
0
],
str
):
name
=
input
[
0
]
name
=
input
[
0
]
input
=
input
[
1
:]
input
=
input
[
1
:]
else
:
else
:
...
@@ -2133,7 +2128,7 @@ def get_info_on_inputs(named_inputs, n_unnamed_inputs):
...
@@ -2133,7 +2128,7 @@ def get_info_on_inputs(named_inputs, n_unnamed_inputs):
)
)
else
:
else
:
if
n_unnamed_inputs
==
0
:
if
n_unnamed_inputs
==
0
:
msg
=
"The function has
%
s named input
%
s (
%
s)."
%
(
msg
=
"The function has
{} named input{} ({})."
.
format
(
n_named_inputs
,
n_named_inputs
,
get_plural
(
n_named_inputs
),
get_plural
(
n_named_inputs
),
", "
.
join
(
named_inputs
),
", "
.
join
(
named_inputs
),
...
...
theano/compile/io.py
浏览文件 @
ffa5e139
...
@@ -6,8 +6,6 @@ Define `SymbolicInput`, `SymbolicOutput`, `In`, `Out`.
...
@@ -6,8 +6,6 @@ Define `SymbolicInput`, `SymbolicOutput`, `In`, `Out`.
import
logging
import
logging
from
six
import
string_types
from
theano
import
gof
from
theano
import
gof
from
.sharedvalue
import
SharedVariable
from
.sharedvalue
import
SharedVariable
...
@@ -18,7 +16,7 @@ _logger = logging.getLogger("theano.compile.io")
...
@@ -18,7 +16,7 @@ _logger = logging.getLogger("theano.compile.io")
__docformat__
=
"restructuredtext en"
__docformat__
=
"restructuredtext en"
class
SymbolicInput
(
object
)
:
class
SymbolicInput
:
"""
"""
Represents a symbolic input for use with function or FunctionMaker.
Represents a symbolic input for use with function or FunctionMaker.
...
@@ -79,7 +77,7 @@ class SymbolicInput(object):
...
@@ -79,7 +77,7 @@ class SymbolicInput(object):
else
:
else
:
self
.
name
=
name
self
.
name
=
name
if
self
.
name
is
not
None
and
not
isinstance
(
self
.
name
,
str
ing_types
):
if
self
.
name
is
not
None
and
not
isinstance
(
self
.
name
,
str
):
raise
TypeError
(
"name must be a string! (got:
%
s)"
%
self
.
name
)
raise
TypeError
(
"name must be a string! (got:
%
s)"
%
self
.
name
)
self
.
update
=
update
self
.
update
=
update
if
update
is
not
None
:
if
update
is
not
None
:
...
@@ -102,7 +100,7 @@ class SymbolicInput(object):
...
@@ -102,7 +100,7 @@ class SymbolicInput(object):
def
__str__
(
self
):
def
__str__
(
self
):
if
self
.
update
:
if
self
.
update
:
return
"In(
%
s ->
%
s)"
%
(
self
.
variable
,
self
.
update
)
return
"In(
{} -> {})"
.
format
(
self
.
variable
,
self
.
update
)
else
:
else
:
return
"In(
%
s)"
%
self
.
variable
return
"In(
%
s)"
%
self
.
variable
...
@@ -216,7 +214,7 @@ class In(SymbolicInput):
...
@@ -216,7 +214,7 @@ class In(SymbolicInput):
implicit
=
isinstance
(
value
,
gof
.
Container
)
or
isinstance
(
implicit
=
isinstance
(
value
,
gof
.
Container
)
or
isinstance
(
value
,
SharedVariable
value
,
SharedVariable
)
)
super
(
In
,
self
)
.
__init__
(
super
()
.
__init__
(
variable
=
variable
,
variable
=
variable
,
name
=
name
,
name
=
name
,
update
=
update
,
update
=
update
,
...
@@ -231,7 +229,7 @@ class In(SymbolicInput):
...
@@ -231,7 +229,7 @@ class In(SymbolicInput):
raise
TypeError
(
"An implicit input must be given a default value"
)
raise
TypeError
(
"An implicit input must be given a default value"
)
class
SymbolicOutput
(
object
)
:
class
SymbolicOutput
:
"""
"""
Represents a symbolic output for use with function or FunctionMaker.
Represents a symbolic output for use with function or FunctionMaker.
...
@@ -250,10 +248,10 @@ class SymbolicOutput(object):
...
@@ -250,10 +248,10 @@ class SymbolicOutput(object):
self
.
borrow
=
borrow
self
.
borrow
=
borrow
def
__str__
(
self
):
def
__str__
(
self
):
return
"Out(
%
s,
%
s)"
%
(
self
.
variable
,
self
.
borrow
)
return
"Out(
{},{})"
.
format
(
self
.
variable
,
self
.
borrow
)
def
__repr__
(
self
):
def
__repr__
(
self
):
return
"Out(
%
s,
%
s)"
%
(
self
.
variable
,
self
.
borrow
)
return
"Out(
{},{})"
.
format
(
self
.
variable
,
self
.
borrow
)
Out
=
SymbolicOutput
Out
=
SymbolicOutput
theano/compile/mode.py
浏览文件 @
ffa5e139
...
@@ -6,8 +6,6 @@ WRITEME
...
@@ -6,8 +6,6 @@ WRITEME
import
logging
import
logging
import
warnings
import
warnings
from
six
import
string_types
import
theano
import
theano
import
theano.gof.vm
import
theano.gof.vm
from
theano
import
config
,
gof
from
theano
import
config
,
gof
...
@@ -132,7 +130,7 @@ class AddDestroyHandler(gof.Optimizer):
...
@@ -132,7 +130,7 @@ class AddDestroyHandler(gof.Optimizer):
)
)
def
add_requirements
(
self
,
fgraph
):
def
add_requirements
(
self
,
fgraph
):
super
(
AddDestroyHandler
,
self
)
.
add_requirements
(
fgraph
)
super
()
.
add_requirements
(
fgraph
)
fgraph
.
attach_feature
(
gof
.
DestroyHandler
())
fgraph
.
attach_feature
(
gof
.
DestroyHandler
())
...
@@ -145,7 +143,7 @@ class AddFeatureOptimizer(gof.Optimizer):
...
@@ -145,7 +143,7 @@ class AddFeatureOptimizer(gof.Optimizer):
self
.
feature
=
feature
self
.
feature
=
feature
def
add_requirements
(
self
,
fgraph
):
def
add_requirements
(
self
,
fgraph
):
super
(
AddFeatureOptimizer
,
self
)
.
add_requirements
(
fgraph
)
super
()
.
add_requirements
(
fgraph
)
fgraph
.
attach_feature
(
self
.
feature
)
fgraph
.
attach_feature
(
self
.
feature
)
...
@@ -259,7 +257,7 @@ optdb.register("CheckStackTrace", gof.CheckStackTraceOptimization(), -1, *_tags)
...
@@ -259,7 +257,7 @@ optdb.register("CheckStackTrace", gof.CheckStackTraceOptimization(), -1, *_tags)
del
_tags
del
_tags
class
Mode
(
object
)
:
class
Mode
:
"""
"""
The Mode represents a way to optimize and then link a computation graph.
The Mode represents a way to optimize and then link a computation graph.
...
@@ -303,10 +301,10 @@ class Mode(object):
...
@@ -303,10 +301,10 @@ class Mode(object):
linker
,
optimizer
=
state
linker
,
optimizer
=
state
self
.
provided_linker
=
linker
self
.
provided_linker
=
linker
self
.
provided_optimizer
=
optimizer
self
.
provided_optimizer
=
optimizer
if
isinstance
(
linker
,
str
ing_types
)
or
linker
is
None
:
if
isinstance
(
linker
,
str
)
or
linker
is
None
:
linker
=
predefined_linkers
[
linker
]
linker
=
predefined_linkers
[
linker
]
self
.
linker
=
linker
self
.
linker
=
linker
if
isinstance
(
optimizer
,
str
ing_types
)
or
optimizer
is
None
:
if
isinstance
(
optimizer
,
str
)
or
optimizer
is
None
:
optimizer
=
predefined_optimizers
[
optimizer
]
optimizer
=
predefined_optimizers
[
optimizer
]
if
isinstance
(
optimizer
,
gof
.
Query
):
if
isinstance
(
optimizer
,
gof
.
Query
):
self
.
provided_optimizer
=
optimizer
self
.
provided_optimizer
=
optimizer
...
@@ -315,7 +313,7 @@ class Mode(object):
...
@@ -315,7 +313,7 @@ class Mode(object):
self
.
fn_time
=
0
self
.
fn_time
=
0
def
__str__
(
self
):
def
__str__
(
self
):
return
"
%
s(linker =
%
s, optimizer =
%
s)"
%
(
return
"
{}(linker = {}, optimizer = {})"
.
format
(
self
.
__class__
.
__name__
,
self
.
__class__
.
__name__
,
self
.
provided_linker
,
self
.
provided_linker
,
self
.
provided_optimizer
,
self
.
provided_optimizer
,
...
@@ -330,9 +328,9 @@ class Mode(object):
...
@@ -330,9 +328,9 @@ class Mode(object):
optimizer
=
property
(
__get_optimizer
)
optimizer
=
property
(
__get_optimizer
)
def
get_linker_optimizer
(
self
,
linker
,
optimizer
):
def
get_linker_optimizer
(
self
,
linker
,
optimizer
):
if
isinstance
(
linker
,
str
ing_types
)
or
linker
is
None
:
if
isinstance
(
linker
,
str
)
or
linker
is
None
:
linker
=
predefined_linkers
[
linker
]
linker
=
predefined_linkers
[
linker
]
if
isinstance
(
optimizer
,
str
ing_types
)
or
optimizer
is
None
:
if
isinstance
(
optimizer
,
str
)
or
optimizer
is
None
:
optimizer
=
predefined_optimizers
[
optimizer
]
optimizer
=
predefined_optimizers
[
optimizer
]
return
(
linker
,
optimizer
)
return
(
linker
,
optimizer
)
...
@@ -432,7 +430,7 @@ def get_mode(orig_string):
...
@@ -432,7 +430,7 @@ def get_mode(orig_string):
string
=
config
.
mode
string
=
config
.
mode
else
:
else
:
string
=
orig_string
string
=
orig_string
if
not
isinstance
(
string
,
str
ing_types
):
if
not
isinstance
(
string
,
str
):
return
string
# it is hopefully already a mode...
return
string
# it is hopefully already a mode...
global
instantiated_default_mode
global
instantiated_default_mode
...
...
theano/compile/monitormode.py
浏览文件 @
ffa5e139
...
@@ -52,17 +52,17 @@ class MonitorMode(Mode):
...
@@ -52,17 +52,17 @@ class MonitorMode(Mode):
linker
,
linker
,
)
)
super
(
MonitorMode
,
self
)
.
__init__
(
wrap_linker
,
optimizer
=
optimizer
)
super
()
.
__init__
(
wrap_linker
,
optimizer
=
optimizer
)
def
__getstate__
(
self
):
def
__getstate__
(
self
):
lnk
,
opt
=
super
(
MonitorMode
,
self
)
.
__getstate__
()
lnk
,
opt
=
super
()
.
__getstate__
()
return
(
lnk
,
opt
,
self
.
pre_func
,
self
.
post_func
)
return
(
lnk
,
opt
,
self
.
pre_func
,
self
.
post_func
)
def
__setstate__
(
self
,
state
):
def
__setstate__
(
self
,
state
):
lnk
,
opt
,
pre_func
,
post_func
=
state
lnk
,
opt
,
pre_func
,
post_func
=
state
self
.
pre_func
=
pre_func
self
.
pre_func
=
pre_func
self
.
post_func
=
post_func
self
.
post_func
=
post_func
super
(
MonitorMode
,
self
)
.
__setstate__
((
lnk
,
opt
))
super
()
.
__setstate__
((
lnk
,
opt
))
def
eval
(
self
,
i
,
node
,
fn
):
def
eval
(
self
,
i
,
node
,
fn
):
"""
"""
...
...
theano/compile/nanguardmode.py
浏览文件 @
ffa5e139
...
@@ -295,6 +295,4 @@ class NanGuardMode(Mode):
...
@@ -295,6 +295,4 @@ class NanGuardMode(Mode):
wrap_linker
=
theano
.
gof
.
vm
.
VM_Linker
(
wrap_linker
=
theano
.
gof
.
vm
.
VM_Linker
(
callback
=
nan_check
,
callback_input
=
nan_check_input
callback
=
nan_check
,
callback_input
=
nan_check_input
)
)
super
(
NanGuardMode
,
self
)
.
__init__
(
super
()
.
__init__
(
wrap_linker
,
optimizer
=
self
.
provided_optimizer
)
wrap_linker
,
optimizer
=
self
.
provided_optimizer
)
theano/compile/ops.py
浏览文件 @
ffa5e139
...
@@ -11,7 +11,6 @@ from collections import OrderedDict
...
@@ -11,7 +11,6 @@ from collections import OrderedDict
import
numpy
as
np
import
numpy
as
np
import
six.moves.cPickle
as
pickle
import
six.moves.cPickle
as
pickle
from
six
import
integer_types
import
theano
import
theano
from
theano.gof
import
Apply
,
Op
,
ParamsType
,
Variable
from
theano.gof
import
Apply
,
Op
,
ParamsType
,
Variable
...
@@ -71,7 +70,7 @@ class ViewOp(Op):
...
@@ -71,7 +70,7 @@ class ViewOp(Op):
return
code
%
locals
()
return
code
%
locals
()
# Else, no C code
# Else, no C code
return
super
(
ViewOp
,
self
)
.
c_code
(
node
,
nodename
,
inp
,
out
,
sub
)
return
super
()
.
c_code
(
node
,
nodename
,
inp
,
out
,
sub
)
def
c_code_cache_version
(
self
):
def
c_code_cache_version
(
self
):
version
=
[]
version
=
[]
...
@@ -206,7 +205,7 @@ class DeepCopyOp(Op):
...
@@ -206,7 +205,7 @@ class DeepCopyOp(Op):
return
code
%
locals
()
return
code
%
locals
()
# Else, no C code
# Else, no C code
return
super
(
DeepCopyOp
,
self
)
.
c_code
(
node
,
name
,
inames
,
onames
,
sub
)
return
super
()
.
c_code
(
node
,
name
,
inames
,
onames
,
sub
)
deep_copy_op
=
DeepCopyOp
()
deep_copy_op
=
DeepCopyOp
()
...
@@ -296,7 +295,7 @@ class Shape(Op):
...
@@ -296,7 +295,7 @@ class Shape(Op):
return
code
%
locals
()
return
code
%
locals
()
# Else, no C code
# Else, no C code
return
super
(
Shape
,
self
)
.
c_code
(
node
,
name
,
inames
,
onames
,
sub
)
return
super
()
.
c_code
(
node
,
name
,
inames
,
onames
,
sub
)
def
c_code_cache_version
(
self
):
def
c_code_cache_version
(
self
):
version
=
[]
version
=
[]
...
@@ -423,7 +422,7 @@ class Shape_i(Op):
...
@@ -423,7 +422,7 @@ class Shape_i(Op):
return
(
check_input
+
code
)
%
locals
()
return
(
check_input
+
code
)
%
locals
()
# Else, no C code
# Else, no C code
return
super
(
Shape_i
,
self
)
.
c_code
(
node
,
name
,
inames
,
onames
,
sub
)
return
super
()
.
c_code
(
node
,
name
,
inames
,
onames
,
sub
)
def
infer_shape
(
self
,
node
,
input_shapes
):
def
infer_shape
(
self
,
node
,
input_shapes
):
return
[()]
return
[()]
...
@@ -583,7 +582,7 @@ class FromFunctionOp(Op):
...
@@ -583,7 +582,7 @@ class FromFunctionOp(Op):
obj
=
load_back
(
mod
,
name
)
obj
=
load_back
(
mod
,
name
)
except
(
ImportError
,
KeyError
,
AttributeError
):
except
(
ImportError
,
KeyError
,
AttributeError
):
raise
pickle
.
PicklingError
(
raise
pickle
.
PicklingError
(
"Can't pickle as_op(), not found as
%
s.
%
s"
%
(
mod
,
name
)
"Can't pickle as_op(), not found as
{}.{}"
.
format
(
mod
,
name
)
)
)
else
:
else
:
if
obj
is
not
self
:
if
obj
is
not
self
:
...
@@ -699,7 +698,7 @@ class Rebroadcast(Op):
...
@@ -699,7 +698,7 @@ class Rebroadcast(Op):
items
=
sorted
(
axis
)
items
=
sorted
(
axis
)
self
.
axis
=
OrderedDict
(
items
)
self
.
axis
=
OrderedDict
(
items
)
for
axis
,
broad
in
self
.
axis
.
items
():
for
axis
,
broad
in
self
.
axis
.
items
():
if
not
isinstance
(
axis
,
(
np
.
integer
,
int
eger_types
)):
if
not
isinstance
(
axis
,
(
np
.
integer
,
int
)):
raise
TypeError
(
raise
TypeError
(
"Rebroadcast needs integer axes. "
"Got {}"
.
format
(
axis
)
"Rebroadcast needs integer axes. "
"Got {}"
.
format
(
axis
)
)
)
...
@@ -723,7 +722,7 @@ class Rebroadcast(Op):
...
@@ -723,7 +722,7 @@ class Rebroadcast(Op):
broadcast_pattern
=
[
"?"
for
i
in
range
(
1
+
max
(
self
.
axis
.
keys
()))]
broadcast_pattern
=
[
"?"
for
i
in
range
(
1
+
max
(
self
.
axis
.
keys
()))]
for
k
,
v
in
self
.
axis
.
items
():
for
k
,
v
in
self
.
axis
.
items
():
broadcast_pattern
[
k
]
=
str
(
int
(
v
))
broadcast_pattern
[
k
]
=
str
(
int
(
v
))
return
"
%
s{
%
s}"
%
(
self
.
__class__
.
__name__
,
","
.
join
(
broadcast_pattern
))
return
"
{}{{{}}}"
.
format
(
self
.
__class__
.
__name__
,
","
.
join
(
broadcast_pattern
))
def
make_node
(
self
,
x
):
def
make_node
(
self
,
x
):
if
self
.
axis
.
keys
()
and
(
x
.
ndim
<=
max
(
self
.
axis
.
keys
())):
if
self
.
axis
.
keys
()
and
(
x
.
ndim
<=
max
(
self
.
axis
.
keys
())):
...
@@ -797,7 +796,7 @@ class Rebroadcast(Op):
...
@@ -797,7 +796,7 @@ class Rebroadcast(Op):
"""
"""
%
locals
()
%
locals
()
)
)
return
super
(
Rebroadcast
,
self
)
.
c_code
(
node
,
nodename
,
inp
,
out
,
sub
)
return
super
()
.
c_code
(
node
,
nodename
,
inp
,
out
,
sub
)
def
c_code_cache_version
(
self
):
def
c_code_cache_version
(
self
):
version
=
[]
version
=
[]
...
@@ -929,7 +928,7 @@ class SpecifyShape(Op):
...
@@ -929,7 +928,7 @@ class SpecifyShape(Op):
_
,
_
,
support_code
=
self
.
c_code_and_version
[
itype
]
_
,
_
,
support_code
=
self
.
c_code_and_version
[
itype
]
if
support_code
:
if
support_code
:
return
support_code
return
support_code
return
super
(
SpecifyShape
,
self
)
.
c_support_code_apply
(
node
,
name
)
return
super
()
.
c_support_code_apply
(
node
,
name
)
def
c_code
(
self
,
node
,
name
,
inames
,
onames
,
sub
):
def
c_code
(
self
,
node
,
name
,
inames
,
onames
,
sub
):
iname
,
shape
=
inames
iname
,
shape
=
inames
...
@@ -941,7 +940,7 @@ class SpecifyShape(Op):
...
@@ -941,7 +940,7 @@ class SpecifyShape(Op):
code
,
version
,
_
=
self
.
c_code_and_version
[
itype
]
code
,
version
,
_
=
self
.
c_code_and_version
[
itype
]
return
code
%
locals
()
return
code
%
locals
()
return
super
(
SpecifyShape
,
self
)
.
c_code
(
node
,
node
,
inames
,
onames
,
sub
)
return
super
()
.
c_code
(
node
,
node
,
inames
,
onames
,
sub
)
def
c_code_cache_version
(
self
):
def
c_code_cache_version
(
self
):
version
=
[]
version
=
[]
...
...
theano/compile/pfunc.py
浏览文件 @
ffa5e139
...
@@ -300,7 +300,7 @@ class Param(In):
...
@@ -300,7 +300,7 @@ class Param(In):
" by theano.In(value=N)"
,
" by theano.In(value=N)"
,
stacklevel
=
2
,
stacklevel
=
2
,
)
)
super
(
Param
,
self
)
.
__init__
(
super
()
.
__init__
(
variable
,
variable
,
name
=
name
,
name
=
name
,
value
=
default
,
value
=
default
,
...
@@ -447,13 +447,11 @@ def pfunc(
...
@@ -447,13 +447,11 @@ def pfunc(
if
v
in
in_variables
[(
i
+
1
)
:]:
if
v
in
in_variables
[(
i
+
1
)
:]:
dup_v_i
=
in_variables
.
index
(
v
,
(
i
+
1
))
dup_v_i
=
in_variables
.
index
(
v
,
(
i
+
1
))
raise
UnusedInputError
(
raise
UnusedInputError
(
(
"Variable
%
s is used twice in inputs to theano.function, "
"Variable
%
s is used twice in inputs to theano.function, "
"at indices
%
i and
%
i. This would result in values "
"at indices
%
i and
%
i. This would result in values "
"provided for it being ignored. Please do not duplicate "
"provided for it being ignored. Please do not duplicate "
"variables in the inputs list."
%
(
v
,
i
,
dup_v_i
)
"variables in the inputs list."
%
(
v
,
i
,
dup_v_i
)
)
)
)
# Check that we are not using `givens` to replace input variables, because
# Check that we are not using `givens` to replace input variables, because
# this typically does nothing, contrary to what one may expect.
# this typically does nothing, contrary to what one may expect.
...
...
theano/compile/profiling.py
浏览文件 @
ffa5e139
...
@@ -181,7 +181,7 @@ def register_profiler_printer(fct):
...
@@ -181,7 +181,7 @@ def register_profiler_printer(fct):
return
fct
return
fct
class
ProfileStats
(
object
)
:
class
ProfileStats
:
"""
"""
Object to store runtime and memory profiling information for all of
Object to store runtime and memory profiling information for all of
...
@@ -851,7 +851,7 @@ class ProfileStats(object):
...
@@ -851,7 +851,7 @@ class ProfileStats(object):
for
node
,
t
in
sorted
(
for
node
,
t
in
sorted
(
self
.
linker_make_thunk_time
.
items
(),
key
=
operator
.
itemgetter
(
1
)
self
.
linker_make_thunk_time
.
items
(),
key
=
operator
.
itemgetter
(
1
)
)[::
-
1
][:
5
]:
)[::
-
1
][:
5
]:
print
(
" Node
%
s time
%
es"
%
(
node
,
t
),
file
=
file
)
print
(
" Node
{} time {:e}s"
.
format
(
node
,
t
),
file
=
file
)
print
(
""
,
file
=
file
)
print
(
""
,
file
=
file
)
# The validation time is a subset of optimizer_time
# The validation time is a subset of optimizer_time
...
@@ -1071,7 +1071,7 @@ class ProfileStats(object):
...
@@ -1071,7 +1071,7 @@ class ProfileStats(object):
mem_bound
=
np
.
inf
mem_bound
=
np
.
inf
# This take only the inputs/outputs dependencies.
# This take only the inputs/outputs dependencies.
dependencies
=
fgraph
.
profile
.
dependencies
dependencies
=
fgraph
.
profile
.
dependencies
done_set
=
set
(
[]
)
done_set
=
set
()
done_dict
=
{}
done_dict
=
{}
# Initial compute_map which is used to check if a node is valid
# Initial compute_map which is used to check if a node is valid
...
@@ -1451,7 +1451,10 @@ class ProfileStats(object):
...
@@ -1451,7 +1451,10 @@ class ProfileStats(object):
else
:
else
:
size
=
"
%10
s"
%
"Unknown"
size
=
"
%10
s"
%
"Unknown"
print
(
"
%
s
%
s
%
s
%
s"
%
(
size
,
shapes
,
" "
.
join
(
code
),
node
),
file
=
file
)
print
(
" {} {} {} {}"
.
format
(
size
,
shapes
,
" "
.
join
(
code
),
node
),
file
=
file
,
)
sum_remaining
=
sum
(
size
for
_
,
size
in
items
[
N
:])
sum_remaining
=
sum
(
size
for
_
,
size
in
items
[
N
:])
size_sum_dense
=
sum
(
node_mem
.
values
())
size_sum_dense
=
sum
(
node_mem
.
values
())
...
@@ -1499,7 +1502,7 @@ class ProfileStats(object):
...
@@ -1499,7 +1502,7 @@ class ProfileStats(object):
file
=
file
,
file
=
file
,
)
)
if
config
.
profiling
.
debugprint
:
if
config
.
profiling
.
debugprint
:
fcts
=
set
([
n
.
fgraph
for
n
in
self
.
apply_time
.
keys
()])
fcts
=
{
n
.
fgraph
for
n
in
self
.
apply_time
.
keys
()}
theano
.
printing
.
debugprint
(
fcts
,
print_type
=
True
)
theano
.
printing
.
debugprint
(
fcts
,
print_type
=
True
)
if
self
.
variable_shape
or
self
.
variable_strides
:
if
self
.
variable_shape
or
self
.
variable_strides
:
self
.
summary_memory
(
file
,
n_apply_to_print
)
self
.
summary_memory
(
file
,
n_apply_to_print
)
...
@@ -1686,10 +1689,7 @@ class ProfileStats(object):
...
@@ -1686,10 +1689,7 @@ class ProfileStats(object):
# tip 6
# tip 6
for
a
in
self
.
apply_time
:
for
a
in
self
.
apply_time
:
node
=
a
node
=
a
if
(
if
isinstance
(
node
.
op
,
T
.
Dot
)
and
len
({
i
.
dtype
for
i
in
node
.
inputs
})
!=
1
:
isinstance
(
node
.
op
,
T
.
Dot
)
and
len
(
set
(
i
.
dtype
for
i
in
node
.
inputs
))
!=
1
):
print
(
print
(
" - You have a dot operation that has different dtype "
" - You have a dot operation that has different dtype "
" for inputs (
%
s). Make sure that the inputs have same "
" for inputs (
%
s). Make sure that the inputs have same "
...
@@ -1742,7 +1742,7 @@ class ScanProfileStats(ProfileStats):
...
@@ -1742,7 +1742,7 @@ class ScanProfileStats(ProfileStats):
call_time
=
0.0
call_time
=
0.0
def
__init__
(
self
,
atexit_print
=
True
,
name
=
None
,
**
kwargs
):
def
__init__
(
self
,
atexit_print
=
True
,
name
=
None
,
**
kwargs
):
super
(
ScanProfileStats
,
self
)
.
__init__
(
atexit_print
,
**
kwargs
)
super
()
.
__init__
(
atexit_print
,
**
kwargs
)
self
.
name
=
name
self
.
name
=
name
def
summary_globals
(
self
,
file
):
def
summary_globals
(
self
,
file
):
...
...
theano/compile/sharedvalue.py
浏览文件 @
ffa5e139
...
@@ -67,9 +67,7 @@ class SharedVariable(Variable):
...
@@ -67,9 +67,7 @@ class SharedVariable(Variable):
# or the "no_default_updates" list passed to "function" contains it.
# or the "no_default_updates" list passed to "function" contains it.
def
__init__
(
self
,
name
,
type
,
value
,
strict
,
allow_downcast
=
None
,
container
=
None
):
def
__init__
(
self
,
name
,
type
,
value
,
strict
,
allow_downcast
=
None
,
container
=
None
):
super
(
SharedVariable
,
self
)
.
__init__
(
super
()
.
__init__
(
type
=
type
,
name
=
name
,
owner
=
None
,
index
=
None
)
type
=
type
,
name
=
name
,
owner
=
None
,
index
=
None
)
if
container
is
not
None
:
if
container
is
not
None
:
self
.
container
=
container
self
.
container
=
container
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论