提交 b68c23d6 authored 作者: Nicolas Bouchard's avatar Nicolas Bouchard

Add little test for cast.

上级 d986f790
...@@ -33,7 +33,7 @@ from theano.sparse import ( ...@@ -33,7 +33,7 @@ from theano.sparse import (
csc_from_dense, csr_from_dense, dense_from_sparse, csc_from_dense, csr_from_dense, dense_from_sparse,
Dot, Usmm, UsmmCscDense, sp_ones_like, GetItemScalar, Dot, Usmm, UsmmCscDense, sp_ones_like, GetItemScalar,
SparseFromDense, SparseFromDense,
Cast, HStack, VStack, AddSSData, add_s_s_data, Cast, cast, HStack, VStack, AddSSData, add_s_s_data,
Poisson, poisson, Binomial, Multinomial, multinomial, Poisson, poisson, Binomial, Multinomial, multinomial,
structured_sigmoid, structured_exp, structured_log, structured_sigmoid, structured_exp, structured_log,
structured_pow, structured_minimum, structured_maximum, structured_add, structured_pow, structured_minimum, structured_maximum, structured_add,
...@@ -1960,6 +1960,11 @@ class TestCast(utt.InferShapeTester): ...@@ -1960,6 +1960,11 @@ class TestCast(utt.InferShapeTester):
for t in self.compatible_types]) for t in self.compatible_types])
for x in self.x_csr]) for x in self.x_csr])
cast_csr_func = dict([
(x, [theano.function([x], cast(x, t))
for t in self.compatible_types])
for x in self.x_csr])
for x in self.x_csc: for x in self.x_csc:
for f, t in zip(cast_csc[x], self.compatible_types): for f, t in zip(cast_csc[x], self.compatible_types):
a = sp.csc_matrix(self.properties, dtype=x.dtype).copy() a = sp.csc_matrix(self.properties, dtype=x.dtype).copy()
...@@ -1970,6 +1975,11 @@ class TestCast(utt.InferShapeTester): ...@@ -1970,6 +1975,11 @@ class TestCast(utt.InferShapeTester):
a = sp.csr_matrix(self.properties, dtype=x.dtype) a = sp.csr_matrix(self.properties, dtype=x.dtype)
assert f(a).dtype == t assert f(a).dtype == t
for x in self.x_csr:
for f, t in zip(cast_csr_func[x], self.compatible_types):
a = sp.csr_matrix(self.properties, dtype=x.dtype)
assert f(a).dtype == t
def test_infer_shape(self): def test_infer_shape(self):
for x in self.x_csc: for x in self.x_csc:
for t in self.compatible_types: for t in self.compatible_types:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论