提交 1d7175c7 authored 作者: Frederic Bastien's avatar Frederic Bastien

Added tests with dimensions at 0 for dot22, dot22scalar and gemm.

上级 dd31b668
......@@ -1179,6 +1179,9 @@ def _approx_eq(a,b,eps=1.0e-4):
print a.shape, b.shape
return False
abs_rel_err = numeric_grad.abs_rel_err(a,b)
# numpy.max don't like empty ndarray.
if a.size == b.size == 0:
return True
if numpy.max(abs_rel_err) >= eps:
if _approx_eq.debug:
print a, b
......
......@@ -72,6 +72,8 @@ class t_gemm(TestCase):
self.assertTrue(_approx_eq(z_after, z))
if a == 0.0 and b == 1.0:
return
elif z_orig.size == 0:
self.assertTrue(z.size==0)
else:
self.assertFalse(numpy.all(z_orig == z))
......@@ -125,6 +127,13 @@ class t_gemm(TestCase):
def test12(self): self.cmp(self.rand(3,4), -1.0,
self.rand(3,5), self.rand(5,4), -1.0)
def test_shape_0(self):
self.cmp(self.rand(0,4), -1.0, self.rand(0,5), self.rand(5,4), -1.0)
self.cmp(self.rand(3,0), -1.0, self.rand(3,5), self.rand(5,0), -1.0)
self.cmp(self.rand(3,4), -1.0, self.rand(3,0), self.rand(0,4), -1.0)
self.cmp(self.rand(0,0), -1.0, self.rand(0,5), self.rand(5,0), -1.0)
self.cmp(self.rand(0,0), -1.0, self.rand(0,0), self.rand(0,0), -1.0)
def test_factorised_scalar(self):
a=T.dmatrix()
b=T.dmatrix()
......@@ -671,9 +680,16 @@ def test_dot22():
assert _dot22 in [x.op for x in topo]
rng = numpy.random.RandomState(unittest_tools.fetch_seed())
av=rng.uniform(size=(5,5)).astype(config.floatX)
bv=rng.uniform(size=(5,5)).astype(config.floatX)
f(av,bv)
def cmp(a_shp, b_shp):
av=rng.uniform(size=a_shp).astype(config.floatX)
bv=rng.uniform(size=b_shp).astype(config.floatX)
f(av,bv)
cmp((3,4),(4,5))
cmp((0,4),(4,5))
cmp((3,0),(0,5))
cmp((3,4),(4,0))
cmp((0,4),(4,0))
cmp((0,0),(0,0))
def test_dot22scalar():
## including does not seem to work for 'local_dot_to_dot22' and
......@@ -686,62 +702,70 @@ def test_dot22scalar():
c=T.matrix()
rng = numpy.random.RandomState(unittest_tools.fetch_seed())
av=rng.uniform(size=(5,5)).astype(config.floatX)
bv=rng.uniform(size=(5,5)).astype(config.floatX)
cv=rng.uniform(size=(5,5)).astype(config.floatX)
if True:
f = theano.function([a,b],0.2*T.dot(a,b),mode=mode_blas_opt)
def cmp(a_shp, b_shp, c_shp, sqr_shp=(5,5)):
av=rng.uniform(size=a_shp).astype(config.floatX)
bv=rng.uniform(size=b_shp).astype(config.floatX)
cv=rng.uniform(size=c_shp).astype(config.floatX)
sv=rng.uniform(size=sqr_shp).astype(config.floatX)
if True:
f = theano.function([a,b],0.2*T.dot(a,b),mode=mode_blas_opt)
topo = f.maker.env.toposort()
assert _dot22scalar in [x.op for x in topo]
assert len(topo)==1
f(av,bv)
if True:
f = theano.function([a,b,c],0.2*c*T.dot(a,b),mode=mode_blas_opt)
topo = f.maker.env.toposort()
assert _dot22scalar in [x.op for x in topo]
assert len(topo)==2
f(av,bv,cv)
f = theano.function([a,b,c],c * 0.2*T.dot(a,b),mode=mode_blas_opt)
topo = f.maker.env.toposort()
assert _dot22scalar in [x.op for x in topo]
assert len(topo)==1
f(av,bv)
assert len(topo)==2
f(av,bv,cv)
if True:
f = theano.function([a,b,c],0.2*c*T.dot(a,b),mode=mode_blas_opt)
## Here, canonicalize also seems needed
## TODO: add only the optimizations needed?
m2 = mode_blas_opt.including('canonicalize')
f = theano.function([a,b,c],0.1*c * 0.2*T.dot(a,b),mode=m2)
topo = f.maker.env.toposort()
assert _dot22scalar in [x.op for x in topo]
assert len(topo)==2
f(av,bv,cv)
f = theano.function([a,b,c],c * 0.2*T.dot(a,b),mode=mode_blas_opt)
topo = f.maker.env.toposort()
assert _dot22scalar in [x.op for x in topo]
assert len(topo)==2
f(av,bv,cv)
## Here, canonicalize also seems needed
## TODO: add only the optimizations needed?
m2 = mode_blas_opt.including('canonicalize')
f = theano.function([a,b,c],0.1*c * 0.2*T.dot(a,b),mode=m2)
topo = f.maker.env.toposort()
assert _dot22scalar in [x.op for x in topo]
assert len(topo)==2
f(av,bv,cv)
f = theano.function([a,b,c],c * 0.2*a*T.dot(a,b),mode=m2)
topo = f.maker.env.toposort()
assert _dot22scalar in [x.op for x in topo]
assert len(topo)==2
f(av,bv,cv)
f = theano.function([a,b,c],0.2*c *a*T.dot(a,b),mode=mode_blas_opt)
topo = f.maker.env.toposort()
#currently the canonizer don't always merge all Mul together...
# dot22scalar optimizer does not do a recursive search
# therefore, it doesn't find potential matches of the scalar.
# TODO: combine with the 'canonicalization' that is part of the Gemm optimizer.
#
# assert _dot22scalar in [x.op for x in topo]
# assert len(topo)==2
f(av,bv,cv)
f = theano.function([a,b,c],c * a*0.2*T.dot(a,b),mode=m2)
topo = f.maker.env.toposort()
assert _dot22scalar in [x.op for x in topo]
assert len(topo)==2
f(av,bv,cv)
f = theano.function([a,b,c],c * 0.2*a*T.dot(a,b),mode=m2)
topo = f.maker.env.toposort()
assert _dot22scalar in [x.op for x in topo]
assert len(topo)==2
f(sv,sv,sv)
f = theano.function([a,b,c],0.2*c *a*T.dot(a,b),mode=mode_blas_opt)
topo = f.maker.env.toposort()
#currently the canonizer don't always merge all Mul together...
# dot22scalar optimizer does not do a recursive search
# therefore, it doesn't find potential matches of the scalar.
# TODO: combine with the 'canonicalization' that is part of the Gemm optimizer.
#
# assert _dot22scalar in [x.op for x in topo]
# assert len(topo)==2
f(sv,sv,sv)
f = theano.function([a,b,c],c * a*0.2*T.dot(a,b),mode=m2)
topo = f.maker.env.toposort()
assert _dot22scalar in [x.op for x in topo]
assert len(topo)==2
f(sv,sv,sv)
cmp((3,4),(4,5),(3,5))
cmp((0,4),(4,5),(0,5))
cmp((3,0),(0,5),(3,5))
cmp((3,4),(4,0),(3,0),(0,0))
cmp((0,4),(4,0),(0,0))
cmp((0,0),(0,0),(0,0))
def test_dot_w_self():
# This can trigger problems in the optimization because what would normally be a gemm must
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论