提交 2f3d63cf authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Change axis to be a class attr rather than an input on argmax.

上级 98cad04c
......@@ -14,7 +14,7 @@ import theano
from theano.compat import izip
from theano.configparser import config
from theano import gof
from theano.gof import Apply, Constant, Op, Variable
from theano.gof import Apply, Constant, Op, Variable, ParamsType
from theano.gof.type import Generic
from theano.tensor import elemwise
......@@ -1429,21 +1429,31 @@ class Argmax(Op):
nin = 2 # tensor, axis
nout = 1
E_axis = 'invalid axis'
__props__ = ()
__props__ = ('axis',)
_f16_ok = True
params_type = ParamsType(c_axis=scal.int64)
def __init__(self, axis):
if axis is not None:
axis = tuple(axis)
self.axis = tuple(axis)
def get_params(self, node):
if self.axis is not None and len(self.axis) == 1:
c_axis = np.int64(self.axis[0])
else:
# The value here doesn't matter, it won't be used
c_axis = np.int64(-1)
return self.params_type.get_params(c_axis=c_axis)
def make_node(self, x, axis=None):
x = _as_tensor_variable(x)
# Check axis and convert it to a Python list of integers.
axis = check_and_normalize_axes(x, axis)
if len(axis) == 0:
axis = NoneConst.clone()
if self.axis is None:
all_axes = list(range(x.ndim))
else:
all_axes = axis
axis = _as_tensor_variable(axis)
assert axis.ndim == 1
inputs = [x, axis]
all_axes = self.axis
inputs = [x]
# We keep the original broadcastable flags for dimensions on which
# we do not perform the argmax.
......@@ -1452,13 +1462,12 @@ class Argmax(Op):
outputs = [tensor('int64', broadcastable, name='argmax')]
return Apply(self, inputs, outputs)
def perform(self, node, inp, outs):
x, axes = inp
def perform(self, node, inp, outs, params):
x, = inp
axes = self.axis
max_idx, = outs
if axes is None:
axes = tuple(range(x.ndim))
else:
axes = tuple(int(ax) for ax in axes)
# Numpy does not support multiple axes for argmax
# Work around
......@@ -1476,18 +1485,18 @@ class Argmax(Op):
dtype='int64')
def c_code(self, node, name, inp, out, sub):
x, axis = inp
x, = inp
argmax, = out
fail = sub["fail"]
if NoneConst.equals(node.inputs[1]):
params = sub["params"]
if self.axis is None:
axis_code = "axis = NPY_MAXDIMS;"
else:
assert node.inputs[1].ndim == 1
# Fall back to perform() if there are multiple axes
if len(node.inputs[1].data) > 1:
if len(self.axis) > 1:
raise NotImplementedError()
# params is only used here for now
axis_code = """
axis = ((dtype_%(axis)s*)PyArray_DATA(%(axis)s))[0];
axis = %(params)s->c_axis;
if(axis > PyArray_NDIM(%(x)s)-1 || axis < -PyArray_NDIM(%(x)s)){
PyErr_SetString(PyExc_ValueError,
"Argmax, bad axis argument");
......@@ -1522,28 +1531,20 @@ class Argmax(Op):
return ret % locals()
def c_code_cache_version(self):
return (0,)
return (1,)
def infer_shape(self, node, shapes):
ishape, axis_shape = shapes
axis = node.inputs[1]
if axis.data is None:
ishape, = shapes
if self.axis is None:
return [()]
rval = tuple([ishape[i] for (i, b) in enumerate(
node.inputs[0].type.broadcastable) if i not in axis.data])
node.inputs[0].type.broadcastable) if i not in self.axis])
return [rval]
def grad(self, inp, grads):
x, axis = inp
axis_grad = grad_undefined(
self, 1, axis,
"argmax is not defined for non-integer axes so"
" argmax(x, axis+eps) is undefined")
return [x.zeros_like(), axis_grad]
x, = inp
_argmax = Argmax()
return [x.zeros_like()]
def makeKeepDims(x, y, axis):
......
......@@ -60,7 +60,7 @@ def local_max_and_argmax(node):
return [new, None]
if len(node.outputs[0].clients) == 0:
return [None, T._argmax(node.inputs[0], axis)]
return [None, T.Argmax(axis)(node.inputs[0])]
@register_uncanonicalize
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论