提交 7077bbfa authored 作者: Tanjay94's avatar Tanjay94

Made more change to GetItemList function.

上级 1744a5b1
......@@ -1019,6 +1019,7 @@ class GetItemList(gof.op.Op):
ind = tensor.as_tensor_variable(index)
assert ind.ndim == 1
assert "int" in ind.dtype
return gof.Apply(self, [x, ind], [x.type()])
......@@ -1068,11 +1069,12 @@ class GetItemListGrad(gof.op.Op):
ind = tensor.as_tensor_variable(index)
assert ind.ndim == 1
assert "int" in ind.dtype
scipy_ver = [int(n) for n in scipy.__version__.split('.')[:2]]
if not scipy_ver >= [0, 13]:
raise NotImplemented("Scipy version is to old")
raise NotImplementedError("Scipy version is to old")
return gof.Apply(self, [x, ind, gz], [x.type()])
......@@ -1109,6 +1111,8 @@ class GetItem2Lists(gof.op.Op):
assert x.format in ["csr", "csc"]
ind1 = tensor.as_tensor_variable(ind1)
ind2 = tensor.as_tensor_variable(ind2)
assert "int" in ind1.dtype
assert "int" in ind2.dtype
return gof.Apply(self, [x, ind1, ind2],
[theano.tensor.vector()])
......
......@@ -2052,7 +2052,7 @@ class Test_getitem(unittest.TestCase):
def op_with_fixed_index(x):
return op(x, index=numpy.asarray([0, 1]))
x, x_val = sparse_random_inputs("csr", (4,5), out_dtype="float64")
x, x_val = sparse_random_inputs("csr", (4,5))
try:
verify_grad_sparse(op_with_fixed_index, x_val)
......@@ -2080,17 +2080,21 @@ class Test_getitem(unittest.TestCase):
def test_GetItem2Lists_wrong_index(self):
a, A = sparse_random_inputs('csr', (4, 5))
y = a[0][[0, 4], [0, 4]]
f = theano.function([a[0]], y)
y1 = a[0][[0, 5], [0, 3]]
y2 = a[0][[0, 3], [0, 5]]
self.assertRaises(IndexError, f, A[0])
f1 = theano.function([a[0]], y1)
f2 = theano.function([a[0]], y2)
self.assertRaises(IndexError, f1, A[0])
self.assertRaises(IndexError, f2, A[0])
def test_get_item_2lists_grad(self):
op = theano.sparse.basic.GetItem2Lists()
def op_with_fixed_index(x):
return op(x, ind1=numpy.asarray([0, 1]), ind2=numpy.asarray([1, 1]))
return op(x, ind1=numpy.asarray([0, 1]), ind2=numpy.asarray([2, 3]))
x, x_val = sparse_random_inputs("csr", (4,5), out_dtype="float64")
x, x_val = sparse_random_inputs("csr", (4,5))
verify_grad_sparse(op_with_fixed_index, x_val)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论