提交 bc582508 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

fixed the rest of the tests that depended on the Value class

上级 2483156c
...@@ -30,12 +30,13 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as, ...@@ -30,12 +30,13 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as,
inplace, iscalar, matrix, minimum, matrices, maximum, mul, neq, inplace, iscalar, matrix, minimum, matrices, maximum, mul, neq,
Reshape, row, scalar, scalars, second, smallest, stack, sub, Tensor, Reshape, row, scalar, scalars, second, smallest, stack, sub, Tensor,
tensor_copy, tensordot, tensordot_grad, TensorType, unbroadcast, tensor_copy, tensordot, tensordot_grad, TensorType, unbroadcast,
var, value, Join, shape, MaxAndArgmax, lscalar, zvector, exp, var, Join, shape, MaxAndArgmax, lscalar, zvector, exp,
get_constant_value, ivector, reshape, scalar_from_tensor, scal, get_constant_value, ivector, reshape, scalar_from_tensor, scal,
iscalars, arange, dscalars, fvector, imatrix, numeric_grad, iscalars, arange, dscalars, fvector, imatrix, numeric_grad,
opt, ComplexError, TensorDot, lvector, true_div, max, min, Split, roll, opt, ComplexError, TensorDot, lvector, true_div, max, min, Split, roll,
tile, patternbroadcast) tile, patternbroadcast)
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
from theano.printing import debugprint
imported_scipy_special = False imported_scipy_special = False
...@@ -210,10 +211,14 @@ def makeTester(name, op, expected, checks=None, good=None, bad_build=None, ...@@ -210,10 +211,14 @@ def makeTester(name, op, expected, checks=None, good=None, bad_build=None,
raise SkipTest(skip) raise SkipTest(skip)
for testname, inputs in self.good.items(): for testname, inputs in self.good.items():
inputs = [copy(input) for input in inputs] inputs = [copy(input) for input in inputs]
inputrs = [value(input) for input in inputs] inputrs = [ TensorType( dtype = input.dtype, broadcastable =
[ shape_elem == 1 for shape_elem in input.shape]
)() for input in inputs]
try: try:
#node = self.op.make_node(*inputrs) #node = self.op.make_node(*inputrs)
node = safe_make_node(self.op, *inputrs) node = safe_make_node(self.op, *inputrs)
print 'node: '
print node
except Exception, exc: except Exception, exc:
err_msg = ("Test %s::%s: Error occurred while" err_msg = ("Test %s::%s: Error occurred while"
" making a node with inputs %s") % ( " making a node with inputs %s") % (
...@@ -223,6 +228,13 @@ def makeTester(name, op, expected, checks=None, good=None, bad_build=None, ...@@ -223,6 +228,13 @@ def makeTester(name, op, expected, checks=None, good=None, bad_build=None,
try: try:
f = inplace_func(inputrs, node.outputs, mode=mode) f = inplace_func(inputrs, node.outputs, mode=mode)
try:
for i, output in f.maker.env.outputs:
print 'output',i
debugprint(output)
except:
print 'only one output?'
debugprint(f.maker.env.outputs)
except Exception, exc: except Exception, exc:
err_msg = ("Test %s::%s: Error occurred while" err_msg = ("Test %s::%s: Error occurred while"
" trying to make a Function") % (self.op, testname) " trying to make a Function") % (self.op, testname)
...@@ -287,7 +299,7 @@ def makeTester(name, op, expected, checks=None, good=None, bad_build=None, ...@@ -287,7 +299,7 @@ def makeTester(name, op, expected, checks=None, good=None, bad_build=None,
raise SkipTest(skip) raise SkipTest(skip)
for testname, inputs in self.bad_build.items(): for testname, inputs in self.bad_build.items():
inputs = [copy(input) for input in inputs] inputs = [copy(input) for input in inputs]
inputrs = [value(input) for input in inputs] inputrs = [shared(input) for input in inputs]
self.assertRaises(Exception, self.assertRaises(Exception,
safe_make_node, self.op, *inputrs) safe_make_node, self.op, *inputrs)
# The old error string was ("Test %s::%s: %s was successfully # The old error string was ("Test %s::%s: %s was successfully
...@@ -299,7 +311,7 @@ def makeTester(name, op, expected, checks=None, good=None, bad_build=None, ...@@ -299,7 +311,7 @@ def makeTester(name, op, expected, checks=None, good=None, bad_build=None,
raise SkipTest(skip) raise SkipTest(skip)
for testname, inputs in self.bad_runtime.items(): for testname, inputs in self.bad_runtime.items():
inputs = [copy(input) for input in inputs] inputs = [copy(input) for input in inputs]
inputrs = [value(input) for input in inputs] inputrs = [shared(input) for input in inputs]
try: try:
node = safe_make_node(self.op, *inputrs) node = safe_make_node(self.op, *inputrs)
except Exception, exc: except Exception, exc:
...@@ -310,7 +322,7 @@ def makeTester(name, op, expected, checks=None, good=None, bad_build=None, ...@@ -310,7 +322,7 @@ def makeTester(name, op, expected, checks=None, good=None, bad_build=None,
raise raise
try: try:
f = inplace_func(inputrs, node.outputs, mode=mode) f = inplace_func([], node.outputs, mode=mode)
except Exception, exc: except Exception, exc:
err_msg = ("Test %s::%s: Error occurred while trying" err_msg = ("Test %s::%s: Error occurred while trying"
" to make a Function") % (self.op, testname) " to make a Function") % (self.op, testname)
...@@ -321,7 +333,7 @@ def makeTester(name, op, expected, checks=None, good=None, bad_build=None, ...@@ -321,7 +333,7 @@ def makeTester(name, op, expected, checks=None, good=None, bad_build=None,
# one? # one?
# TODO: test that only this one is raised and catch only this # TODO: test that only this one is raised and catch only this
# one or the subset that get raised. # one or the subset that get raised.
self.assertRaises(Exception, f, *inputs) self.assertRaises(Exception, f, [])
def test_grad(self): def test_grad(self):
if skip: if skip:
...@@ -332,7 +344,7 @@ def makeTester(name, op, expected, checks=None, good=None, bad_build=None, ...@@ -332,7 +344,7 @@ def makeTester(name, op, expected, checks=None, good=None, bad_build=None,
try: try:
for testname, inputs in self.grad.items(): for testname, inputs in self.grad.items():
inputs = [copy(input) for input in inputs] inputs = [copy(input) for input in inputs]
inputrs = [value(input) for input in inputs] #inputrs = [shared(input) for input in inputs]
try: try:
utt.verify_grad(self.op, inputs, utt.verify_grad(self.op, inputs,
mode=self.mode, mode=self.mode,
...@@ -3796,17 +3808,17 @@ class T_add(unittest.TestCase): ...@@ -3796,17 +3808,17 @@ class T_add(unittest.TestCase):
def test_complex_all_ops(self): def test_complex_all_ops(self):
for nbits in (64, 128): for nbits in (64, 128):
a = value(numpy.ones(3, dtype='complex%i' % nbits)+0.5j) a = shared(numpy.ones(3, dtype='complex%i' % nbits)+0.5j)
b = value(numpy.ones(3, dtype='complex%i' % nbits)+1.5j) b = shared(numpy.ones(3, dtype='complex%i' % nbits)+1.5j)
tests = (("+", lambda x,y: x+y), tests = (("+", lambda x,y: x+y),
("-", lambda x,y: x-y), ("-", lambda x,y: x-y),
("*", lambda x,y: x*y), ("*", lambda x,y: x*y),
("/", lambda x,y: x/y)) ("/", lambda x,y: x/y))
for s, fn in tests: for s, fn in tests:
f = inplace_func([a,b], fn(a, b)) f = inplace_func([], fn(a, b))
#print 'valid output:', fn(a.data, b.data) #print 'valid output:', fn(a.data, b.data)
#print 'theano output:', f(a.data, b.data) #print 'theano output:', f(a.data, b.data)
self.assertTrue(a.type.values_eq_approx(fn(a.data, b.data), f(a.data, b.data))) self.assertTrue(a.type.values_eq_approx(fn(a.get_value(), b.get_value()), f()))
def test_grad_scalar_l(self): def test_grad_scalar_l(self):
utt.verify_grad(add, [numpy.asarray([3.0]), rand(3)]) utt.verify_grad(add, [numpy.asarray([3.0]), rand(3)])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论