提交 7c5f01ae authored 作者: James Bergstra's avatar James Bergstra

added tensor functions: min, argmin, smallest. Fixed scalar return-value bug in MaxAndArgMax

上级 824f5fa7
...@@ -809,8 +809,8 @@ class MaxAndArgmax(Op): ...@@ -809,8 +809,8 @@ class MaxAndArgmax(Op):
tensor(axis.type.dtype, broadcastable)] tensor(axis.type.dtype, broadcastable)]
return Apply(self, inputs, outputs) return Apply(self, inputs, outputs)
def perform(self, node, (x, axis), (max, max_idx)): def perform(self, node, (x, axis), (max, max_idx)):
max[0] = numpy.max(x, axis) max[0] = numpy.asarray(numpy.max(x, axis))
max_idx[0] = numpy.argmax(x, axis) max_idx[0] = numpy.asarray(numpy.argmax(x, axis))
def grad(self, (x, axis), (g_max, g_max_idx)): def grad(self, (x, axis), (g_max, g_max_idx)):
# @warning: This only works if axis is 0, else the max is # @warning: This only works if axis is 0, else the max is
# broadcasted wrong in the call to eq. # broadcasted wrong in the call to eq.
...@@ -859,6 +859,27 @@ def argmax(x, axis=None): ...@@ -859,6 +859,27 @@ def argmax(x, axis=None):
# but when Argmax.c_impl() is in place, it should be fine. # but when Argmax.c_impl() is in place, it should be fine.
return max_and_argmax(x,axis)[1] return max_and_argmax(x,axis)[1]
@constructor
def min(x, axis=None):
if 'float'in str(x.dtype):
return -max(-x, axis=axis)
else:
#Be careful about unsigned integers, complex
raise NotImplementedError()
@constructor
def argmin(x, axis=None):
if 'float'in str(x.dtype):
return argmax(-x, axis=axis)
else:
#Be careful about unsigned integers, complex
raise NotImplementedError()
@constructor
def smallest(*args):
"""Return the [elementwise] smallest of a variable number of arguments (like python's min)."""
return min(stack(*args), axis=0)
########################## ##########################
# Comparison # Comparison
......
...@@ -1767,6 +1767,27 @@ class test_tensordot(unittest.TestCase): ...@@ -1767,6 +1767,27 @@ class test_tensordot(unittest.TestCase):
f6(bval,aval))) f6(bval,aval)))
tensor.verify_grad(None, TensorDot(axes), [bval,aval]) tensor.verify_grad(None, TensorDot(axes), [bval,aval])
def test_smallest_stack():
sx, sy = dscalar(), dscalar()
rval = function([sx,sy], stack(sx,sy))(-4.0, -2.0)
assert type(rval) == numpy.ndarray
assert [-4, -2] == list(rval)
def test_smallest():
x = dvector()
y = dvector()
z = dvector()
f1 = function([x], smallest(x))
assert numpy.all([1,2,3] == f1([1,2,3]))
f3 = function([x,y,z], smallest(x,y,z))
assert numpy.all([1,2,3] == f3([1,3,9], [7,7,7], [8,2,3]))
sx, sy = dscalar(), dscalar()
assert -4 == function([sx,sy], smallest(sx,sy))(-4.0, -2.0)
if __name__ == '__main__': if __name__ == '__main__':
if len(sys.argv) >= 2 and sys.argv[1] == 'OPT': if len(sys.argv) >= 2 and sys.argv[1] == 'OPT':
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论