提交 e22eee61 authored 作者: Frederic's avatar Frederic

Add/test sparce.CSM.infer_shape.

上级 cdeeb509
...@@ -704,6 +704,14 @@ class CSM(gof.Op): ...@@ -704,6 +704,14 @@ class CSM(gof.Op):
g_data = csm_grad(self.kmap)(data, csm_data(g_out), csm_indices(g_out)) g_data = csm_grad(self.kmap)(data, csm_data(g_out), csm_indices(g_out))
return [g_data, None, None, None] return [g_data, None, None, None]
def infer_shape(self, node, shapes):
if self.kmap is None:
# node.inputs[3] is of lenght as we only support sparse matrix.
return [(node.inputs[3][0], node.inputs[3][1])]
else:
return node.env.shape_feature.default_infer_shape(node, shapes)
CSC = CSM('csc') CSC = CSM('csc')
CSR = CSM('csr') CSR = CSM('csr')
......
...@@ -153,6 +153,22 @@ class SparseInferShapeTester(utt.InferShapeTester): ...@@ -153,6 +153,22 @@ class SparseInferShapeTester(utt.InferShapeTester):
config.floatX, 3))], config.floatX, 3))],
GetItemScalar) GetItemScalar)
def test_csm(self):
for sparsetype in ('csr', 'csc'):
x = tensor.vector()
y = tensor.ivector()
z = tensor.ivector()
s = tensor.ivector()
call = getattr(sp, sparsetype + '_matrix')
spm = call(random_lil((300, 400), config.floatX, 5))
out = CSM(sparsetype)(x, y, z, s)
self._compile_and_check([x, y, z, s],
[out],
[spm.data, spm.indices, spm.indptr,
spm.shape],
CSM
)
def test_csm_grad(self): def test_csm_grad(self):
for sparsetype in ('csr', 'csc'): for sparsetype in ('csr', 'csc'):
x = tensor.vector() x = tensor.vector()
...@@ -262,10 +278,6 @@ class SparseInferShapeTester(utt.InferShapeTester): ...@@ -262,10 +278,6 @@ class SparseInferShapeTester(utt.InferShapeTester):
config.floatX, 3))], config.floatX, 3))],
StructuredDot) StructuredDot)
def test_csm(self):
# We also need the grad of CSM to be implemetned.
raise SkipTest('infer_shape not implemented for CSM')
def test_structured_dot_grad(self): def test_structured_dot_grad(self):
# We also need the grad of CSM to be implemetned. # We also need the grad of CSM to be implemetned.
raise SkipTest('infer_shape not implemented for the grad' raise SkipTest('infer_shape not implemented for the grad'
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论