提交 f860c915 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

bug in sparse matrix

上级 25da3fc6
...@@ -11,6 +11,7 @@ import numpy, theano ...@@ -11,6 +11,7 @@ import numpy, theano
import scipy.sparse import scipy.sparse
from theano.printing import Print from theano.printing import Print
from theano import gof from theano import gof
from theano import tensor from theano import tensor
from theano import compile from theano import compile
...@@ -532,6 +533,7 @@ class Transpose(gof.op.Op): ...@@ -532,6 +533,7 @@ class Transpose(gof.op.Op):
def perform(self, node, (x, ), (out, )): def perform(self, node, (x, ), (out, )):
assert _is_sparse(x) assert _is_sparse(x)
out[0] = x.transpose() out[0] = x.transpose()
def grad(self, (x,), (gz,)): def grad(self, (x,), (gz,)):
assert _is_sparse_variable(x) and _is_sparse_variable(gz) assert _is_sparse_variable(x) and _is_sparse_variable(gz)
return transpose(gz), return transpose(gz),
...@@ -738,7 +740,7 @@ class StructuredDot(gof.Op): ...@@ -738,7 +740,7 @@ class StructuredDot(gof.Op):
def __hash__(self): def __hash__(self):
return hash(type(self)) return hash(type(self))
def make_node(self, a, b): def make_node(self, a, b):
if type(a) is not SparseVariable and type(a) is not SparseConstant: if not _is_sparse_variable(a):
raise TypeError('First argument must be of type SparseVariable or SparseConstant'); raise TypeError('First argument must be of type SparseVariable or SparseConstant');
dtype_out = scalar.upcast(a.type.dtype, b.type.dtype) dtype_out = scalar.upcast(a.type.dtype, b.type.dtype)
if b.type.ndim != 2: if b.type.ndim != 2:
...@@ -792,7 +794,6 @@ def structured_dot(x, y): ...@@ -792,7 +794,6 @@ def structured_dot(x, y):
x_is_sparse_variable = _is_sparse_variable(x) x_is_sparse_variable = _is_sparse_variable(x)
y_is_sparse_variable = _is_sparse_variable(y) y_is_sparse_variable = _is_sparse_variable(y)
if not x_is_sparse_variable and not y_is_sparse_variable: if not x_is_sparse_variable and not y_is_sparse_variable:
raise TypeError('structured_dot requires at least one sparse argument') raise TypeError('structured_dot requires at least one sparse argument')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论