提交 7077bbfa authored 作者: Tanjay94's avatar Tanjay94

Made more change to GetItemList function.

上级 1744a5b1
...@@ -1019,6 +1019,7 @@ class GetItemList(gof.op.Op): ...@@ -1019,6 +1019,7 @@ class GetItemList(gof.op.Op):
ind = tensor.as_tensor_variable(index) ind = tensor.as_tensor_variable(index)
assert ind.ndim == 1 assert ind.ndim == 1
assert "int" in ind.dtype
return gof.Apply(self, [x, ind], [x.type()]) return gof.Apply(self, [x, ind], [x.type()])
...@@ -1068,11 +1069,12 @@ class GetItemListGrad(gof.op.Op): ...@@ -1068,11 +1069,12 @@ class GetItemListGrad(gof.op.Op):
ind = tensor.as_tensor_variable(index) ind = tensor.as_tensor_variable(index)
assert ind.ndim == 1 assert ind.ndim == 1
assert "int" in ind.dtype
scipy_ver = [int(n) for n in scipy.__version__.split('.')[:2]] scipy_ver = [int(n) for n in scipy.__version__.split('.')[:2]]
if not scipy_ver >= [0, 13]: if not scipy_ver >= [0, 13]:
raise NotImplemented("Scipy version is to old") raise NotImplementedError("Scipy version is to old")
return gof.Apply(self, [x, ind, gz], [x.type()]) return gof.Apply(self, [x, ind, gz], [x.type()])
...@@ -1109,6 +1111,8 @@ class GetItem2Lists(gof.op.Op): ...@@ -1109,6 +1111,8 @@ class GetItem2Lists(gof.op.Op):
assert x.format in ["csr", "csc"] assert x.format in ["csr", "csc"]
ind1 = tensor.as_tensor_variable(ind1) ind1 = tensor.as_tensor_variable(ind1)
ind2 = tensor.as_tensor_variable(ind2) ind2 = tensor.as_tensor_variable(ind2)
assert "int" in ind1.dtype
assert "int" in ind2.dtype
return gof.Apply(self, [x, ind1, ind2], return gof.Apply(self, [x, ind1, ind2],
[theano.tensor.vector()]) [theano.tensor.vector()])
......
...@@ -2052,7 +2052,7 @@ class Test_getitem(unittest.TestCase): ...@@ -2052,7 +2052,7 @@ class Test_getitem(unittest.TestCase):
def op_with_fixed_index(x): def op_with_fixed_index(x):
return op(x, index=numpy.asarray([0, 1])) return op(x, index=numpy.asarray([0, 1]))
x, x_val = sparse_random_inputs("csr", (4,5), out_dtype="float64") x, x_val = sparse_random_inputs("csr", (4,5))
try: try:
verify_grad_sparse(op_with_fixed_index, x_val) verify_grad_sparse(op_with_fixed_index, x_val)
...@@ -2080,17 +2080,21 @@ class Test_getitem(unittest.TestCase): ...@@ -2080,17 +2080,21 @@ class Test_getitem(unittest.TestCase):
def test_GetItem2Lists_wrong_index(self): def test_GetItem2Lists_wrong_index(self):
a, A = sparse_random_inputs('csr', (4, 5)) a, A = sparse_random_inputs('csr', (4, 5))
y = a[0][[0, 4], [0, 4]] y1 = a[0][[0, 5], [0, 3]]
f = theano.function([a[0]], y) y2 = a[0][[0, 3], [0, 5]]
self.assertRaises(IndexError, f, A[0]) f1 = theano.function([a[0]], y1)
f2 = theano.function([a[0]], y2)
self.assertRaises(IndexError, f1, A[0])
self.assertRaises(IndexError, f2, A[0])
def test_get_item_2lists_grad(self): def test_get_item_2lists_grad(self):
op = theano.sparse.basic.GetItem2Lists() op = theano.sparse.basic.GetItem2Lists()
def op_with_fixed_index(x): def op_with_fixed_index(x):
return op(x, ind1=numpy.asarray([0, 1]), ind2=numpy.asarray([1, 1])) return op(x, ind1=numpy.asarray([0, 1]), ind2=numpy.asarray([2, 3]))
x, x_val = sparse_random_inputs("csr", (4,5), out_dtype="float64") x, x_val = sparse_random_inputs("csr", (4,5))
verify_grad_sparse(op_with_fixed_index, x_val) verify_grad_sparse(op_with_fixed_index, x_val)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论