提交 c64f6e87 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

added tests for shared variable contract for tensor and sparse

上级 b577b992
...@@ -400,5 +400,99 @@ class test_structureddot(unittest.TestCase): ...@@ -400,5 +400,99 @@ class test_structureddot(unittest.TestCase):
self.failUnless(numpy.allclose(theano_result, scipy_result)) self.failUnless(numpy.allclose(theano_result, scipy_result))
self.failIf(theano_time > overhead_rtol*scipy_time + overhead_tol) self.failIf(theano_time > overhead_rtol*scipy_time + overhead_tol)
def test_shared_dont_alias():
rng = numpy.random.RandomState([3,5,17])
x_lil = random_lil((2,4), 'float64', 5)
x = sp.csr_matrix(x_lil)
densifier = rng.randn(*x.shape)
x_shared = theano.shared(x, borrow = False)
prod = theano.sparse.mul_s_d(x_shared,densifier)
prod_func = theano.function([],prod)
prod_val = prod_func().todense()
x_dense_fake = x.todense()
x_dense = numpy.asarray(x_dense_fake)
scipy_prod_val = x_dense * densifier
assert numpy.allclose(scipy_prod_val, prod_val)
for i in xrange(x.shape[0]):
for j in xrange(x.shape[1]):
if x[i,j] != 0:
x[i,j] += 1
prod_val_2 = prod_func().todense()
#value used to construct should not alias with internal
assert numpy.allclose(prod_val_2,prod_val)
x = x_shared.get_value(borrow = False)
for i in xrange(x.shape[0]):
for j in xrange(x.shape[1]):
if x[i,j] != 0:
x[i,j] += 1
prod_val_3 = prod_func().todense()
#value returned by access should not alias with internal
assert numpy.allclose(prod_val, prod_val_3)
#in this case we can alias
x = x_shared.get_value(borrow = True)
for i in xrange(x.shape[0]):
for j in xrange(x.shape[1]):
if x[i,j] != 0:
x[i,j] += 1
x_dense_fake = x.todense()
x_dense = numpy.asarray(x_dense_fake)
scipy_prod_val = x_dense * densifier
#this is not required by the contract but it is a feature we've implemented
assert numpy.allclose(scipy_prod_val, prod_func().todense())
def test_shared_do_alias():
rng = numpy.random.RandomState([2,4,16])
x_lil = random_lil((2,4), 'float64', 5)
x = sp.csr_matrix(x_lil)
x_shared = theano.shared(x, borrow = True)
densifier = rng.randn(*x.shape)
prod = theano.sparse.mul_s_d(x_shared,densifier)
prod_func = theano.function([],prod)
prod_val = prod_func().todense()
x_dense_fake = x.todense()
x_dense = numpy.asarray(x_dense_fake)
scipy_prod_val = x_dense * densifier
assert numpy.allclose(scipy_prod_val, prod_val)
for i in xrange(x.shape[0]):
for j in xrange(x.shape[1]):
if x[i,j] != 0:
x[i,j] += 1
x_dense_fake = x.todense()
x_dense = numpy.asarray(x_dense_fake)
scipy_prod_val = x_dense * densifier
#not required by the contract but it is a feature we've implemented
assert numpy.allclose(scipy_prod_val, prod_func().todense())
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -3366,6 +3366,62 @@ def test_dimshuffle_duplicate(): ...@@ -3366,6 +3366,62 @@ def test_dimshuffle_duplicate():
assert success assert success
def test_shared_dont_alias():
rng = numpy.random.RandomState([3,5,17])
x = rng.uniform(0,1,[2,4])
x_shared = theano.shared(x, borrow = False)
total = theano.tensor.sum(x_shared)
total_func = theano.function([],total)
total_val = total_func()
assert numpy.allclose(x.sum(), total_val)
x += 1
total_val_2 = total_func()
#value used to construct should not alias with internal
assert total_val == total_val_2
x = x_shared.get_value(borrow = False)
x += 1
total_val_3 = total_func()
#value returned by access should not alias with internal
assert total_val == total_val_3
#in this case we can alias
x = x_shared.get_value(borrow = True)
x += 1
#this is not required by the contract but it is a feature we've implemented
assert numpy.allclose(x.sum(), total_func())
def test_shared_do_alias():
rng = numpy.random.RandomState([2,4,16])
x = rng.uniform(1,2,[4,2])
x_shared = theano.shared(x, borrow = True)
total = theano.tensor.sum(x_shared)
total_func = theano.function([],total)
total_val = total_func()
assert numpy.allclose(x.sum(), total_val)
x += 1
#not required by the contract but it is a feature we've implemented
assert numpy.allclose(x.sum(), total_func())
if __name__ == '__main__': if __name__ == '__main__':
if 1: if 1:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论