提交 15584c1d authored 作者: Jakub Sygnowski's avatar Jakub Sygnowski

optimizer for substituting max_and_argmax with just argmax

上级 88160c11
......@@ -1732,26 +1732,6 @@ class Argmax(Op):
node.inputs[0].type.broadcastable) if i not in axis.data])
return [rval]
def R_op(self, inputs, eval_points):
if eval_points[0] is None:
return [None]
if not isinstance(inputs[1], theano.Constant):
raise ValueError(('R_op supported for argmax only for '
'constant axis!'))
if inputs[1].data > 1:
raise ValueError(('R_op supported for argmax only when '
' axis is 0 or 1'))
if inputs[0].ndim != 2:
raise ValueError(('R_op supported for argmax only when '
' input is a matrix'))
max_pos = self.make_node(*inputs).outputs
if inputs[1].data == 0:
return [eval_points[0][max_pos,
arange(eval_points[0].shape[1])], None]
else:
return [eval_points[0][arange(eval_points[0].shape[0]),
max_pos], None]
def grad(self, inp, grads):
x, axis = inp
......@@ -1887,7 +1867,7 @@ def argmax(x, axis=None, keepdims=False):
will broadcast correctly against the original tensor.
"""
argout = _argmax(x, axis)
argout = max_and_argmax(x, axis)[1]
if keepdims:
argout = makeKeepDims(x, argout, axis)
......
......@@ -1314,10 +1314,10 @@ def test_argmax_pushdown():
# for node in fgraph.toposort():
# print node.op
assert len(fgraph.toposort()) == 2 # an output_guard is second
assert fgraph.toposort()[0].op == tensor.basic._max_and_argmax
assert fgraph.toposort()[0].op == tensor.basic._argmax
assert str(fgraph.toposort()[1].op) == 'OutputGuard'
assert check_stack_trace(
fgraph, ops_to_check=tensor.basic._max_and_argmax)
fgraph, ops_to_check=tensor.basic._argmax)
x = tensor.matrix()
# test that the max_and_argmax is not pushed down if the max is used
out = tensor.max_and_argmax(
......@@ -1362,7 +1362,7 @@ def test_argmax_pushdown_bias():
# print 'AFTER'
# for node in fgraph.toposort():
# print node.op
types_to_check = (tensor.DimShuffle, tensor.Elemwise, tensor.MaxAndArgmax)
types_to_check = (tensor.DimShuffle, tensor.Elemwise, tensor.Argmax)
assert len(fgraph.toposort()) == 4
for i, type in enumerate(types_to_check):
assert isinstance(fgraph.toposort()[i].op, type)
......
......@@ -73,6 +73,9 @@ def local_max_and_argmax(node):
new = CAReduce(scal.maximum, axis)(node.inputs[0])
return [new, None]
if len(node.outputs[0].clients) == 0:
return [None, T._argmax(node.inputs[0], node.inputs[1])]
@register_uncanonicalize
@gof.local_optimizer([T.neg])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论