提交 dcbb5804 authored 作者: Frederic Bastien's avatar Frederic Bastien

Re-enable the grad on max in some cases. The new max_and_argmax optimization pe…

Re-enable the grad on max in some cases. The new max_and_argmax optimization performs automatically the optimization that was previously done manually, which had disabled the grad in some cases.
上级 78c2c35a
...@@ -1471,11 +1471,8 @@ def max(x, axis=None): ...@@ -1471,11 +1471,8 @@ def max(x, axis=None):
:note: we return an error as numpy when we reduce a dim with a shape of 0 :note: we return an error as numpy when we reduce a dim with a shape of 0
:note2: see MaxAndArgmax note for a difference between numpy and theano when axis==None :note2: see MaxAndArgmax note for a difference between numpy and theano when axis==None
""" """
if isinstance(axis,int) or axis is None or (isinstance(axis,(list,tuple)) and all([isinstance(i,int) for i in axis])): if isinstance(axis,(list,tuple)) and len(axis)>1:
if axis is None: return CAReduce(scal.maximum,axis)(x)
axis = len(x.type.broadcastable)-1
return CAReduce(scal.maximum,axis)(x)
#TODO: do CAReduce need axis to be constant?
try: try:
const = get_constant_value(axis) const = get_constant_value(axis)
return CAReduce(scal.maximum,list(const))(x) return CAReduce(scal.maximum,list(const))(x)
......
...@@ -856,6 +856,11 @@ class T_max_and_argmax(unittest.TestCase): ...@@ -856,6 +856,11 @@ class T_max_and_argmax(unittest.TestCase):
assert len(topo)==1 assert len(topo)==1
assert isinstance(topo[0].op,CAReduce) assert isinstance(topo[0].op,CAReduce)
f = function([n],max_and_argmax(n,0))
topo = f.maker.env.toposort()
assert len(topo)==1
assert isinstance(topo[0].op,MaxAndArgmax)
def test_grad(self): def test_grad(self):
data = numpy.random.rand(2,3) data = numpy.random.rand(2,3)
n = as_tensor_variable(data) n = as_tensor_variable(data)
...@@ -1016,7 +1021,7 @@ class T_max(unittest.TestCase): ...@@ -1016,7 +1021,7 @@ class T_max(unittest.TestCase):
assert isinstance(topo[0].op,CAReduce) assert isinstance(topo[0].op,CAReduce)
f(data) f(data)
def _test_grad(self): def test_grad(self):
data = numpy.random.rand(2,3) data = numpy.random.rand(2,3)
n = as_tensor_variable(data) n = as_tensor_variable(data)
...@@ -1048,6 +1053,12 @@ class T_max(unittest.TestCase): ...@@ -1048,6 +1053,12 @@ class T_max(unittest.TestCase):
utt.verify_grad(lambda v: max(v.flatten()), [data]) utt.verify_grad(lambda v: max(v.flatten()), [data])
check_grad_max(data,eval_outputs(grad(max_and_argmax(n.flatten())[0],n))) check_grad_max(data,eval_outputs(grad(max_and_argmax(n.flatten())[0],n)))
@dec.knownfailureif(True,
"We don't implement the gradient of max with multiple axis as the same time")
def test_grad_list(self):
utt.verify_grad(lambda v: max(v,axis=[0,1]), [data])
#check_grad_max(data,eval_outputs(grad(max_and_argmax(n,axis=1)[0],n)),axis=1)
class T_subtensor(unittest.TestCase): class T_subtensor(unittest.TestCase):
def setUp(self): def setUp(self):
Subtensor.debug = False Subtensor.debug = False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论