提交 471aa985 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Ignore unused inputs in some tests.

上级 287ffba8
...@@ -54,26 +54,45 @@ class T_function(unittest.TestCase): ...@@ -54,26 +54,45 @@ class T_function(unittest.TestCase):
def test_missing_inputs(self): def test_missing_inputs(self):
MissingInputException = TypeError MissingInputException = TypeError
UnusedInputException = ValueError
def fn(): def fn():
x,s = T.scalars('xs') x,s = T.scalars('xs')
fn = function([], [x]) fn = function([], [x])
checkfor(self, fn, MissingInputException) checkfor(self, fn, MissingInputException)
def fn():
x,s = T.scalars('xs')
# Ignore unused input s, as it hides the other error
fn = function([s], [x], on_unused_input='ignore')
checkfor(self, fn, MissingInputException)
def fn(): def fn():
x,s = T.scalars('xs') x,s = T.scalars('xs')
fn = function([s], [x]) fn = function([s], [x])
checkfor(self, fn, UnusedInputException)
def fn():
x,s = T.scalars('xs')
# Ignore unused input s, as it hides the other error
fn = function([s], x, on_unused_input='ignore')
checkfor(self, fn, MissingInputException) checkfor(self, fn, MissingInputException)
def fn(): def fn():
x,s = T.scalars('xs') x,s = T.scalars('xs')
fn = function([s], x) fn = function([s], x)
checkfor(self, fn, UnusedInputException)
def fn():
x,s = T.scalars('xs')
# Ignore unused input s, as it hides the other error
fn = function([s], Out(x), on_unused_input='ignore')
checkfor(self, fn, MissingInputException) checkfor(self, fn, MissingInputException)
def fn(): def fn():
x,s = T.scalars('xs') x,s = T.scalars('xs')
fn = function([s], Out(x)) fn = function([s], Out(x))
checkfor(self, fn, MissingInputException) checkfor(self, fn, UnusedInputException)
def fn(): def fn():
x,s = T.scalars('xs') x,s = T.scalars('xs')
...@@ -124,7 +143,8 @@ class T_function(unittest.TestCase): ...@@ -124,7 +143,8 @@ class T_function(unittest.TestCase):
x,s = T.scalars('xs') x,s = T.scalars('xs')
#x's name is ignored because it is followed by anonymous parameter a. #x's name is ignored because it is followed by anonymous parameter a.
f = function([x, a, s], a/s) # Ignore unused input x, as it hides the other error
f = function([x, a, s], a/s, on_unused_input='ignore')
self.assertTrue(f(9,1,2) == 0.5) self.assertTrue(f(9,1,2) == 0.5)
self.assertTrue(f(9,2,1) == 2.0) self.assertTrue(f(9,2,1) == 2.0)
self.assertTrue(f(9,2, s=1) == 2.0) self.assertTrue(f(9,2, s=1) == 2.0)
...@@ -355,6 +375,20 @@ class T_function(unittest.TestCase): ...@@ -355,6 +375,20 @@ class T_function(unittest.TestCase):
f(o+.1) #should clobber the memory used to store four f(o+.1) #should clobber the memory used to store four
assert not numpy.all(four==4) assert not numpy.all(four==4)
def test_disconnected_input(self):
a = T.scalar('a')
v = T.vector('v')
self.assertRaises(ValueError, function, [a, v], v*2)
f = function([a, v], v*2, on_unused_input='ignore')
def test_masked_input(self):
m = T.matrix('m')
mt = m.T
mt.name = 'm.T'
self.assertRaises(ValueError, function, [m, mt], mt*2)
f = function([m, mt], mt*2, on_unused_input='ignore')
class T_picklefunction(unittest.TestCase): class T_picklefunction(unittest.TestCase):
def test_deepcopy(self): def test_deepcopy(self):
...@@ -631,7 +665,6 @@ class T_picklefunction(unittest.TestCase): ...@@ -631,7 +665,6 @@ class T_picklefunction(unittest.TestCase):
assert blah.f1[blah.s] != blah2.f1[blah2.s] assert blah.f1[blah.s] != blah2.f1[blah2.s]
class SomethingToPickle(object): class SomethingToPickle(object):
def __init__(self): def __init__(self):
a = T.scalar() # the a is for 'anonymous' (un-named). a = T.scalar() # the a is for 'anonymous' (un-named).
......
...@@ -57,13 +57,15 @@ else: ...@@ -57,13 +57,15 @@ else:
utt.seed_rng() utt.seed_rng()
def inplace_func(inputs, outputs, mode=None, allow_input_downcast=False): def inplace_func(inputs, outputs, mode=None, allow_input_downcast=False,
on_unused_input='raise'):
if mode is None: if mode is None:
mode = get_default_mode() mode = get_default_mode()
return function(inputs, outputs, return function(inputs, outputs,
mode=mode, mode=mode,
allow_input_downcast=allow_input_downcast, allow_input_downcast=allow_input_downcast,
accept_inplace=True) accept_inplace=True,
on_unused_input=on_unused_input)
def eval_outputs(outputs): def eval_outputs(outputs):
......
...@@ -399,7 +399,8 @@ def just_gemm(i, o, ishapes = [(4,3), (3,5), (4,5), (), ()], max_graphlen=0): ...@@ -399,7 +399,8 @@ def just_gemm(i, o, ishapes = [(4,3), (3,5), (4,5), (), ()], max_graphlen=0):
f = inplace_func( f = inplace_func(
[Param(ii, mutable=True, allow_downcast=True) for ii in i], [Param(ii, mutable=True, allow_downcast=True) for ii in i],
o, o,
mode='FAST_RUN') mode='FAST_RUN',
on_unused_input='ignore')
at_least_one_gemm = False at_least_one_gemm = False
for node in f.maker.env.nodes: for node in f.maker.env.nodes:
if node.op == T.dot: if node.op == T.dot:
...@@ -410,7 +411,7 @@ def just_gemm(i, o, ishapes = [(4,3), (3,5), (4,5), (), ()], max_graphlen=0): ...@@ -410,7 +411,7 @@ def just_gemm(i, o, ishapes = [(4,3), (3,5), (4,5), (), ()], max_graphlen=0):
at_least_one_gemm = True at_least_one_gemm = True
assert at_least_one_gemm assert at_least_one_gemm
g = inplace_func(i, o, mode=compile.Mode(linker='py', optimizer=None), g = inplace_func(i, o, mode=compile.Mode(linker='py', optimizer=None),
allow_input_downcast=True) allow_input_downcast=True, on_unused_input='ignore')
for node in g.maker.env.nodes: for node in g.maker.env.nodes:
if node.op == gemm_inplace: if node.op == gemm_inplace:
raise Exception('gemm_inplace in original graph') raise Exception('gemm_inplace in original graph')
...@@ -475,11 +476,12 @@ def test_gemm_opt_double_gemm(): ...@@ -475,11 +476,12 @@ def test_gemm_opt_double_gemm():
+ gemm_inplace(Z, b, S.T, R.T, T.constant(1.0).astype('float64')))] + gemm_inplace(Z, b, S.T, R.T, T.constant(1.0).astype('float64')))]
try: try:
f = inplace_func([Param(ii, mutable=True) for ii in i],o, f = inplace_func([Param(ii, mutable=True) for ii in i],o,
mode='FAST_RUN') mode='FAST_RUN', on_unused_input='ignore')
for node in f.maker.env.nodes: for node in f.maker.env.nodes:
if node.op == T.dot: raise Failure('dot in graph') if node.op == T.dot: raise Failure('dot in graph')
if node.op == _dot22: raise Failure('_dot22 in graph') if node.op == _dot22: raise Failure('_dot22 in graph')
g = inplace_func(i, o, mode=compile.Mode(linker='py', optimizer=None)) g = inplace_func(i, o, mode=compile.Mode(linker='py', optimizer=None),
on_unused_input='ignore')
#for node in g.maker.env.nodes: #for node in g.maker.env.nodes:
# if node.op == gemm_inplace: raise Failure('gemm_inplace in graph') # if node.op == gemm_inplace: raise Failure('gemm_inplace in graph')
......
...@@ -3256,7 +3256,8 @@ class T_local_sum_dimshuffle(unittest.TestCase): ...@@ -3256,7 +3256,8 @@ class T_local_sum_dimshuffle(unittest.TestCase):
try: try:
for i, s in enumerate(sums): for i, s in enumerate(sums):
print i print i
f = theano.function([a, b, c, d], s, mode=self.mode) f = theano.function([a, b, c, d], s, mode=self.mode,
on_unused_input='ignore')
theano.printing.debugprint(f) theano.printing.debugprint(f)
g = f.maker.env.toposort() g = f.maker.env.toposort()
#print 'g =', g #print 'g =', g
...@@ -3294,7 +3295,7 @@ def test_make_vector(): ...@@ -3294,7 +3295,7 @@ def test_make_vector():
]: ]:
mv = opt.MakeVector(dtype=dtype)(*inputs) mv = opt.MakeVector(dtype=dtype)(*inputs)
assert mv.dtype == dtype assert mv.dtype == dtype
f = theano.function([b, i, d], mv) f = theano.function([b, i, d], mv, on_unused_input='ignore')
f_val = f(val[b], val[i], val[d]) f_val = f(val[b], val[i], val[d])
#print 'f_val =', f_val #print 'f_val =', f_val
......
...@@ -106,12 +106,12 @@ class RopLop_checker(unittest.TestCase): ...@@ -106,12 +106,12 @@ class RopLop_checker(unittest.TestCase):
vv = numpy.asarray(self.rng.uniform(size=self.mat_in_shape), vv = numpy.asarray(self.rng.uniform(size=self.mat_in_shape),
theano.config.floatX) theano.config.floatX)
yv = tensor.Rop(y, self.mx, self.mv) yv = tensor.Rop(y, self.mx, self.mv)
rop_f = function([self.mx, self.mv], yv) rop_f = function([self.mx, self.mv], yv, on_unused_input='ignore')
sy, _ = theano.scan(lambda i, y, x, v: \ sy, _ = theano.scan(lambda i, y, x, v: \
(tensor.grad(y[i], x) * v).sum(), (tensor.grad(y[i], x) * v).sum(),
sequences=tensor.arange(y.shape[0]), sequences=tensor.arange(y.shape[0]),
non_sequences=[y, self.mx, self.mv]) non_sequences=[y, self.mx, self.mv])
scan_f = function([self.mx, self.mv], sy) scan_f = function([self.mx, self.mv], sy, on_unused_input='ignore')
v1 = rop_f(vx, vv) v1 = rop_f(vx, vv)
v2 = scan_f(vx, vv) v2 = scan_f(vx, vv)
...@@ -146,13 +146,13 @@ class RopLop_checker(unittest.TestCase): ...@@ -146,13 +146,13 @@ class RopLop_checker(unittest.TestCase):
theano.config.floatX) theano.config.floatX)
yv = tensor.Rop(y, self.x, self.v) yv = tensor.Rop(y, self.x, self.v)
rop_f = function([self.x, self.v], yv) rop_f = function([self.x, self.v], yv, on_unused_input='ignore')
J, _ = theano.scan(lambda i, y, x: tensor.grad(y[i], x), J, _ = theano.scan(lambda i, y, x: tensor.grad(y[i], x),
sequences=tensor.arange(y.shape[0]), sequences=tensor.arange(y.shape[0]),
non_sequences=[y, self.x]) non_sequences=[y, self.x])
sy = tensor.dot(J, self.v) sy = tensor.dot(J, self.v)
scan_f = function([self.x, self.v], sy) scan_f = function([self.x, self.v], sy, on_unused_input='ignore')
v1 = rop_f(vx, vv) v1 = rop_f(vx, vv)
v2 = scan_f(vx, vv) v2 = scan_f(vx, vv)
...@@ -168,7 +168,7 @@ class RopLop_checker(unittest.TestCase): ...@@ -168,7 +168,7 @@ class RopLop_checker(unittest.TestCase):
theano.config.floatX) theano.config.floatX)
yv = tensor.Lop(y, self.x, self.v) yv = tensor.Lop(y, self.x, self.v)
lop_f = function([self.x, self.v], yv) lop_f = function([self.x, self.v], yv, on_unused_input='ignore')
J, _ = theano.scan(lambda i, y, x: tensor.grad(y[i], x), J, _ = theano.scan(lambda i, y, x: tensor.grad(y[i], x),
sequences=tensor.arange(y.shape[0]), sequences=tensor.arange(y.shape[0]),
non_sequences=[y, self.x]) non_sequences=[y, self.x])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论