提交 689e5680 authored 作者: Frederic's avatar Frederic

[Test]test the forward of IncSubtensor when the new value is broadcasted implicitly.

上级 6d1d5fa5
...@@ -100,30 +100,42 @@ class Test_inc_subtensor(unittest.TestCase): ...@@ -100,30 +100,42 @@ class Test_inc_subtensor(unittest.TestCase):
sl2 = slice(sl2_end) sl2 = slice(sl2_end)
sl3 = 2 sl3 = 2
for do_set in [True, False]: val_a = numpy.ones((5, 3, 4))
print "Set", do_set val_inc = 2.3
val_sl2_end = 2
if do_set: for method in [tt.set_subtensor, tt.inc_subtensor]:
resut = tt.set_subtensor(a[sl1, sl3, sl2], increment) print "MethodSet", method
else:
resut = tt.inc_subtensor(a[sl1, sl3, sl2], increment)
f = theano.function([a, increment, sl2_end], resut) resut = method(a[sl1, sl3, sl2], increment)
val_a = numpy.ones((5, 3, 4)) f = theano.function([a, increment, sl2_end], resut)
val_inc = 2.3
val_sl2_end = 2
expected_result = numpy.copy(val_a) expected_result = numpy.copy(val_a)
result = f(val_a, val_inc, val_sl2_end) result = f(val_a, val_inc, val_sl2_end)
if do_set: if method is tt.set_subtensor:
expected_result[:, sl3, :val_sl2_end] = val_inc expected_result[:, sl3, :val_sl2_end] = val_inc
else: else:
expected_result[:, sl3, :val_sl2_end] += val_inc expected_result[:, sl3, :val_sl2_end] += val_inc
utt.assert_allclose(result, expected_result) utt.assert_allclose(result, expected_result)
# Test when we broadcast the result
resut = method(a[sl1, sl2], increment)
f = theano.function([a, increment, sl2_end], resut)
expected_result = numpy.copy(val_a)
result = f(val_a, val_inc, val_sl2_end)
if method is tt.set_subtensor:
expected_result[:, :val_sl2_end] = val_inc
else:
expected_result[:, :val_sl2_end] += val_inc
utt.assert_allclose(result, expected_result)
def test_grad_inc_set(self): def test_grad_inc_set(self):
def inc_slice(*s): def inc_slice(*s):
def just_numeric_args(a, b): def just_numeric_args(a, b):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论