提交 97f4fdfd authored 作者: Frederic Bastien's avatar Frederic Bastien

import fix in sparse sandbox.

上级 aa261d20
import unittest
import numpy
from theano import gof, tensor,compile
from theano.sparse.tests.test_basic import eval_outputs
from theano.sparse.basic import _is_sparse_variable, _is_dense_variable, as_sparse_variable, _is_sparse, _mtypes, _mtype_to_str
from theano.sparse import SparseType, dense_from_sparse, transpose
############### ###############
# #
...@@ -43,7 +52,7 @@ class TrueDot(gof.op.Op): ...@@ -43,7 +52,7 @@ class TrueDot(gof.op.Op):
raise NotImplementedError() raise NotImplementedError()
inputs = [x, y] # Need to convert? e.g. assparse inputs = [x, y] # Need to convert? e.g. assparse
outputs = [Sparse(dtype = x.type.dtype, format = myformat).make_variable()] outputs = [SparseType(dtype = x.type.dtype, format = myformat).make_variable()]
return gof.Apply(self, inputs, outputs) return gof.Apply(self, inputs, outputs)
def perform(self, node, inp, out_): def perform(self, node, inp, out_):
""" """
...@@ -194,7 +203,7 @@ class test_true_dot(unittest.TestCase): ...@@ -194,7 +203,7 @@ class test_true_dot(unittest.TestCase):
def test_graph_bprop0(self): def test_graph_bprop0(self):
for mtype in _mtypes: for mtype in _mtypes:
x = tensor.matrix('x') #TensorType('float64', broadcastable=[False,False], name='x') x = tensor.matrix('x') #TensorType('float64', broadcastable=[False,False], name='x')
w = Sparse(dtype = 'float64', format = _mtype_to_str[mtype]).make_variable() w = SparseType(dtype = 'float64', format = _mtype_to_str[mtype]).make_variable()
xw = dense_from_sparse(true_dot(w, x)) xw = dense_from_sparse(true_dot(w, x))
y = dense_from_sparse(true_dot(w.T, xw)) y = dense_from_sparse(true_dot(w.T, xw))
diff = x-y diff = x-y
...@@ -221,7 +230,7 @@ class test_true_dot(unittest.TestCase): ...@@ -221,7 +230,7 @@ class test_true_dot(unittest.TestCase):
xorig = numpy.random.rand(3,2) xorig = numpy.random.rand(3,2)
for mtype in _mtypes: for mtype in _mtypes:
x = tensor.matrix('x') x = tensor.matrix('x')
w = Sparse(dtype = 'float64', format = _mtype_to_str[mtype]).make_variable() w = SparseType(dtype = 'float64', format = _mtype_to_str[mtype]).make_variable()
xw = dense_from_sparse(true_dot(w, x)) xw = dense_from_sparse(true_dot(w, x))
y = dense_from_sparse(true_dot(w.T, xw)) y = dense_from_sparse(true_dot(w.T, xw))
diff = x-y diff = x-y
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论