提交 2e5a0886 authored 作者: Frederic Bastien's avatar Frederic Bastien

Add test for the new fix

上级 3a2138de
......@@ -18,6 +18,8 @@ import theano.scalar as scal
import theano.tensor as tensor
from theano.tests import unittest_tools as utt
from theano.tensor.subtensor import (inc_subtensor, set_subtensor,
advanced_inc_subtensor1,
advanced_set_subtensor1,
Subtensor, IncSubtensor,
AdvancedSubtensor1, AdvancedSubtensor,
advanced_subtensor1, inplace_increment,
......@@ -519,6 +521,19 @@ class T_subtensor(unittest.TestCase, utt.TestOptimizationMixin):
self.assertTrue(g_00.shape == (1, 3))
self.assertTrue(numpy.allclose(g_00, 2))
utt.verify_grad(lambda m: m[[1, 3]],
[numpy.random.rand(5, 5).astype(self.dtype)])
def fun(x, y):
return advanced_inc_subtensor1(x, y, [1, 3])
utt.verify_grad(fun, [numpy.random.rand(5, 5).astype(self.dtype),
numpy.random.rand(2, 5).astype(self.dtype)])
def fun(x, y):
return advanced_set_subtensor1(x, y, [1, 3])
utt.verify_grad(fun, [numpy.random.rand(5, 5).astype(self.dtype),
numpy.random.rand(2, 5).astype(self.dtype)])
def test_adv_sub1_idx_broadcast(self):
# The idx can be a broadcastable vector.
ones = numpy.ones((4, 3), dtype=self.dtype)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论