提交 043c3eef authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Some pep8

上级 7215c905
from __future__ import absolute_import, print_function, division from __future__ import absolute_import, print_function, division
# PENDING REWRITE OF tensor_opt.py
import copy import copy
import logging import logging
import pickle
import os import os
import sys import sys
import time import time
...@@ -13,8 +11,6 @@ import numpy ...@@ -13,8 +11,6 @@ import numpy
from six.moves import xrange from six.moves import xrange
from nose.plugins.skip import SkipTest from nose.plugins.skip import SkipTest
from nose.tools import assert_raises, assert_true from nose.tools import assert_raises, assert_true
from numpy.testing import dec
from numpy.testing.noseclasses import KnownFailureTest
import theano import theano
import theano.scalar as scal import theano.scalar as scal
...@@ -43,15 +39,14 @@ from theano.tensor.opt import ( ...@@ -43,15 +39,14 @@ from theano.tensor.opt import (
Assert, Assert,
MakeVector, MakeVector,
make_vector, make_vector,
local_expm1,
local_canonicalize_alloc local_canonicalize_alloc
) )
from theano import tensor from theano import tensor
from theano import tensor as T from theano import tensor as T
from theano.tensor import scalar, iscalar, lscalar, fscalar, dscalar from theano.tensor import scalar, iscalar, lscalar, fscalar, dscalar
from theano.tensor import vector, ivector, lvector, fvector, dvector from theano.tensor import vector, lvector, fvector, dvector
from theano.tensor import matrix, imatrix, lmatrix, fmatrix, dmatrix, tensor3 from theano.tensor import matrix, fmatrix, dmatrix, tensor3
from theano.tensor import scalars, vectors, matrices, fmatrices, dmatrices from theano.tensor import vectors, matrices, fmatrices, dmatrices
from theano.tensor import ( from theano.tensor import (
AdvancedSubtensor, AdvancedSubtensor,
AdvancedSubtensor1, AdvancedSubtensor1,
...@@ -69,8 +64,6 @@ from theano.tensor import ( ...@@ -69,8 +64,6 @@ from theano.tensor import (
from theano.tensor.elemwise import DimShuffle from theano.tensor.elemwise import DimShuffle
from theano.tensor.type import values_eq_approx_remove_nan from theano.tensor.type import values_eq_approx_remove_nan
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
from theano.compile.mode import optdb
from theano.compile import Mode
from theano.gof.opt import check_stack_trace, out2in from theano.gof.opt import check_stack_trace, out2in
from nose.plugins.attrib import attr from nose.plugins.attrib import attr
...@@ -79,7 +72,6 @@ if mode_opt == 'FAST_COMPILE': ...@@ -79,7 +72,6 @@ if mode_opt == 'FAST_COMPILE':
mode_opt = 'FAST_RUN' mode_opt = 'FAST_RUN'
mode_opt = theano.compile.mode.get_mode(mode_opt) mode_opt = theano.compile.mode.get_mode(mode_opt)
ds = lambda x, y: DimShuffle(x.type.broadcastable, y)(x)
dimshuffle_lift = out2in(local_dimshuffle_lift) dimshuffle_lift = out2in(local_dimshuffle_lift)
_optimizer_stabilize = gof.Query(include=['fast_run']) _optimizer_stabilize = gof.Query(include=['fast_run'])
...@@ -94,6 +86,10 @@ _optimizer_fast_run = gof.Query(include=['fast_run']) ...@@ -94,6 +86,10 @@ _optimizer_fast_run = gof.Query(include=['fast_run'])
_optimizer_fast_run = compile.optdb.query(_optimizer_fast_run) _optimizer_fast_run = compile.optdb.query(_optimizer_fast_run)
def ds(x, y):
return DimShuffle(x.type.broadcastable, y)(x)
def optimize(g, level='fast_run'): def optimize(g, level='fast_run'):
if level == 'fast_run': if level == 'fast_run':
_optimizer_fast_run.optimize(g) _optimizer_fast_run.optimize(g)
...@@ -138,8 +134,8 @@ class test_dimshuffle_lift(unittest.TestCase): ...@@ -138,8 +134,8 @@ class test_dimshuffle_lift(unittest.TestCase):
x, y, z = inputs() x, y, z = inputs()
e = ds(ds(ds(x, (0, 'x', 1)), (2, 0, 'x', 1)), (1, 0)) e = ds(ds(ds(x, (0, 'x', 1)), (2, 0, 'x', 1)), (1, 0))
g = FunctionGraph([x], [e]) g = FunctionGraph([x], [e])
self.assertTrue(str(g) == "[InplaceDimShuffle{1,0}(InplaceDimShuffle{2,0,x,1}" self.assertTrue(str(g) == ("[InplaceDimShuffle{1,0}(InplaceDimShuffle{2,0,x,1}"
"(InplaceDimShuffle{0,x,1}(x)))]", "(InplaceDimShuffle{0,x,1}(x)))]"),
str(g)) str(g))
dimshuffle_lift.optimize(g) dimshuffle_lift.optimize(g)
self.assertTrue(str(g) == "[x]", str(g)) self.assertTrue(str(g) == "[x]", str(g))
...@@ -259,7 +255,7 @@ def test_local_useless_dimshuffle_in_reshape(): ...@@ -259,7 +255,7 @@ def test_local_useless_dimshuffle_in_reshape():
h = FunctionGraph([mat], [reshape_dimshuffle_mat2]) h = FunctionGraph([mat], [reshape_dimshuffle_mat2])
str_h = str(h) str_h = str(h)
useless_dimshuffle_in_reshape.optimize(h) useless_dimshuffle_in_reshape.optimize(h)
assert_true(str(h) == str(h)) assert_true(str(h) == str_h)
def test_add_canonizer_problem0(): def test_add_canonizer_problem0():
...@@ -269,6 +265,7 @@ def test_add_canonizer_problem0(): ...@@ -269,6 +265,7 @@ def test_add_canonizer_problem0():
r = segment_labels * 5 r = segment_labels * 5
f = function([label], r) f = function([label], r)
f(3)
class test_greedy_distribute(unittest.TestCase): class test_greedy_distribute(unittest.TestCase):
...@@ -300,8 +297,8 @@ class test_greedy_distribute(unittest.TestCase): ...@@ -300,8 +297,8 @@ class test_greedy_distribute(unittest.TestCase):
eps = scalar('eps') eps = scalar('eps')
s = scalar('s') s = scalar('s')
#r = theano.tensor.mul(theano.tensor.fill(x, 2.*a), x/a , (y+z) , a) # r = theano.tensor.mul(theano.tensor.fill(x, 2.*a), x/a , (y+z) , a)
#r = theano.tensor.mul((x/a+y) , a, z) # r = theano.tensor.mul((x/a+y) , a, z)
r = tensor.mul(s - 1, r = tensor.mul(s - 1,
eps + x / s, eps + x / s,
eps + y / s, eps + y / s,
...@@ -326,16 +323,16 @@ class test_canonize(unittest.TestCase): ...@@ -326,16 +323,16 @@ class test_canonize(unittest.TestCase):
def test_muldiv(self): def test_muldiv(self):
x, y, z = matrices('xyz') x, y, z = matrices('xyz')
a, b, c, d = matrices('abcd') a, b, c, d = matrices('abcd')
# e = (2.0 * x) / (2.0 * y) # e = (2.0 * x) / (2.0 * y)
# e = (2.0 * x) / (4.0 * y) # e = (2.0 * x) / (4.0 * y)
# e = x / (y / z) # e = x / (y / z)
# e = (x * y) / x # e = (x * y) / x
# e = (x / y) * (y / z) * (z / x) # e = (x / y) * (y / z) * (z / x)
# e = (a / b) * (b / c) * (c / d) # e = (a / b) * (b / c) * (c / d)
# e = (a * b) / (b * c) / (c * d) # e = (a * b) / (b * c) / (c * d)
# e = 2 * x / 2 # e = 2 * x / 2
# e = x / y / x # e = x / y / x
# e = (x / x) * (y / y) # e = (x / x) * (y / y)
e = (-1 * x) / y / (-2 * z) e = (-1 * x) / y / (-2 * z)
g = FunctionGraph([x, y, z, a, b, c, d], [e]) g = FunctionGraph([x, y, z, a, b, c, d], [e])
print(pprint(g.outputs[0])) print(pprint(g.outputs[0]))
...@@ -355,60 +352,60 @@ class test_canonize(unittest.TestCase): ...@@ -355,60 +352,60 @@ class test_canonize(unittest.TestCase):
shp = (5, 5) shp = (5, 5)
fx, fy, fz = fmatrices('xyz') fx, fy, fz = fmatrices('xyz')
dx, dy, dz = dmatrices('xyz') dx, dy, dz = dmatrices('xyz')
fv = fvector('r').dimshuffle('x', 0) # fv = fvector('r').dimshuffle('x', 0)
dv = dvector('s').dimshuffle('x', 0) # dv = dvector('s').dimshuffle('x', 0)
fxv = theano._asarray(numpy.random.rand(*shp), dtype='float32') fxv = theano._asarray(numpy.random.rand(*shp), dtype='float32')
fyv = theano._asarray(numpy.random.rand(*shp), dtype='float32') fyv = theano._asarray(numpy.random.rand(*shp), dtype='float32')
fzv = theano._asarray(numpy.random.rand(*shp), dtype='float32') fzv = theano._asarray(numpy.random.rand(*shp), dtype='float32')
fvv = theano._asarray(numpy.random.rand(shp[0]), dtype='float32').reshape(1, shp[0]) # fvv = theano._asarray(numpy.random.rand(shp[0]), dtype='float32').reshape(1, shp[0])
dxv = theano._asarray(numpy.random.rand(*shp), dtype='float64') # dxv = theano._asarray(numpy.random.rand(*shp), dtype='float64')
dyv = theano._asarray(numpy.random.rand(*shp), dtype='float64') # dyv = theano._asarray(numpy.random.rand(*shp), dtype='float64')
dzv = theano._asarray(numpy.random.rand(*shp), dtype='float64') # dzv = theano._asarray(numpy.random.rand(*shp), dtype='float64')
dvv = theano._asarray(numpy.random.rand(shp[0]), dtype='float64').reshape(1, shp[0]) # dvv = theano._asarray(numpy.random.rand(shp[0]), dtype='float64').reshape(1, shp[0])
cases = [ cases = [
(fx + fy, (fx, fy), (fxv, fyv), 1, 'float32'), (fx + fy, (fx, fy), (fxv, fyv), 1, 'float32'),
(fx * fy, (fx, fy), (fxv, fyv), 1, 'float32'), (fx * fy, (fx, fy), (fxv, fyv), 1, 'float32'),
# (fx+fy+fz,(fx,fy,fz),(fxv,fyv,fzv),1,'float32'), # (fx+fy+fz,(fx,fy,fz),(fxv,fyv,fzv),1,'float32'),
# (dx+dy+dz,(dx,dy,dz),(dxv,dyv,dzv),1,'float64'), # (dx+dy+dz,(dx,dy,dz),(dxv,dyv,dzv),1,'float64'),
# (fx*fy*fz,(fx,fy,fz),(fxv,fyv,fzv),1,'float32'), # (fx*fy*fz,(fx,fy,fz),(fxv,fyv,fzv),1,'float32'),
# (dx*dy*dz,(dx,dy,dz),(dxv,dyv,dzv),1,'float64'), # (dx*dy*dz,(dx,dy,dz),(dxv,dyv,dzv),1,'float64'),
# (fx*fy*(fx+fy+fz),(fx,fy,fz),(fxv,fyv,fzv),2,'float32'), # (fx*fy*(fx+fy+fz),(fx,fy,fz),(fxv,fyv,fzv),2,'float32'),
# (dx*dy*(dx+dy+dz),(dx,dy,dz),(dxv,dyv,dzv),2,'float64'), # (dx*dy*(dx+dy+dz),(dx,dy,dz),(dxv,dyv,dzv),2,'float64'),
# (fx*fy*(fx+fy+dz),(fx,fy,dz),(dxv,dyv,dzv),2,'float64'),#check mixed type add # (fx*fy*(fx+fy+dz),(fx,fy,dz),(dxv,dyv,dzv),2,'float64'), # check mixed type add
# (dz*fy*(fx+fy),(fx,fy,dz),(dxv,dyv,dzv),2,'float64'),#check mixed type mul # (dz*fy*(fx+fy),(fx,fy,dz),(dxv,dyv,dzv),2,'float64'), # check mixed type mul
# check with dimshuffle of constant # check with dimshuffle of constant
(fx + fy + fz + 2, (fx, fy, fz), (fxv, fyv, fzv), 1, {'custom': (fx + fy + fz + 2, (fx, fy, fz), (fxv, fyv, fzv), 1,
'float32', 'numpy+floatX': config.floatX, 'numpy': 'float64'}), {'custom': 'float32', 'numpy+floatX': config.floatX, 'numpy': 'float64'}),
(fx * fy * fz * 2, (fx, fy, fz), (fxv, fyv, fzv), 1, {'custom': (fx * fy * fz * 2, (fx, fy, fz), (fxv, fyv, fzv), 1,
'float32', 'numpy+floatX': config.floatX, 'numpy': 'float64'}), {'custom': 'float32', 'numpy+floatX': config.floatX, 'numpy': 'float64'}),
# (2+fx+fy+fz,(fx,fy,fz),(fxv,fyv,fzv),1,'float32'), # (2+fx+fy+fz,(fx,fy,fz),(fxv,fyv,fzv),1,'float32'),
# (2*fx*fy*fz,(fx,fy,fz),(fxv,fyv,fzv),1,'float32'), # (2*fx*fy*fz,(fx,fy,fz),(fxv,fyv,fzv),1,'float32'),
(2 + fx + fy + fz + 2, (fx, fy, fz), (fxv, fyv, fzv), 1, { (2 + fx + fy + fz + 2, (fx, fy, fz), (fxv, fyv, fzv), 1,
'custom': 'float32', 'numpy+floatX': config.floatX, 'numpy': 'float64'}), {'custom': 'float32', 'numpy+floatX': config.floatX, 'numpy': 'float64'}),
(2 * fx * fy * fz * 2, (fx, fy, fz), (fxv, fyv, fzv), 1, { (2 * fx * fy * fz * 2, (fx, fy, fz), (fxv, fyv, fzv), 1,
'custom': 'float32', 'numpy+floatX': config.floatX, 'numpy': 'float64'}), {'custom': 'float32', 'numpy+floatX': config.floatX, 'numpy': 'float64'}),
# (fx*fy*2*(fx+fy+fz),(fx,fy,fz),(fxv,fyv,fzv),2,'float32'), # (fx*fy*2*(fx+fy+fz),(fx,fy,fz),(fxv,fyv,fzv),2,'float32'),
# (fx*fy*(2+fx+fy+fz),(fx,fy,fz),(fxv,fyv,fzv),2,'float32'), # (fx*fy*(2+fx+fy+fz),(fx,fy,fz),(fxv,fyv,fzv),2,'float32'),
(fx * fy * 2 * (fx + fy + fz+2), (fx, fy, fz), (fxv, fyv, fzv), 2, { (fx * fy * 2 * (fx + fy + fz + 2), (fx, fy, fz), (fxv, fyv, fzv), 2,
'custom': 'float32', 'numpy+floatX': config.floatX, 'numpy': 'float64'}), {'custom': 'float32', 'numpy+floatX': config.floatX, 'numpy': 'float64'}),
# check with broadcast of row # check with broadcast of row
# (fx+fy+fz+fv,(fx,fy,fz,fv),(fxv,fyv,fzv,fvv),1,'float32'), # (fx+fy+fz+fv,(fx,fy,fz,fv),(fxv,fyv,fzv,fvv),1,'float32'),
# (fx*fy*fz*fv,(fx,fy,fz,fv),(fxv,fyv,fzv,fvv),1,'float32'), # (fx*fy*fz*fv,(fx,fy,fz,fv),(fxv,fyv,fzv,fvv),1,'float32'),
# (fv+fx+fy+fz,(fx,fy,fz,fv),(fxv,fyv,fzv,fvv),1,'float32'), # (fv+fx+fy+fz,(fx,fy,fz,fv),(fxv,fyv,fzv,fvv),1,'float32'),
# (fv*fx*fy*fz,(fx,fy,fz,fv),(fxv,fyv,fzv,fvv),1,'float32'), # (fv*fx*fy*fz,(fx,fy,fz,fv),(fxv,fyv,fzv,fvv),1,'float32'),
# (fx*fy*fv*(fx+fy+fz),(fx,fy,fz,fv),(fxv,fyv,fzv,fvv),2,'float32'), # (fx*fy*fv*(fx+fy+fz),(fx,fy,fz,fv),(fxv,fyv,fzv,fvv),2,'float32'),
# (fx*fy*(fv+fx+fy+fz),(fx,fy,fz,fv),(fxv,fyv,fzv,fvv),2,'float32'), # (fx*fy*(fv+fx+fy+fz),(fx,fy,fz,fv),(fxv,fyv,fzv,fvv),2,'float32'),
# (fx*fy*fv*(fv+fx+fy+fz),(fx,fy,fz,fv),(fxv,fyv,fzv,fvv),2,'float32'), # (fx*fy*fv*(fv+fx+fy+fz),(fx,fy,fz,fv),(fxv,fyv,fzv,fvv),2,'float32'),
# (dx+dy+dz+dv,(dx,dy,dz,dv),(dxv,dyv,dzv,dvv),1,'float64'), # (dx+dy+dz+dv,(dx,dy,dz,dv),(dxv,dyv,dzv,dvv),1,'float64'),
# (dx*dy*dz*dv,(dx,dy,dz,dv),(dxv,dyv,dzv,dvv),1,'float64'), # (dx*dy*dz*dv,(dx,dy,dz,dv),(dxv,dyv,dzv,dvv),1,'float64'),
# (dv+dx+dy+dz,(dx,dy,dz,dv),(dxv,dyv,dzv,dvv),1,'float64'), # (dv+dx+dy+dz,(dx,dy,dz,dv),(dxv,dyv,dzv,dvv),1,'float64'),
# (dv*dx*dy*dz,(dx,dy,dz,dv),(dxv,dyv,dzv,dvv),1,'float64'), # (dv*dx*dy*dz,(dx,dy,dz,dv),(dxv,dyv,dzv,dvv),1,'float64'),
# (dx*dy*dv*(dx+dy+dz),(dx,dy,dz,dv),(dxv,dyv,dzv,dvv),2,'float64'), # (dx*dy*dv*(dx+dy+dz),(dx,dy,dz,dv),(dxv,dyv,dzv,dvv),2,'float64'),
# (dx*dy*(dv+dx+dy+dz),(dx,dy,dz,dv),(dxv,dyv,dzv,dvv),2,'float64'), # (dx*dy*(dv+dx+dy+dz),(dx,dy,dz,dv),(dxv,dyv,dzv,dvv),2,'float64'),
# (dx*dy*dv*(dv+dx+dy+dz),(dx,dy,dz,dv),(dxv,dyv,dzv,dvv),2,'float64'), # (dx*dy*dv*(dv+dx+dy+dz),(dx,dy,dz,dv),(dxv,dyv,dzv,dvv),2,'float64'),
] # [10:11] ] # [10:11]
# print cases # print cases
# We must be sure that the Canonizer is working, but that we don't have other # We must be sure that the Canonizer is working, but that we don't have other
# optimisations that could hide bugs in the Canonizer, such as local_elemwise_fusion # optimisations that could hide bugs in the Canonizer, such as local_elemwise_fusion
...@@ -457,61 +454,38 @@ class test_canonize(unittest.TestCase): ...@@ -457,61 +454,38 @@ class test_canonize(unittest.TestCase):
(dx + dy + dz, (dx, dy, dz), (dxv, dyv, dzv), 1, 'float64'), (dx + dy + dz, (dx, dy, dz), (dxv, dyv, dzv), 1, 'float64'),
(fx * fy * fz, (fx, fy, fz), (fxv, fyv, fzv), 1, 'float32'), (fx * fy * fz, (fx, fy, fz), (fxv, fyv, fzv), 1, 'float32'),
(dx * dy * dz, (dx, dy, dz), (dxv, dyv, dzv), 1, 'float64'), (dx * dy * dz, (dx, dy, dz), (dxv, dyv, dzv), 1, 'float64'),
(fx * fy * (fx + fy + fz), (fx, fy, fz), (fxv, fyv, (fx * fy * (fx + fy + fz), (fx, fy, fz), (fxv, fyv, fzv), 2, 'float32'),
fzv), 2, 'float32'), (dx * dy * (dx + dy + dz), (dx, dy, dz), (dxv, dyv, dzv), 2, 'float64'),
(dx * dy * (dx + dy + dz), (dx, dy, dz), (dxv, dyv, (fx * fy * (fx + fy + dz), (fx, fy, dz), (dxv, dyv, dzv), 2, 'float64'), # check mixed type add
dzv), 2, 'float64'), (dz * fy * (fx + fy), (fx, fy, dz), (dxv, dyv, dzv), 2, 'float64'), # check mixed type mul
(fx * fy * (fx + fy + dz), (fx, fy, dz), (dxv, dyv, dzv), 2,
'float64'), # check mixed type add
(dz * fy * (fx + fy), (fx, fy, dz), (dxv, dyv, dzv), 2,
'float64'), # check mixed type mul
# check with dimshuffle of constant # check with dimshuffle of constant
(fx + fy + fz + 2, (fx, fy, fz), (fxv, fyv, fzv), 1, 'float32'), (fx + fy + fz + 2, (fx, fy, fz), (fxv, fyv, fzv), 1, 'float32'),
(fx * fy * fz * 2, (fx, fy, fz), (fxv, fyv, fzv), 1, 'float32'), (fx * fy * fz * 2, (fx, fy, fz), (fxv, fyv, fzv), 1, 'float32'),
(2 + fx + fy + fz, (fx, fy, fz), (fxv, fyv, fzv), 1, 'float32'), (2 + fx + fy + fz, (fx, fy, fz), (fxv, fyv, fzv), 1, 'float32'),
(2 * fx * fy * fz, (fx, fy, fz), (fxv, fyv, fzv), 1, 'float32'), (2 * fx * fy * fz, (fx, fy, fz), (fxv, fyv, fzv), 1, 'float32'),
(2 + fx + fy + fz + 2, (fx, fy, fz), (fxv, fyv, (2 + fx + fy + fz + 2, (fx, fy, fz), (fxv, fyv, fzv), 1, 'float32'),
fzv), 1, 'float32'), (2 * fx * fy * fz * 2, (fx, fy, fz), (fxv, fyv, fzv), 1, 'float32'),
(2 * fx * fy * fz * 2, (fx, fy, fz), (fxv, fyv, (fx * fy * 2 * (fx + fy + fz), (fx, fy, fz), (fxv, fyv, fzv), 2, 'float32'),
fzv), 1, 'float32'), (fx * fy * (2 + fx + fy + fz), (fx, fy, fz), (fxv, fyv, fzv), 2, 'float32'),
(fx * fy * 2 * (fx+fy+fz), (fx, fy, fz), (fxv, fyv, (fx * fy * 2 * (fx + fy + fz + 2), (fx, fy, fz), (fxv, fyv, fzv), 2, 'float32'),
fzv), 2, 'float32'),
(fx*fy*(2+fx+fy+fz), (fx, fy, fz), (fxv, fyv, fzv), 2, 'float32'),
(fx*fy*2*(fx+fy+fz+2), (fx, fy, fz), (fxv, fyv,
fzv), 2, 'float32'),
# check with broadcast of row # check with broadcast of row
(fx+fy+fz+fv, (fx, fy, fz, fv), (fxv, fyv, fzv, (fx + fy + fz + fv, (fx, fy, fz, fv), (fxv, fyv, fzv, fvv), 1, 'float32'),
fvv), 1, 'float32'), (fx * fy * fz * fv, (fx, fy, fz, fv), (fxv, fyv, fzv, fvv), 1, 'float32'),
(fx*fy*fz*fv, (fx, fy, fz, fv), (fxv, fyv, fzv, (fv + fx + fy + fz, (fx, fy, fz, fv), (fxv, fyv, fzv, fvv), 1, 'float32'),
fvv), 1, 'float32'), (fv * fx * fy * fz, (fx, fy, fz, fv), (fxv, fyv, fzv, fvv), 1, 'float32'),
(fv+fx+fy+fz, (fx, fy, fz, fv), (fxv, fyv, fzv, (fx * fy * fv * (fx + fy + fz), (fx, fy, fz, fv), (fxv, fyv, fzv, fvv), 2, 'float32'),
fvv), 1, 'float32'), (fx * fy * (fv + fx + fy + fz), (fx, fy, fz, fv), (fxv, fyv, fzv, fvv), 2, 'float32'),
(fv*fx*fy*fz, (fx, fy, fz, fv), (fxv, fyv, fzv, (fx * fy * fv * (fv + fx + fy + fz), (fx, fy, fz, fv), (fxv, fyv, fzv, fvv), 2, 'float32'),
fvv), 1, 'float32'), (dx + dy + dz + dv, (dx, dy, dz, dv), (dxv, dyv, dzv, dvv), 1, 'float64'),
(fx*fy*fv*(fx+fy+fz), (fx, fy, fz, fv), (fxv, fyv, (dx * dy * dz * dv, (dx, dy, dz, dv), (dxv, dyv, dzv, dvv), 1, 'float64'),
fzv, fvv), 2, 'float32'), (dv + dx + dy + dz, (dx, dy, dz, dv), (dxv, dyv, dzv, dvv), 1, 'float64'),
(fx*fy*(fv+fx+fy+fz), (fx, fy, fz, fv), (fxv, fyv, (dv * dx * dy * dz, (dx, dy, dz, dv), (dxv, dyv, dzv, dvv), 1, 'float64'),
fzv, fvv), 2, 'float32'), (dx * dy * dv * (dx + dy + dz), (dx, dy, dz, dv), (dxv, dyv, dzv, dvv), 2, 'float64'),
(fx*fy*fv*(fv+fx+fy+fz), (fx, fy, fz, fv), (fxv, fyv, fzv, (dx * dy * (dv + dx + dy + dz), (dx, dy, dz, dv), (dxv, dyv, dzv, dvv), 2, 'float64'),
fvv), 2, 'float32'), (dx * dy * dv * (dv + dx + dy + dz), (dx, dy, dz, dv), (dxv, dyv, dzv, dvv), 2, 'float64'),
(dx+dy+dz+dv, (dx, dy, dz, dv), (dxv, dyv, dzv,
dvv), 1, 'float64'),
(dx*dy*dz*dv, (dx, dy, dz, dv), (dxv, dyv, dzv,
dvv), 1, 'float64'),
(dv+dx+dy+dz, (dx, dy, dz, dv), (dxv, dyv, dzv,
dvv), 1, 'float64'),
(dv*dx*dy*dz, (dx, dy, dz, dv), (dxv, dyv, dzv,
dvv), 1, 'float64'),
(dx*dy*dv*(dx+dy+dz), (dx, dy, dz, dv), (dxv, dyv,
dzv, dvv), 2, 'float64'),
(dx*dy*(dv+dx+dy+dz), (dx, dy, dz, dv), (dxv, dyv,
dzv, dvv), 2, 'float64'),
(dx*dy*dv*(dv+dx+dy+dz), (dx, dy, dz, dv), (dxv, dyv, dzv,
dvv), 2, 'float64'),
] # [10:11] ] # [10:11]
# print cases # print cases
# We must be sure that the Canonizer is working, but that we don't have other # We must be sure that the Canonizer is working, but that we don't have other
# optimisations that could hide bugs in the Canonizer, such as local_elemwise_fusion # optimisations that could hide bugs in the Canonizer, such as local_elemwise_fusion
...@@ -568,11 +542,11 @@ class test_canonize(unittest.TestCase): ...@@ -568,11 +542,11 @@ class test_canonize(unittest.TestCase):
'local_elemwise_fusion') 'local_elemwise_fusion')
mode = mode.__class__(linker=mode.linker, optimizer=opt) mode = mode.__class__(linker=mode.linker, optimizer=opt)
# test x / x -> 1 # test x / x -> 1
for id, (g, sym_inputs, val_inputs, out_dtype) in enumerate([(fx/fx, [fx], [fxv], 'float32'), for id, (g, sym_inputs, val_inputs, out_dtype) in enumerate([
(dx/dx, [dx], [dxv], 'float64'), (fx / fx, [fx], [fxv], 'float32'),
(fv/fv, [fv], [fvv], 'float32'), (dx / dx, [dx], [dxv], 'float64'),
(dv/dv, [dv], [dvv], 'float64'), (fv / fv, [fv], [fvv], 'float32'),
]): (dv / dv, [dv], [dvv], 'float64')]):
f = compile.function(list(sym_inputs), g, f = compile.function(list(sym_inputs), g,
mode=mode) mode=mode)
out = f(*val_inputs) out = f(*val_inputs)
...@@ -591,14 +565,14 @@ class test_canonize(unittest.TestCase): ...@@ -591,14 +565,14 @@ class test_canonize(unittest.TestCase):
# test (x * y) / x -> y # test (x * y) / x -> y
for id, (g, sym_inputs, val_inputs, nb_elemwise, out_dtype) in enumerate([ for id, (g, sym_inputs, val_inputs, nb_elemwise, out_dtype) in enumerate([
((dx*dy)/dx, [dx, dy], [dxv, dyv], 0, 'float64'), ((dx * dy) / dx, [dx, dy], [dxv, dyv], 0, 'float64'),
((fx*fy)/fx, [fx, fy], [fxv, fyv], 0, 'float32'), ((fx * fy) / fx, [fx, fy], [fxv, fyv], 0, 'float32'),
((dv*dy)/dv, [dv, dy], [dvv, dyv], 0, 'float64'), ((dv * dy) / dv, [dv, dy], [dvv, dyv], 0, 'float64'),
((fv*fy)/fv, [fv, fy], [fvv, fyv], 0, 'float32'), ((fv * fy) / fv, [fv, fy], [fvv, fyv], 0, 'float32'),
# must broadcast as their is a dimshuffle in the computation # must broadcast as there is a dimshuffle in the computation
((dx*dv)/dx, [dx, dv], [dxv, dvv], 1, 'float64'), ((dx * dv) / dx, [dx, dv], [dxv, dvv], 1, 'float64'),
# topo: [Elemwise{second,no_inplace}(x, <TensorType(float64, row)>)] # topo: [Elemwise{second,no_inplace}(x, <TensorType(float64, row)>)]
((fx*fv)/fx, [fx, fv], [fxv, fvv], 1, 'float32') ((fx * fv) / fx, [fx, fv], [fxv, fvv], 1, 'float32')
# topo: [Elemwise{second,no_inplace}(x, <TensorType(float32, row)>)] # topo: [Elemwise{second,no_inplace}(x, <TensorType(float32, row)>)]
]): ]):
f = compile.function(list(sym_inputs), g, f = compile.function(list(sym_inputs), g,
...@@ -614,19 +588,17 @@ class test_canonize(unittest.TestCase): ...@@ -614,19 +588,17 @@ class test_canonize(unittest.TestCase):
# test x / y / x -> 1 / y # test x / y / x -> 1 / y
for id, (g, sym_inputs, val_inputs, nb_elemwise, out_dtype) in enumerate([ for id, (g, sym_inputs, val_inputs, nb_elemwise, out_dtype) in enumerate([
((dx/dy)/dx, [dx, dy], [dxv, dyv], 1, 'float64'), ((dx / dy) / dx, [dx, dy], [dxv, dyv], 1, 'float64'),
((fx/fy)/fx, [fx, fy], [fxv, fyv], 1, 'float32'), ((fx / fy) / fx, [fx, fy], [fxv, fyv], 1, 'float32'),
((dv/dy)/dv, [dv, dy], [dvv, dyv], 1, 'float64'), ((dv / dy) / dv, [dv, dy], [dvv, dyv], 1, 'float64'),
((fv/fy)/fv, [fv, fy], [fvv, fyv], 1, 'float32'), ((fv / fy) / fv, [fv, fy], [fvv, fyv], 1, 'float32'),
# must broadcast as there is a dimshuffle in the computation # must broadcast as there is a dimshuffle in the computation
((dx / dv) / dx, [dx, dv], [dxv, dvv], 1, 'float64'),
((dx/dv)/dx, [dx, dv], [dxv, dvv], 1, 'float64'), # topo: [Shape_i, Shape_i, Elemwise{inv,no_inplace}(<TensorType(float64, row)>), Alloc]
# topo: [Shape_i, Shape_i, Elemwise{inv,no_inplace}(<TensorType(float64, row)>), Alloc] ((fx / fv) / fx, [fx, fv], [fxv, fvv], 1, 'float32'),
((fx/fv)/fx, [fx, fv], [fxv, fvv], 1, 'float32'), # topo: [Shape_i, Shape_i, Elemwise{inv,no_inplace}(<TensorType(float32, row)>), Alloc]
# topo:[Shape_i, Shape_i, Elemwise{inv,no_inplace}(<TensorType(float32, row)>), Alloc]
]): ]):
f = compile.function(list(sym_inputs), g, f = compile.function(list(sym_inputs), g, mode=mode)
mode=mode)
out = f(*val_inputs) out = f(*val_inputs)
utt.assert_allclose(out, (1 / val_inputs[1])) utt.assert_allclose(out, (1 / val_inputs[1]))
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
...@@ -650,58 +622,50 @@ class test_canonize(unittest.TestCase): ...@@ -650,58 +622,50 @@ class test_canonize(unittest.TestCase):
((dx / dy) * (dy / dz) * (dz / dv), [dx, dy, dz, dv], [dxv, dyv, dzv, dvv], 'float64'), ((dx / dy) * (dy / dz) * (dz / dv), [dx, dy, dz, dv], [dxv, dyv, dzv, dvv], 'float64'),
((fx / fy) * (fy / fz) * (fz / fv), [fx, fy, fz, fv], [fxv, fyv, fzv, fvv], 'float32'), ((fx / fy) * (fy / fz) * (fz / fv), [fx, fy, fz, fv], [fxv, fyv, fzv, fvv], 'float32'),
]): ]):
f = compile.function(list(sym_inputs), g, f = compile.function(list(sym_inputs), g, mode=mode)
mode=mode)
out = f(*val_inputs) out = f(*val_inputs)
utt.assert_allclose(out, (val_inputs[0] / val_inputs[3])) utt.assert_allclose(out, (val_inputs[0] / val_inputs[3]))
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
assert len(topo) == 1 assert len(topo) == 1
assert isinstance(topo[0].op, (T.Elemwise, )) assert isinstance(topo[0].op, (T.Elemwise, ))
assert isinstance(topo[0].op.scalar_op, assert isinstance(topo[0].op.scalar_op, theano.scalar.basic.TrueDiv)
theano.scalar.basic.TrueDiv)
assert len(topo[0].inputs) == 2 assert len(topo[0].inputs) == 2
assert(out_dtype == out.dtype) assert(out_dtype == out.dtype)
# test (2.0 * x) / (4.0 * y) -> (0.5 * x) / y # test (2.0 * x) / (4.0 * y) -> (0.5 * x) / y
for id, (g, sym_inputs, val_inputs, out_dtype) in enumerate([ for id, (g, sym_inputs, val_inputs, out_dtype) in enumerate([
(((2.0*dx)/(4.0*dy)), [dx, dy], [dxv, dyv], 'float64'), (((2.0 * dx) / (4.0 * dy)), [dx, dy], [dxv, dyv], 'float64'),
(((2.0*fx)/(4.0*fy)), [fx, fy], [fxv, fyv], {'custom': 'float32', 'numpy+floatX': config.floatX, 'numpy': 'float64'}), (((2.0 * fx) / (4.0 * fy)), [fx, fy], [fxv, fyv], {'custom': 'float32', 'numpy+floatX': config.floatX, 'numpy': 'float64'}),
(((2.0*dv)/(4.0*dy)), [dv, dy], [dvv, dyv], 'float64'), (((2.0 * dv) / (4.0 * dy)), [dv, dy], [dvv, dyv], 'float64'),
(((2.0*fv)/(4.0*fy)), [fv, fy], [fvv, fyv], {'custom': 'float32', 'numpy+floatX': config.floatX, 'numpy': 'float64'}), (((2.0 * fv) / (4.0 * fy)), [fv, fy], [fvv, fyv], {'custom': 'float32', 'numpy+floatX': config.floatX, 'numpy': 'float64'}),
(((2.0*dx)/(4.0*dv)), [dx, dv], [dxv, dvv], 'float64'), (((2.0 * dx) / (4.0 * dv)), [dx, dv], [dxv, dvv], 'float64'),
(((2.0*fx)/(4.0*fv)), [fx, fv], [fxv, fvv], {'custom': 'float32', 'numpy+floatX': config.floatX, 'numpy': 'float64'}), (((2.0 * fx) / (4.0 * fv)), [fx, fv], [fxv, fvv], {'custom': 'float32', 'numpy+floatX': config.floatX, 'numpy': 'float64'}),
]): ]):
if isinstance(out_dtype, dict): if isinstance(out_dtype, dict):
out_dtype = out_dtype[config.cast_policy] out_dtype = out_dtype[config.cast_policy]
f = compile.function(list(sym_inputs), g, f = compile.function(list(sym_inputs), g, mode=mode)
mode=mode)
out = f(*val_inputs) out = f(*val_inputs)
utt.assert_allclose(out, (0.5 * utt.assert_allclose(out, (0.5 * val_inputs[0] / val_inputs[1]))
val_inputs[0] / val_inputs[1]))
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
assert len(topo) == 2 assert len(topo) == 2
assert isinstance(topo[0].op, (T.Elemwise, )) assert isinstance(topo[0].op, (T.Elemwise, ))
assert isinstance(topo[0].op.scalar_op, assert isinstance(topo[0].op.scalar_op, theano.scalar.basic.Mul)
theano.scalar.basic.Mul)
assert len(topo[0].inputs) == 2 assert len(topo[0].inputs) == 2
assert isinstance(topo[1].op, (T.Elemwise, )) assert isinstance(topo[1].op, (T.Elemwise, ))
assert isinstance(topo[1].op.scalar_op, assert isinstance(topo[1].op.scalar_op, theano.scalar.basic.TrueDiv)
theano.scalar.basic.TrueDiv)
assert len(topo[1].inputs) == 2 assert len(topo[1].inputs) == 2
assert(out_dtype == out.dtype) assert(out_dtype == out.dtype)
# test 2 * x / 2 -> x # test 2 * x / 2 -> x
for id, (g, sym_inputs, val_inputs, out_dtype) in enumerate([ for id, (g, sym_inputs, val_inputs, out_dtype) in enumerate([
((2*dx)/2, [dx], [dxv], 'float64'), ((2 * dx) / 2, [dx], [dxv], 'float64'),
((2*fx)/2, [fx], [fxv], {'custom': 'float32', 'numpy+floatX': config.floatX, 'numpy': 'float64'}), ((2 * fx) / 2, [fx], [fxv], {'custom': 'float32', 'numpy+floatX': config.floatX, 'numpy': 'float64'}),
((2*dv)/2, [dv], [dvv], 'float64'), ((2 * dv) / 2, [dv], [dvv], 'float64'),
((2*fv)/2, [fv], [fvv], {'custom': 'float32', 'numpy+floatX': config.floatX, 'numpy': 'float64'}), ((2 * fv) / 2, [fv], [fvv], {'custom': 'float32', 'numpy+floatX': config.floatX, 'numpy': 'float64'}),
]): ]):
if isinstance(out_dtype, dict): if isinstance(out_dtype, dict):
out_dtype = out_dtype[config.cast_policy] out_dtype = out_dtype[config.cast_policy]
f = compile.function(list(sym_inputs), g, f = compile.function(list(sym_inputs), g, mode=mode)
mode=mode)
out = f(*val_inputs) out = f(*val_inputs)
utt.assert_allclose(out, val_inputs[0]) utt.assert_allclose(out, val_inputs[0])
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
...@@ -711,15 +675,14 @@ class test_canonize(unittest.TestCase): ...@@ -711,15 +675,14 @@ class test_canonize(unittest.TestCase):
# test x / abs(x) -> sign(x) # test x / abs(x) -> sign(x)
for id, (g, sym_inputs, val_inputs, out_dtype) in enumerate([ for id, (g, sym_inputs, val_inputs, out_dtype) in enumerate([
(dx/abs(dx), [dx], [0.5-dxv], 'float64'), (dx / abs(dx), [dx], [0.5 - dxv], 'float64'),
(fx/abs(fx), [fx], [0.5-fxv], 'float32'), (fx / abs(fx), [fx], [0.5 - fxv], 'float32'),
(dx/abs(dx), [dx], [0.1*dxv], 'float64'), (dx / abs(dx), [dx], [0.1 * dxv], 'float64'),
(fx/abs(fx), [fx], [0.1*fxv], 'float32'), (fx / abs(fx), [fx], [0.1 * fxv], 'float32'),
(dv/abs(dv), [dv], [0.5-dvv], 'float64'), (dv / abs(dv), [dv], [0.5 - dvv], 'float64'),
(fv/abs(fv), [fv], [0.5-fvv], 'float32'), (fv / abs(fv), [fv], [0.5 - fvv], 'float32'),
]): ]):
f = compile.function(list(sym_inputs), g, f = compile.function(list(sym_inputs), g, mode=mode)
mode=mode)
out = f(*val_inputs) out = f(*val_inputs)
assert numpy.all(numpy.isfinite(out)) assert numpy.all(numpy.isfinite(out))
utt.assert_allclose(out, numpy.sign(val_inputs[0])) utt.assert_allclose(out, numpy.sign(val_inputs[0]))
...@@ -756,7 +719,7 @@ class test_canonize(unittest.TestCase): ...@@ -756,7 +719,7 @@ class test_canonize(unittest.TestCase):
""" """
x = T.dscalar() x = T.dscalar()
a = T.abs_(x) # a = T.abs_(x)
if theano.config.mode == 'FAST_COMPILE': if theano.config.mode == 'FAST_COMPILE':
mode = theano.compile.mode.get_mode('FAST_RUN').excluding( mode = theano.compile.mode.get_mode('FAST_RUN').excluding(
...@@ -804,49 +767,43 @@ class test_canonize(unittest.TestCase): ...@@ -804,49 +767,43 @@ class test_canonize(unittest.TestCase):
dxv = theano._asarray(numpy.random.rand(*shp), dtype='float32') dxv = theano._asarray(numpy.random.rand(*shp), dtype='float32')
dyv = theano._asarray(numpy.random.rand(*shp), dtype='float32') dyv = theano._asarray(numpy.random.rand(*shp), dtype='float32')
dzv = theano._asarray(numpy.random.rand(*shp), dtype='float32') dzv = theano._asarray(numpy.random.rand(*shp), dtype='float32')
fvv = theano._asarray(numpy.random.rand(shp[0]), dtype='float32').reshape(1, shp[0]) # fvv = theano._asarray(numpy.random.rand(shp[0]), dtype='float32').reshape(1, shp[0])
# We must be sure that the Canonizer is working, but that we don't have other # We must be sure that the Canonizer is working, but that we don't have other
# optimisations that could hide bugs in the Canonizer, such as local_elemwise_fusion # optimisations that could hide bugs in the Canonizer, such as local_elemwise_fusion
mode = compile.mode.get_default_mode() mode = compile.mode.get_default_mode()
opt = gof.Query(["canonicalize"]) opt = gof.Query(["canonicalize"])
opt = opt.excluding( opt = opt.excluding('local_elemwise_fusion')
'local_elemwise_fusion')
mode = mode.__class__(linker=mode.linker, optimizer=opt) mode = mode.__class__(linker=mode.linker, optimizer=opt)
# test fail! # test fail!
# test x / y / z -> x / (y * z) # test x / y / z -> x / (y * z)
for (g, sym_inputs, val_inputs, out_dtype) in [ for (g, sym_inputs, val_inputs, out_dtype) in [
((dx/dy)/dz, [dx, dy, dz], [dxv, dyv, dzv], 'float64'), ((dx / dy) / dz, [dx, dy, dz], [dxv, dyv, dzv], 'float64'),
((fx/fy)/fz, [fx, fy, fz], [fxv, fyv, fzv], 'float32') ((fx / fy) / fz, [fx, fy, fz], [fxv, fyv, fzv], 'float32')
]: ]:
f = compile.function(list(sym_inputs), g, f = compile.function(list(sym_inputs), g, mode=mode)
mode=mode)
out = f(*val_inputs) out = f(*val_inputs)
utt.assert_allclose(out, val_inputs[0] / utt.assert_allclose(out, val_inputs[0] / val_inputs[1] / val_inputs[2])
val_inputs[1] / val_inputs[2])
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
assert len(topo) == 2 assert len(topo) == 2
assert isinstance(topo[0].op, (T.Elemwise, )) assert isinstance(topo[0].op, (T.Elemwise, ))
assert isinstance(topo[0].op.scalar_op, assert isinstance(topo[0].op.scalar_op, theano.scalar.basic.Inv)
theano.scalar.basic.Inv)
assert len(topo[0].inputs) == 1 assert len(topo[0].inputs) == 1
assert(out_dtype == out.dtype) assert(out_dtype == out.dtype)
# test x / (y / z) -> (x * z) / y # test x / (y / z) -> (x * z) / y
for (g, sym_inputs, val_inputs, out_dtype) in [ for (g, sym_inputs, val_inputs, out_dtype) in [
(dx/(dy/dz), [dx, dy, dz], [dxv, dyv, dzv], 'float64'), (dx / (dy / dz), [dx, dy, dz], [dxv, dyv, dzv], 'float64'),
(fx/(fy/fz), [fx, fy, fz], [fxv, fyv, fzv], 'float32') (fx / (fy / fz), [fx, fy, fz], [fxv, fyv, fzv], 'float32')
]: ]:
f = compile.function(list(sym_inputs), g, f = compile.function(list(sym_inputs), g,
mode=mode) mode=mode)
out = f(*val_inputs) out = f(*val_inputs)
utt.assert_allclose(out, val_inputs[0] / ( utt.assert_allclose(out, val_inputs[0] / (val_inputs[1] / val_inputs[2]))
val_inputs[1] / val_inputs[2]))
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
assert len(topo) == 2 assert len(topo) == 2
assert isinstance(topo[0].op, (T.Elemwise, )) assert isinstance(topo[0].op, (T.Elemwise, ))
assert isinstance(topo[0].op.scalar_op, assert isinstance(topo[0].op.scalar_op, theano.scalar.basic.Inv)
theano.scalar.basic.Inv)
assert len(topo[0].inputs) == 1 assert len(topo[0].inputs) == 1
assert(out_dtype == out.dtype) assert(out_dtype == out.dtype)
...@@ -868,7 +825,7 @@ class test_canonize(unittest.TestCase): ...@@ -868,7 +825,7 @@ class test_canonize(unittest.TestCase):
logging.getLogger('theano.gof.opt').addHandler(handler) logging.getLogger('theano.gof.opt').addHandler(handler)
try: try:
x = vector() x = vector()
f = theano.function([x], x + numpy.nan) theano.function([x], x + numpy.nan)
finally: finally:
logging.getLogger('theano.gof.opt').removeHandler(handler) logging.getLogger('theano.gof.opt').removeHandler(handler)
# Ideally this test would only catch the maxed out equilibrium # Ideally this test would only catch the maxed out equilibrium
...@@ -960,7 +917,6 @@ class test_fusion(unittest.TestCase): ...@@ -960,7 +917,6 @@ class test_fusion(unittest.TestCase):
""" """
# TODO: disable the canonizer? # TODO: disable the canonizer?
def my_init(shp, dtype='float64', num=0): def my_init(shp, dtype='float64', num=0):
#ret = theano._asarray(numpy.random.rand(*shp),dtype=dtype)
ret = numpy.zeros(shp, dtype=dtype) + num ret = numpy.zeros(shp, dtype=dtype) + num
return ret return ret
fw, fx, fy, fz = [theano.tensor.tensor(dtype='float32', fw, fx, fy, fz = [theano.tensor.tensor(dtype='float32',
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论