Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
2fe4b0b8
提交
2fe4b0b8
authored
11月 11, 2013
作者:
Pascal Lamblin
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #1594 from nouiz/faster_opt
[MRG] Fix test python 2.4 and faster opt
上级
f07e4644
9c4cd7dc
隐藏空白字符变更
内嵌
并排
正在显示
9 个修改的文件
包含
113 行增加
和
86 行删除
+113
-86
optimization.txt
doc/extending/optimization.txt
+3
-1
fg.py
theano/gof/fg.py
+4
-2
graph.py
theano/gof/graph.py
+6
-8
opt.py
theano/gof/opt.py
+56
-30
toolbox.py
theano/gof/toolbox.py
+6
-5
nnet.py
theano/tensor/nnet/nnet.py
+1
-1
opt.py
theano/tensor/opt.py
+19
-0
opt_uncanonicalize.py
theano/tensor/opt_uncanonicalize.py
+17
-38
var.py
theano/tensor/var.py
+1
-1
没有找到文件。
doc/extending/optimization.txt
浏览文件 @
2fe4b0b8
...
...
@@ -283,7 +283,9 @@ The local version of the above code would be the following:
The definition of transform is the inner loop of the global optimizer,
where the node is given as argument. If no changes are to be made,
``False`` must be returned. Else, a list of what to replace the node's
outputs with must be returned.
outputs with must be returned. This list must have the same length as
node.ouputs. If one of node.outputs don't have clients(it is not used
in the graph), you can put None in the returned list to remove it.
In order to apply the local optimizer we must use it in conjunction
with a :ref:`navigator`. Basically, a :ref:`navigator` is a global
...
...
theano/gof/fg.py
浏览文件 @
2fe4b0b8
...
...
@@ -93,6 +93,7 @@ class FunctionGraph(utils.object2):
inputs
,
outputs
=
graph
.
clone
(
inputs
,
outputs
)
self
.
execute_callbacks_time
=
0
self
.
execute_callbacks_times
=
{}
if
features
is
None
:
features
=
[]
...
...
@@ -507,7 +508,7 @@ class FunctionGraph(utils.object2):
attach
(
self
)
except
toolbox
.
AlreadyThere
:
return
self
.
execute_callbacks_times
.
setdefault
(
feature
,
0
)
#it would be nice if we could require a specific class instead of
#a "workalike" so we could do actual error checking
#if not isinstance(feature, toolbox.Feature):
...
...
@@ -549,8 +550,9 @@ class FunctionGraph(utils.object2):
# try; the AttributeError reall must come from feature.${name}
# not existing
continue
tf0
=
time
.
time
()
fn
(
self
,
*
args
,
**
kwargs
)
self
.
execute_callbacks_times
[
feature
]
+=
time
.
time
()
-
tf0
self
.
execute_callbacks_time
+=
time
.
time
()
-
t0
def
collect_callbacks
(
self
,
name
,
*
args
):
...
...
theano/gof/graph.py
浏览文件 @
2fe4b0b8
...
...
@@ -495,14 +495,14 @@ def stack_search(start, expand, mode='bfs', build_inv=False):
:param start: search from these nodes
:type expand: callable
:param expand:
when we get to a node, add expand(node) to the list of nodes to visit.
This function
should return a list, or None
when we get to a node, add expand(node) to the list of nodes to visit.
This function
should return a list, or None
:rtype: list of `Variable` or `Apply` instances (depends on `expend`)
:return: the list of nodes in order of traversal.
:note:
a node will appear at most once in the return value, even if it
appears multiple times
in the start parameter.
a node will appear at most once in the return value, even if it
appears multiple times
in the start parameter.
:postcondition: every element of start is transferred to the returned list.
:postcondition: start is empty.
...
...
@@ -549,9 +549,7 @@ def ancestors(variable_list, blockers=None):
"""
def
expand
(
r
):
if
r
.
owner
and
(
not
blockers
or
r
not
in
blockers
):
l
=
list
(
r
.
owner
.
inputs
)
l
.
reverse
()
return
l
return
reversed
(
r
.
owner
.
inputs
)
dfs_variables
=
stack_search
(
deque
(
variable_list
),
expand
,
'dfs'
)
return
dfs_variables
...
...
@@ -801,7 +799,7 @@ def io_toposort(inputs, outputs, orderings=None):
if
isinstance
(
obj
,
Variable
):
if
obj
.
owner
:
rval
=
[
obj
.
owner
]
if
isinstance
(
obj
,
Apply
):
el
if
isinstance
(
obj
,
Apply
):
rval
=
list
(
obj
.
inputs
)
rval
.
extend
(
orderings
.
get
(
obj
,
[]))
else
:
...
...
theano/gof/opt.py
浏览文件 @
2fe4b0b8
...
...
@@ -514,7 +514,8 @@ class MergeFeature(object):
continue
inputs_match
=
all
(
node_in
is
cand_in
for
node_in
,
cand_in
in
zip
(
node
.
inputs
,
candidate
.
inputs
))
for
node_in
,
cand_in
in
zip
(
node
.
inputs
,
candidate
.
inputs
))
if
inputs_match
and
node
.
op
==
candidate
.
op
:
if
(
node
,
candidate
)
in
self
.
blacklist
:
# They were already tried, and there was an error
...
...
@@ -566,6 +567,8 @@ class MergeOptimizer(Optimizer):
if
fgraph
.
profile
:
validate_before
=
fgraph
.
profile
.
validate_time
callback_before
=
fgraph
.
execute_callbacks_time
callbacks_before
=
fgraph
.
execute_callbacks_times
.
copy
()
nb_merged
=
0
nb_constant
=
0
while
sched
:
...
...
@@ -589,20 +592,28 @@ class MergeOptimizer(Optimizer):
if
fgraph
.
profile
:
validate_time
=
fgraph
.
profile
.
validate_time
-
validate_before
callback_time
=
fgraph
.
execute_callbacks_time
-
callback_before
callbacks_time
=
{}
for
k
,
v
in
fgraph
.
execute_callbacks_times
.
iteritems
():
if
k
in
callbacks_before
:
callbacks_time
[
k
]
=
v
-
callbacks_before
[
k
]
else
:
callbacks_time
[
k
]
=
v
else
:
validate_time
=
None
callback_time
=
None
callbacks_time
=
{}
# clear blacklist
fgraph
.
merge_feature
.
blacklist
=
[]
return
(
nb_fail
,
time
.
time
()
-
t0
,
validate_time
,
callback_time
,
nb_merged
,
nb_constant
)
callback_time
,
callbacks_time
,
nb_merged
,
nb_constant
)
def
__str__
(
self
):
return
self
.
__class__
.
__name__
@staticmethod
def
print_profile
(
stream
,
prof
,
level
=
0
):
nb_fail
,
replace_time
,
validate_time
,
callback_time
,
nb_merged
,
nb_constant
=
prof
(
nb_fail
,
replace_time
,
validate_time
,
callback_time
,
callbacks_time
,
nb_merged
,
nb_constant
)
=
prof
blanc
=
(
' '
*
level
)
print
>>
stream
,
blanc
,
"MergeOptimizer"
...
...
@@ -610,6 +621,7 @@ class MergeOptimizer(Optimizer):
print
>>
stream
,
blanc
,
" replace_time"
,
replace_time
print
>>
stream
,
blanc
,
" validate_time"
,
validate_time
print
>>
stream
,
blanc
,
" callback_time"
,
callback_time
print
>>
stream
,
blanc
,
" callback_times"
,
callbacks_time
print
>>
stream
,
blanc
,
" nb_merged"
,
nb_merged
print
>>
stream
,
blanc
,
" nb_constant"
,
nb_constant
...
...
@@ -740,7 +752,7 @@ class LocalOptimizer(object):
"""
raise
utils
.
MethodNotDefined
(
"transform"
,
type
(
self
),
self
.
__class__
.
__name__
)
type
(
self
),
self
.
__class__
.
__name__
)
def
add_requirements
(
self
,
fgraph
):
"""
...
...
@@ -770,7 +782,7 @@ class FromFunctionLocalOptimizer(LocalOptimizer):
def
__str__
(
self
):
return
getattr
(
self
,
'__name__'
,
'<FromFunctionLocalOptimizer instance>'
)
'<FromFunctionLocalOptimizer instance>'
)
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
print
>>
stream
,
"
%
s
%
s id=
%
i"
%
(
...
...
@@ -800,8 +812,8 @@ class LocalOptGroup(LocalOptimizer):
def
__str__
(
self
):
return
getattr
(
self
,
'__name__'
,
(
'<theano.gof.opt.LocalOptGroup instance>'
+
str
([
str
(
o
)
for
o
in
self
.
opts
])))
(
'<theano.gof.opt.LocalOptGroup instance>'
+
str
([
str
(
o
)
for
o
in
self
.
opts
])))
def
transform
(
self
,
node
):
for
opt
in
self
.
opts
:
...
...
@@ -957,7 +969,7 @@ class PatternSub(LocalOptimizer):
"""
def
__init__
(
self
,
in_pattern
,
out_pattern
,
allow_multiple_clients
=
False
,
skip_identities_fn
=
None
,
name
=
None
,
pdb
=
False
):
skip_identities_fn
=
None
,
name
=
None
,
pdb
=
False
):
"""
Creates a PatternSub that replaces occurrences of
in_pattern by occurrences of out_pattern.
...
...
@@ -978,10 +990,10 @@ class PatternSub(LocalOptimizer):
self
.
op
=
self
.
in_pattern
[
'pattern'
][
0
]
else
:
raise
TypeError
(
"The pattern to search for must start with "
"a specific Op instance."
)
self
.
__doc__
=
(
self
.
__class__
.
__doc__
+
"
\n\n
This instance does: "
+
str
(
self
)
+
"
\n
"
)
"a specific Op instance."
)
self
.
__doc__
=
(
self
.
__class__
.
__doc__
+
"
\n\n
This instance does: "
+
str
(
self
)
+
"
\n
"
)
self
.
allow_multiple_clients
=
allow_multiple_clients
self
.
skip_identities_fn
=
skip_identities_fn
if
name
:
...
...
@@ -1024,7 +1036,7 @@ class PatternSub(LocalOptimizer):
#TODO: Not sure how to handle multiple_clients flag
###print 'retrying match', pattern, expr_equiv
return
match
(
pattern
,
expr_equiv
,
u
,
allow_multiple_clients
=
allow_multiple_clients
)
allow_multiple_clients
=
allow_multiple_clients
)
if
isinstance
(
pattern
,
(
list
,
tuple
)):
if
expr
.
owner
is
None
:
...
...
@@ -1044,8 +1056,8 @@ class PatternSub(LocalOptimizer):
real_pattern
=
pattern
[
'pattern'
]
except
KeyError
:
raise
KeyError
(
"Malformed pattern:
%
s (expected key 'pattern')"
%
pattern
)
"Malformed pattern:
%
s (expected key 'pattern')"
%
pattern
)
constraint
=
pattern
.
get
(
'constraint'
,
lambda
expr
:
True
)
if
constraint
(
expr
):
return
match
(
real_pattern
,
expr
,
u
,
...
...
@@ -1275,7 +1287,8 @@ class NavigatorOptimizer(Optimizer):
except
Exception
,
e
:
if
self
.
failure_callback
is
not
None
:
self
.
failure_callback
(
e
,
self
,
[(
x
,
None
)
for
x
in
node
.
outputs
],
lopt
)
[(
x
,
None
)
for
x
in
node
.
outputs
],
lopt
)
return
False
else
:
raise
...
...
@@ -1283,14 +1296,20 @@ class NavigatorOptimizer(Optimizer):
return
False
if
not
isinstance
(
replacements
,
(
tuple
,
list
)):
raise
TypeError
(
'Optimizer
%
s gave wrong type of replacement. '
'Expected list or tuple.'
%
lopt
)
'Expected list or tuple.'
%
lopt
)
if
len
(
node
.
outputs
)
!=
len
(
replacements
):
raise
ValueError
(
'Optimizer
%
s gave wrong number of replacements'
%
lopt
)
%
lopt
)
# None in the replacement mean that this variable isn't used
# and we want to remove it
for
r
,
rnew
in
zip
(
node
.
outputs
,
replacements
):
if
rnew
is
None
and
len
(
r
.
clients
)
>
0
:
raise
ValueError
(
"A local optimizer tried to remove a Variable that is used"
)
# If an output would be replaced by itself, no need to perform
# the replacement
repl_pairs
=
[(
r
,
rnew
)
for
r
,
rnew
in
zip
(
node
.
outputs
,
replacements
)
if
rnew
is
not
r
]
if
rnew
is
not
r
and
rnew
is
not
None
]
if
len
(
repl_pairs
)
==
0
:
return
False
try
:
...
...
@@ -1319,19 +1338,19 @@ class NavigatorOptimizer(Optimizer):
(
' '
*
level
),
self
.
__class__
.
__name__
,
id
(
self
))
if
depth
!=
0
:
self
.
local_opt
.
print_summary
(
stream
,
level
=
(
level
+
2
),
depth
=
(
depth
-
1
))
depth
=
(
depth
-
1
))
class
TopoOptimizer
(
NavigatorOptimizer
):
"""WRITEME"""
def
__init__
(
self
,
local_opt
,
order
=
'in_to_out'
,
ignore_newtrees
=
False
,
failure_callback
=
None
):
failure_callback
=
None
):
if
order
not
in
[
'out_to_in'
,
'in_to_out'
]:
raise
ValueError
(
"order must be 'out_to_in' or 'in_to_out'"
)
self
.
order
=
order
NavigatorOptimizer
.
__init__
(
self
,
local_opt
,
ignore_newtrees
,
failure_callback
)
failure_callback
)
def
apply
(
self
,
fgraph
,
start_from
=
None
):
if
start_from
is
None
:
...
...
@@ -1397,12 +1416,12 @@ class OpKeyOptimizer(NavigatorOptimizer):
"""WRITEME"""
def
__init__
(
self
,
local_opt
,
ignore_newtrees
=
False
,
failure_callback
=
None
):
failure_callback
=
None
):
if
not
hasattr
(
local_opt
,
'op_key'
):
raise
TypeError
(
"LocalOptimizer for OpKeyOptimizer must have "
"an 'op_key' method."
)
"an 'op_key' method."
)
NavigatorOptimizer
.
__init__
(
self
,
local_opt
,
ignore_newtrees
,
failure_callback
)
failure_callback
)
def
apply
(
self
,
fgraph
):
op
=
self
.
local_opt
.
op_key
()
...
...
@@ -1513,6 +1532,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
max_use_abort
=
False
opt_name
=
None
global_process_count
=
{}
start_nb_nodes
=
len
(
fgraph
.
apply_nodes
)
max_nb_nodes
=
len
(
fgraph
.
apply_nodes
)
max_use
=
max_nb_nodes
*
self
.
max_use_ratio
...
...
@@ -1597,13 +1617,16 @@ class EquilibriumOptimizer(NavigatorOptimizer):
loop_process_count
.
append
(
process_count
)
loop_timing
.
append
(
float
(
time
.
time
()
-
t0
))
end_nb_nodes
=
len
(
fgraph
.
apply_nodes
)
if
max_use_abort
:
_logger
.
error
(
"EquilibriumOptimizer max'ed out by '
%
s'"
%
opt_name
+
". You can safely raise the current threshold of "
+
"
%
f with the theano flag 'optdb.max_use_ratio'."
%
config
.
optdb
.
max_use_ratio
)
return
(
self
,
loop_timing
,
loop_process_count
,
max_nb_nodes
,
return
(
self
,
loop_timing
,
loop_process_count
,
(
start_nb_nodes
,
end_nb_nodes
,
max_nb_nodes
),
global_opt_timing
,
nb_nodes
,
time_opts
,
io_toposort_timing
)
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
...
...
@@ -1613,19 +1636,22 @@ class EquilibriumOptimizer(NavigatorOptimizer):
if
depth
!=
0
:
for
lopt
in
self
.
local_optimizers
:
lopt
.
print_summary
(
stream
,
level
=
(
level
+
2
),
depth
=
(
depth
-
1
))
depth
=
(
depth
-
1
))
@staticmethod
def
print_profile
(
stream
,
prof
,
level
=
0
):
(
opt
,
loop_timing
,
loop_process_count
,
max_nb_nodes
,
(
opt
,
loop_timing
,
loop_process_count
,
(
start_nb_nodes
,
end_nb_nodes
,
max_nb_nodes
),
global_opt_timing
,
nb_nodes
,
time_opts
,
io_toposort_timing
)
=
prof
blanc
=
(
' '
*
level
)
print
>>
stream
,
blanc
,
"EquilibriumOptimizer"
,
print
>>
stream
,
blanc
,
getattr
(
opt
,
"name"
,
getattr
(
opt
,
"__name__"
,
""
))
print
>>
stream
,
blanc
,
" time
%.3
fs for
%
d passes,
%
d nodes max"
%
(
sum
(
loop_timing
),
len
(
loop_timing
),
max_nb_nodes
)
print
>>
stream
,
blanc
,
" time
%.3
fs for
%
d passes"
%
(
sum
(
loop_timing
),
len
(
loop_timing
))
print
>>
stream
,
blanc
,
" nb nodes (start, end, max)
%
d
%
d
%
d"
%
(
start_nb_nodes
,
end_nb_nodes
,
max_nb_nodes
)
print
>>
stream
,
blanc
,
" time io_toposort
%.3
fs"
%
sum
(
io_toposort_timing
)
s
=
sum
([
time_opts
[
o
]
for
o
in
opt
.
local_optimizers
])
...
...
theano/gof/toolbox.py
浏览文件 @
2fe4b0b8
...
...
@@ -248,10 +248,11 @@ class ReplaceValidate(History, Validator):
raise
ReplacementDidntRemovedError
()
class
NodeFinder
(
dict
,
Bookkeeper
):
class
NodeFinder
(
Bookkeeper
):
def
__init__
(
self
):
self
.
fgraph
=
None
self
.
d
=
{}
def
on_attach
(
self
,
fgraph
):
if
self
.
fgraph
is
not
None
:
...
...
@@ -273,7 +274,7 @@ class NodeFinder(dict, Bookkeeper):
def
on_import
(
self
,
fgraph
,
node
,
reason
):
try
:
self
.
setdefault
(
node
.
op
,
[])
.
append
(
node
)
self
.
d
.
setdefault
(
node
.
op
,
[])
.
append
(
node
)
except
TypeError
:
# node.op is unhashable
return
except
Exception
,
e
:
...
...
@@ -286,16 +287,16 @@ class NodeFinder(dict, Bookkeeper):
def
on_prune
(
self
,
fgraph
,
node
,
reason
):
try
:
nodes
=
self
[
node
.
op
]
nodes
=
self
.
d
[
node
.
op
]
except
TypeError
:
# node.op is unhashable
return
nodes
.
remove
(
node
)
if
not
nodes
:
del
self
[
node
.
op
]
del
self
.
d
[
node
.
op
]
def
query
(
self
,
fgraph
,
op
):
try
:
all
=
self
.
get
(
op
,
[])
all
=
self
.
d
.
get
(
op
,
[])
except
TypeError
:
raise
TypeError
(
"
%
s in unhashable and cannot be queried by the"
" optimizer"
%
op
)
...
...
theano/tensor/nnet/nnet.py
浏览文件 @
2fe4b0b8
...
...
@@ -353,7 +353,7 @@ class Softmax(gof.Op):
x
=
tensor
.
as_tensor_variable
(
x
)
if
x
.
type
.
ndim
not
in
(
1
,
2
)
\
or
x
.
type
.
dtype
not
in
tensor
.
float_dtypes
:
raise
ValueError
(
'x must be 1-d or 2-d tensor of floats
'
)
raise
ValueError
(
'x must be 1-d or 2-d tensor of floats
. Got '
,
x
.
type
)
if
x
.
ndim
==
1
:
x
=
tensor
.
shape_padleft
(
x
,
n_ones
=
1
)
return
Apply
(
self
,
[
x
],
[
x
.
type
()])
...
...
theano/tensor/opt.py
浏览文件 @
2fe4b0b8
...
...
@@ -915,6 +915,13 @@ class ShapeFeature(object):
# If no info is known on r's shape, use other_shape
self
.
set_shape
(
r
,
other_shape
)
return
if
(
other_r
.
owner
and
r
.
owner
and
other_r
.
owner
.
inputs
==
r
.
owner
.
inputs
and
other_r
.
owner
.
op
==
r
.
owner
.
op
):
# We are doing a merge. So the 2 shapes graph will be the
# same. This is only a speed optimization to call
# ancestors() less frequently.
return
# Merge other_shape with r_shape, giving the priority to other_shape
merged_shape
=
[]
...
...
@@ -928,6 +935,18 @@ class ShapeFeature(object):
# - Shape_i(i)(other_r);
# - Shape_i(i)(r).
merged_shape
.
append
(
r_shape
[
i
])
elif
isinstance
(
r_shape
[
i
],
(
Constant
,
int
)):
# We do this to call less often ancestors and make
# sure we have the simplest shape possible.
merged_shape
.
append
(
r_shape
[
i
])
elif
isinstance
(
other_shape
[
i
],
(
Constant
,
int
)):
# We do this to call less often ancestors and make
# sure we have the simplest shape possible.
merged_shape
.
append
(
other_shape
[
i
])
elif
other_shape
[
i
]
==
r_shape
[
i
]:
# This mean the shape is equivalent
# We do not want to do the ancestor check in those cases
merged_shape
.
append
(
r_shape
[
i
])
elif
r_shape
[
i
]
in
theano
.
gof
.
graph
.
ancestors
([
other_shape
[
i
]]):
# Another case where we want to use r_shape[i] is when
# other_shape[i] actually depends on r_shape[i]. In that case,
...
...
theano/tensor/opt_uncanonicalize.py
浏览文件 @
2fe4b0b8
...
...
@@ -26,54 +26,33 @@ import logging
_logger
=
logging
.
getLogger
(
'theano.tensor.opt'
)
from
theano
import
gof
from
theano.compat.python2x
import
deque
from
theano.tensor.elemwise
import
CAReduce
from
theano.tensor
import
basic
as
T
from
theano.gof.opt
import
Optimizer
from
theano.gof
import
InconsistencyError
,
toolbox
from
theano.tensor.basic
import
(
get_scalar_constant_value
,
NotScalarConstantError
)
from
theano.tensor.opt
import
register_uncanonicalize
from
theano
import
scalar
as
scal
class
MaxAndArgmaxOptimizer
(
Optimizer
):
"""Replace MaxAndArgmax by CAReduce when the argmax is not used
This is faster as MaxAndArgmax don't have c code and execute it
in two pass.
@register_uncanonicalize
@gof.local_optimizer
([
T
.
_max_and_argmax
])
def
local_max_and_argmax
(
node
):
"""
def
add_requirements
(
self
,
fgraph
):
fgraph
.
attach_feature
(
toolbox
.
ReplaceValidate
())
def
apply
(
self
,
fgraph
):
did_something
=
True
while
did_something
:
nodelist
=
fgraph
.
toposort
()
did_something
=
False
for
node
in
nodelist
:
if
node
.
op
==
T
.
_max_and_argmax
:
if
len
(
node
.
outputs
[
1
]
.
clients
)
==
0
:
try
:
axis
=
get_scalar_constant_value
(
node
.
inputs
[
1
])
except
NotScalarConstantError
:
return
False
new
=
CAReduce
(
scal
.
maximum
,
axis
)(
node
.
inputs
[
0
])
try
:
fgraph
.
replace_all_validate
(
((
node
.
outputs
[
0
],
new
),),
reason
=
self
.
__class__
.
__name__
)
did_something
=
True
break
except
InconsistencyError
,
e
:
pass
register_uncanonicalize
(
MaxAndArgmaxOptimizer
(),
name
=
'MaxAndArgmaxOptimizer'
)
If we don't use the argmax, change it to a max only.
"""
if
node
.
op
==
T
.
_max_and_argmax
:
if
len
(
node
.
outputs
[
1
]
.
clients
)
==
0
:
#MaxAndArgmax support variable axis,
#but CAReduce support only constant axis.
try
:
axis
=
get_scalar_constant_value
(
node
.
inputs
[
1
])
except
NotScalarConstantError
:
return
False
new
=
CAReduce
(
scal
.
maximum
,
axis
)(
node
.
inputs
[
0
])
return
[
new
,
None
]
@register_uncanonicalize
@gof.local_optimizer
([
T
.
_shape
])
...
...
theano/tensor/var.py
浏览文件 @
2fe4b0b8
...
...
@@ -3,7 +3,7 @@ import copy
import
numpy
import
theano
from
theano.compat
import
PY3
from
theano.compat
import
all
,
PY3
from
theano.scalar
import
ComplexError
,
IntegerDivisionError
from
theano.gof
import
Constant
,
Variable
from
theano.gof.utils
import
hashtype
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论