提交 7943678b authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Accept other Variable Types in sparse check

Add a test case that was reported on theano-users.
上级 c73dd4d6
......@@ -43,25 +43,27 @@ _mtype_to_str = {scipy.sparse.csc_matrix: "csc",
def _is_sparse_variable(x):
"""
@rtype: boolean
@return: True iff x is a L{SparseVariable} (and not a L{tensor.TensorType})
@return: True iff x is a L{SparseVariable} (and not a L{tensor.TensorType},
for instance)
"""
if not isinstance(x.type, (SparseType, tensor.TensorType)):
if not isinstance(x, gof.Variable):
raise NotImplementedError("this function should only be called on "
"*variables* (of type sparse.SparseType "
"or tensor.TensorType), not,", x)
"or tensor.TensorType, for instance), not ",
x)
return isinstance(x.type, SparseType)
def _is_dense_variable(x):
"""
@rtype: boolean
@return: True unless x is a L{SparseVariable} (and not a
L{tensor.TensorType})
@return: True if x is a L{tensor.TensorType} (and not a
L{SparseVariable}, for instance)
"""
if not isinstance(x.type, (SparseType, tensor.TensorType)):
if not isinstance(x, gof.Variable):
raise NotImplementedError("this function should only be called on "
"*variables* (of type sparse.SparseType or "
"tensor.TensorType), not,", x)
"tensor.TensorType, for instance), not ", x)
return isinstance(x.type, tensor.TensorType)
......@@ -3073,8 +3075,22 @@ class Dot(gof.op.Op):
def make_node(self, x, y):
dtype_out = scalar.upcast(x.type.dtype, y.type.dtype)
if not _is_sparse_variable(x) and not _is_sparse_variable(y):
raise TypeError(x)
# Sparse dot product should have at least one sparse variable
# as input. If the other one is not sparse, it has to be converted
# into a tensor.
x_is_sparse_var = _is_sparse_variable(x)
y_is_sparse_var = _is_sparse_variable(y)
if not x_is_sparse_var and not y_is_sparse_var:
raise TypeError("Sparse dot product should have at least one "
"sparse variable as inputs, but the inputs are "
"%s (%s) and %s (%s)." % (x, x.type, y, y.type))
if not x_is_sparse_var:
x = tensor.as_tensor_variable(x)
if not y_is_sparse_var:
y = tensor.as_tensor_variable(y)
return gof.Apply(self, [x, y], [tensor.tensor(dtype=dtype_out,
broadcastable=(False, False))])
......
......@@ -1120,6 +1120,22 @@ class DotTests(unittest.TestCase):
assert sum([isinstance(node.op, (Dot, Usmm, UsmmCscDense))
for node in topo]) == nb
def test_cuda(self):
import theano.sandbox.cuda as cuda
if not cuda.cuda_available:
raise SkipTest("Optional package cuda not available")
a = sparse.csr_matrix('a', dtype='float32')
b = cuda.float32_shared_constructor(
numpy.random.rand(3, 4).astype('float32'))
d = sparse.dot(a, b)
f = theano.function([a], d)
a_val = scipy.sparse.csr_matrix(random_lil((5, 3), 'float32', 5))
d_theano = f(a_val)
d_numpy = a_val * b.get_value()
assert numpy.allclose(d_theano, d_numpy)
class UsmmTests(unittest.TestCase):
""" Test the Usmm and UsmmCscDense class and related optimization """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论