提交 f1a541b3 authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Fixed max/argmax gradient tests in float32

上级 074a40d2
......@@ -6,6 +6,9 @@ import sys
import unittest
import warnings
from copy import copy, deepcopy
# Import builtin min to be able to use it after importing the tensor version.
import __builtin__
builtin_min = __builtin__.min
from nose.plugins.skip import SkipTest
import numpy
......@@ -1562,6 +1565,31 @@ class T_max_and_argmax(unittest.TestCase):
data = rand(2, 3)
n = as_tensor_variable(data)
def safe_verify_grad(func, data):
"""
Wrapper around 'verify_grad' that picks a proper value for epsilon.
This is needed because 'verify_grad' may fail when its epsilon is
too large, due to the fact the argmax is not continuous.
We make sure epsilon is less than the minimum absolute value found
in the matrix of pairwise differences between all elements in the
data. This way, the argmax will not change when adding epsilon.
"""
# 'data' is a one-element list.
data_tensor, = data
# Flatten it into a 1D vector.
data_vector = data_tensor.flatten()
# Compute pairwise absolute differences.
diff = numpy.abs(data_vector.reshape((-1, 1)) - data_vector)
# Alter the diagonal to avoid a zero minimum.
for i in xrange(len(diff)):
diff[i, i] = 1
# Find an appropriate epsilon.
eps = builtin_min(numeric_grad.type_eps[config.floatX],
diff.min() / 2)
# Run gradient verification.
utt.verify_grad(func, data, eps=eps)
def check_grad_max(data, max_grad_data, axis=None):
"""
Why this is needed? verify_grad is not enough?
......@@ -1583,10 +1611,10 @@ class T_max_and_argmax(unittest.TestCase):
for axis in (-1, 0, 1, None):
for j in xrange(2):
utt.verify_grad(lambda v: max_and_argmax(v, axis=axis)[j],
safe_verify_grad(lambda v: max_and_argmax(v, axis=axis)[j],
[data])
if axis != 1:
utt.verify_grad(lambda v: max_and_argmax(v.flatten(),
safe_verify_grad(lambda v: max_and_argmax(v.flatten(),
axis=axis)[j],
[data])
if axis in (0, None):
......@@ -1599,16 +1627,15 @@ class T_max_and_argmax(unittest.TestCase):
data = rand(3, 4, 5)
for i in [0, 1, 2]:
utt.verify_grad(lambda v: max_and_argmax(v, axis=[i])[0], [data])
utt.verify_grad(lambda v: max_and_argmax(v, axis=[i])[1], [data])
safe_verify_grad(lambda v: max_and_argmax(v, axis=[i])[0], [data])
safe_verify_grad(lambda v: max_and_argmax(v, axis=[i])[1], [data])
# Test 4d inner dimensions
# Use float64 as otherwise the test does not pass.
data = rand(2, 3, 4, 5).astype("float64")
data = rand(2, 3, 4, 5)
for i in [0, 1, 2, 3]:
utt.verify_grad(lambda v: max_and_argmax(v, axis=[i])[0], [data])
utt.verify_grad(lambda v: max_and_argmax(v, axis=[i])[1], [data])
safe_verify_grad(lambda v: max_and_argmax(v, axis=[i])[0], [data])
safe_verify_grad(lambda v: max_and_argmax(v, axis=[i])[1], [data])
class T_argmin_argmax(unittest.TestCase):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论