提交 0d989978 authored 作者: Frederic Bastien's avatar Frederic Bastien

added test for the canonizer with dimshuffle.

上级 2dc01120
...@@ -265,21 +265,28 @@ class test_canonize(unittest.TestCase): ...@@ -265,21 +265,28 @@ class test_canonize(unittest.TestCase):
shp=(3,3) shp=(3,3)
fx, fy, fz, fw = fmatrices('xyzw') fx, fy, fz, fw = fmatrices('xyzw')
dx, dy, dz, dw = dmatrices('xyzw') dx, dy, dz, dw = dmatrices('xyzw')
fv = fvector('r').dimshuffle('x',0)
dv = dvector('s').dimshuffle('x',0)
fxv = numpy.asarray(numpy.random.rand(*shp),dtype='float32') fxv = numpy.asarray(numpy.random.rand(*shp),dtype='float32')
fyv = numpy.asarray(numpy.random.rand(*shp),dtype='float32') fyv = numpy.asarray(numpy.random.rand(*shp),dtype='float32')
fzv = numpy.asarray(numpy.random.rand(*shp),dtype='float32') fzv = numpy.asarray(numpy.random.rand(*shp),dtype='float32')
fwv = numpy.asarray(numpy.random.rand(*shp),dtype='float32') fwv = numpy.asarray(numpy.random.rand(*shp),dtype='float32')
dxv = numpy.asarray(numpy.random.rand(*shp),dtype='float32') fvv = numpy.asarray(numpy.random.rand(shp[0]),dtype='float32').reshape(1,shp[0])
dyv = numpy.asarray(numpy.random.rand(*shp),dtype='float32') dxv = numpy.asarray(numpy.random.rand(*shp),dtype='float64')
dzv = numpy.asarray(numpy.random.rand(*shp),dtype='float32') dyv = numpy.asarray(numpy.random.rand(*shp),dtype='float64')
dwv = numpy.asarray(numpy.random.rand(*shp),dtype='float32') dzv = numpy.asarray(numpy.random.rand(*shp),dtype='float64')
dwv = numpy.asarray(numpy.random.rand(*shp),dtype='float64')
dvv = numpy.asarray(numpy.random.rand(shp[0]),dtype='float64').reshape(1,shp[0])
#we need the optimisation enabled, debug do this. #we need the optimisation enabled, debug do this.
mode=compile.mode.predefined_modes['DEBUG_MODE'] mode=compile.mode.predefined_modes['DEBUG_MODE']
#test x / x -> 1 #test x / x -> 1
for (g, sym_inputs, val_inputs, out_dtype) in [(fx/fx,[fx],[fxv],'float32'), for id, (g, sym_inputs, val_inputs, out_dtype) in enumerate([(fx/fx,[fx],[fxv],'float32'),
(dx/dx,[dx],[dxv],'float64')]: (dx/dx,[dx],[dxv],'float64'),
(fv/fv,[fv],[fvv],'float32'),
(dv/dv,[dv],[dvv],'float64'),
]):
f = compile.function(list(sym_inputs), g, f = compile.function(list(sym_inputs), g,
mode=mode) mode=mode)
out = f(*val_inputs) out = f(*val_inputs)
...@@ -292,39 +299,62 @@ class test_canonize(unittest.TestCase): ...@@ -292,39 +299,62 @@ class test_canonize(unittest.TestCase):
assert(out_dtype==out.dtype) assert(out_dtype==out.dtype)
#test (x * y) / x -> y #test (x * y) / x -> y
for (g, sym_inputs, val_inputs, out_dtype) in [ for id,(g, sym_inputs, val_inputs, nb_elemwise, out_dtype) in enumerate([
((dx*dy)/dx,[dx,dy],[dxv,dyv],'float64'), ((dx*dy)/dx,[dx,dy],[dxv,dyv],0,'float64'),
((fx*fy)/fx,[fx,fy],[fxv,fyv],'float32') ((fx*fy)/fx,[fx,fy],[fxv,fyv],0,'float32'),
]: ((dv*dy)/dv,[dv,dy],[dvv,dyv],0,'float64'),
((fv*fy)/fv,[fv,fy],[fvv,fyv],0,'float32'),
#must broadcast as their is a dimshuffle in the computation
((dx*dv)/dx,[dx,dv],[dxv,dvv],1,'float64'),
#topo: [Elemwise{second,no_inplace}(x, <TensorType(float64, row)>)]
((fx*fv)/fx,[fx,fv],[fxv,fvv],1,'float32')
#topo: [Elemwise{second,no_inplace}(x, <TensorType(float32, row)>)]
]):
f = compile.function(list(sym_inputs), g, f = compile.function(list(sym_inputs), g,
mode=mode) mode=mode)
out = f(*val_inputs) out = f(*val_inputs)
assert (out==val_inputs[1]).all() assert numpy.allclose(out,val_inputs[1])
topo=f.maker.env.toposort() topo=f.maker.env.toposort()
assert len(topo)==0 assert len(topo)==nb_elemwise
assert(out_dtype==out.dtype) assert(out_dtype==out.dtype)
#test x / y / x -> 1 / y #test x / y / x -> 1 / y
for (g, sym_inputs, val_inputs, out_dtype) in [ for id,(g, sym_inputs, val_inputs, nb_elemwise, out_dtype) in enumerate([
((dx/dy)/dx,[dx,dy],[dxv,dyv],'float64'), ((dx/dy)/dx,[dx,dy],[dxv,dyv],1,'float64'),
((fx/fy)/fx,[fx,fy],[fxv,fyv],'float32') ((fx/fy)/fx,[fx,fy],[fxv,fyv],1,'float32'),
]: ((dv/dy)/dv,[dv,dy],[dvv,dyv],1,'float64'),
((fv/fy)/fv,[fv,fy],[fvv,fyv],1,'float32'),
#must broadcast as their is a dimshuffle in the computation
((dx/dv)/dx,[dx,dv],[dxv,dvv],2,'float64'),
#topo: [Elemwise{inv,no_inplace}(<TensorType(float64, row)>), Elemwise{second,no_inplace}(x, Elemwise{inv,no_inplace}.0)]
((fx/fv)/fx,[fx,fv],[fxv,fvv],2,'float32'),
#topo:[Elemwise{inv,no_inplace}(<TensorType(float32, row)>), Elemwise{second,no_inplace}(x, Elemwise{inv,no_inplace}.0)]
]):
f = compile.function(list(sym_inputs), g, f = compile.function(list(sym_inputs), g,
mode=mode) mode=mode)
out = f(*val_inputs) out = f(*val_inputs)
assert numpy.allclose(out,(1/val_inputs[1])) assert numpy.allclose(out,(1/val_inputs[1]))
topo=f.maker.env.toposort() topo=f.maker.env.toposort()
assert len(topo)==1 assert len(topo)==nb_elemwise
assert isinstance(topo[0].op,(T.Elemwise,)) assert isinstance(topo[0].op,(T.Elemwise,))
assert isinstance(topo[0].op.scalar_op,theano.scalar.basic.Inv) assert isinstance(topo[0].op.scalar_op,theano.scalar.basic.Inv)
assert len(topo[0].inputs)==1 assert len(topo[0].inputs)==1
assert(out_dtype==out.dtype) assert(out_dtype==out.dtype)
#test (a / b) * (b / c) * (c / d) -> a / d #test (a / b) * (b / c) * (c / d) -> a / d
for (g, sym_inputs, val_inputs, out_dtype) in [ for id,(g, sym_inputs, val_inputs, out_dtype) in enumerate([
((dx / dy) * (dy / dz) * (dz / dw),[dx,dy,dz,dw],[dxv,dyv,dzv,dwv],'float64'), ((dx / dy) * (dy / dz) * (dz / dw),[dx,dy,dz,dw],[dxv,dyv,dzv,dwv],'float64'),
((fx / fy) * (fy / fz) * (fz / fw),[fx,fy,fz,fw],[fxv,fyv,fzv,fwv],'float32') ((fx / fy) * (fy / fz) * (fz / fw),[fx,fy,fz,fw],[fxv,fyv,fzv,fwv],'float32'),
]: ((dv / dy) * (dy / dz) * (dz / dw),[dv,dy,dz,dw],[dvv,dyv,dzv,dwv],'float64'),
((fv / fy) * (fy / fz) * (fz / fw),[fv,fy,fz,fw],[fvv,fyv,fzv,fwv],'float32'),
((dx / dv) * (dv / dz) * (dz / dw),[dx,dv,dz,dw],[dxv,dvv,dzv,dwv],'float64'),
((fx / fv) * (fv / fz) * (fz / fw),[fx,fv,fz,fw],[fxv,fvv,fzv,fwv],'float32'),
((dx / dy) * (dy / dv) * (dv / dw),[dx,dy,dv,dw],[dxv,dyv,dvv,dwv],'float64'),
((fx / fy) * (fy / fv) * (fv / fw),[fx,fy,fv,fw],[fxv,fyv,fvv,fwv],'float32'),
((dx / dy) * (dy / dz) * (dz / dv),[dx,dy,dz,dv],[dxv,dyv,dzv,dvv],'float64'),
((fx / fy) * (fy / fz) * (fz / fv),[fx,fy,fz,fv],[fxv,fyv,fzv,fvv],'float32'),
]):
f = compile.function(list(sym_inputs), g, f = compile.function(list(sym_inputs), g,
mode=mode) mode=mode)
out = f(*val_inputs) out = f(*val_inputs)
...@@ -337,10 +367,14 @@ class test_canonize(unittest.TestCase): ...@@ -337,10 +367,14 @@ class test_canonize(unittest.TestCase):
assert(out_dtype==out.dtype) assert(out_dtype==out.dtype)
#test (2.0 * x) / (4.0 * y) -> (0.5 * x) / y #test (2.0 * x) / (4.0 * y) -> (0.5 * x) / y
for (g, sym_inputs, val_inputs, out_dtype) in [ for id,(g, sym_inputs, val_inputs, out_dtype) in enumerate([
(((2.0*dx)/(4.0*dy)),[dx,dy],[dxv,dyv],'float64'), (((2.0*dx)/(4.0*dy)),[dx,dy],[dxv,dyv],'float64'),
(((2.0*fx)/(4.0*fy)),[fx,fy],[fxv,fyv],'float32'), (((2.0*fx)/(4.0*fy)),[fx,fy],[fxv,fyv],'float32'),
]: (((2.0*dv)/(4.0*dy)),[dv,dy],[dvv,dyv],'float64'),
(((2.0*fv)/(4.0*fy)),[fv,fy],[fvv,fyv],'float32'),
(((2.0*dx)/(4.0*dv)),[dx,dv],[dxv,dvv],'float64'),
(((2.0*fx)/(4.0*fv)),[fx,fv],[fxv,fvv],'float32'),
]):
f = compile.function(list(sym_inputs), g, f = compile.function(list(sym_inputs), g,
mode=mode) mode=mode)
out = f(*val_inputs) out = f(*val_inputs)
...@@ -356,14 +390,16 @@ class test_canonize(unittest.TestCase): ...@@ -356,14 +390,16 @@ class test_canonize(unittest.TestCase):
assert(out_dtype==out.dtype) assert(out_dtype==out.dtype)
#test 2 * x / 2 -> x #test 2 * x / 2 -> x
for (g, sym_inputs, val_inputs, out_dtype) in [ for id,(g, sym_inputs, val_inputs, out_dtype) in enumerate([
((2*dx)/2,[dx],[dxv],'float64'), ((2*dx)/2,[dx],[dxv],'float64'),
((2*fx)/2,[fx],[fxv],'float32'), ((2*fx)/2,[fx],[fxv],'float32'),
]: ((2*dv)/2,[dv],[dvv],'float64'),
((2*fv)/2,[fv],[fvv],'float32'),
]):
f = compile.function(list(sym_inputs), g, f = compile.function(list(sym_inputs), g,
mode=mode) mode=mode)
out = f(*val_inputs) out = f(*val_inputs)
assert (out==val_inputs[0]).all() assert numpy.allclose(out,val_inputs[0])
topo=f.maker.env.toposort() topo=f.maker.env.toposort()
assert len(topo)==0 assert len(topo)==0
assert(out_dtype==out.dtype) assert(out_dtype==out.dtype)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论