Commit be95a5ce authored by James Bergstra

Merge pull request #684 from nouiz/blas

Improve gemm optimizer to never accidentally duplicate work.
......@@ -33,7 +33,8 @@ from optdb import \
EquilibriumDB, SequenceDB, ProxyDB
from toolbox import \
Bookkeeper, History, Validator, ReplaceValidate, NodeFinder, PrintListener
Bookkeeper, History, Validator, ReplaceValidate, NodeFinder,\
PrintListener, ReplacementDidntRemovedError
from type import \
Type, Generic, generic
......
......@@ -9,6 +9,15 @@ class AlreadyThere(Exception):
pass
class ReplacementDidntRemovedError(Exception):
    """Raised by ``replace_all_validate_remove`` when an optimization
    wanted a Variable or a Node removed from the graph, but the
    replacement it supplied did not actually remove it.
    """
    pass
class Bookkeeper:
def on_attach(self, env):
......@@ -91,12 +100,15 @@ class ReplaceValidate(History, Validator):
" or in conflict with another plugin.")
env.replace_validate = partial(self.replace_validate, env)
env.replace_all_validate = partial(self.replace_all_validate, env)
env.replace_all_validate_remove = partial(
self.replace_all_validate_remove, env)
def on_detach(self, env):
History.on_detach(self, env)
Validator.on_detach(self, env)
del env.replace_validate
del env.replace_all_validate
del env.replace_all_validate_remove
def replace_validate(self, env, r, new_r, reason=None):
self.replace_all_validate(env, [(r, new_r)], reason=reason)
......@@ -121,6 +133,28 @@ class ReplaceValidate(History, Validator):
except Exception, e:
env.revert(chk)
raise
return chk
def replace_all_validate_remove(self, env, replacements,
                                remove, reason=None):
    """Like ``replace_all_validate``, but additionally require that every
    element of ``remove`` (a list of Variables/Nodes) disappear from the
    graph once ``replacements`` are applied.

    If any element of ``remove`` is still in ``env.nodes`` or
    ``env.variables`` after the replacement, the whole replacement is
    reverted, a warning is printed to stderr, and
    ``ReplacementDidntRemovedError`` is raised.
    """
    chk = env.replace_all_validate(replacements, reason)
    for rm in remove:
        # The replacement was supposed to eliminate `rm` from the graph;
        # if it survived, undo everything and disable the optimization.
        if rm in env.nodes or rm in env.variables:
            env.revert(chk)
            out = sys.stderr
            print >> out, (
                "WARNING: An optimization wanted to replace a Variable"
                " in the graph, but the replacement for it doesn't"
                " remove it. We disabled the optimization."
                " Your function runs correctly, but it would be"
                " appreciated if you submit this problem to the mailing"
                " list theano-users so that we can fix it.")
            print >> out, reason, replacements
            raise ReplacementDidntRemovedError()
class NodeFinder(dict, Bookkeeper):
......
......@@ -133,8 +133,10 @@ import numpy.distutils
from theano.configparser import config, AddConfigVar, StrParam
from theano.gof import (utils, Op, view_roots, DestroyHandler,
local_optimizer, Optimizer,
InconsistencyError, toolbox, SequenceDB, EquilibriumOptimizer, Apply)
local_optimizer, Optimizer,
InconsistencyError, toolbox, SequenceDB,
EquilibriumOptimizer, Apply,
ReplacementDidntRemovedError)
from theano.printing import pprint, FunctionPrinter, debugprint
from theano.compile.mode import optdb
from theano.gof.python25 import all, any
......@@ -1022,7 +1024,7 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip=True):
Ml, Mr = M.owner.inputs
rval = [gemm_no_inplace(L, alpha, Ml, Mr, beta)]
#print 'GEMM 0', rval, beta, L, alpha, M
return rval
return rval, M
# it also might be the case that there is a dimshuffle between the +
# and the dot22. local_dot_to_dot22 in particular will put in such things.
......@@ -1035,7 +1037,7 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip=True):
g = gemm_no_inplace(L.dimshuffle(0, 'x'),
alpha, MMl, MMr, beta)
rval = [g.dimshuffle(0)]
return rval
return rval, MM
if tuple(M.owner.op.new_order) == (1,):
# it is making a row MM into a vector
if MM.owner and MM.owner.op == _dot22:
......@@ -1043,7 +1045,7 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip=True):
g = gemm_no_inplace(L.dimshuffle('x', 0),
alpha, MMl, MMr, beta)
rval = [g.dimshuffle(1)]
return rval
return rval, MM
if tuple(M.owner.op.new_order) == ():
# it is making a row MM into a vector
if MM.owner and MM.owner.op == _dot22:
......@@ -1051,7 +1053,7 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip=True):
g = gemm_no_inplace(L.dimshuffle('x', 'x'),
alpha, MMl, MMr, beta)
rval = [g.dimshuffle()]
return rval
return rval, MM
# this is False'd out because of inadequate testing.
# TODO see ticket #237
......@@ -1085,7 +1087,7 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip=True):
if recurse_flip:
return _beta_L_plus_alpha_M(alpha, M, beta, L, recurse_flip=False)
else:
return False
return False, False
def _gemm_canonicalize(r, scale, rval, maxclients):
......@@ -1250,7 +1252,8 @@ def _gemm_from_factored_list(lst):
#print 'TRYING', (s_i, M_i, s_j, M_j)
gemm_of_sM_list = _beta_L_plus_alpha_M(s_i, M_i, s_j, M_j)
gemm_of_sM_list, old_dot22 = _beta_L_plus_alpha_M(s_i, M_i,
s_j, M_j)
#print 'GOT IT', gemm_of_sM_list
if gemm_of_sM_list:
def item_to_var(t):
......@@ -1273,7 +1276,7 @@ def _gemm_from_factored_list(lst):
else:
rval = add_inputs
#print "RETURNING GEMM THIGN", rval
return rval
return rval, old_dot22
def _gemm_from_node2(node):
......@@ -1301,7 +1304,7 @@ def _gemm_from_node2(node):
# http://groups.google.com/group/theano-dev/browse_thread/thread/a3096c82856e3ad5,
# but never made it into a trac ticket.
if rval and (rval[0].type == node.outputs[0].type):
if rval and (rval[0][0].type == node.outputs[0].type):
return rval
......@@ -1326,11 +1329,13 @@ class GemmOptimizer(Optimizer):
except InconsistencyError, e:
continue
if new_outputs:
new_outputs, old_dot22 = new_outputs
assert len(new_outputs) == len(node.outputs)
try:
env.replace_all_validate(
zip(node.outputs, new_outputs),
reason='GemmOptimizer'
env.replace_all_validate_remove(
zip(node.outputs, new_outputs),
[old_dot22],
reason='GemmOptimizer'
)
did_something = True
break
......@@ -1338,6 +1343,8 @@ class GemmOptimizer(Optimizer):
# TODO: retry other applications of gemm (see comment
# in _gemm_from_node)
pass
except ReplacementDidntRemovedError, e:
pass
class Dot22(GemmRelated):
......
......@@ -15,7 +15,6 @@ from numpy.testing import assert_array_almost_equal
#from numpy.testing import dec
#from numpy.testing.noseclasses import KnownFailureTest
#from theano.tensor.blas import *
from theano.tensor.blas import (_dot22, _dot22scalar, res_is_a, _as_scalar,
_is_real_matrix, _gemm_canonicalize,
_factor_canonicalized, Gemm, Gemv,
......@@ -46,6 +45,10 @@ def test_dot_eq():
assert T.Dot() == T.Dot()
def sharedX(x, name):
    """Return a theano shared variable named *name* holding *x* cast to
    the configured floatX dtype."""
    value = numpy.asarray(x, config.floatX)
    return theano.shared(value, name=name)
class t_gemm(TestCase):
"""This test suite is supposed to establish that gemm works as it
is supposed to.
......@@ -171,20 +174,22 @@ class t_gemm(TestCase):
self.cmp(self.rand(0, 0), -1.0, self.rand(0, 0), self.rand(0, 0), -1.0)
def test_factorised_scalar(self):
a = T.dmatrix()
b = T.dmatrix()
c = T.dmatrix()
s = theano.shared(numpy.zeros((5, 5)))
a = T.matrix()
b = T.matrix()
c = T.matrix()
s = theano.shared(numpy.zeros((5, 5)).astype(config.floatX))
lr1 = T.constant(0.01).astype('float64')
lr2 = T.constant(2).astype('float64')
l2_reg = T.constant(0.0001).astype('float64')
lr1 = T.constant(0.01).astype(config.floatX)
lr2 = T.constant(2).astype(config.floatX)
l2_reg = T.constant(0.0001).astype(config.floatX)
#test constant merge with gemm
f = theano.function([a, b], updates={s: lr1 * T.dot(a, b) +
l2_reg * lr2 * s},
mode=mode_not_fast_compile).maker.env.toposort()
#[Gemm{inplace}(<TensorType(float64, matrix)>, 0.01, <TensorType(float64, matrix)>, <TensorType(float64, matrix)>, 2e-06)]
#[Gemm{inplace}(<TensorType(float64, matrix)>, 0.01,
# <TensorType(float64, matrix)>, <TensorType(float64, matrix)>,
# 2e-06)]
assert len(f) == 1
assert f[0].op == gemm_inplace
......@@ -192,14 +197,19 @@ class t_gemm(TestCase):
f = theano.function([a, b], updates={s: lr1 * (T.dot(a, b) -
l2_reg * s)},
mode=mode_not_fast_compile).maker.env.toposort()
#[Gemm{inplace}(<TensorType(float64, matrix)>, 0.01, <TensorType(float64, matrix)>, <TensorType(float64, matrix)>, -2e-06)]
#[Gemm{inplace}(<TensorType(float64, matrix)>, 0.01,
# <TensorType(float64, matrix)>, <TensorType(float64, matrix)>,
# -2e-06)]
assert len(f) == 1
assert f[0].op == gemm_inplace
#test factored scalar with merge and neg
f = theano.function([a,b],updates={s:s-lr1*(s*.0002+T.dot(a,b))},
f = theano.function([a, b],
updates={s: s - lr1 * (s * .0002 + T.dot(a, b))},
mode=mode_not_fast_compile).maker.env.toposort()
#[Gemm{inplace}(<TensorType(float64, matrix)>, -0.01, <TensorType(float64, matrix)>, <TensorType(float64, matrix)>, 0.999998)]
#[Gemm{inplace}(<TensorType(float64, matrix)>, -0.01,
# <TensorType(float64, matrix)>, <TensorType(float64, matrix)>,
# 0.999998)]
assert len(f) == 1
assert f[0].op == gemm_inplace
......@@ -291,7 +301,8 @@ class t_gemm(TestCase):
tx.set_value(y_T, borrow=True)
f()
# test that the transposed version of multiplication gives same answer
# test that the transposed version of multiplication gives
# same answer
self.assertTrue(_approx_eq(z_after, tz.get_value(borrow=True).T))
t(C, A, B)
......@@ -330,12 +341,14 @@ class t_gemm(TestCase):
z_orig = z.copy()
z_after = numpy.zeros_like(z_orig)
for i in xrange(3):
z_after[:,:,i] = self._gemm(z[:,:,i], a, x[:,:,i], y[:,:,i], b)
z_after[:, :, i] = self._gemm(z[:, :, i], a,
x[:, :, i], y[:, :, i], b)
tz, ta, tx, ty, tb = [shared(p) for p in z, a, x, y, b]
for i in xrange(3):
f_i = inplace_func([],
gemm_inplace(tz[:,:,i], ta, tx[:,:,i], ty[:,:,i], tb),
gemm_inplace(tz[:, :, i],
ta, tx[:, :, i], ty[:, :, i], tb),
mode=compile.Mode(optimizer=None, linker=l))
for j in xrange(3):
# tz will not _always_ be overwritten,
......@@ -347,30 +360,32 @@ class t_gemm(TestCase):
self.assertTrue(
_approx_eq(z_after[:, :, i],
tz.get_value(borrow=True)[:,:,i]),
(z_orig[:,:,i], z_after[:,:,i],
z[:,:,i], z_after[:,:,i] - z[:,:,i]))
tz.get_value(borrow=True)[:, :, i]),
(z_orig[:, :, i], z_after[:, :, i],
z[:, :, i], z_after[:, :, i] - z[:, :, i]))
tz_i = gemm_no_inplace(tz[:,:,i], ta, tx[:,:,i], ty[:,:,i], tb)
tz_i = gemm_no_inplace(tz[:, :, i], ta, tx[
:, :, i], ty[:, :, i], tb)
g_i = theano.function([], tz_i,
updates={tz:T.set_subtensor(tz[:,:,i], tz_i)},
updates={tz: T.set_subtensor(tz[:, :, i], tz_i)},
mode=compile.Mode(optimizer=None, linker=l))
for j in xrange(3):
g_i()
self.assertTrue(
_approx_eq(z_after[:,:,i],
tz.get_value(borrow=True)[:,:,i]),
(z_orig[:,:,i], z_after[:,:,i],
z[:,:,i], z_after[:,:,i] - z[:,:,i]))
_approx_eq(z_after[:, :, i],
tz.get_value(borrow=True)[:, :, i]),
(z_orig[:, :, i], z_after[:, :, i],
z[:, :, i], z_after[:, :, i] - z[:, :, i]))
t(C, A, B)
t(C.transpose((1,0,2)), A, B)
t(C, A.transpose((1,0,2)), B, dt='float32')
t(C, A, B.transpose((1,0,2)))
t(C.transpose((1,0,2)), A.transpose((1,0,2)), B)
t(C, A.transpose((1,0,2)), B.transpose((1,0,2)), dt='float32')
t(C.transpose((1,0,2)), A, B.transpose((1,0,2)))
t(C.transpose((1,0,2)), A.transpose((1,0,2)), B.transpose((1,0,2)), dt='float32')
t(C.transpose((1, 0, 2)), A, B)
t(C, A.transpose((1, 0, 2)), B, dt='float32')
t(C, A, B.transpose((1, 0, 2)))
t(C.transpose((1, 0, 2)), A.transpose((1, 0, 2)), B)
t(C, A.transpose((1, 0, 2)), B.transpose((1, 0, 2)), dt='float32')
t(C.transpose((1, 0, 2)), A, B.transpose((1, 0, 2)))
t(C.transpose((1, 0, 2)), A.transpose((1, 0, 2)), B.transpose((
1, 0, 2)), dt='float32')
def test_res_is_a():
......@@ -418,7 +433,7 @@ class t_as_scalar(TestCase):
def test3(self):
"""Test that it fails on nonscalar variables"""
a = T.dmatrix()
a = T.matrix()
self.assertTrue(None == _as_scalar(a))
self.assertTrue(None == _as_scalar(T.DimShuffle([False, False],
[0, 'x', 1])(a)))
......@@ -427,7 +442,7 @@ class t_as_scalar(TestCase):
class T_real_matrix(TestCase):
def test0(self):
self.assertTrue(_is_real_matrix(T.DimShuffle([False, False],
[1, 0])(T.dmatrix())))
[1, 0])(T.matrix())))
self.assertTrue(not _is_real_matrix(T.DimShuffle([False],
['x', 0])
(T.dvector())))
......@@ -438,32 +453,38 @@ def fail(msg):
assert False
"""This test suite ensures that Gemm is inserted where it belongs, and that the resulting
functions compute the same things as the originals."""
"""This test suite ensures that Gemm is inserted where it belongs, and
that the resulting functions compute the same things as the
originals.
"""
def XYZab():
return T.dmatrix(), T.dmatrix(), T.dmatrix(), T.dscalar(), T.dscalar()
return T.matrix(), T.matrix(), T.matrix(), T.scalar(), T.scalar()
class Failure(Exception):
pass
def just_gemm(i, o, ishapes=[(4, 3), (3, 5), (4, 5), (), ()], max_graphlen=0):
def just_gemm(i, o, ishapes=[(4, 3), (3, 5), (4, 5), (), ()],
max_graphlen=0, expected_nb_gemm=1):
try:
f = inplace_func(
[Param(ii, mutable=True, allow_downcast=True) for ii in i],
o,
mode='FAST_RUN',
on_unused_input='ignore')
at_least_one_gemm = False
nb_gemm = 0
for node in f.maker.env.nodes:
if node.op == T.dot:
raise Failure('dot not changed to gemm_inplace in graph')
if node.op == _dot22:
raise Failure('_dot22 not changed to gemm_inplace in graph')
if node.op == gemm_inplace:
at_least_one_gemm = True
assert at_least_one_gemm
nb_gemm += 1
assert nb_gemm == expected_nb_gemm, (nb_gemm, expected_nb_gemm)
g = inplace_func(i, o, mode=compile.Mode(linker='py', optimizer=None),
allow_input_downcast=True, on_unused_input='ignore')
for node in g.maker.env.nodes:
......@@ -476,11 +497,16 @@ def just_gemm(i, o, ishapes=[(4, 3), (3, 5), (4, 5), (), ()], max_graphlen=0):
assert False, 'graphlen=%i>%i' % (graphlen, max_graphlen)
rng = numpy.random.RandomState(unittest_tools.fetch_seed(234))
r0 = f(*[rng.randn(*sh) for sh in ishapes])
r0 = f(*[numpy.asarray(rng.randn(*sh), config.floatX)
for sh in ishapes])
rng = numpy.random.RandomState(unittest_tools.fetch_seed(234))
r1 = g(*[rng.randn(*sh) for sh in ishapes])
r1 = g(*[numpy.asarray(rng.randn(*sh), config.floatX)
for sh in ishapes])
max_abs_err = numpy.max(numpy.abs(r0[0] - r1[0]))
if max_abs_err > 1.0e-8:
eps = 1.0e-8
if config.floatX == 'float32':
eps = 1.0e-6
if max_abs_err > eps:
raise Failure('GEMM is computing the wrong output. max_rel_err =',
max_abs_err)
except Failure:
......@@ -491,62 +517,73 @@ def just_gemm(i, o, ishapes=[(4, 3), (3, 5), (4, 5), (), ()], max_graphlen=0):
def test_gemm_opt0():
"""Many subgraphs whose dots can be eliminated"""
X,Y,Z,a,b = XYZab()
X, Y, Z, a, b = XYZab()
just_gemm([X,Y,Z,a,b], [T.dot(X,Y) * a + Z * b])
just_gemm([X,Y,Z,a,b], [a * T.dot(X,Y) + b * Z])
just_gemm([X,Y,Z,a,b], [b * Z + a * T.dot(X,Y)])
just_gemm([X,Y,Z,a,b], [T.dot(X,Y) * a - Z * b])
just_gemm([X,Y,Z,a,b], [a * T.dot(X,Y) - b * Z])
just_gemm([X,Y,Z,a,b], [b * Z - a * T.dot(X,Y)])
just_gemm([X, Y, Z, a, b], [T.dot(X, Y) * a + Z * b])
just_gemm([X, Y, Z, a, b], [a * T.dot(X, Y) + b * Z])
just_gemm([X, Y, Z, a, b], [b * Z + a * T.dot(X, Y)])
just_gemm([X, Y, Z, a, b], [T.dot(X, Y) * a - Z * b])
just_gemm([X, Y, Z, a, b], [a * T.dot(X, Y) - b * Z])
just_gemm([X, Y, Z, a, b], [b * Z - a * T.dot(X, Y)])
#with transposes (transposes should be pushed through dot in canonicalize)
just_gemm([X,Y,Z,a,b], [b * Z.T - a * T.dot(Y.T,X.T)])
just_gemm([X,Y,Z,a,b], [b * Z.T + a * b * T.dot(X,Y).T])
just_gemm([X,Y,Z,a,b], [b * Z + a * T.dot(X,Y).T],
ishapes=[(5,3), (3,4), (4,5), (), ()])
just_gemm([X, Y, Z, a, b], [b * Z.T - a * T.dot(Y.T, X.T)])
just_gemm([X, Y, Z, a, b], [b * Z.T + a * b * T.dot(X, Y).T])
just_gemm([X, Y, Z, a, b], [b * Z + a * T.dot(X, Y).T],
ishapes=[(5, 3), (3, 4), (4, 5), (), ()])
#with N multiplications instead of just one
just_gemm([X,Y,Z,a,b], [(b * b) * Z * a + (a * a) * T.dot(X,Y) * b])
just_gemm([X,Y,Z,a,b], [Z + T.dot(X,Y)])
just_gemm([X,Y,Z,a,b], [Z*b + T.dot(X,Y)])
just_gemm([X,Y,Z,a,b], [Z + a*b*a*T.dot(X,Y)])
just_gemm([X,Y,Z,a,b], [(b * b) * Z * a - (a * a) * T.dot(X,Y) * b])
just_gemm([X,Y,Z,a,b], [Z - T.dot(X,Y)])
just_gemm([X,Y,Z,a,b], [Z*b - T.dot(X,Y)])
just_gemm([X,Y,Z,a,b], [Z - a*b*a*T.dot(X,Y)])
just_gemm([X, Y, Z, a, b], [(b * b) * Z * a + (a * a) * T.dot(X, Y) * b])
just_gemm([X, Y, Z, a, b], [Z + T.dot(X, Y)])
just_gemm([X, Y, Z, a, b], [Z * b + T.dot(X, Y)])
just_gemm([X, Y, Z, a, b], [Z + a * b * a * T.dot(X, Y)])
just_gemm([X, Y, Z, a, b], [(b * b) * Z * a - (a * a) * T.dot(X, Y) * b])
just_gemm([X, Y, Z, a, b], [Z - T.dot(X, Y)])
just_gemm([X, Y, Z, a, b], [Z * b - T.dot(X, Y)])
just_gemm([X, Y, Z, a, b], [Z - a * b * a * T.dot(X, Y)])
def test_gemm_opt_double_gemm():
"""This is the pattern that shows up in the autoencoder"""
X,Y,Z,a,b = T.dmatrix(), T.dmatrix(), T.dmatrix(), T.dscalar(), T.dscalar()
R, S, c = T.dmatrix(), T.dmatrix(), T.dscalar()
just_gemm([X,Y,Z,a,b, R, S, c], [Z *c + a * T.dot(X,Y) + b * T.dot(R,S).T],
ishapes=[(4,3), (3,5), (4,5), (), (), (5,9), (9,4), ()])
ishapes=[(4,3), (3,5), (4,5), (), (), (5,9), (9,4), ()]
i = [X,Y,Z,a,b, R, S, c]
o = [(a * T.dot(X,Y)
+ gemm_inplace(Z, b, S.T, R.T, T.constant(1.0).astype('float64')))]
X, Y, Z, a, b = T.matrix(), T.matrix(), T.matrix(), T.scalar(), T.scalar()
R, S, c = T.matrix(), T.matrix(), T.scalar()
just_gemm([X, Y, Z, a, b, R, S, c],
[Z * c + a * T.dot(X, Y) + b * T.dot(R, S).T],
ishapes=[(4, 3), (3, 5), (4, 5), (), (), (5, 9), (9, 4), ()],
expected_nb_gemm=2)
ishapes = [(4, 3), (3, 5), (4, 5), (), (), (5, 9), (9, 4), ()]
i = [X, Y, Z, a, b, R, S, c]
o = [(a * T.dot(X, Y)
+ gemm_inplace(Z, b, S.T, R.T, T.constant(1.0).astype(config.floatX)))]
try:
f = inplace_func([Param(ii, mutable=True) for ii in i],o,
f = inplace_func([Param(ii, mutable=True) for ii in i], o,
mode='FAST_RUN', on_unused_input='ignore')
for node in f.maker.env.nodes:
if node.op == T.dot: raise Failure('dot in graph')
if node.op == _dot22: raise Failure('_dot22 in graph')
if node.op == T.dot:
raise Failure('dot in graph')
if node.op == _dot22:
raise Failure('_dot22 in graph')
g = inplace_func(i, o, mode=compile.Mode(linker='py', optimizer=None),
on_unused_input='ignore')
#for node in g.maker.env.nodes:
# if node.op == gemm_inplace: raise Failure('gemm_inplace in graph')
rng = numpy.random.RandomState(unittest_tools.fetch_seed(234))
r0 = f(*[rng.randn(*sh) for sh in ishapes])
r0 = f(*[numpy.asarray(rng.randn(*sh), config.floatX)
for sh in ishapes])
rng = numpy.random.RandomState(unittest_tools.fetch_seed(234))
r1 = g(*[rng.randn(*sh) for sh in ishapes])
r1 = g(*[numpy.asarray(rng.randn(*sh), config.floatX)
for sh in ishapes])
max_abs_err = numpy.max(numpy.abs(r0[0] - r1[0]))
if max_abs_err > 1.0e-8:
raise Failure('GEMM is computing the wrong output. max_rel_err =', max_abs_err)
eps = 1.0e-8
if config.floatX == 'float32':
eps = 1.0e-6
if max_abs_err > eps:
raise Failure(
'GEMM is computing the wrong output. max_rel_err =',
max_abs_err)
except Failure:
for node in f.maker.env.toposort():
print 'GRAPH', node
......@@ -554,8 +591,10 @@ def test_gemm_opt_double_gemm():
def test_gemm_canonicalize():
X,Y,Z,a,b = T.dmatrix('X'), T.dmatrix('Y'), T.dmatrix('Z'), T.dscalar('a'), T.dscalar('b')
R,S,U,c,d = T.dmatrix('R'), T.dmatrix('S'), T.dmatrix('U'), T.dscalar('c'), T.dscalar('d')
X, Y, Z, a, b = T.matrix('X'), T.matrix('Y'), T.matrix('Z'), T.scalar(
'a'), T.scalar('b')
R, S, U, c, d = T.matrix('R'), T.matrix('S'), T.matrix('U'), T.scalar(
'c'), T.scalar('d')
u = T.row('u')
v = T.vector('v')
w = T.col('w')
......@@ -584,7 +623,7 @@ def test_gemm_canonicalize():
assert can == [(1.0, X), (1.0, Y), (1.0, w)], can
can = []
_gemm_canonicalize(a*X + Y - b*Z*c, 1.0, can, 0)
_gemm_canonicalize(a * X + Y - b * Z * c, 1.0, can, 0)
assert can[0] == (a, X)
assert can[1] == (1.0, Y)
assert can[2][0].owner.op == T.mul
......@@ -593,7 +632,7 @@ def test_gemm_canonicalize():
assert can[2][0].owner.inputs[1] == b
can = []
_gemm_canonicalize((-d) * X - (a*X + Y - b*Z*c), 1.0, can, 0)
_gemm_canonicalize((-d) * X - (a * X + Y - b * Z * c), 1.0, can, 0)
#print can
assert can[0][0].owner.op == T.neg
assert can[0][0].owner.inputs[0] == d
......@@ -602,14 +641,18 @@ def test_gemm_canonicalize():
assert can[1][0].owner.inputs[0] == a
assert can[2] == (-1.0, Y)
assert can[3][0].owner.op == T.mul
assert can[3][0].owner.inputs == [c,b]
assert can[3][0].owner.inputs == [c, b]
def test_gemm_factor():
X,Y,Z,a,b = T.dmatrix('X'), T.dmatrix('Y'), T.dmatrix('Z'), T.dscalar('a'), T.dscalar('b')
R,S,U,c,d = T.dmatrix('R'), T.dmatrix('S'), T.dmatrix('U'), T.dscalar('c'), T.dscalar('d')
X, Y, Z, a, b = T.matrix('X'), T.matrix('Y'), T.matrix('Z'), T.scalar(
'a'), T.scalar('b')
R, S, U, c, d = T.matrix('R'), T.matrix('S'), T.matrix('U'), T.scalar(
'c'), T.scalar('d')
assert [(1.0, X), (1.0, Y)] == _factor_canonicalized([(1.0, X), (1.0, Y)])
assert [(2.0, X)] == _factor_canonicalized([(1.0, X),(1.0, X)])
assert [(2.0, X)] == _factor_canonicalized([(1.0, X), (1.0, X)])
def test_upcasting_scalar_nogemm():
# Test that the optimization does not crash when the scale has an incorrect
......@@ -643,119 +686,183 @@ def test_upcasting_scalar_nogemm():
assert numpy.sum([isinstance(n.op, Gemm) for n in t]) == 0
#theano.printing.debugprint(f, print_type=True)
def test_gemm_nested():
X,Y,Z,a,b = T.dmatrix('X'), T.dmatrix('Y'), T.dmatrix('Z'), T.dscalar('a'), T.dscalar('b')
R,S,U,c,d = T.dmatrix('R'), T.dmatrix('S'), T.dmatrix('U'), T.dscalar('c'), T.dscalar('d')
just_gemm([X,Y,Z,R,S,U,a,b,c,d],
[a * Z - b * (c*T.dot(X,Y) + d*Z)],
ishapes=[(2,3),(3,4),(2,4),(2,3),(3,4),(2,4),(),(),(),()],
def test_gemm_nested():
X, Y, Z, a, b = T.matrix('X'), T.matrix('Y'), T.matrix('Z'), T.scalar(
'a'), T.scalar('b')
R, S, U, c, d = T.matrix('R'), T.matrix('S'), T.matrix('U'), T.scalar(
'c'), T.scalar('d')
just_gemm([X, Y, Z, R, S, U, a, b, c, d],
[a * Z - b * (c * T.dot(X, Y) + d * Z)],
ishapes=[(2, 3), (3, 4), (2, 4), (2, 3), (3, 4), (
2, 4), (), (), (), ()],
max_graphlen=1)
#print "---------------------"
just_gemm([X,Y,Z,R,S,U,a,b,c,d],
[a * Z - b * (c*T.dot(X,Y) + d*Z + c*Z)],
ishapes=[(2,3),(3,4),(2,4),(2,3),(3,4),(2,4),(),(),(),()],
just_gemm([X, Y, Z, R, S, U, a, b, c, d],
[a * Z - b * (c * T.dot(X, Y) + d * Z + c * Z)],
ishapes=[(2, 3), (3, 4), (2, 4), (2, 3), (3, 4), (
2, 4), (), (), (), ()],
max_graphlen=1)
#print "---------------------"
just_gemm([X,Y,Z,R,S,U,a,b,c,d],
[a * Z - b * (c*T.dot(X,Y) + d*Z + c*U)],
ishapes=[(2,3),(3,4),(2,4),(2,3),(3,4),(2,4),(),(),(),()],
just_gemm([X, Y, Z, R, S, U, a, b, c, d],
[a * Z - b * (c * T.dot(X, Y) + d * Z + c * U)],
ishapes=[(2, 3), (3, 4), (2, 4), (2, 3), (3, 4), (
2, 4), (), (), (), ()],
max_graphlen=3)
def test_gemm_opt_wishlist():
X,Y,Z,a,b = T.dmatrix(), T.dmatrix(), T.dmatrix(), T.dscalar(), T.dscalar()
X, Y, Z, a, b = T.matrix(), T.matrix(), T.matrix(), T.scalar(), T.scalar()
#with >2 additions of the same T.dot(X,Y term
just_gemm([X,Y,Z,a,b], [(b * b) * Z * a + (a * a) * T.dot(X,Y) + b * T.dot(X,Y)])
just_gemm([X, Y, Z, a, b],
[(b * b) * Z * a + (a * a) * T.dot(X, Y) + b * T.dot(X, Y)])
just_gemm([X, Y, Z, a, b], [Z + T.dot(X, Y) + T.dot(X, Y)])
just_gemm([X,Y,Z,a,b], [Z + T.dot(X,Y) + T.dot(X,Y)])
def test_gemm_with_vector():
"""Many subgraphs whose dots can be eliminated.
This adds a vector two the previous test, which triggers the long-sought GEMM bug.
"""Many subgraphs whose dots can be eliminated. This adds a
vector two the previous test, which triggers the long-sought GEMM
bug.
"""
X,Y,Z,a,b = XYZab()
X, Y, Z, a, b = XYZab()
v = T.vector()
def my_just_gemm(o):
i = [X,Y,Z,a,b,v]
ishapes = [(4,3), (3,5), (4,5), (), (), (5,)]
i = [X, Y, Z, a, b, v]
ishapes = [(4, 3), (3, 5), (4, 5), (), (), (5, )]
rval = just_gemm(i, o, ishapes=ishapes)
my_just_gemm([v + T.dot(X,Y) * a + Z * b])
my_just_gemm([v + a * T.dot(X,Y) + b * Z])
my_just_gemm([v + b * Z + a * T.dot(X,Y)])
my_just_gemm([v + T.dot(X,Y) * a - Z * b])
my_just_gemm([v + a * T.dot(X,Y) - b * Z])
my_just_gemm([v + b * Z - a * T.dot(X,Y)])
my_just_gemm([v + T.dot(X, Y) * a + Z * b])
my_just_gemm([v + a * T.dot(X, Y) + b * Z])
my_just_gemm([v + b * Z + a * T.dot(X, Y)])
my_just_gemm([v + T.dot(X, Y) * a - Z * b])
my_just_gemm([v + a * T.dot(X, Y) - b * Z])
my_just_gemm([v + b * Z - a * T.dot(X, Y)])
#with N multiplications instead of just one
my_just_gemm([v + (b * b) * Z * a + (a * a) * T.dot(X,Y) * b])
my_just_gemm([v + Z + T.dot(X,Y)])
my_just_gemm([v + Z*b + T.dot(X,Y)])
my_just_gemm([v + Z + a*b*a*T.dot(X,Y)])
my_just_gemm([v + (b * b) * Z * a - (a * a) * T.dot(X,Y) * b])
my_just_gemm([Z - T.dot(X,Y) + v])
my_just_gemm([Z*b - T.dot(X,Y) + v])
my_just_gemm([Z - a*b*a*T.dot(X,Y) + v])
my_just_gemm([v + (b * b) * Z * a + (a * a) * T.dot(X, Y) * b])
my_just_gemm([v + Z + T.dot(X, Y)])
my_just_gemm([v + Z * b + T.dot(X, Y)])
my_just_gemm([v + Z + a * b * a * T.dot(X, Y)])
my_just_gemm([v + (b * b) * Z * a - (a * a) * T.dot(X, Y) * b])
my_just_gemm([Z - T.dot(X, Y) + v])
my_just_gemm([Z * b - T.dot(X, Y) + v])
my_just_gemm([Z - a * b * a * T.dot(X, Y) + v])
def test_gemm_opt_vector_stuff():
X,Y,Z,a,b = T.dmatrix(), T.dmatrix(), T.dmatrix(), T.dscalar(), T.dscalar()
u,v = T.dvector(), T.dvector()
X, Y, Z, a, b = T.matrix(), T.matrix(), T.matrix(), T.scalar(), T.scalar()
u, v = T.vector(), T.vector()
f = inplace_func([a, u, v], a + T.dot(u,v), mode='FAST_RUN')
f = inplace_func([a, u, v], a + T.dot(u, v), mode='FAST_RUN')
if gemm_inplace in [n.op for n in f.maker.env.nodes]:
raise Failure('gemm_inplace in graph')
f = inplace_func([a, u, X,Y], a * u + T.dot(X,Y), mode='FAST_RUN')
f = inplace_func([a, u, X, Y], a * u + T.dot(X, Y), mode='FAST_RUN')
if (gemm_inplace in [n.op for n in f.maker.env.nodes]):
raise Failure('gemm_inplace in graph')
def test_gemm_unrolled():
    """Test that the gemm optimizer removes the dot22 that was present
    in the graph.  Otherwise the optimizer adds a gemm but the graph
    still computes the dot22 as well, duplicating work.

    This was not always the case with the following code.
    """
    batch_size = 100
    rep_size = 40
    rng = numpy.random.RandomState([1, 2, 3])
    for num_rounds in range(1, 10):
        W = sharedX(rng.randn(rep_size, rep_size), name='W')
        V = sharedX(numpy.zeros((batch_size, rep_size)), name='V')
        H = sharedX(numpy.zeros((batch_size, rep_size)), name='H')
        G = sharedX(numpy.zeros((batch_size, rep_size)), name='G')
        init_V = sharedX(rng.uniform(0, 1, (batch_size, rep_size)), name='init_V')
        init_H = sharedX(rng.uniform(0, 1, (batch_size, rep_size)), name='init_H')
        cur_V = V
        cur_H = H

        def update_V(cur_H):
            return T.nnet.sigmoid(T.dot(cur_H, W.T))

        def update_H(cur_V):
            return T.nnet.sigmoid(T.dot(cur_V, W) + T.dot(G, W.T))

        # Unroll the recurrence num_rounds times into one graph.
        for i in xrange(num_rounds):
            cur_V = update_V(cur_H)
            cur_H = update_H(cur_V)
        unrolled_theano = theano.function([], updates={V: cur_V, H: cur_H},
                                          name='unrolled_theano')
        nb_dot = sum([1 for node in unrolled_theano.maker.env.toposort()
                      if isinstance(node.op, (theano.tensor.Dot,
                                              theano.tensor.blas.Dot22,
                                              theano.tensor.blas.Gemm))])
        # Each round adds 3 dot-like ops, but one of them is always the
        # same, so the final graph should have 1 + 2 * num_rounds
        # dot-variant ops.
        assert nb_dot == num_rounds * 2 + 1, nb_dot
        unrolled_theano()
def test_inplace0():
#should fail to insert gemm_inplace because gemm_inplace would create cycles
X,Y,Z,a,b = T.dmatrix('X'), T.dmatrix('Y'), T.dmatrix('Z'), T.dscalar('a'), T.dscalar('b')
R, S, c = T.dmatrix('R'), T.dmatrix('S'), T.dscalar('c')
#should fail to insert gemm_inplace because gemm_inplace would
#create cycles
X, Y, Z, a, b = T.matrix('X'), T.matrix('Y'), T.matrix('Z'), T.scalar(
'a'), T.scalar('b')
R, S, c = T.matrix('R'), T.matrix('S'), T.scalar('c')
f = inplace_func([Z, b, R, S],
[Z * (Z + b * T.dot(R,S).T)], mode='FAST_RUN')
[Z * (Z + b * T.dot(R, S).T)], mode='FAST_RUN')
if (gemm_inplace in [n.op for n in f.maker.env.nodes]):
print pp(f.maker.env.outputs[0])
raise Failure('gemm_inplace in graph')
assert gemm_no_inplace in [n.op for n in f.maker.env.nodes]
# gemm_inplace should be inserted here, to work in-place on Z*c
f = inplace_func([X,Y,Z,a,b, R, S, c],
[Z * (c*Z + a * T.dot(X,Y) + b * T.dot(R,S).T)],
f = inplace_func([X, Y, Z, a, b, R, S, c],
[Z * (c * Z + a * T.dot(X, Y) + b * T.dot(R, S).T)],
mode='FAST_RUN')
if (not gemm_inplace in [n.op for n in f.maker.env.nodes]):
theano.printing.debugprint(f)
raise Failure('no gemm_inplace in graph')
def test_inplace1():
X,Y,Z,a,b = XYZab()
X, Y, Z, a, b = XYZab()
# with > 2 terms in the overall addition
f = inplace_func([X, Y, Z],
[Z + Z + T.dot(X,Y)], mode='FAST_RUN')
[Z + Z + T.dot(X, Y)], mode='FAST_RUN')
#theano.printing.debugprint(f)
# it doesn't work inplace because we didn't mark Z as mutable input
assert [n.op for n in f.maker.env.nodes] == [gemm_no_inplace]
def test_dot22():
for dtype1 in ['float32', 'float64', 'complex64', 'complex128']:
a = T.matrix(dtype=dtype1)
for dtype2 in ['float32', 'float64', 'complex64', 'complex128']:
b = T.matrix(dtype=dtype2)
f = theano.function([a,b],T.dot(a,b),mode=mode_blas_opt)
f = theano.function([a, b], T.dot(a, b), mode=mode_blas_opt)
topo = f.maker.env.toposort()
if dtype1 == dtype2:
assert _dot22 in [x.op for x in topo], (dtype1,dtype2)
assert _dot22 in [x.op for x in topo], (dtype1, dtype2)
else:
assert T.dot in [x.op for x in topo], (dtype1,dtype2)
assert T.dot in [x.op for x in topo], (dtype1, dtype2)
rng = numpy.random.RandomState(unittest_tools.fetch_seed())
def cmp(a_shp, b_shp):
av=rng.uniform(size=a_shp).astype(dtype1)
bv=rng.uniform(size=b_shp).astype(dtype2)
f(av,bv)
av = rng.uniform(size=a_shp).astype(dtype1)
bv = rng.uniform(size=b_shp).astype(dtype2)
f(av, bv)
cmp((3, 4), (4, 5))
cmp((0, 4), (4, 5))
......@@ -764,11 +871,13 @@ def test_dot22():
cmp((0, 4), (4, 0))
cmp((0, 0), (0, 0))
def test_dot22scalar():
## including does not seem to work for 'local_dot_to_dot22' and
## 'local_dot22_to_dot22scalar'
## TODO: exclude other optimizations in BlasOpt?
#m = theano.compile.get_default_mode().including('local_dot_to_dot22','local_dot22_to_dot22scalar','specialize')
#m = theano.compile.get_default_mode().including('local_dot_to_dot22',
# 'local_dot22_to_dot22scalar','specialize')
#m = theano.compile.get_default_mode().including('BlasOpt', 'specialize')
rng = numpy.random.RandomState(unittest_tools.fetch_seed())
for dtype1 in ['complex64', 'complex128']:
......@@ -784,88 +893,111 @@ def test_dot22scalar():
def check_dot22scalar(func, len_topo_scalar=-1):
topo = func.maker.env.toposort()
ops = [x.op for x in topo]
dtype4_upcast = theano.scalar.upcast(dtype4, dtype1, dtype2)
dtype4_upcast = theano.scalar.upcast(dtype4, dtype1,
dtype2)
if dtype1 == dtype2 == dtype3 == dtype4_upcast:
if len_topo_scalar>0:
if len_topo_scalar > 0:
assert len(topo) == len_topo_scalar
assert _dot22scalar in ops, (dtype1, dtype2, dtype3, dtype4)
assert _dot22scalar in ops, (dtype1, dtype2,
dtype3, dtype4)
elif dtype1 == dtype2 == dtype4_upcast:
if not (len_topo_scalar > 0):
assert len(topo) == len_topo_scalar
assert _dot22scalar in ops, (dtype1, dtype2, dtype3, dtype4)
assert _dot22scalar in ops, (dtype1, dtype2,
dtype3, dtype4)
else:
# Currently there is a problem of optimization order
# The constant get upcasted to float64 before we try to merge it
# with the dot22 of float32. So this prevent the merge.
assert _dot22scalar in ops or _dot22 in ops, (dtype1, dtype2, dtype3, dtype4)
# Currently there is a problem of
# optimization order The constant get
# upcasted to float64 before we try to
# merge it with the dot22 of
# float32. So this prevent the merge.
assert _dot22scalar in ops or _dot22 in ops, (
dtype1, dtype2, dtype3, dtype4)
elif dtype1 == dtype2:
assert _dot22 in ops, (dtype1, dtype2, dtype3, dtype4)
assert _dot22 in ops, (dtype1, dtype2,
dtype3, dtype4)
else:
assert T.dot in ops, (dtype1, dtype2, dtype3, dtype4)
assert T.dot in ops, (dtype1, dtype2,
dtype3, dtype4)
def cmp(a_shp, b_shp, c_shp, sqr_shp=(5, 5)):
    """Compile several scalar*dot(a, b) expressions and verify, via
    ``check_dot22scalar``, which BLAS op each was optimized to.

    Uses ``a``, ``b``, ``c``, ``cst``, ``cst2``, ``rng``,
    ``mode_blas_opt`` and the dtypes from the enclosing scope.

    NOTE(review): recovered from a whitespace-mangled diff paste; the
    nesting of the ``if True:`` / ``if dtype1 == dtype2 == dtype3:``
    blocks was reconstructed — confirm against upstream Theano.
    """
    av = rng.uniform(size=a_shp).astype(dtype1)
    bv = rng.uniform(size=b_shp).astype(dtype2)
    cv = rng.uniform(size=c_shp).astype(dtype3)
    # Square inputs, reused for the cases that multiply by `a` itself.
    sv = rng.uniform(size=sqr_shp).astype(dtype1)

    if False:
        # Disabled case: a single scalar times a dot.
        f = theano.function([a, b], cst * T.dot(a, b),
                            mode=mode_blas_opt)
        topo = f.maker.env.toposort()
        check_dot22scalar(f, 1)
        f(av, bv)

    if True:
        f = theano.function([a, b, c],
                            cst * c * T.dot(a, b),
                            mode=mode_blas_opt)
        topo = f.maker.env.toposort()
        check_dot22scalar(f, 2)
        f(av, bv, cv)

    # Same expression with the multiplications in the other order.
    f = theano.function([a, b, c],
                        c * cst * T.dot(a, b),
                        mode=mode_blas_opt)
    topo = f.maker.env.toposort()
    check_dot22scalar(f, 2)
    f(av, bv, cv)

    ## Here, canonicalize also seems needed
    ## TODO: add only the optimizations needed?
    m2 = mode_blas_opt.including('canonicalize')
    f = theano.function([a, b, c],
                        cst2 * c * cst * T.dot(a, b),
                        mode=m2)
    topo = f.maker.env.toposort()
    check_dot22scalar(f, 2)
    f(av, bv, cv)

    if dtype1 == dtype2 == dtype3:
        # These cases multiply by `a` itself, so all dtypes must match
        # and the square inputs `sv` are used.
        f = theano.function([a, b, c],
                            c * cst * a * T.dot(a, b),
                            mode=m2)
        topo = f.maker.env.toposort()
        check_dot22scalar(f, 2)
        f(sv, sv, sv)

        f = theano.function([a, b, c],
                            cst * c * a * T.dot(a, b),
                            mode=mode_blas_opt)
        topo = f.maker.env.toposort()
        # Currently the canonizer doesn't always merge all Mul
        # together... The dot22scalar optimizer does not do a
        # recursive search, therefore it doesn't find potential
        # matches of the scalar.  TODO: combine with the
        # 'canonicalization' that is part of the Gemm optimizer.
        #
        # assert _dot22scalar in [x.op for x in topo]
        # assert len(topo) == 2
        f(sv, sv, sv)

        f = theano.function([a, b, c],
                            c * a * cst * T.dot(a, b),
                            mode=m2)
        topo = f.maker.env.toposort()
        check_dot22scalar(f, 2)
        f(sv, sv, sv)
# Exercise cmp over normal shapes and every zero-sized combination
# (empty rows, empty columns, fully empty operands).
cmp((3, 4), (4, 5), (3, 5))
cmp((0, 4), (4, 5), (0, 5))
cmp((3, 0), (0, 5), (3, 5))
cmp((3, 4), (4, 0), (3, 0), (0, 0))
cmp((0, 4), (4, 0), (0, 0))
cmp((0, 0), (0, 0), (0, 0))
def test_dot22scalar_cast():
......@@ -889,19 +1021,20 @@ def test_dot22scalar_cast():
def test_dot_w_self():
    # This can trigger problems in the optimization because what would
    # normally be a gemm must not be because the output is aliased to
    # one of the inputs.
    A = shared(value=numpy.ones((2, 2)))
    B = T.matrix()

    p = T.dot(A, A) * B

    grad = T.grad(T.mean(p), A)
    f = theano.function([B], p, updates={A: A - grad})

    # tests correctness in debugmode
    f(numpy.asarray([[0, 1], [2, 3]], dtype=config.floatX))
###############################################################################
......@@ -927,8 +1060,9 @@ class TestGemv(TestCase, unittest_tools.TestOptimizationMixin):
''' Test vector dot matrix '''
rng = numpy.random.RandomState(unittest_tools.fetch_seed())
v = theano.shared(numpy.array(rng.uniform(size=(2,)), dtype='float32'))
m = theano.shared(numpy.array(rng.uniform(size=(2,3)), dtype='float32'))
f = theano.function([], theano.dot(v,m), mode=mode_blas_opt)
m = theano.shared(numpy.array(rng.uniform(size=(2, 3)),
dtype='float32'))
f = theano.function([], theano.dot(v, m), mode=mode_blas_opt)
# Assert that the dot was optimized somehow
self.assertFunctionContains0(f, T.dot)
......@@ -942,14 +1076,13 @@ class TestGemv(TestCase, unittest_tools.TestOptimizationMixin):
borrow=True)
assert numpy.allclose(f(), numpy.dot(v.get_value(), m.get_value()))
def test_dot_mv(self):
''' Test matrix dot vector '''
rng = numpy.random.RandomState(unittest_tools.fetch_seed())
v = theano.shared(numpy.array(rng.uniform(size=(2,)), dtype='float32'))
m = theano.shared(numpy.array(rng.uniform(size=(3,2)),
m = theano.shared(numpy.array(rng.uniform(size=(3, 2)),
dtype='float32'))
f = theano.function([], theano.dot(m,v), mode=mode_blas_opt)
f = theano.function([], theano.dot(m, v), mode=mode_blas_opt)
# Assert that the dot was optimized somehow
self.assertFunctionContains0(f, T.dot)
......@@ -967,34 +1100,36 @@ class TestGemv(TestCase, unittest_tools.TestOptimizationMixin):
def t_gemv1(m_shp):
''' test vector2+dot(matrix,vector1) '''
rng = numpy.random.RandomState(unittest_tools.fetch_seed())
v1 = theano.shared(numpy.array(rng.uniform(size=(m_shp[1],)), dtype='float32'))
v1 = theano.shared(numpy.array(rng.uniform(size=(m_shp[1],)
), dtype='float32'))
v2_orig = numpy.array(rng.uniform(size=(m_shp[0],)), dtype='float32')
v2 = theano.shared(v2_orig)
m = theano.shared(numpy.array(rng.uniform(size=m_shp), dtype='float32'))
m = theano.shared(numpy.array(rng.uniform(size=m_shp),
dtype='float32'))
f = theano.function([], v2+theano.dot(m,v1), mode = mode_blas_opt)
f = theano.function([], v2 + theano.dot(m, v1), mode=mode_blas_opt)
# Assert they produce the same output
assert numpy.allclose(f(),
numpy.dot(m.get_value(), v1.get_value()) + v2_orig)
topo = f.maker.env.toposort()
assert len(topo)==1
assert len(topo) == 1
assert isinstance(topo[0].op, Gemv)
assert topo[0].op.inplace==False
assert topo[0].op.inplace == False
#test the inplace version
g = theano.function([], [], updates={v2:v2+theano.dot(m,v1)}
, mode = mode_blas_opt)
g = theano.function([], [], updates={v2: v2 + theano.dot(m, v1)},
mode=mode_blas_opt)
# Assert they produce the same output
g()
assert numpy.allclose(v2.get_value(),
numpy.dot(m.get_value(), v1.get_value()) + v2_orig)
topo = g.maker.env.toposort()
assert len(topo)==1
assert len(topo) == 1
assert isinstance(topo[0].op, Gemv)
if config.mode != 'FAST_COMPILE':
assert topo[0].op.inplace==True
assert topo[0].op.inplace == True
# Do the same tests with a matrix with strides in both dimensions
m.set_value(
......@@ -1008,40 +1143,42 @@ class TestGemv(TestCase, unittest_tools.TestOptimizationMixin):
numpy.dot(m.get_value(), v1.get_value()) + v2_orig)
def test_gemv1(self):
    # Run the gemv check over a normal shape and every degenerate
    # (zero-sized) shape combination.
    self.t_gemv1((3, 2))
    self.t_gemv1((0, 2))
    self.t_gemv1((3, 0))
    self.t_gemv1((0, 0))
def test_gemv2(self):
''' test vector2+dot(vector1,matrix) '''
rng = numpy.random.RandomState(unittest_tools.fetch_seed())
v1 = theano.shared(numpy.array(rng.uniform(size=(2,)), dtype='float32'))
v1 = theano.shared(numpy.array(rng.uniform(size=(2,)),
dtype='float32'))
v2_orig = numpy.array(rng.uniform(size=(3,)), dtype='float32')
v2 = theano.shared(v2_orig )
m = theano.shared(numpy.array(rng.uniform(size=(2,3)), dtype='float32'))
v2 = theano.shared(v2_orig)
m = theano.shared(numpy.array(rng.uniform(size=(2, 3)),
dtype='float32'))
f = theano.function([], v2+theano.dot(v1,m), mode = mode_blas_opt)
f = theano.function([], v2 + theano.dot(v1, m), mode=mode_blas_opt)
# Assert they produce the same output
assert numpy.allclose(f(),
numpy.dot(v1.get_value(), m.get_value()) + v2.get_value())
topo = f.maker.env.toposort()
assert sum(isinstance(node.op, Gemv) for node in topo)==1
assert topo[-1].op.inplace==False
assert sum(isinstance(node.op, Gemv) for node in topo) == 1
assert topo[-1].op.inplace == False
#test the inplace version
g = theano.function([], [], updates={v2:v2+theano.dot(v1,m)}
, mode = mode_blas_opt)
g = theano.function([], [], updates={v2: v2 + theano.dot(v1, m)},
mode=mode_blas_opt)
# Assert they produce the same output
g()
assert numpy.allclose(v2.get_value(),
numpy.dot(v1.get_value(), m.get_value()) + v2_orig)
topo = g.maker.env.toposort()
assert sum(isinstance(node.op, Gemv) for node in topo)==1
assert sum(isinstance(node.op, Gemv) for node in topo) == 1
if config.mode != 'FAST_COMPILE':
assert topo[-1].op.inplace==True
assert topo[-1].op.inplace == True
# Do the same tests with a matrix with strides in both dimensions
m.set_value(
......@@ -1066,7 +1203,7 @@ class TestGemv(TestCase, unittest_tools.TestOptimizationMixin):
f = theano.function([A, x, y], z)
# Matrix value
A_val = numpy.ones((5,3), dtype=config.floatX)
A_val = numpy.ones((5, 3), dtype=config.floatX)
# Different vector length
ones_3 = numpy.ones(3, dtype=config.floatX)
ones_4 = numpy.ones(4, dtype=config.floatX)
......@@ -1090,7 +1227,7 @@ class TestGemv(TestCase, unittest_tools.TestOptimizationMixin):
def matrixmultiply(a, b):
    """Naive reference matrix product used to check BLAS results.

    `a` is a 2-d array; `b` is a 2-d array or a 1-d vector (a vector is
    promoted to a column and the result squeezed back to 1-d).
    Returns the product computed with explicit Python loops.

    NOTE(review): the diff paste this file was recovered from elided
    the allocation of `c` and the outer row loop as unchanged context;
    both were reconstructed from the `c[i, j]` / `c.reshape` usage —
    confirm against upstream.  `xrange` was replaced by `range`, which
    iterates identically.
    """
    if len(b.shape) == 1:
        b_is_vector = True
        b = b[:, newaxis]
    else:
        b_is_vector = False
    assert a.shape[1] == b.shape[0]
    c = zeros((a.shape[0], b.shape[1]), common_type(a, b))
    for i in range(a.shape[0]):
        for j in range(b.shape[1]):
            s = 0
            for k in range(a.shape[1]):
                s += a[i, k] * b[k, j]
            c[i, j] = s
    if b_is_vector:
        c = c.reshape((a.shape[0],))
    return c
......@@ -1110,23 +1247,25 @@ class BaseGemv(object):
mode = mode_blas_opt # can be overridden with self.mode
shared = staticmethod(theano.shared)
def get_data(self, x_stride=1, y_stride=1):
    """Build deterministic gemv test data for ``self.dtype``.

    Returns ``(alpha, beta, a, x, y)`` where `a` is 3x3 and `x`/`y`
    are oversized by `x_stride`/`y_stride` so callers can slice them
    into strided vectors.  For complex dtypes every value is
    multiplied by (1+1j) to exercise the imaginary part.
    """
    rng = numpy.random.RandomState(unittest_tools.fetch_seed())
    mult = array(1, dtype=self.dtype)
    if self.dtype in [complex64, complex128]:
        mult = array(1 + 1j, dtype=self.dtype)
    alpha = array(1., dtype=self.dtype) * mult
    beta = array(1., dtype=self.dtype) * mult
    a = rng.randn(3, 3).astype(self.dtype) * mult
    x = arange(shape(a)[0] * x_stride, dtype=self.dtype) * mult
    y = arange(shape(a)[1] * y_stride, dtype=self.dtype) * mult
    return alpha, beta, a, x, y
def test_simple(self):
alpha, beta, a, x, y = [ self.shared(value) for value in self.get_data() ]
desired_oy = alpha.get_value() * matrixmultiply(a.get_value(),x.get_value()) + beta.get_value() * y.get_value()
alpha, beta, a, x, y = [self.shared(value)
for value in self.get_data()]
desired_oy = alpha.get_value() * matrixmultiply(a.
get_value(), x.get_value()) + beta.get_value() * y.get_value()
oy = alpha * T.dot(a,x) + beta * y
oy = alpha * T.dot(a, x) + beta * y
oy_func = theano.function([], oy, mode=self.mode)
......@@ -1146,7 +1285,7 @@ class BaseGemv(object):
desired_oy = matrixmultiply(a_v, x_v)
oy = T.dot(a,x)
oy = T.dot(a, x)
oy_func = theano.function([], oy, mode=self.mode)
......@@ -1155,15 +1294,15 @@ class BaseGemv(object):
oy_v = oy_func()
assert_array_almost_equal(desired_oy, oy_v)
def test_simple_transpose(self):
vs = self.get_data()
alpha_v, beta_v, a_v, x_v, y_v = vs
alpha, beta, a, x, y = [ self.shared(v) for v in vs ]
alpha, beta, a, x, y = [self.shared(v) for v in vs]
desired_oy = alpha_v * matrixmultiply(transpose(a_v),x_v)+beta_v*y_v
desired_oy = alpha_v * matrixmultiply(transpose(a_v),
x_v) + beta_v * y_v
oy = alpha * T.dot(a.T,x)+beta*y
oy = alpha * T.dot(a.T, x) + beta * y
oy_func = theano.function([], oy, mode=self.mode)
......@@ -1173,13 +1312,13 @@ class BaseGemv(object):
assert_array_almost_equal(desired_oy, oy_v)
def test_x_stride(self):
vs = self.get_data(x_stride = 2)
vs = self.get_data(x_stride=2)
alpha_v, beta_v, a_v, x_v, y_v = vs
alpha, beta, a, x, y = [ self.shared(v) for v in vs ]
alpha, beta, a, x, y = [self.shared(v) for v in vs]
desired_oy = alpha_v * matrixmultiply(a_v,x_v[::2])+beta_v*y_v
desired_oy = alpha_v * matrixmultiply(a_v, x_v[::2]) + beta_v * y_v
oy = alpha * T.dot(a,x[::2])+beta*y
oy = alpha * T.dot(a, x[::2]) + beta * y
oy_func = theano.function([], oy, mode=self.mode)
......@@ -1189,13 +1328,14 @@ class BaseGemv(object):
assert_array_almost_equal(desired_oy, oy_v)
def test_x_stride_transpose(self):
vs = self.get_data(x_stride = 2)
vs = self.get_data(x_stride=2)
alpha_v, beta_v, a_v, x_v, y_v = vs
alpha, beta, a, x, y = [ self.shared(v) for v in vs ]
alpha, beta, a, x, y = [self.shared(v) for v in vs]
desired_oy = alpha_v * matrixmultiply(transpose(a_v),x_v[::2])+beta_v*y_v
desired_oy = alpha_v * matrixmultiply(transpose(a_v), x_v[::
2]) + beta_v * y_v
oy = alpha * T.dot(a.T,x[::2])+beta*y
oy = alpha * T.dot(a.T, x[::2]) + beta * y
oy_func = theano.function([], oy, mode=self.mode)
......@@ -1205,13 +1345,13 @@ class BaseGemv(object):
assert_array_almost_equal(desired_oy, oy_v)
def test_y_stride(self):
vs = self.get_data(y_stride = 2)
vs = self.get_data(y_stride=2)
alpha_v, beta_v, a_v, x_v, y_v = vs
alpha, beta, a, x, y = [ self.shared(v) for v in vs ]
alpha, beta, a, x, y = [self.shared(v) for v in vs]
desired_oy = alpha_v * matrixmultiply(a_v,x_v)+beta_v*y_v[::2]
desired_oy = alpha_v * matrixmultiply(a_v, x_v) + beta_v * y_v[::2]
oy = alpha * T.dot(a,x)+beta*y[::2]
oy = alpha * T.dot(a, x) + beta * y[::2]
oy_func = theano.function([], oy, mode=self.mode)
......@@ -1221,13 +1361,14 @@ class BaseGemv(object):
assert_array_almost_equal(desired_oy, oy_v)
def test_y_stride_transpose(self):
vs = self.get_data(y_stride = 2)
vs = self.get_data(y_stride=2)
alpha_v, beta_v, a_v, x_v, y_v = vs
alpha, beta, a, x, y = [ self.shared(v) for v in vs ]
alpha, beta, a, x, y = [self.shared(v) for v in vs]
desired_oy = alpha_v * matrixmultiply(transpose(a_v),x_v)+beta_v*y_v[::2]
desired_oy = alpha_v * matrixmultiply(transpose(a_v),
x_v) + beta_v * y_v[::2]
oy = alpha * T.dot(a.T,x)+beta*y[::2]
oy = alpha * T.dot(a.T, x) + beta * y[::2]
oy_func = theano.function([], oy, mode=self.mode)
......@@ -1239,15 +1380,16 @@ class BaseGemv(object):
def test_a_strides(self):
vs = self.get_data()
alpha_v, beta_v, a_v, x_v, y_v = vs
alpha, beta, a, x, y = [ self.shared(v) for v in vs ]
alpha, beta, a, x, y = [self.shared(v) for v in vs]
a_v = a_v[::-1, ::-1]
a.set_value(
a.get_value(borrow=True, return_internal_type=True)[::-1, ::-1],
a.get_value(borrow=True,
return_internal_type=True)[::-1, ::-1],
borrow=True)
desired_oy = alpha_v * matrixmultiply(a_v,x_v)+beta_v*y_v
desired_oy = alpha_v * matrixmultiply(a_v, x_v) + beta_v * y_v
oy = alpha * T.dot(a,x)+beta*y
oy = alpha * T.dot(a, x) + beta * y
oy_func = theano.function([], oy, mode=self.mode)
......@@ -1259,15 +1401,17 @@ class BaseGemv(object):
def test_a_strides_transpose(self):
vs = self.get_data()
alpha_v, beta_v, a_v, x_v, y_v = vs
alpha, beta, a, x, y = [ self.shared(v) for v in vs ]
alpha, beta, a, x, y = [self.shared(v) for v in vs]
a_v = a_v[::-1, ::-1]
a.set_value(
a.get_value(borrow=True, return_internal_type=True)[::-1, ::-1],
a.get_value(borrow=True,
return_internal_type=True)[::-1, ::-1],
borrow=True)
desired_oy = alpha_v * matrixmultiply(transpose(a_v),x_v)+beta_v*y_v
desired_oy = alpha_v * matrixmultiply(transpose(a_v),
x_v) + beta_v * y_v
oy = alpha * T.dot(a.T,x)+beta*y
oy = alpha * T.dot(a.T, x) + beta * y
oy_func = theano.function([], oy, mode=self.mode)
......@@ -1332,6 +1476,7 @@ class TestDgemv(TestCase, BaseGemv, unittest_tools.TestOptimizationMixin):
## Tests for Ger
###############################################################################
class TestGer_make_node(TestCase):
def setUp(self):
self.iv = T.tensor(dtype='int32', broadcastable=(False,))
......@@ -1439,19 +1584,21 @@ class TestGer(TestCase, unittest_tools.TestOptimizationMixin):
""" test local_gemm_to_ger opt"""
assert T.blas.local_gemm_to_ger.transform(
gemm_no_inplace(
self.A, self.a, self.x.dimshuffle(0,'x'),
self.A, self.a, self.x.dimshuffle(0, 'x'),
self.y.dimshuffle('x', 0), self.b(0)).owner)
def test_b_1_triggers_ger(self):
    """ test local_gemm_to_ger opt"""
    # beta == 1 must allow the gemm -> ger rewrite.
    assert T.blas.local_gemm_to_ger.transform(
        gemm_no_inplace(
            self.A, self.a, self.x.dimshuffle(0, 'x'),
            self.y.dimshuffle('x', 0), self.b(1)).owner)
def test_b_other_does_not_triggers_ger(self):
    """ test local_gemm_to_ger opt"""
    # A beta other than 0 or 1 must NOT allow the gemm -> ger rewrite.
    assert not T.blas.local_gemm_to_ger.transform(
        gemm_no_inplace(
            self.A, self.a, self.x.dimshuffle(0, 'x'),
            self.y.dimshuffle('x', 0), self.b(1.5)).owner)
def test_outer(self):
......@@ -1555,6 +1702,7 @@ class TestGer(TestCase, unittest_tools.TestOptimizationMixin):
f(numpy.random.rand(4).astype(self.dtype),
numpy.random.rand(5).astype(self.dtype))
class TestBlasStrides(TestCase):
dtype = 'float64'
shared = staticmethod(tensor._shared)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论