提交 618a863f authored 作者: Frederic Bastien's avatar Frederic Bastien

added test for the Canonizer.

上级 6d348aa4
...@@ -237,6 +237,177 @@ class test_canonize(unittest.TestCase): ...@@ -237,6 +237,177 @@ class test_canonize(unittest.TestCase):
out = f(*val_inputs) out = f(*val_inputs)
assert(len(f.maker.env.toposort())==expected_out_nb_elemwise) assert(len(f.maker.env.toposort())==expected_out_nb_elemwise)
assert(out_dtype==out.dtype) assert(out_dtype==out.dtype)
def test_multiple_case(self):
""" test those case take from the comment in Canonizer
x / x -> 1
(x * y) / x -> y
x / y / x -> 1 / y
x / y / z -> x / (y * z)
x / (y / z) -> (x * z) / y
(a / b) * (b / c) * (c / d) -> a / d
(2.0 * x) / (4.0 * y) -> (0.5 * x) / y
2 * x / 2 -> x
with and without DimShuffle
TODO: with DimShuffle
"""
import theano.tensor, theano.compile
shp=(3,3)
fx, fy, fz, fw = fmatrices('xyzw')
dx, dy, dz, dw = dmatrices('xyzw')
fxv = numpy.asarray(numpy.random.rand(*shp),dtype='float32')
fyv = numpy.asarray(numpy.random.rand(*shp),dtype='float32')
fzv = numpy.asarray(numpy.random.rand(*shp),dtype='float32')
fwv = numpy.asarray(numpy.random.rand(*shp),dtype='float32')
dxv = numpy.asarray(numpy.random.rand(*shp),dtype='float32')
dyv = numpy.asarray(numpy.random.rand(*shp),dtype='float32')
dzv = numpy.asarray(numpy.random.rand(*shp),dtype='float32')
dwv = numpy.asarray(numpy.random.rand(*shp),dtype='float32')
#we need the optimisation enabled, debug do this.
mode=compile.mode.predefined_modes['DEBUG_MODE']
#test x / x -> 1
for (g, sym_inputs, val_inputs, out_dtype) in [(fx/fx,[fx],[fxv],'float32'),
(dx/dx,[dx],[dxv],'float64')]:
f = compile.function(list(sym_inputs), g,
mode=mode)
out = f(*val_inputs)
assert (out==numpy.ones(shp, dtype=out_dtype)).all()
topo=f.maker.env.toposort()
assert len(topo)==1
assert isinstance(topo[0].op,(T.Elemwise,))
assert isinstance(topo[0].op.scalar_op,theano.scalar.basic.Second)
assert len(topo[0].inputs)==2
assert(out_dtype==out.dtype)
#test (x * y) / x -> y
for (g, sym_inputs, val_inputs, out_dtype) in [
((dx*dy)/dx,[dx,dy],[dxv,dyv],'float64'),
((fx*fy)/fx,[fx,fy],[fxv,fyv],'float32')
]:
f = compile.function(list(sym_inputs), g,
mode=mode)
out = f(*val_inputs)
assert (out==val_inputs[1]).all()
topo=f.maker.env.toposort()
assert len(topo)==0
assert(out_dtype==out.dtype)
#test x / y / x -> 1 / y
for (g, sym_inputs, val_inputs, out_dtype) in [
((dx/dy)/dx,[dx,dy],[dxv,dyv],'float64'),
((fx/fy)/fx,[fx,fy],[fxv,fyv],'float32')
]:
f = compile.function(list(sym_inputs), g,
mode=mode)
out = f(*val_inputs)
assert numpy.allclose(out,(1/val_inputs[1]))
topo=f.maker.env.toposort()
assert len(topo)==1
assert isinstance(topo[0].op,(T.Elemwise,))
assert isinstance(topo[0].op.scalar_op,theano.scalar.basic.Inv)
assert len(topo[0].inputs)==1
assert(out_dtype==out.dtype)
#test (a / b) * (b / c) * (c / d) -> a / d
for (g, sym_inputs, val_inputs, out_dtype) in [
((dx / dy) * (dy / dz) * (dz / dw),[dx,dy,dz,dw],[dxv,dyv,dzv,dwv],'float64'),
((fx / fy) * (fy / fz) * (fz / fw),[fx,fy,fz,fw],[fxv,fyv,fzv,fwv],'float32')
]:
f = compile.function(list(sym_inputs), g,
mode=mode)
out = f(*val_inputs)
assert numpy.allclose(out,(val_inputs[0]/val_inputs[3]))
topo=f.maker.env.toposort()
assert len(topo)==1
assert isinstance(topo[0].op,(T.Elemwise,))
assert isinstance(topo[0].op.scalar_op,theano.scalar.basic.TrueDiv)
assert len(topo[0].inputs)==2
assert(out_dtype==out.dtype)
#test (2.0 * x) / (4.0 * y) -> (0.5 * x) / y
for (g, sym_inputs, val_inputs, out_dtype) in [
(((2.0*dx)/(4.0*dy)),[dx,dy],[dxv,dyv],'float64'),
(((2.0*fx)/(4.0*fy)),[fx,fy],[fxv,fyv],'float32'),
]:
f = compile.function(list(sym_inputs), g,
mode=mode)
out = f(*val_inputs)
assert numpy.allclose(out,(0.5*val_inputs[0]/val_inputs[1]))
topo=f.maker.env.toposort()
assert len(topo)==2
assert isinstance(topo[0].op,(T.Elemwise,))
assert isinstance(topo[0].op.scalar_op,theano.scalar.basic.Mul)
assert len(topo[0].inputs)==2
assert isinstance(topo[1].op,(T.Elemwise,))
assert isinstance(topo[1].op.scalar_op,theano.scalar.basic.TrueDiv)
assert len(topo[1].inputs)==2
assert(out_dtype==out.dtype)
#test 2 * x / 2 -> x
for (g, sym_inputs, val_inputs, out_dtype) in [
((2*dx)/2,[dx],[dxv],'float64'),
((2*fx)/2,[fx],[fxv],'float32'),
]:
f = compile.function(list(sym_inputs), g,
mode=mode)
out = f(*val_inputs)
assert (out==val_inputs[0]).all()
topo=f.maker.env.toposort()
assert len(topo)==0
assert(out_dtype==out.dtype)
def test_multiple_case_that_fail(self):
import theano.tensor, theano.compile
shp=(4,4)
fx, fy, fz = fmatrices('xyz')
dx, dy, dz = dmatrices('xyz')
fxv = numpy.asarray(numpy.random.rand(*shp),dtype='float32')
fyv = numpy.asarray(numpy.random.rand(*shp),dtype='float32')
fzv = numpy.asarray(numpy.random.rand(*shp),dtype='float32')
dxv = numpy.asarray(numpy.random.rand(*shp),dtype='float32')
dyv = numpy.asarray(numpy.random.rand(*shp),dtype='float32')
dzv = numpy.asarray(numpy.random.rand(*shp),dtype='float32')
fvv = numpy.asarray(numpy.random.rand(shp[0]),dtype='float32').reshape(1,shp[0])
mode=compile.mode.predefined_modes['DEBUG_MODE']
#test fail!
#test x / y / z -> x / (y * z)
for (g, sym_inputs, val_inputs, out_dtype) in [
((dx/dy)/dz,[dx,dy,dz],[dxv,dyv,dzv],'float64'),
((fx/fy)/fz,[fx,fy,fz],[fxv,fyv,fzv],'float32')
]:
f = compile.function(list(sym_inputs), g,
mode=mode)
out = f(*val_inputs)
assert numpy.allclose(out,val_inputs[0]/val_inputs[1]/val_inputs[2])
topo=f.maker.env.toposort()
print topo
assert len(topo)==2
assert isinstance(topo[0].op,(T.Elemwise,))
assert isinstance(topo[0].op.scalar_op,theano.scalar.basic.Inv)
assert len(topo[0].inputs)==1
assert(out_dtype==out.dtype)
#test x / (y / z) -> (x * z) / y
for (g, sym_inputs, val_inputs, out_dtype) in [
(dx/(dy/dz),[dx,dy,dz],[dxv,dyv,dzv],'float64'),
(fx/(fy/fz),[fx,fy,fz],[fxv,fyv,fzv],'float32')
]:
f = compile.function(list(sym_inputs), g,
mode=mode)
out = f(*val_inputs)
assert numpy.allclose(out,val_inputs[0]/(val_inputs[1]/val_inputs[2]))
topo=f.maker.env.toposort()
print topo
assert len(topo)==2
assert isinstance(topo[0].op,(T.Elemwise,))
assert isinstance(topo[0].op.scalar_op,theano.scalar.basic.Inv)
assert len(topo[0].inputs)==1
assert(out_dtype==out.dtype)
def test_mixeddiv(): def test_mixeddiv():
"""Test that int division is preserved""" """Test that int division is preserved"""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论