提交 a7922a1f authored 作者: Harm de Vries's avatar Harm de Vries

add grad tests

上级 052ddd0e
...@@ -133,6 +133,7 @@ class test_sort(unittest.TestCase): ...@@ -133,6 +133,7 @@ class test_sort(unittest.TestCase):
data = np.random.rand(2, 3, 4, 2).astype(theano.config.floatX) data = np.random.rand(2, 3, 4, 2).astype(theano.config.floatX)
utt.verify_grad(lambda x: sort(x, 3), [data]) utt.verify_grad(lambda x: sort(x, 3), [data])
class TensorInferShapeTester(utt.InferShapeTester): class TensorInferShapeTester(utt.InferShapeTester):
def test_sort(self): def test_sort(self):
x = tensor.matrix() x = tensor.matrix()
...@@ -209,3 +210,13 @@ def test_argsort(): ...@@ -209,3 +210,13 @@ def test_argsort():
assert np.allclose(gv, gt) assert np.allclose(gv, gt)
def test_argsort_grad():
# Testing grad of argsort
data = np.random.rand(2, 3).astype(theano.config.floatX)
utt.verify_grad(lambda x: argsort(x, axis=-1), [data])
data = np.random.rand(2, 3, 4, 5).astype(theano.config.floatX)
utt.verify_grad(lambda x: argsort(x, axis=-3), [data])
data = np.random.rand(2, 3, 3).astype(theano.config.floatX)
utt.verify_grad(lambda x: argsort(x, axis=2), [data])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论