提交 3c5882e1 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

merge

...@@ -67,7 +67,7 @@ the most important ones: ...@@ -67,7 +67,7 @@ the most important ones:
- When the computations are done, transfer the results from the C - When the computations are done, transfer the results from the C
structure we put them in to the destination Python object. This structure we put them in to the destination Python object. This
will only be called for the inputs. will only be called for the outputs.
- **c_cleanup(name, sub)** - **c_cleanup(name, sub)**
......
...@@ -1091,7 +1091,11 @@ class Module(ComponentDict): ...@@ -1091,7 +1091,11 @@ class Module(ComponentDict):
directly. directly.
""" """
InstanceType = ModuleInstance # By default, we use build ModuleInstance InstanceType = ModuleInstance # By default, we use build ModuleInstance
def __init__(self):
super(Module, self).__init__()
self.__dict__["local_attr"]={}
def __wrapper__(self, x): def __wrapper__(self, x):
""" """
This function is called whenever x is set as an attribute of This function is called whenever x is set as an attribute of
...@@ -1139,12 +1143,8 @@ class Module(ComponentDict): ...@@ -1139,12 +1143,8 @@ class Module(ComponentDict):
# raise NotImplementedError # raise NotImplementedError
# print "WARNING: unknow:",v # print "WARNING: unknow:",v
return v return v
value=unpack_member_and_external(value)
if not hasattr(self,"local_attr"):
self.__dict__["local_attr"]={}
self.__dict__["local_attr"][attr] = value self.__dict__["local_attr"][attr] = unpack_member_and_external(value)
def build(self, mode, memo): def build(self, mode, memo):
if self in memo: if self in memo:
......
...@@ -14,6 +14,10 @@ import theano ...@@ -14,6 +14,10 @@ import theano
#TODO: add test for module.make(member=init_value) #TODO: add test for module.make(member=init_value)
class T_module(unittest.TestCase): class T_module(unittest.TestCase):
def test_empty_module(self):
m = Module()
m.make()
def test_whats_up_with_submembers(self): def test_whats_up_with_submembers(self):
class Blah(Module): class Blah(Module):
def __init__(self, stepsize): def __init__(self, stepsize):
...@@ -504,6 +508,7 @@ def test_tuple_members(): ...@@ -504,6 +508,7 @@ def test_tuple_members():
class Temp(Module): class Temp(Module):
def __init__(self): def __init__(self):
super(Temp, self).__init__()
self.a = (1,1) self.a = (1,1)
M = Temp() M = Temp()
assert isinstance(M.a, tuple) assert isinstance(M.a, tuple)
......
...@@ -871,20 +871,20 @@ sd_csc = StructuredDotCSC() ...@@ -871,20 +871,20 @@ sd_csc = StructuredDotCSC()
class StructuredDotCSR(gof.Op): class StructuredDotCSR(gof.Op):
def make_node(self, a_val, a_ind, a_ptr, a_ncols, b): def make_node(self, a_val, a_ind, a_ptr, b):
assert a_val.type.dtype == b.type.dtype assert a_val.type.dtype == b.type.dtype
r = gof.Apply(self, [a_val, a_ind, a_ptr, a_ncols, b], r = gof.Apply(self, [a_val, a_ind, a_ptr, b],
[tensor.tensor(a_val.type.dtype, (False, False))]) [tensor.tensor(a_val.type.dtype, (False, False))])
return r return r
def perform(self, node, (a_val, a_ind, a_ptr, a_ncols, b), (out,)): def perform(self, node, (a_val, a_ind, a_ptr, b), (out,)):
a = sparse.csr_matrix((a_val, a_ind, a_ptr), a = sparse.csr_matrix((a_val, a_ind, a_ptr),
(a_ncols, b.shape[0]), (len(a_ptr)-1, b.shape[0]),
copy = False) copy = True) #use view_map before setting this to False
out[0] = a.dot(b) out[0] = a.dot(b)
assert _is_dense(out[0]) # scipy 0.7 automatically converts to dense assert _is_dense(out[0]) # scipy 0.7 automatically converts to dense, but not .6 sometimes
def c_code(self, node, name, (a_val, a_ind, a_ptr, a_ncols, b), (z,), sub): def c_code(self, node, name, (a_val, a_ind, a_ptr, b), (z,), sub):
""" """
C-implementation of the dot product of the sparse matrix A and matrix B. C-implementation of the dot product of the sparse matrix A and matrix B.
@param a_val: non-zero values of the sparse matrix @param a_val: non-zero values of the sparse matrix
...@@ -899,7 +899,6 @@ class StructuredDotCSR(gof.Op): ...@@ -899,7 +899,6 @@ class StructuredDotCSR(gof.Op):
if (%(a_val)s->nd != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(a_val) != 1"); %(fail)s;} if (%(a_val)s->nd != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(a_val) != 1"); %(fail)s;}
if (%(a_ind)s->nd != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(a_ind) != 1"); %(fail)s;} if (%(a_ind)s->nd != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(a_ind) != 1"); %(fail)s;}
if (%(a_ptr)s->nd != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(a_ptr) != 1"); %(fail)s;} if (%(a_ptr)s->nd != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(a_ptr) != 1"); %(fail)s;}
if (%(a_ncols)s->nd != 0) {PyErr_SetString(PyExc_NotImplementedError, "rank(ncols) != 0"); %(fail)s;}
if (%(b)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 2"); %(fail)s;} if (%(b)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 2"); %(fail)s;}
if (%(a_val)s->descr->type_num != PyArray_DOUBLE) if (%(a_val)s->descr->type_num != PyArray_DOUBLE)
...@@ -911,26 +910,20 @@ class StructuredDotCSR(gof.Op): ...@@ -911,26 +910,20 @@ class StructuredDotCSR(gof.Op):
if (%(a_ptr)s->descr->type_num != PyArray_INT32) if (%(a_ptr)s->descr->type_num != PyArray_INT32)
{PyErr_SetString(PyExc_NotImplementedError, "a_ptr dtype not INT32"); %(fail)s;} {PyErr_SetString(PyExc_NotImplementedError, "a_ptr dtype not INT32"); %(fail)s;}
if (%(a_ncols)s->descr->type_num != PyArray_INT32)
{PyErr_SetString(PyExc_NotImplementedError, "a_ncols dtype not INT32"); %(fail)s;}
if (%(b)s->descr->type_num != PyArray_DOUBLE) if (%(b)s->descr->type_num != PyArray_DOUBLE)
{PyErr_SetString(PyExc_NotImplementedError, "b's dtype not NPY_DOUBLE"); %(fail)s;} {PyErr_SetString(PyExc_NotImplementedError, "b's dtype not NPY_DOUBLE"); %(fail)s;}
if (%(a_val)s->dimensions[0] != %(a_ind)s->dimensions[0]) if (%(a_val)s->dimensions[0] != %(a_ind)s->dimensions[0])
{PyErr_SetString(PyExc_NotImplementedError, "a_val and a_ind have different lengths"); %(fail)s;} {PyErr_SetString(PyExc_NotImplementedError, "a_val and a_ind have different lengths"); %(fail)s;}
if (%(a_ptr)s->dimensions[0] != %(b)s->dimensions[0]+1)
{PyErr_SetString(PyExc_NotImplementedError, "a's number of columns doesn't match b's rows"); %(fail)s;}
if ((!%(z)s) if ((!%(z)s)
|| (%(z)s->dimensions[0] != ((npy_int32 *)%(a_ncols)s->data)[0]) || (%(z)s->dimensions[0] != %(a_ptr)s->dimensions[0]-1) //a's rows
|| (%(z)s->dimensions[1] != %(b)s->dimensions[1]) || (%(z)s->dimensions[1] != %(b)s->dimensions[1]) //b's columns
) )
{ {
if (%(z)s) Py_DECREF(%(z)s); if (%(z)s) Py_DECREF(%(z)s);
npy_intp dims[] = {0,0}; npy_intp dims[] = {0,0};
dims[0] = ((npy_int32 *)%(a_ncols)s->data)[0]; dims[0] = %(a_ptr)s->dimensions[0]-1;
dims[1] = %(b)s->dimensions[1]; dims[1] = %(b)s->dimensions[1];
%(z)s = (PyArrayObject*) PyArray_SimpleNew(2, dims, %(b)s->descr->type_num); %(z)s = (PyArrayObject*) PyArray_SimpleNew(2, dims, %(b)s->descr->type_num);
} }
...@@ -1013,11 +1006,13 @@ sd_csr = StructuredDotCSR() ...@@ -1013,11 +1006,13 @@ sd_csr = StructuredDotCSR()
def local_structured_dot(node): def local_structured_dot(node):
if node.op == _structured_dot: if node.op == _structured_dot:
a, b = node.inputs a, b = node.inputs
if a.type.format in ('csc','csr'): if a.type.format == 'csc':
a_val, a_ind, a_ptr, a_shape = csm_properties(a) a_val, a_ind, a_ptr, a_shape = csm_properties(a)
a_nsparse = a_shape[0] a_nsparse = a_shape[0]
sd_csx = sd_csc if a.type.format == 'csc' else sd_csr return [sd_csc(a_val, a_ind, a_ptr, a_nsparse, b)]
return [sd_csx(a_val,a_ind, a_ptr, a_nsparse, b)] if a.type.format == 'csr':
a_val, a_ind, a_ptr, a_shape = csm_properties(a)
return [sd_csr(a_val, a_ind, a_ptr, b)]
return False return False
register_specialize(local_structured_dot) register_specialize(local_structured_dot)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论