提交 2f3d63cf authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Change axis to be a class attr rather than an input on argmax.

上级 98cad04c
...@@ -14,7 +14,7 @@ import theano ...@@ -14,7 +14,7 @@ import theano
from theano.compat import izip from theano.compat import izip
from theano.configparser import config from theano.configparser import config
from theano import gof from theano import gof
from theano.gof import Apply, Constant, Op, Variable from theano.gof import Apply, Constant, Op, Variable, ParamsType
from theano.gof.type import Generic from theano.gof.type import Generic
from theano.tensor import elemwise from theano.tensor import elemwise
...@@ -1429,21 +1429,31 @@ class Argmax(Op): ...@@ -1429,21 +1429,31 @@ class Argmax(Op):
nin = 2 # tensor, axis nin = 2 # tensor, axis
nout = 1 nout = 1
E_axis = 'invalid axis' E_axis = 'invalid axis'
__props__ = () __props__ = ('axis',)
_f16_ok = True _f16_ok = True
params_type = ParamsType(c_axis=scal.int64)
def __init__(self, axis):
if axis is not None:
axis = tuple(axis)
self.axis = tuple(axis)
def get_params(self, node):
if self.axis is not None and len(self.axis) == 1:
c_axis = np.int64(self.axis[0])
else:
# The value here doesn't matter, it won't be used
c_axis = np.int64(-1)
return self.params_type.get_params(c_axis=c_axis)
def make_node(self, x, axis=None): def make_node(self, x, axis=None):
x = _as_tensor_variable(x) x = _as_tensor_variable(x)
# Check axis and convert it to a Python list of integers. if self.axis is None:
axis = check_and_normalize_axes(x, axis)
if len(axis) == 0:
axis = NoneConst.clone()
all_axes = list(range(x.ndim)) all_axes = list(range(x.ndim))
else: else:
all_axes = axis all_axes = self.axis
axis = _as_tensor_variable(axis) inputs = [x]
assert axis.ndim == 1
inputs = [x, axis]
# We keep the original broadcastable flags for dimensions on which # We keep the original broadcastable flags for dimensions on which
# we do not perform the argmax. # we do not perform the argmax.
...@@ -1452,13 +1462,12 @@ class Argmax(Op): ...@@ -1452,13 +1462,12 @@ class Argmax(Op):
outputs = [tensor('int64', broadcastable, name='argmax')] outputs = [tensor('int64', broadcastable, name='argmax')]
return Apply(self, inputs, outputs) return Apply(self, inputs, outputs)
def perform(self, node, inp, outs): def perform(self, node, inp, outs, params):
x, axes = inp x, = inp
axes = self.axis
max_idx, = outs max_idx, = outs
if axes is None: if axes is None:
axes = tuple(range(x.ndim)) axes = tuple(range(x.ndim))
else:
axes = tuple(int(ax) for ax in axes)
# Numpy does not support multiple axes for argmax # Numpy does not support multiple axes for argmax
# Work around # Work around
...@@ -1476,18 +1485,18 @@ class Argmax(Op): ...@@ -1476,18 +1485,18 @@ class Argmax(Op):
dtype='int64') dtype='int64')
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
x, axis = inp x, = inp
argmax, = out argmax, = out
fail = sub["fail"] fail = sub["fail"]
if NoneConst.equals(node.inputs[1]): params = sub["params"]
if self.axis is None:
axis_code = "axis = NPY_MAXDIMS;" axis_code = "axis = NPY_MAXDIMS;"
else: else:
assert node.inputs[1].ndim == 1 if len(self.axis) > 1:
# Fall back to perform() if there are multiple axes
if len(node.inputs[1].data) > 1:
raise NotImplementedError() raise NotImplementedError()
# params is only used here for now
axis_code = """ axis_code = """
axis = ((dtype_%(axis)s*)PyArray_DATA(%(axis)s))[0]; axis = %(params)s->c_axis;
if(axis > PyArray_NDIM(%(x)s)-1 || axis < -PyArray_NDIM(%(x)s)){ if(axis > PyArray_NDIM(%(x)s)-1 || axis < -PyArray_NDIM(%(x)s)){
PyErr_SetString(PyExc_ValueError, PyErr_SetString(PyExc_ValueError,
"Argmax, bad axis argument"); "Argmax, bad axis argument");
...@@ -1522,28 +1531,20 @@ class Argmax(Op): ...@@ -1522,28 +1531,20 @@ class Argmax(Op):
return ret % locals() return ret % locals()
def c_code_cache_version(self): def c_code_cache_version(self):
return (0,) return (1,)
def infer_shape(self, node, shapes): def infer_shape(self, node, shapes):
ishape, axis_shape = shapes ishape, = shapes
axis = node.inputs[1] if self.axis is None:
if axis.data is None:
return [()] return [()]
rval = tuple([ishape[i] for (i, b) in enumerate( rval = tuple([ishape[i] for (i, b) in enumerate(
node.inputs[0].type.broadcastable) if i not in axis.data]) node.inputs[0].type.broadcastable) if i not in self.axis])
return [rval] return [rval]
def grad(self, inp, grads): def grad(self, inp, grads):
x, axis = inp x, = inp
axis_grad = grad_undefined(
self, 1, axis,
"argmax is not defined for non-integer axes so"
" argmax(x, axis+eps) is undefined")
return [x.zeros_like(), axis_grad]
_argmax = Argmax() return [x.zeros_like()]
def makeKeepDims(x, y, axis): def makeKeepDims(x, y, axis):
......
...@@ -60,7 +60,7 @@ def local_max_and_argmax(node): ...@@ -60,7 +60,7 @@ def local_max_and_argmax(node):
return [new, None] return [new, None]
if len(node.outputs[0].clients) == 0: if len(node.outputs[0].clients) == 0:
return [None, T._argmax(node.inputs[0], axis)] return [None, T.Argmax(axis)(node.inputs[0])]
@register_uncanonicalize @register_uncanonicalize
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论