提交 5c4a7c72 authored 作者: Frederic Bastien's avatar Frederic Bastien

added test for gemm insertion when the scalars alpha and beta are equal and factorized.

上级 1e2cf8e3
......@@ -98,6 +98,34 @@ class t_gemm(TestCase):
def test12(self): self.cmp(self.rand(3,4), -1.0,
self.rand(3,5), self.rand(5,4), -1.0)
def test_factorised_scalar(self):
a=T.matrix()
b=T.matrix()
c=T.matrix()
s=theano.shared(numpy.zeros((5,5)))
lr1=T.constant(0.01)
lr2=T.constant(2)
l2_reg=T.constant(0.0001)
#test constant merge with gemm
f = theano.function([a,b],updates={s:lr1*T.dot(a,b)+l2_reg*lr2*s}).maker.env.toposort()
#[Gemm{inplace}(<TensorType(float64, matrix)>, 0.01, <TensorType(float64, matrix)>, <TensorType(float64, matrix)>, 2e-06)]
assert len(f)==1
assert f[0].op==gemm_inplace
#test factored scalar with merge
f = theano.function([a,b],updates={s:lr1*(T.dot(a,b)-l2_reg*s)}).maker.env.toposort()
#[Gemm{inplace}(<TensorType(float64, matrix)>, 0.01, <TensorType(float64, matrix)>, <TensorType(float64, matrix)>, -2e-06)]
assert len(f)==1
assert f[0].op==gemm_inplace
#test factored scalar with merge and neg
f = theano.function([a,b],updates={s:s-lr1*(s*.0002+T.dot(a,b))}).maker.env.toposort()
#[Gemm{inplace}(<TensorType(float64, matrix)>, -0.01, <TensorType(float64, matrix)>, <TensorType(float64, matrix)>, 0.999998)]
assert len(f)==1
assert f[0].op==gemm_inplace
def test_destroy_map0(self):
"""test that only first input can be overwritten"""
Z = as_tensor_variable(self.rand(2,2))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论