提交 c8615bdc authored 作者: Frederic's avatar Frederic

Update test T_subtensor.test_advanced_inc_and_set to work correctly.

上级 2c8c27d3
...@@ -2100,23 +2100,20 @@ class T_subtensor(unittest.TestCase, utt.TestOptimizationMixin): ...@@ -2100,23 +2100,20 @@ class T_subtensor(unittest.TestCase, utt.TestOptimizationMixin):
return super(T_subtensor, self).__init__(name) return super(T_subtensor, self).__init__(name)
def function(self, inputs, outputs, accept_inplace=False, def function(self, inputs, outputs, accept_inplace=False,
op=None, mode=None, N=1): op=None, mode=None, N=1, N_fast=None):
""" wrapper around theano.function that also check the output """ wrapper around theano.function that also check the output
:param N: the number of op expected in the toposort :param N: the number of op expected in the toposort
if tuple of length 2, (expected if fast_compile, if tuple of length 2, (expected if fast_compile,
if not fast_compile) if not fast_compile)
""" """
if isinstance(N, tuple): if self.fast_compile and N_fast is not None:
assert len(N) == 2 N = N_fast
if self.fast_compile:
N = N[0]
else:
N = N[1]
if mode is None: if mode is None:
mode = self.mode mode = self.mode
if op is None: if op is None:
op = self.sub op = self.sub
f = theano.function(inputs, outputs, mode=mode, f = theano.function(inputs, outputs, mode=mode,
accept_inplace=accept_inplace) accept_inplace=accept_inplace)
self.assertFunctionContainsClassN(f, op, N) self.assertFunctionContainsClassN(f, op, N)
...@@ -2694,7 +2691,7 @@ class T_subtensor(unittest.TestCase, utt.TestOptimizationMixin): ...@@ -2694,7 +2691,7 @@ class T_subtensor(unittest.TestCase, utt.TestOptimizationMixin):
if idx is idxs[0]: if idx is idxs[0]:
f = self.function([], [gn.shape, n[idx_].shape], f = self.function([], [gn.shape, n[idx_].shape],
op=ops, op=ops,
N=(2, 0)) N=0, N_fast=2)
f() f()
def test_wrong_exception_regression(self): def test_wrong_exception_regression(self):
...@@ -2747,7 +2744,7 @@ class T_subtensor(unittest.TestCase, utt.TestOptimizationMixin): ...@@ -2747,7 +2744,7 @@ class T_subtensor(unittest.TestCase, utt.TestOptimizationMixin):
data = numpy.asarray(data, dtype=self.dtype) data = numpy.asarray(data, dtype=self.dtype)
n = self.shared(data) n = self.shared(data)
t = n[idx] t = n[idx]
f = self.function([], t.shape, op=self.ops, N=(1, 0)) f = self.function([], t.shape, op=self.ops, N=0, N_fast=1)
val = f() val = f()
self.assertTrue(numpy.allclose(val, data[idx].shape)) self.assertTrue(numpy.allclose(val, data[idx].shape))
...@@ -2850,6 +2847,8 @@ class T_subtensor(unittest.TestCase, utt.TestOptimizationMixin): ...@@ -2850,6 +2847,8 @@ class T_subtensor(unittest.TestCase, utt.TestOptimizationMixin):
data_copy[idx] = inc_num data_copy[idx] = inc_num
else: else:
data_copy[idx] += inc_num data_copy[idx] += inc_num
data_var = theano.In(data_var, mutable=True)
# Remember data for the Theano function (see below). # Remember data for the Theano function (see below).
all_inputs_var += [data_var, idx_var, inc_var] all_inputs_var += [data_var, idx_var, inc_var]
all_inputs_num += [data_num, idx_num, inc_num] all_inputs_num += [data_num, idx_num, inc_num]
...@@ -2869,9 +2868,16 @@ class T_subtensor(unittest.TestCase, utt.TestOptimizationMixin): ...@@ -2869,9 +2868,16 @@ class T_subtensor(unittest.TestCase, utt.TestOptimizationMixin):
assert (data_num == data_num_init).all() assert (data_num == data_num_init).all()
# Actual test (we compile a single Theano function to make it faster). # Actual test (we compile a single Theano function to make it faster).
f = self.function(all_inputs_var, all_outputs_var, orig_warn = theano.config.warn.gpu_set_subtensor1
accept_inplace=True, op=self.adv_incsub1, try:
N=len(all_outputs_var)) theano.config.warn.gpu_set_subtensor1 = False
f = self.function(all_inputs_var, all_outputs_var,
accept_inplace=True,
op=self.adv_incsub1,
N=len(all_outputs_var))
finally:
theano.config.warn.gpu_set_subtensor1 = orig_warn
f_outs = f(*all_inputs_num) f_outs = f(*all_inputs_num)
assert len(f_outs) == len(all_outputs_num) assert len(f_outs) == len(all_outputs_num)
for f_out, output_num in izip(f_outs, all_outputs_num): for f_out, output_num in izip(f_outs, all_outputs_num):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论