提交 42cb4b75 authored 作者: Nicolas Bouchard's avatar Nicolas Bouchard 提交者: Frederic

Add optimization tests.

上级 3327d4ad
......@@ -541,5 +541,72 @@ class SamplingDotTester(utt.InferShapeTester):
return S2.sampling_dot(x, y, self.a[2])
verify_grad_sparse(_helper, self.a[:2])
# ##################
# Optimization tests
# ##################
def test_local_mul_s_d():
mode = theano.compile.mode.get_default_mode()
mode = mode.including("specialize", "local_mul_s_d")
for sp_format in sparse.sparse_formats:
inputs = [getattr(theano.sparse, sp_format + '_matrix')(),
tensor.matrix()]
f = theano.function(inputs,
sparse.mul_s_d(*inputs),
mode=mode)
assert not any(isinstance(node.op, sparse.MulSD) for node
in f.maker.env.toposort())
def test_local_mul_s_v():
mode = theano.compile.mode.get_default_mode()
mode = mode.including("specialize", "local_mul_s_v")
for sp_format in ['csr']: # Not implemented for other format
inputs = [getattr(theano.sparse, sp_format + '_matrix')(),
tensor.vector()]
f = theano.function(inputs,
S2.mul_s_v(*inputs),
mode=mode)
assert not any(isinstance(node.op, S2.MulSV) for node
in f.maker.env.toposort())
def test_local_structured_add_s_v():
mode = theano.compile.mode.get_default_mode()
mode = mode.including("specialize", "local_structured_add_s_v")
for sp_format in ['csr']: # Not implemented for other format
inputs = [getattr(theano.sparse, sp_format + '_matrix')(),
tensor.vector()]
f = theano.function(inputs,
S2.structured_add_s_v(*inputs),
mode=mode)
assert not any(isinstance(node.op, S2.StructuredAddSV) for node
in f.maker.env.toposort())
def test_local_sampling_dot_csr():
mode = theano.compile.mode.get_default_mode()
mode = mode.including("specialize", "local_sampling_dot_csr")
for sp_format in ['csr']: # Not implemented for other format
inputs = [tensor.matrix(),
tensor.matrix(),
getattr(theano.sparse, sp_format + '_matrix')()]
f = theano.function(inputs,
S2.sampling_dot(*inputs),
mode=mode)
assert not any(isinstance(node.op, S2.SamplingDot) for node
in f.maker.env.toposort())
if __name__ == '__main__':
unittest.main()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论