提交 985e62a0 authored 作者: Frederic Bastien's avatar Frederic Bastien

added some test case for the optimisation Canonizer for elemwise.

上级 38f08697
...@@ -185,14 +185,58 @@ class test_canonize(unittest.TestCase): ...@@ -185,14 +185,58 @@ class test_canonize(unittest.TestCase):
""" """
verify that the Canonizer merge sequential Elemwise({mul,add}) verify that the Canonizer merge sequential Elemwise({mul,add})
""" """
x, y, z = matrices('xyz') shp=(5,5)
for g,n in [ fx, fy, fz = fmatrices('xyz')
(x+y+z,1), dx, dy, dz = dmatrices('xyz')
(x*y*z,1), fv = fvector('r').dimshuffle('x',0)
(x*y*(x+y+z),2), fxv = numpy.asarray(numpy.random.rand(*shp),dtype='float32')
]: fyv = numpy.asarray(numpy.random.rand(*shp),dtype='float32')
f = compile.function([x,y,z], g, mode=compile.Mode(optimizer='fast_run')) fzv = numpy.asarray(numpy.random.rand(*shp),dtype='float32')
assert(len(f.maker.env.toposort())==n) dxv = numpy.asarray(numpy.random.rand(*shp),dtype='float32')
dyv = numpy.asarray(numpy.random.rand(*shp),dtype='float32')
dzv = numpy.asarray(numpy.random.rand(*shp),dtype='float32')
fvv = numpy.asarray(numpy.random.rand(shp[0]),dtype='float32').reshape(1,shp[0])
cases = [
(fx+fy,(fx,fy),(fxv,fyv),1,'float32'),
(fx*fy,(fx,fy),(fxv,fyv),1,'float32'),
(fx+fy+fz,(fx,fy,fz),(fxv,fyv,fzv),1,'float32'),
(dx+dy+dz,(dx,dy,dz),(dxv,dyv,dzv),1,'float64'),
(fx*fy*fz,(fx,fy,fz),(fxv,fyv,fzv),1,'float32'),
(dx*dy*dz,(dx,dy,dz),(dxv,dyv,dzv),1,'float64'),
(fx*fy*(fx+fy+fz),(fx,fy,fz),(fxv,fyv,fzv),2,'float32'),
(dx*dy*(dx+dy+dz),(dx,dy,dz),(dxv,dyv,dzv),2,'float64'),
(fx*fy*(fx+fy+dz),(fx,fy,dz),(dxv,dyv,dzv),2,'float64'),#check mixed type add
(dz*fy*(fx+fy),(fx,fy,dz),(dxv,dyv,dzv),2,'float64'),#check mixed type mul
#check with dimshuffle of constant
(fx+fy+fz+2,(fx,fy,fz),(fxv,fyv,fzv),1,'float32'),
(fx*fy*fz*2,(fx,fy,fz),(fxv,fyv,fzv),1,'float32'),
(2+fx+fy+fz,(fx,fy,fz),(fxv,fyv,fzv),1,'float32'),
(2*fx*fy*fz,(fx,fy,fz),(fxv,fyv,fzv),1,'float32'),
(2+fx+fy+fz+2,(fx,fy,fz),(fxv,fyv,fzv),1,'float32'),
(2*fx*fy*fz*2,(fx,fy,fz),(fxv,fyv,fzv),1,'float32'),
(fx*fy*2*(fx+fy+fz),(fx,fy,fz),(fxv,fyv,fzv),2,'float32'),
(fx*fy*(2+fx+fy+fz),(fx,fy,fz),(fxv,fyv,fzv),2,'float32'),
(fx*fy*2*(fx+fy+fz+2),(fx,fy,fz),(fxv,fyv,fzv),2,'float32'),
#check with broadcast of row
(fx+fy+fz+fv,(fx,fy,fz,fv),(fxv,fyv,fzv,fvv),1,'float32'),
(fx*fy*fz*fv,(fx,fy,fz,fv),(fxv,fyv,fzv,fvv),1,'float32'),
(fv+fx+fy+fz,(fx,fy,fz,fv),(fxv,fyv,fzv,fvv),1,'float32'),
(fv*fx*fy*fz,(fx,fy,fz,fv),(fxv,fyv,fzv,fvv),1,'float32'),
(fx*fy*fv*(fx+fy+fz),(fx,fy,fz,fv),(fxv,fyv,fzv,fvv),2,'float32'),
(fx*fy*(fv+fx+fy+fz),(fx,fy,fz,fv),(fxv,fyv,fzv,fvv),2,'float32'),
(fx*fy*fv*(fv+fx+fy+fz),(fx,fy,fz,fv),(fxv,fyv,fzv,fvv),2,'float32'),
]#[10:11]
# print cases
for id, [g, sym_inputs, val_inputs, expected_out_nb_elemwise, out_dtype] in enumerate(cases):
f = compile.function(list(sym_inputs), g,
#we need the optimisation enabled, debug do this.
mode=compile.mode.predefined_modes['DEBUG_MODE'])
out = f(*val_inputs)
assert(len(f.maker.env.toposort())==expected_out_nb_elemwise)
assert(out_dtype==out.dtype)
def test_mixeddiv(): def test_mixeddiv():
"""Test that int division is preserved""" """Test that int division is preserved"""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论