提交 c1876a37 authored 作者: Frederic's avatar Frederic

fix __getitem__ on shared sparse variable and test it.

No change in NEWS.txt as this was not released.
上级 69a5ddda
...@@ -4,7 +4,7 @@ from theano.compile import shared_constructor, SharedVariable ...@@ -4,7 +4,7 @@ from theano.compile import shared_constructor, SharedVariable
from theano import config from theano import config
from basic import SparseType, _sparse_py_operators from basic import SparseType, _sparse_py_operators
class SparseTensorSharedVariable(SharedVariable, _sparse_py_operators): class SparseTensorSharedVariable(_sparse_py_operators, SharedVariable):
pass pass
@shared_constructor @shared_constructor
......
...@@ -1464,6 +1464,14 @@ class Test_getitem(unittest.TestCase): ...@@ -1464,6 +1464,14 @@ class Test_getitem(unittest.TestCase):
assert r11.shape == t11.shape assert r11.shape == t11.shape
assert numpy.all(r11.toarray() == t11.toarray()) assert numpy.all(r11.toarray() == t11.toarray())
# Test that is work with shared variable
sx = theano.shared(vx)
f12 = theano.function([a], sx[:, a:])
r12 = f12(p)
t12 = vx[:, p:]
assert r12.shape == t12.shape
assert numpy.all(r12.toarray() == t12.toarray())
#------------------------------------------------------------ #------------------------------------------------------------
# Invalid things # Invalid things
# The syntax is a bit awkward because assertRaises forbids # The syntax is a bit awkward because assertRaises forbids
...@@ -1526,6 +1534,14 @@ class Test_getitem(unittest.TestCase): ...@@ -1526,6 +1534,14 @@ class Test_getitem(unittest.TestCase):
assert r3.shape == t3.shape assert r3.shape == t3.shape
assert numpy.all(t4 == r4) assert numpy.all(t4 == r4)
# Test that is work with shared variable
sx = theano.shared(vx)
f1 = theano.function([a, b], sx[a, b])
r1 = f1(10, 10)
t1 = vx[10, 10]
assert r1.shape == t1.shape
assert numpy.all(t1 == r1)
import theano.tensor.tests.test_sharedvar import theano.tensor.tests.test_sharedvar
test_shared_options = theano.tensor.tests.test_sharedvar.makeSharedTester( test_shared_options = theano.tensor.tests.test_sharedvar.makeSharedTester(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论