提交 be95a5ce authored 作者: James Bergstra's avatar James Bergstra

Merge pull request #684 from nouiz/blas

Improve gemm optimizer to never accidentally duplicate work.
...@@ -33,7 +33,8 @@ from optdb import \ ...@@ -33,7 +33,8 @@ from optdb import \
EquilibriumDB, SequenceDB, ProxyDB EquilibriumDB, SequenceDB, ProxyDB
from toolbox import \ from toolbox import \
Bookkeeper, History, Validator, ReplaceValidate, NodeFinder, PrintListener Bookkeeper, History, Validator, ReplaceValidate, NodeFinder,\
PrintListener, ReplacementDidntRemovedError
from type import \ from type import \
Type, Generic, generic Type, Generic, generic
......
...@@ -9,6 +9,15 @@ class AlreadyThere(Exception): ...@@ -9,6 +9,15 @@ class AlreadyThere(Exception):
pass pass
class ReplacementDidntRemovedError(Exception):
"""This exception should be thrown by replace_all_validate_remove
when an optimization wanted to remove a Variable or a Node from
the graph, but the replacement it gived didn't do that.
"""
pass
class Bookkeeper: class Bookkeeper:
def on_attach(self, env): def on_attach(self, env):
...@@ -91,12 +100,15 @@ class ReplaceValidate(History, Validator): ...@@ -91,12 +100,15 @@ class ReplaceValidate(History, Validator):
" or in conflict with another plugin.") " or in conflict with another plugin.")
env.replace_validate = partial(self.replace_validate, env) env.replace_validate = partial(self.replace_validate, env)
env.replace_all_validate = partial(self.replace_all_validate, env) env.replace_all_validate = partial(self.replace_all_validate, env)
env.replace_all_validate_remove = partial(
self.replace_all_validate_remove, env)
def on_detach(self, env): def on_detach(self, env):
History.on_detach(self, env) History.on_detach(self, env)
Validator.on_detach(self, env) Validator.on_detach(self, env)
del env.replace_validate del env.replace_validate
del env.replace_all_validate del env.replace_all_validate
del env.replace_all_validate_remove
def replace_validate(self, env, r, new_r, reason=None): def replace_validate(self, env, r, new_r, reason=None):
self.replace_all_validate(env, [(r, new_r)], reason=reason) self.replace_all_validate(env, [(r, new_r)], reason=reason)
...@@ -121,6 +133,28 @@ class ReplaceValidate(History, Validator): ...@@ -121,6 +133,28 @@ class ReplaceValidate(History, Validator):
except Exception, e: except Exception, e:
env.revert(chk) env.revert(chk)
raise raise
return chk
def replace_all_validate_remove(self, env, replacements,
remove, reason=None):
"""As replace_all_validate, revert the replacement if the ops
in the list remove are still in the graph. It also print a warning.
"""
chk = env.replace_all_validate(replacements, reason)
for rm in remove:
if rm in env.nodes or rm in env.variables:
env.revert(chk)
out = sys.stderr
print >> out, (
"WARNING: An optimization wanted to replace a Variable"
" in the graph, but the replacement for it doesn't"
" remove it. We disabled the optimization."
" Your function runs correctly, but it would be"
" appreciated if you submit this problem to the mailing"
" list theano-users so that we can fix it.")
print >> out, reason, replacements
raise ReplacementDidntRemovedError()
class NodeFinder(dict, Bookkeeper): class NodeFinder(dict, Bookkeeper):
......
...@@ -133,8 +133,10 @@ import numpy.distutils ...@@ -133,8 +133,10 @@ import numpy.distutils
from theano.configparser import config, AddConfigVar, StrParam from theano.configparser import config, AddConfigVar, StrParam
from theano.gof import (utils, Op, view_roots, DestroyHandler, from theano.gof import (utils, Op, view_roots, DestroyHandler,
local_optimizer, Optimizer, local_optimizer, Optimizer,
InconsistencyError, toolbox, SequenceDB, EquilibriumOptimizer, Apply) InconsistencyError, toolbox, SequenceDB,
EquilibriumOptimizer, Apply,
ReplacementDidntRemovedError)
from theano.printing import pprint, FunctionPrinter, debugprint from theano.printing import pprint, FunctionPrinter, debugprint
from theano.compile.mode import optdb from theano.compile.mode import optdb
from theano.gof.python25 import all, any from theano.gof.python25 import all, any
...@@ -1022,7 +1024,7 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip=True): ...@@ -1022,7 +1024,7 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip=True):
Ml, Mr = M.owner.inputs Ml, Mr = M.owner.inputs
rval = [gemm_no_inplace(L, alpha, Ml, Mr, beta)] rval = [gemm_no_inplace(L, alpha, Ml, Mr, beta)]
#print 'GEMM 0', rval, beta, L, alpha, M #print 'GEMM 0', rval, beta, L, alpha, M
return rval return rval, M
# it also might be the case that there is a dimshuffle between the + # it also might be the case that there is a dimshuffle between the +
# and the dot22. local_dot_to_dot22 in particular will put in such things. # and the dot22. local_dot_to_dot22 in particular will put in such things.
...@@ -1035,7 +1037,7 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip=True): ...@@ -1035,7 +1037,7 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip=True):
g = gemm_no_inplace(L.dimshuffle(0, 'x'), g = gemm_no_inplace(L.dimshuffle(0, 'x'),
alpha, MMl, MMr, beta) alpha, MMl, MMr, beta)
rval = [g.dimshuffle(0)] rval = [g.dimshuffle(0)]
return rval return rval, MM
if tuple(M.owner.op.new_order) == (1,): if tuple(M.owner.op.new_order) == (1,):
# it is making a row MM into a vector # it is making a row MM into a vector
if MM.owner and MM.owner.op == _dot22: if MM.owner and MM.owner.op == _dot22:
...@@ -1043,7 +1045,7 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip=True): ...@@ -1043,7 +1045,7 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip=True):
g = gemm_no_inplace(L.dimshuffle('x', 0), g = gemm_no_inplace(L.dimshuffle('x', 0),
alpha, MMl, MMr, beta) alpha, MMl, MMr, beta)
rval = [g.dimshuffle(1)] rval = [g.dimshuffle(1)]
return rval return rval, MM
if tuple(M.owner.op.new_order) == (): if tuple(M.owner.op.new_order) == ():
# it is making a row MM into a vector # it is making a row MM into a vector
if MM.owner and MM.owner.op == _dot22: if MM.owner and MM.owner.op == _dot22:
...@@ -1051,7 +1053,7 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip=True): ...@@ -1051,7 +1053,7 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip=True):
g = gemm_no_inplace(L.dimshuffle('x', 'x'), g = gemm_no_inplace(L.dimshuffle('x', 'x'),
alpha, MMl, MMr, beta) alpha, MMl, MMr, beta)
rval = [g.dimshuffle()] rval = [g.dimshuffle()]
return rval return rval, MM
# this is False'd out because of inadequate testing. # this is False'd out because of inadequate testing.
# TODO see ticket #237 # TODO see ticket #237
...@@ -1085,7 +1087,7 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip=True): ...@@ -1085,7 +1087,7 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip=True):
if recurse_flip: if recurse_flip:
return _beta_L_plus_alpha_M(alpha, M, beta, L, recurse_flip=False) return _beta_L_plus_alpha_M(alpha, M, beta, L, recurse_flip=False)
else: else:
return False return False, False
def _gemm_canonicalize(r, scale, rval, maxclients): def _gemm_canonicalize(r, scale, rval, maxclients):
...@@ -1250,7 +1252,8 @@ def _gemm_from_factored_list(lst): ...@@ -1250,7 +1252,8 @@ def _gemm_from_factored_list(lst):
#print 'TRYING', (s_i, M_i, s_j, M_j) #print 'TRYING', (s_i, M_i, s_j, M_j)
gemm_of_sM_list = _beta_L_plus_alpha_M(s_i, M_i, s_j, M_j) gemm_of_sM_list, old_dot22 = _beta_L_plus_alpha_M(s_i, M_i,
s_j, M_j)
#print 'GOT IT', gemm_of_sM_list #print 'GOT IT', gemm_of_sM_list
if gemm_of_sM_list: if gemm_of_sM_list:
def item_to_var(t): def item_to_var(t):
...@@ -1273,7 +1276,7 @@ def _gemm_from_factored_list(lst): ...@@ -1273,7 +1276,7 @@ def _gemm_from_factored_list(lst):
else: else:
rval = add_inputs rval = add_inputs
#print "RETURNING GEMM THIGN", rval #print "RETURNING GEMM THIGN", rval
return rval return rval, old_dot22
def _gemm_from_node2(node): def _gemm_from_node2(node):
...@@ -1301,7 +1304,7 @@ def _gemm_from_node2(node): ...@@ -1301,7 +1304,7 @@ def _gemm_from_node2(node):
# http://groups.google.com/group/theano-dev/browse_thread/thread/a3096c82856e3ad5, # http://groups.google.com/group/theano-dev/browse_thread/thread/a3096c82856e3ad5,
# but never made it into a trac ticket. # but never made it into a trac ticket.
if rval and (rval[0].type == node.outputs[0].type): if rval and (rval[0][0].type == node.outputs[0].type):
return rval return rval
...@@ -1326,11 +1329,13 @@ class GemmOptimizer(Optimizer): ...@@ -1326,11 +1329,13 @@ class GemmOptimizer(Optimizer):
except InconsistencyError, e: except InconsistencyError, e:
continue continue
if new_outputs: if new_outputs:
new_outputs, old_dot22 = new_outputs
assert len(new_outputs) == len(node.outputs) assert len(new_outputs) == len(node.outputs)
try: try:
env.replace_all_validate( env.replace_all_validate_remove(
zip(node.outputs, new_outputs), zip(node.outputs, new_outputs),
reason='GemmOptimizer' [old_dot22],
reason='GemmOptimizer'
) )
did_something = True did_something = True
break break
...@@ -1338,6 +1343,8 @@ class GemmOptimizer(Optimizer): ...@@ -1338,6 +1343,8 @@ class GemmOptimizer(Optimizer):
# TODO: retry other applications of gemm (see comment # TODO: retry other applications of gemm (see comment
# in _gemm_from_node) # in _gemm_from_node)
pass pass
except ReplacementDidntRemovedError, e:
pass
class Dot22(GemmRelated): class Dot22(GemmRelated):
......
...@@ -15,7 +15,6 @@ from numpy.testing import assert_array_almost_equal ...@@ -15,7 +15,6 @@ from numpy.testing import assert_array_almost_equal
#from numpy.testing import dec #from numpy.testing import dec
#from numpy.testing.noseclasses import KnownFailureTest #from numpy.testing.noseclasses import KnownFailureTest
#from theano.tensor.blas import *
from theano.tensor.blas import (_dot22, _dot22scalar, res_is_a, _as_scalar, from theano.tensor.blas import (_dot22, _dot22scalar, res_is_a, _as_scalar,
_is_real_matrix, _gemm_canonicalize, _is_real_matrix, _gemm_canonicalize,
_factor_canonicalized, Gemm, Gemv, _factor_canonicalized, Gemm, Gemv,
...@@ -46,6 +45,10 @@ def test_dot_eq(): ...@@ -46,6 +45,10 @@ def test_dot_eq():
assert T.Dot() == T.Dot() assert T.Dot() == T.Dot()
def sharedX(x, name):
return theano.shared(numpy.asarray(x, config.floatX), name=name)
class t_gemm(TestCase): class t_gemm(TestCase):
"""This test suite is supposed to establish that gemm works as it """This test suite is supposed to establish that gemm works as it
is supposed to. is supposed to.
...@@ -171,20 +174,22 @@ class t_gemm(TestCase): ...@@ -171,20 +174,22 @@ class t_gemm(TestCase):
self.cmp(self.rand(0, 0), -1.0, self.rand(0, 0), self.rand(0, 0), -1.0) self.cmp(self.rand(0, 0), -1.0, self.rand(0, 0), self.rand(0, 0), -1.0)
def test_factorised_scalar(self): def test_factorised_scalar(self):
a = T.dmatrix() a = T.matrix()
b = T.dmatrix() b = T.matrix()
c = T.dmatrix() c = T.matrix()
s = theano.shared(numpy.zeros((5, 5))) s = theano.shared(numpy.zeros((5, 5)).astype(config.floatX))
lr1 = T.constant(0.01).astype('float64') lr1 = T.constant(0.01).astype(config.floatX)
lr2 = T.constant(2).astype('float64') lr2 = T.constant(2).astype(config.floatX)
l2_reg = T.constant(0.0001).astype('float64') l2_reg = T.constant(0.0001).astype(config.floatX)
#test constant merge with gemm #test constant merge with gemm
f = theano.function([a, b], updates={s: lr1 * T.dot(a, b) + f = theano.function([a, b], updates={s: lr1 * T.dot(a, b) +
l2_reg * lr2 * s}, l2_reg * lr2 * s},
mode=mode_not_fast_compile).maker.env.toposort() mode=mode_not_fast_compile).maker.env.toposort()
#[Gemm{inplace}(<TensorType(float64, matrix)>, 0.01, <TensorType(float64, matrix)>, <TensorType(float64, matrix)>, 2e-06)] #[Gemm{inplace}(<TensorType(float64, matrix)>, 0.01,
# <TensorType(float64, matrix)>, <TensorType(float64, matrix)>,
# 2e-06)]
assert len(f) == 1 assert len(f) == 1
assert f[0].op == gemm_inplace assert f[0].op == gemm_inplace
...@@ -192,14 +197,19 @@ class t_gemm(TestCase): ...@@ -192,14 +197,19 @@ class t_gemm(TestCase):
f = theano.function([a, b], updates={s: lr1 * (T.dot(a, b) - f = theano.function([a, b], updates={s: lr1 * (T.dot(a, b) -
l2_reg * s)}, l2_reg * s)},
mode=mode_not_fast_compile).maker.env.toposort() mode=mode_not_fast_compile).maker.env.toposort()
#[Gemm{inplace}(<TensorType(float64, matrix)>, 0.01, <TensorType(float64, matrix)>, <TensorType(float64, matrix)>, -2e-06)] #[Gemm{inplace}(<TensorType(float64, matrix)>, 0.01,
# <TensorType(float64, matrix)>, <TensorType(float64, matrix)>,
# -2e-06)]
assert len(f) == 1 assert len(f) == 1
assert f[0].op == gemm_inplace assert f[0].op == gemm_inplace
#test factored scalar with merge and neg #test factored scalar with merge and neg
f = theano.function([a,b],updates={s:s-lr1*(s*.0002+T.dot(a,b))}, f = theano.function([a, b],
updates={s: s - lr1 * (s * .0002 + T.dot(a, b))},
mode=mode_not_fast_compile).maker.env.toposort() mode=mode_not_fast_compile).maker.env.toposort()
#[Gemm{inplace}(<TensorType(float64, matrix)>, -0.01, <TensorType(float64, matrix)>, <TensorType(float64, matrix)>, 0.999998)] #[Gemm{inplace}(<TensorType(float64, matrix)>, -0.01,
# <TensorType(float64, matrix)>, <TensorType(float64, matrix)>,
# 0.999998)]
assert len(f) == 1 assert len(f) == 1
assert f[0].op == gemm_inplace assert f[0].op == gemm_inplace
...@@ -291,7 +301,8 @@ class t_gemm(TestCase): ...@@ -291,7 +301,8 @@ class t_gemm(TestCase):
tx.set_value(y_T, borrow=True) tx.set_value(y_T, borrow=True)
f() f()
# test that the transposed version of multiplication gives same answer # test that the transposed version of multiplication gives
# same answer
self.assertTrue(_approx_eq(z_after, tz.get_value(borrow=True).T)) self.assertTrue(_approx_eq(z_after, tz.get_value(borrow=True).T))
t(C, A, B) t(C, A, B)
...@@ -330,12 +341,14 @@ class t_gemm(TestCase): ...@@ -330,12 +341,14 @@ class t_gemm(TestCase):
z_orig = z.copy() z_orig = z.copy()
z_after = numpy.zeros_like(z_orig) z_after = numpy.zeros_like(z_orig)
for i in xrange(3): for i in xrange(3):
z_after[:,:,i] = self._gemm(z[:,:,i], a, x[:,:,i], y[:,:,i], b) z_after[:, :, i] = self._gemm(z[:, :, i], a,
x[:, :, i], y[:, :, i], b)
tz, ta, tx, ty, tb = [shared(p) for p in z, a, x, y, b] tz, ta, tx, ty, tb = [shared(p) for p in z, a, x, y, b]
for i in xrange(3): for i in xrange(3):
f_i = inplace_func([], f_i = inplace_func([],
gemm_inplace(tz[:,:,i], ta, tx[:,:,i], ty[:,:,i], tb), gemm_inplace(tz[:, :, i],
ta, tx[:, :, i], ty[:, :, i], tb),
mode=compile.Mode(optimizer=None, linker=l)) mode=compile.Mode(optimizer=None, linker=l))
for j in xrange(3): for j in xrange(3):
# tz will not _always_ be overwritten, # tz will not _always_ be overwritten,
...@@ -347,30 +360,32 @@ class t_gemm(TestCase): ...@@ -347,30 +360,32 @@ class t_gemm(TestCase):
self.assertTrue( self.assertTrue(
_approx_eq(z_after[:, :, i], _approx_eq(z_after[:, :, i],
tz.get_value(borrow=True)[:,:,i]), tz.get_value(borrow=True)[:, :, i]),
(z_orig[:,:,i], z_after[:,:,i], (z_orig[:, :, i], z_after[:, :, i],
z[:,:,i], z_after[:,:,i] - z[:,:,i])) z[:, :, i], z_after[:, :, i] - z[:, :, i]))
tz_i = gemm_no_inplace(tz[:,:,i], ta, tx[:,:,i], ty[:,:,i], tb) tz_i = gemm_no_inplace(tz[:, :, i], ta, tx[
:, :, i], ty[:, :, i], tb)
g_i = theano.function([], tz_i, g_i = theano.function([], tz_i,
updates={tz:T.set_subtensor(tz[:,:,i], tz_i)}, updates={tz: T.set_subtensor(tz[:, :, i], tz_i)},
mode=compile.Mode(optimizer=None, linker=l)) mode=compile.Mode(optimizer=None, linker=l))
for j in xrange(3): for j in xrange(3):
g_i() g_i()
self.assertTrue( self.assertTrue(
_approx_eq(z_after[:,:,i], _approx_eq(z_after[:, :, i],
tz.get_value(borrow=True)[:,:,i]), tz.get_value(borrow=True)[:, :, i]),
(z_orig[:,:,i], z_after[:,:,i], (z_orig[:, :, i], z_after[:, :, i],
z[:,:,i], z_after[:,:,i] - z[:,:,i])) z[:, :, i], z_after[:, :, i] - z[:, :, i]))
t(C, A, B) t(C, A, B)
t(C.transpose((1,0,2)), A, B) t(C.transpose((1, 0, 2)), A, B)
t(C, A.transpose((1,0,2)), B, dt='float32') t(C, A.transpose((1, 0, 2)), B, dt='float32')
t(C, A, B.transpose((1,0,2))) t(C, A, B.transpose((1, 0, 2)))
t(C.transpose((1,0,2)), A.transpose((1,0,2)), B) t(C.transpose((1, 0, 2)), A.transpose((1, 0, 2)), B)
t(C, A.transpose((1,0,2)), B.transpose((1,0,2)), dt='float32') t(C, A.transpose((1, 0, 2)), B.transpose((1, 0, 2)), dt='float32')
t(C.transpose((1,0,2)), A, B.transpose((1,0,2))) t(C.transpose((1, 0, 2)), A, B.transpose((1, 0, 2)))
t(C.transpose((1,0,2)), A.transpose((1,0,2)), B.transpose((1,0,2)), dt='float32') t(C.transpose((1, 0, 2)), A.transpose((1, 0, 2)), B.transpose((
1, 0, 2)), dt='float32')
def test_res_is_a(): def test_res_is_a():
...@@ -418,7 +433,7 @@ class t_as_scalar(TestCase): ...@@ -418,7 +433,7 @@ class t_as_scalar(TestCase):
def test3(self): def test3(self):
"""Test that it fails on nonscalar variables""" """Test that it fails on nonscalar variables"""
a = T.dmatrix() a = T.matrix()
self.assertTrue(None == _as_scalar(a)) self.assertTrue(None == _as_scalar(a))
self.assertTrue(None == _as_scalar(T.DimShuffle([False, False], self.assertTrue(None == _as_scalar(T.DimShuffle([False, False],
[0, 'x', 1])(a))) [0, 'x', 1])(a)))
...@@ -427,7 +442,7 @@ class t_as_scalar(TestCase): ...@@ -427,7 +442,7 @@ class t_as_scalar(TestCase):
class T_real_matrix(TestCase): class T_real_matrix(TestCase):
def test0(self): def test0(self):
self.assertTrue(_is_real_matrix(T.DimShuffle([False, False], self.assertTrue(_is_real_matrix(T.DimShuffle([False, False],
[1, 0])(T.dmatrix()))) [1, 0])(T.matrix())))
self.assertTrue(not _is_real_matrix(T.DimShuffle([False], self.assertTrue(not _is_real_matrix(T.DimShuffle([False],
['x', 0]) ['x', 0])
(T.dvector()))) (T.dvector())))
...@@ -438,32 +453,38 @@ def fail(msg): ...@@ -438,32 +453,38 @@ def fail(msg):
assert False assert False
"""This test suite ensures that Gemm is inserted where it belongs, and that the resulting """This test suite ensures that Gemm is inserted where it belongs, and
functions compute the same things as the originals.""" that the resulting functions compute the same things as the
originals.
"""
def XYZab(): def XYZab():
return T.dmatrix(), T.dmatrix(), T.dmatrix(), T.dscalar(), T.dscalar() return T.matrix(), T.matrix(), T.matrix(), T.scalar(), T.scalar()
class Failure(Exception): class Failure(Exception):
pass pass
def just_gemm(i, o, ishapes=[(4, 3), (3, 5), (4, 5), (), ()], max_graphlen=0): def just_gemm(i, o, ishapes=[(4, 3), (3, 5), (4, 5), (), ()],
max_graphlen=0, expected_nb_gemm=1):
try: try:
f = inplace_func( f = inplace_func(
[Param(ii, mutable=True, allow_downcast=True) for ii in i], [Param(ii, mutable=True, allow_downcast=True) for ii in i],
o, o,
mode='FAST_RUN', mode='FAST_RUN',
on_unused_input='ignore') on_unused_input='ignore')
at_least_one_gemm = False nb_gemm = 0
for node in f.maker.env.nodes: for node in f.maker.env.nodes:
if node.op == T.dot: if node.op == T.dot:
raise Failure('dot not changed to gemm_inplace in graph') raise Failure('dot not changed to gemm_inplace in graph')
if node.op == _dot22: if node.op == _dot22:
raise Failure('_dot22 not changed to gemm_inplace in graph') raise Failure('_dot22 not changed to gemm_inplace in graph')
if node.op == gemm_inplace: if node.op == gemm_inplace:
at_least_one_gemm = True nb_gemm += 1
assert at_least_one_gemm assert nb_gemm == expected_nb_gemm, (nb_gemm, expected_nb_gemm)
g = inplace_func(i, o, mode=compile.Mode(linker='py', optimizer=None), g = inplace_func(i, o, mode=compile.Mode(linker='py', optimizer=None),
allow_input_downcast=True, on_unused_input='ignore') allow_input_downcast=True, on_unused_input='ignore')
for node in g.maker.env.nodes: for node in g.maker.env.nodes:
...@@ -476,11 +497,16 @@ def just_gemm(i, o, ishapes=[(4, 3), (3, 5), (4, 5), (), ()], max_graphlen=0): ...@@ -476,11 +497,16 @@ def just_gemm(i, o, ishapes=[(4, 3), (3, 5), (4, 5), (), ()], max_graphlen=0):
assert False, 'graphlen=%i>%i' % (graphlen, max_graphlen) assert False, 'graphlen=%i>%i' % (graphlen, max_graphlen)
rng = numpy.random.RandomState(unittest_tools.fetch_seed(234)) rng = numpy.random.RandomState(unittest_tools.fetch_seed(234))
r0 = f(*[rng.randn(*sh) for sh in ishapes]) r0 = f(*[numpy.asarray(rng.randn(*sh), config.floatX)
for sh in ishapes])
rng = numpy.random.RandomState(unittest_tools.fetch_seed(234)) rng = numpy.random.RandomState(unittest_tools.fetch_seed(234))
r1 = g(*[rng.randn(*sh) for sh in ishapes]) r1 = g(*[numpy.asarray(rng.randn(*sh), config.floatX)
for sh in ishapes])
max_abs_err = numpy.max(numpy.abs(r0[0] - r1[0])) max_abs_err = numpy.max(numpy.abs(r0[0] - r1[0]))
if max_abs_err > 1.0e-8: eps = 1.0e-8
if config.floatX == 'float32':
eps = 1.0e-6
if max_abs_err > eps:
raise Failure('GEMM is computing the wrong output. max_rel_err =', raise Failure('GEMM is computing the wrong output. max_rel_err =',
max_abs_err) max_abs_err)
except Failure: except Failure:
...@@ -491,62 +517,73 @@ def just_gemm(i, o, ishapes=[(4, 3), (3, 5), (4, 5), (), ()], max_graphlen=0): ...@@ -491,62 +517,73 @@ def just_gemm(i, o, ishapes=[(4, 3), (3, 5), (4, 5), (), ()], max_graphlen=0):
def test_gemm_opt0(): def test_gemm_opt0():
"""Many subgraphs whose dots can be eliminated""" """Many subgraphs whose dots can be eliminated"""
X,Y,Z,a,b = XYZab() X, Y, Z, a, b = XYZab()
just_gemm([X,Y,Z,a,b], [T.dot(X,Y) * a + Z * b]) just_gemm([X, Y, Z, a, b], [T.dot(X, Y) * a + Z * b])
just_gemm([X,Y,Z,a,b], [a * T.dot(X,Y) + b * Z]) just_gemm([X, Y, Z, a, b], [a * T.dot(X, Y) + b * Z])
just_gemm([X,Y,Z,a,b], [b * Z + a * T.dot(X,Y)]) just_gemm([X, Y, Z, a, b], [b * Z + a * T.dot(X, Y)])
just_gemm([X,Y,Z,a,b], [T.dot(X,Y) * a - Z * b]) just_gemm([X, Y, Z, a, b], [T.dot(X, Y) * a - Z * b])
just_gemm([X,Y,Z,a,b], [a * T.dot(X,Y) - b * Z]) just_gemm([X, Y, Z, a, b], [a * T.dot(X, Y) - b * Z])
just_gemm([X,Y,Z,a,b], [b * Z - a * T.dot(X,Y)]) just_gemm([X, Y, Z, a, b], [b * Z - a * T.dot(X, Y)])
#with transposes (transposes should be pushed through dot in canonicalize) #with transposes (transposes should be pushed through dot in canonicalize)
just_gemm([X,Y,Z,a,b], [b * Z.T - a * T.dot(Y.T,X.T)]) just_gemm([X, Y, Z, a, b], [b * Z.T - a * T.dot(Y.T, X.T)])
just_gemm([X,Y,Z,a,b], [b * Z.T + a * b * T.dot(X,Y).T]) just_gemm([X, Y, Z, a, b], [b * Z.T + a * b * T.dot(X, Y).T])
just_gemm([X,Y,Z,a,b], [b * Z + a * T.dot(X,Y).T], just_gemm([X, Y, Z, a, b], [b * Z + a * T.dot(X, Y).T],
ishapes=[(5,3), (3,4), (4,5), (), ()]) ishapes=[(5, 3), (3, 4), (4, 5), (), ()])
#with N multiplications instead of just one #with N multiplications instead of just one
just_gemm([X,Y,Z,a,b], [(b * b) * Z * a + (a * a) * T.dot(X,Y) * b]) just_gemm([X, Y, Z, a, b], [(b * b) * Z * a + (a * a) * T.dot(X, Y) * b])
just_gemm([X,Y,Z,a,b], [Z + T.dot(X,Y)]) just_gemm([X, Y, Z, a, b], [Z + T.dot(X, Y)])
just_gemm([X,Y,Z,a,b], [Z*b + T.dot(X,Y)]) just_gemm([X, Y, Z, a, b], [Z * b + T.dot(X, Y)])
just_gemm([X,Y,Z,a,b], [Z + a*b*a*T.dot(X,Y)]) just_gemm([X, Y, Z, a, b], [Z + a * b * a * T.dot(X, Y)])
just_gemm([X,Y,Z,a,b], [(b * b) * Z * a - (a * a) * T.dot(X,Y) * b]) just_gemm([X, Y, Z, a, b], [(b * b) * Z * a - (a * a) * T.dot(X, Y) * b])
just_gemm([X,Y,Z,a,b], [Z - T.dot(X,Y)]) just_gemm([X, Y, Z, a, b], [Z - T.dot(X, Y)])
just_gemm([X,Y,Z,a,b], [Z*b - T.dot(X,Y)]) just_gemm([X, Y, Z, a, b], [Z * b - T.dot(X, Y)])
just_gemm([X,Y,Z,a,b], [Z - a*b*a*T.dot(X,Y)]) just_gemm([X, Y, Z, a, b], [Z - a * b * a * T.dot(X, Y)])
def test_gemm_opt_double_gemm(): def test_gemm_opt_double_gemm():
"""This is the pattern that shows up in the autoencoder""" """This is the pattern that shows up in the autoencoder"""
X,Y,Z,a,b = T.dmatrix(), T.dmatrix(), T.dmatrix(), T.dscalar(), T.dscalar() X, Y, Z, a, b = T.matrix(), T.matrix(), T.matrix(), T.scalar(), T.scalar()
R, S, c = T.dmatrix(), T.dmatrix(), T.dscalar() R, S, c = T.matrix(), T.matrix(), T.scalar()
just_gemm([X,Y,Z,a,b, R, S, c], [Z *c + a * T.dot(X,Y) + b * T.dot(R,S).T], just_gemm([X, Y, Z, a, b, R, S, c],
ishapes=[(4,3), (3,5), (4,5), (), (), (5,9), (9,4), ()]) [Z * c + a * T.dot(X, Y) + b * T.dot(R, S).T],
ishapes=[(4, 3), (3, 5), (4, 5), (), (), (5, 9), (9, 4), ()],
ishapes=[(4,3), (3,5), (4,5), (), (), (5,9), (9,4), ()] expected_nb_gemm=2)
i = [X,Y,Z,a,b, R, S, c]
o = [(a * T.dot(X,Y) ishapes = [(4, 3), (3, 5), (4, 5), (), (), (5, 9), (9, 4), ()]
+ gemm_inplace(Z, b, S.T, R.T, T.constant(1.0).astype('float64')))] i = [X, Y, Z, a, b, R, S, c]
o = [(a * T.dot(X, Y)
+ gemm_inplace(Z, b, S.T, R.T, T.constant(1.0).astype(config.floatX)))]
try: try:
f = inplace_func([Param(ii, mutable=True) for ii in i],o, f = inplace_func([Param(ii, mutable=True) for ii in i], o,
mode='FAST_RUN', on_unused_input='ignore') mode='FAST_RUN', on_unused_input='ignore')
for node in f.maker.env.nodes: for node in f.maker.env.nodes:
if node.op == T.dot: raise Failure('dot in graph') if node.op == T.dot:
if node.op == _dot22: raise Failure('_dot22 in graph') raise Failure('dot in graph')
if node.op == _dot22:
raise Failure('_dot22 in graph')
g = inplace_func(i, o, mode=compile.Mode(linker='py', optimizer=None), g = inplace_func(i, o, mode=compile.Mode(linker='py', optimizer=None),
on_unused_input='ignore') on_unused_input='ignore')
#for node in g.maker.env.nodes: #for node in g.maker.env.nodes:
# if node.op == gemm_inplace: raise Failure('gemm_inplace in graph') # if node.op == gemm_inplace: raise Failure('gemm_inplace in graph')
rng = numpy.random.RandomState(unittest_tools.fetch_seed(234)) rng = numpy.random.RandomState(unittest_tools.fetch_seed(234))
r0 = f(*[rng.randn(*sh) for sh in ishapes]) r0 = f(*[numpy.asarray(rng.randn(*sh), config.floatX)
for sh in ishapes])
rng = numpy.random.RandomState(unittest_tools.fetch_seed(234)) rng = numpy.random.RandomState(unittest_tools.fetch_seed(234))
r1 = g(*[rng.randn(*sh) for sh in ishapes]) r1 = g(*[numpy.asarray(rng.randn(*sh), config.floatX)
for sh in ishapes])
max_abs_err = numpy.max(numpy.abs(r0[0] - r1[0])) max_abs_err = numpy.max(numpy.abs(r0[0] - r1[0]))
if max_abs_err > 1.0e-8: eps = 1.0e-8
raise Failure('GEMM is computing the wrong output. max_rel_err =', max_abs_err) if config.floatX == 'float32':
eps = 1.0e-6
if max_abs_err > eps:
raise Failure(
'GEMM is computing the wrong output. max_rel_err =',
max_abs_err)
except Failure: except Failure:
for node in f.maker.env.toposort(): for node in f.maker.env.toposort():
print 'GRAPH', node print 'GRAPH', node
...@@ -554,8 +591,10 @@ def test_gemm_opt_double_gemm(): ...@@ -554,8 +591,10 @@ def test_gemm_opt_double_gemm():
def test_gemm_canonicalize(): def test_gemm_canonicalize():
X,Y,Z,a,b = T.dmatrix('X'), T.dmatrix('Y'), T.dmatrix('Z'), T.dscalar('a'), T.dscalar('b') X, Y, Z, a, b = T.matrix('X'), T.matrix('Y'), T.matrix('Z'), T.scalar(
R,S,U,c,d = T.dmatrix('R'), T.dmatrix('S'), T.dmatrix('U'), T.dscalar('c'), T.dscalar('d') 'a'), T.scalar('b')
R, S, U, c, d = T.matrix('R'), T.matrix('S'), T.matrix('U'), T.scalar(
'c'), T.scalar('d')
u = T.row('u') u = T.row('u')
v = T.vector('v') v = T.vector('v')
w = T.col('w') w = T.col('w')
...@@ -584,7 +623,7 @@ def test_gemm_canonicalize(): ...@@ -584,7 +623,7 @@ def test_gemm_canonicalize():
assert can == [(1.0, X), (1.0, Y), (1.0, w)], can assert can == [(1.0, X), (1.0, Y), (1.0, w)], can
can = [] can = []
_gemm_canonicalize(a*X + Y - b*Z*c, 1.0, can, 0) _gemm_canonicalize(a * X + Y - b * Z * c, 1.0, can, 0)
assert can[0] == (a, X) assert can[0] == (a, X)
assert can[1] == (1.0, Y) assert can[1] == (1.0, Y)
assert can[2][0].owner.op == T.mul assert can[2][0].owner.op == T.mul
...@@ -593,7 +632,7 @@ def test_gemm_canonicalize(): ...@@ -593,7 +632,7 @@ def test_gemm_canonicalize():
assert can[2][0].owner.inputs[1] == b assert can[2][0].owner.inputs[1] == b
can = [] can = []
_gemm_canonicalize((-d) * X - (a*X + Y - b*Z*c), 1.0, can, 0) _gemm_canonicalize((-d) * X - (a * X + Y - b * Z * c), 1.0, can, 0)
#print can #print can
assert can[0][0].owner.op == T.neg assert can[0][0].owner.op == T.neg
assert can[0][0].owner.inputs[0] == d assert can[0][0].owner.inputs[0] == d
...@@ -602,14 +641,18 @@ def test_gemm_canonicalize(): ...@@ -602,14 +641,18 @@ def test_gemm_canonicalize():
assert can[1][0].owner.inputs[0] == a assert can[1][0].owner.inputs[0] == a
assert can[2] == (-1.0, Y) assert can[2] == (-1.0, Y)
assert can[3][0].owner.op == T.mul assert can[3][0].owner.op == T.mul
assert can[3][0].owner.inputs == [c,b] assert can[3][0].owner.inputs == [c, b]
def test_gemm_factor(): def test_gemm_factor():
X,Y,Z,a,b = T.dmatrix('X'), T.dmatrix('Y'), T.dmatrix('Z'), T.dscalar('a'), T.dscalar('b') X, Y, Z, a, b = T.matrix('X'), T.matrix('Y'), T.matrix('Z'), T.scalar(
R,S,U,c,d = T.dmatrix('R'), T.dmatrix('S'), T.dmatrix('U'), T.dscalar('c'), T.dscalar('d') 'a'), T.scalar('b')
R, S, U, c, d = T.matrix('R'), T.matrix('S'), T.matrix('U'), T.scalar(
'c'), T.scalar('d')
assert [(1.0, X), (1.0, Y)] == _factor_canonicalized([(1.0, X), (1.0, Y)]) assert [(1.0, X), (1.0, Y)] == _factor_canonicalized([(1.0, X), (1.0, Y)])
assert [(2.0, X)] == _factor_canonicalized([(1.0, X),(1.0, X)]) assert [(2.0, X)] == _factor_canonicalized([(1.0, X), (1.0, X)])
def test_upcasting_scalar_nogemm(): def test_upcasting_scalar_nogemm():
# Test that the optimization does not crash when the scale has an incorrect # Test that the optimization does not crash when the scale has an incorrect
...@@ -643,119 +686,183 @@ def test_upcasting_scalar_nogemm(): ...@@ -643,119 +686,183 @@ def test_upcasting_scalar_nogemm():
assert numpy.sum([isinstance(n.op, Gemm) for n in t]) == 0 assert numpy.sum([isinstance(n.op, Gemm) for n in t]) == 0
#theano.printing.debugprint(f, print_type=True) #theano.printing.debugprint(f, print_type=True)
def test_gemm_nested():
X,Y,Z,a,b = T.dmatrix('X'), T.dmatrix('Y'), T.dmatrix('Z'), T.dscalar('a'), T.dscalar('b')
R,S,U,c,d = T.dmatrix('R'), T.dmatrix('S'), T.dmatrix('U'), T.dscalar('c'), T.dscalar('d')
just_gemm([X,Y,Z,R,S,U,a,b,c,d], def test_gemm_nested():
[a * Z - b * (c*T.dot(X,Y) + d*Z)], X, Y, Z, a, b = T.matrix('X'), T.matrix('Y'), T.matrix('Z'), T.scalar(
ishapes=[(2,3),(3,4),(2,4),(2,3),(3,4),(2,4),(),(),(),()], 'a'), T.scalar('b')
R, S, U, c, d = T.matrix('R'), T.matrix('S'), T.matrix('U'), T.scalar(
'c'), T.scalar('d')
just_gemm([X, Y, Z, R, S, U, a, b, c, d],
[a * Z - b * (c * T.dot(X, Y) + d * Z)],
ishapes=[(2, 3), (3, 4), (2, 4), (2, 3), (3, 4), (
2, 4), (), (), (), ()],
max_graphlen=1) max_graphlen=1)
#print "---------------------" #print "---------------------"
just_gemm([X,Y,Z,R,S,U,a,b,c,d], just_gemm([X, Y, Z, R, S, U, a, b, c, d],
[a * Z - b * (c*T.dot(X,Y) + d*Z + c*Z)], [a * Z - b * (c * T.dot(X, Y) + d * Z + c * Z)],
ishapes=[(2,3),(3,4),(2,4),(2,3),(3,4),(2,4),(),(),(),()], ishapes=[(2, 3), (3, 4), (2, 4), (2, 3), (3, 4), (
2, 4), (), (), (), ()],
max_graphlen=1) max_graphlen=1)
#print "---------------------" #print "---------------------"
just_gemm([X,Y,Z,R,S,U,a,b,c,d], just_gemm([X, Y, Z, R, S, U, a, b, c, d],
[a * Z - b * (c*T.dot(X,Y) + d*Z + c*U)], [a * Z - b * (c * T.dot(X, Y) + d * Z + c * U)],
ishapes=[(2,3),(3,4),(2,4),(2,3),(3,4),(2,4),(),(),(),()], ishapes=[(2, 3), (3, 4), (2, 4), (2, 3), (3, 4), (
2, 4), (), (), (), ()],
max_graphlen=3) max_graphlen=3)
def test_gemm_opt_wishlist(): def test_gemm_opt_wishlist():
X,Y,Z,a,b = T.dmatrix(), T.dmatrix(), T.dmatrix(), T.dscalar(), T.dscalar() X, Y, Z, a, b = T.matrix(), T.matrix(), T.matrix(), T.scalar(), T.scalar()
#with >2 additions of the same T.dot(X,Y term #with >2 additions of the same T.dot(X,Y term
just_gemm([X,Y,Z,a,b], [(b * b) * Z * a + (a * a) * T.dot(X,Y) + b * T.dot(X,Y)]) just_gemm([X, Y, Z, a, b],
[(b * b) * Z * a + (a * a) * T.dot(X, Y) + b * T.dot(X, Y)])
just_gemm([X, Y, Z, a, b], [Z + T.dot(X, Y) + T.dot(X, Y)])
just_gemm([X,Y,Z,a,b], [Z + T.dot(X,Y) + T.dot(X,Y)])
def test_gemm_with_vector(): def test_gemm_with_vector():
"""Many subgraphs whose dots can be eliminated. """Many subgraphs whose dots can be eliminated. This adds a
This adds a vector two the previous test, which triggers the long-sought GEMM bug. vector two the previous test, which triggers the long-sought GEMM
bug.
""" """
X,Y,Z,a,b = XYZab() X, Y, Z, a, b = XYZab()
v = T.vector() v = T.vector()
def my_just_gemm(o): def my_just_gemm(o):
i = [X,Y,Z,a,b,v] i = [X, Y, Z, a, b, v]
ishapes = [(4,3), (3,5), (4,5), (), (), (5,)] ishapes = [(4, 3), (3, 5), (4, 5), (), (), (5, )]
rval = just_gemm(i, o, ishapes=ishapes) rval = just_gemm(i, o, ishapes=ishapes)
my_just_gemm([v + T.dot(X,Y) * a + Z * b]) my_just_gemm([v + T.dot(X, Y) * a + Z * b])
my_just_gemm([v + a * T.dot(X,Y) + b * Z]) my_just_gemm([v + a * T.dot(X, Y) + b * Z])
my_just_gemm([v + b * Z + a * T.dot(X,Y)]) my_just_gemm([v + b * Z + a * T.dot(X, Y)])
my_just_gemm([v + T.dot(X,Y) * a - Z * b]) my_just_gemm([v + T.dot(X, Y) * a - Z * b])
my_just_gemm([v + a * T.dot(X,Y) - b * Z]) my_just_gemm([v + a * T.dot(X, Y) - b * Z])
my_just_gemm([v + b * Z - a * T.dot(X,Y)]) my_just_gemm([v + b * Z - a * T.dot(X, Y)])
#with N multiplications instead of just one #with N multiplications instead of just one
my_just_gemm([v + (b * b) * Z * a + (a * a) * T.dot(X,Y) * b]) my_just_gemm([v + (b * b) * Z * a + (a * a) * T.dot(X, Y) * b])
my_just_gemm([v + Z + T.dot(X,Y)]) my_just_gemm([v + Z + T.dot(X, Y)])
my_just_gemm([v + Z*b + T.dot(X,Y)]) my_just_gemm([v + Z * b + T.dot(X, Y)])
my_just_gemm([v + Z + a*b*a*T.dot(X,Y)]) my_just_gemm([v + Z + a * b * a * T.dot(X, Y)])
my_just_gemm([v + (b * b) * Z * a - (a * a) * T.dot(X,Y) * b]) my_just_gemm([v + (b * b) * Z * a - (a * a) * T.dot(X, Y) * b])
my_just_gemm([Z - T.dot(X,Y) + v]) my_just_gemm([Z - T.dot(X, Y) + v])
my_just_gemm([Z*b - T.dot(X,Y) + v]) my_just_gemm([Z * b - T.dot(X, Y) + v])
my_just_gemm([Z - a*b*a*T.dot(X,Y) + v]) my_just_gemm([Z - a * b * a * T.dot(X, Y) + v])
def test_gemm_opt_vector_stuff(): def test_gemm_opt_vector_stuff():
X,Y,Z,a,b = T.dmatrix(), T.dmatrix(), T.dmatrix(), T.dscalar(), T.dscalar() X, Y, Z, a, b = T.matrix(), T.matrix(), T.matrix(), T.scalar(), T.scalar()
u,v = T.dvector(), T.dvector() u, v = T.vector(), T.vector()
f = inplace_func([a, u, v], a + T.dot(u,v), mode='FAST_RUN') f = inplace_func([a, u, v], a + T.dot(u, v), mode='FAST_RUN')
if gemm_inplace in [n.op for n in f.maker.env.nodes]: if gemm_inplace in [n.op for n in f.maker.env.nodes]:
raise Failure('gemm_inplace in graph') raise Failure('gemm_inplace in graph')
f = inplace_func([a, u, X,Y], a * u + T.dot(X,Y), mode='FAST_RUN') f = inplace_func([a, u, X, Y], a * u + T.dot(X, Y), mode='FAST_RUN')
if (gemm_inplace in [n.op for n in f.maker.env.nodes]): if (gemm_inplace in [n.op for n in f.maker.env.nodes]):
raise Failure('gemm_inplace in graph') raise Failure('gemm_inplace in graph')
def test_gemm_unrolled():
"""This test that the gemm optimizer remove the dot22 that was
present in the graph. Otherwise, this add a gemm, but still
compute the dot22.
This was not always the case in the with this the following code.
"""
batch_size = 100
rep_size = 40
rng = numpy.random.RandomState([1, 2, 3])
for num_rounds in range(1, 10):
W = sharedX(rng.randn(rep_size, rep_size), name='W')
V = sharedX(numpy.zeros((batch_size, rep_size)), name='V')
H = sharedX(numpy.zeros((batch_size, rep_size)), name='H')
G = sharedX(numpy.zeros((batch_size, rep_size)), name='G')
init_V = sharedX(rng.uniform(0, 1, (batch_size, rep_size)), name='init_V')
init_H = sharedX(rng.uniform(0, 1, (batch_size, rep_size)), name='init_H')
cur_V = V
cur_H = H
def update_V(cur_H):
return T.nnet.sigmoid(T.dot(cur_H, W.T))
def update_H(cur_V):
return T.nnet.sigmoid(T.dot(cur_V, W) + T.dot(G, W.T))
for i in xrange(num_rounds):
cur_V = update_V(cur_H)
cur_H = update_H(cur_V)
unrolled_theano = theano.function([], updates={V: cur_V, H: cur_H},
name='unrolled_theano')
nb_dot = sum([1 for node in unrolled_theano.maker.env.toposort()
if isinstance(node.op, (theano.tensor.Dot,
theano.tensor.blas.Dot22,
theano.tensor.blas.Gemm))])
# Each num_rounds add 3 dot, but one of them is always the same.
# So the final graph should have 1 + 2* num_rounds dot varient op.
assert nb_dot == num_rounds * 2 + 1, nb_dot
unrolled_theano()
def test_inplace0(): def test_inplace0():
#should fail to insert gemm_inplace because gemm_inplace would create cycles #should fail to insert gemm_inplace because gemm_inplace would
X,Y,Z,a,b = T.dmatrix('X'), T.dmatrix('Y'), T.dmatrix('Z'), T.dscalar('a'), T.dscalar('b') #create cycles
R, S, c = T.dmatrix('R'), T.dmatrix('S'), T.dscalar('c') X, Y, Z, a, b = T.matrix('X'), T.matrix('Y'), T.matrix('Z'), T.scalar(
'a'), T.scalar('b')
R, S, c = T.matrix('R'), T.matrix('S'), T.scalar('c')
f = inplace_func([Z, b, R, S], f = inplace_func([Z, b, R, S],
[Z * (Z + b * T.dot(R,S).T)], mode='FAST_RUN') [Z * (Z + b * T.dot(R, S).T)], mode='FAST_RUN')
if (gemm_inplace in [n.op for n in f.maker.env.nodes]): if (gemm_inplace in [n.op for n in f.maker.env.nodes]):
print pp(f.maker.env.outputs[0]) print pp(f.maker.env.outputs[0])
raise Failure('gemm_inplace in graph') raise Failure('gemm_inplace in graph')
assert gemm_no_inplace in [n.op for n in f.maker.env.nodes] assert gemm_no_inplace in [n.op for n in f.maker.env.nodes]
# gemm_inplace should be inserted here, to work in-place on Z*c # gemm_inplace should be inserted here, to work in-place on Z*c
f = inplace_func([X,Y,Z,a,b, R, S, c], f = inplace_func([X, Y, Z, a, b, R, S, c],
[Z * (c*Z + a * T.dot(X,Y) + b * T.dot(R,S).T)], [Z * (c * Z + a * T.dot(X, Y) + b * T.dot(R, S).T)],
mode='FAST_RUN') mode='FAST_RUN')
if (not gemm_inplace in [n.op for n in f.maker.env.nodes]): if (not gemm_inplace in [n.op for n in f.maker.env.nodes]):
theano.printing.debugprint(f) theano.printing.debugprint(f)
raise Failure('no gemm_inplace in graph') raise Failure('no gemm_inplace in graph')
def test_inplace1(): def test_inplace1():
X,Y,Z,a,b = XYZab() X, Y, Z, a, b = XYZab()
# with > 2 terms in the overall addition # with > 2 terms in the overall addition
f = inplace_func([X, Y, Z], f = inplace_func([X, Y, Z],
[Z + Z + T.dot(X,Y)], mode='FAST_RUN') [Z + Z + T.dot(X, Y)], mode='FAST_RUN')
#theano.printing.debugprint(f) #theano.printing.debugprint(f)
# it doesn't work inplace because we didn't mark Z as mutable input # it doesn't work inplace because we didn't mark Z as mutable input
assert [n.op for n in f.maker.env.nodes] == [gemm_no_inplace] assert [n.op for n in f.maker.env.nodes] == [gemm_no_inplace]
def test_dot22(): def test_dot22():
for dtype1 in ['float32', 'float64', 'complex64', 'complex128']: for dtype1 in ['float32', 'float64', 'complex64', 'complex128']:
a = T.matrix(dtype=dtype1) a = T.matrix(dtype=dtype1)
for dtype2 in ['float32', 'float64', 'complex64', 'complex128']: for dtype2 in ['float32', 'float64', 'complex64', 'complex128']:
b = T.matrix(dtype=dtype2) b = T.matrix(dtype=dtype2)
f = theano.function([a,b],T.dot(a,b),mode=mode_blas_opt) f = theano.function([a, b], T.dot(a, b), mode=mode_blas_opt)
topo = f.maker.env.toposort() topo = f.maker.env.toposort()
if dtype1 == dtype2: if dtype1 == dtype2:
assert _dot22 in [x.op for x in topo], (dtype1,dtype2) assert _dot22 in [x.op for x in topo], (dtype1, dtype2)
else: else:
assert T.dot in [x.op for x in topo], (dtype1,dtype2) assert T.dot in [x.op for x in topo], (dtype1, dtype2)
rng = numpy.random.RandomState(unittest_tools.fetch_seed()) rng = numpy.random.RandomState(unittest_tools.fetch_seed())
def cmp(a_shp, b_shp): def cmp(a_shp, b_shp):
av=rng.uniform(size=a_shp).astype(dtype1) av = rng.uniform(size=a_shp).astype(dtype1)
bv=rng.uniform(size=b_shp).astype(dtype2) bv = rng.uniform(size=b_shp).astype(dtype2)
f(av,bv) f(av, bv)
cmp((3, 4), (4, 5)) cmp((3, 4), (4, 5))
cmp((0, 4), (4, 5)) cmp((0, 4), (4, 5))
...@@ -764,11 +871,13 @@ def test_dot22(): ...@@ -764,11 +871,13 @@ def test_dot22():
cmp((0, 4), (4, 0)) cmp((0, 4), (4, 0))
cmp((0, 0), (0, 0)) cmp((0, 0), (0, 0))
def test_dot22scalar(): def test_dot22scalar():
## including does not seem to work for 'local_dot_to_dot22' and ## including does not seem to work for 'local_dot_to_dot22' and
## 'local_dot22_to_dot22scalar' ## 'local_dot22_to_dot22scalar'
## TODO: exclude other optimizations in BlasOpt? ## TODO: exclude other optimizations in BlasOpt?
#m = theano.compile.get_default_mode().including('local_dot_to_dot22','local_dot22_to_dot22scalar','specialize') #m = theano.compile.get_default_mode().including('local_dot_to_dot22',
# 'local_dot22_to_dot22scalar','specialize')
#m = theano.compile.get_default_mode().including('BlasOpt', 'specialize') #m = theano.compile.get_default_mode().including('BlasOpt', 'specialize')
rng = numpy.random.RandomState(unittest_tools.fetch_seed()) rng = numpy.random.RandomState(unittest_tools.fetch_seed())
for dtype1 in ['complex64', 'complex128']: for dtype1 in ['complex64', 'complex128']:
...@@ -784,88 +893,111 @@ def test_dot22scalar(): ...@@ -784,88 +893,111 @@ def test_dot22scalar():
def check_dot22scalar(func, len_topo_scalar=-1): def check_dot22scalar(func, len_topo_scalar=-1):
topo = func.maker.env.toposort() topo = func.maker.env.toposort()
ops = [x.op for x in topo] ops = [x.op for x in topo]
dtype4_upcast = theano.scalar.upcast(dtype4, dtype1, dtype2) dtype4_upcast = theano.scalar.upcast(dtype4, dtype1,
dtype2)
if dtype1 == dtype2 == dtype3 == dtype4_upcast: if dtype1 == dtype2 == dtype3 == dtype4_upcast:
if len_topo_scalar>0: if len_topo_scalar > 0:
assert len(topo) == len_topo_scalar assert len(topo) == len_topo_scalar
assert _dot22scalar in ops, (dtype1, dtype2, dtype3, dtype4) assert _dot22scalar in ops, (dtype1, dtype2,
dtype3, dtype4)
elif dtype1 == dtype2 == dtype4_upcast: elif dtype1 == dtype2 == dtype4_upcast:
if not (len_topo_scalar > 0): if not (len_topo_scalar > 0):
assert len(topo) == len_topo_scalar assert len(topo) == len_topo_scalar
assert _dot22scalar in ops, (dtype1, dtype2, dtype3, dtype4) assert _dot22scalar in ops, (dtype1, dtype2,
dtype3, dtype4)
else: else:
# Currently there is a problem of optimization order # Currently there is a problem of
# The constant get upcasted to float64 before we try to merge it # optimization order The constant get
# with the dot22 of float32. So this prevent the merge. # upcasted to float64 before we try to
assert _dot22scalar in ops or _dot22 in ops, (dtype1, dtype2, dtype3, dtype4) # merge it with the dot22 of
# float32. So this prevent the merge.
assert _dot22scalar in ops or _dot22 in ops, (
dtype1, dtype2, dtype3, dtype4)
elif dtype1 == dtype2: elif dtype1 == dtype2:
assert _dot22 in ops, (dtype1, dtype2, dtype3, dtype4) assert _dot22 in ops, (dtype1, dtype2,
dtype3, dtype4)
else: else:
assert T.dot in ops, (dtype1, dtype2, dtype3, dtype4) assert T.dot in ops, (dtype1, dtype2,
dtype3, dtype4)
def cmp(a_shp, b_shp, c_shp, sqr_shp=(5, 5)):
def cmp(a_shp, b_shp, c_shp, sqr_shp=(5,5)): av = rng.uniform(size=a_shp).astype(dtype1)
av=rng.uniform(size=a_shp).astype(dtype1) bv = rng.uniform(size=b_shp).astype(dtype2)
bv=rng.uniform(size=b_shp).astype(dtype2) cv = rng.uniform(size=c_shp).astype(dtype3)
cv=rng.uniform(size=c_shp).astype(dtype3) sv = rng.uniform(size=sqr_shp).astype(dtype1)
sv=rng.uniform(size=sqr_shp).astype(dtype1)
if False: if False:
f = theano.function([a,b],cst*T.dot(a,b),mode=mode_blas_opt) f = theano.function([a, b], cst * T.dot(a, b),
mode=mode_blas_opt)
topo = f.maker.env.toposort() topo = f.maker.env.toposort()
check_dot22scalar(f, 1) check_dot22scalar(f, 1)
f(av,bv) f(av, bv)
if True: if True:
f = theano.function([a,b,c],cst*c*T.dot(a,b),mode=mode_blas_opt) f = theano.function([a, b, c],
cst * c * T.dot(a, b),
mode=mode_blas_opt)
topo = f.maker.env.toposort() topo = f.maker.env.toposort()
check_dot22scalar(f, 2) check_dot22scalar(f, 2)
f(av,bv,cv) f(av, bv, cv)
f = theano.function([a,b,c],c * cst*T.dot(a,b),mode=mode_blas_opt) f = theano.function([a, b, c],
c * cst * T.dot(a, b),
mode=mode_blas_opt)
topo = f.maker.env.toposort() topo = f.maker.env.toposort()
check_dot22scalar(f, 2) check_dot22scalar(f, 2)
f(av,bv,cv) f(av, bv, cv)
## Here, canonicalize also seems needed ## Here, canonicalize also seems needed
## TODO: add only the optimizations needed? ## TODO: add only the optimizations needed?
m2 = mode_blas_opt.including('canonicalize') m2 = mode_blas_opt.including('canonicalize')
f = theano.function([a,b,c],cst2 *c * cst*T.dot(a,b),mode=m2) f = theano.function([a, b, c],
cst2 * c * cst * T.dot(a, b),
mode=m2)
topo = f.maker.env.toposort() topo = f.maker.env.toposort()
check_dot22scalar(f, 2) check_dot22scalar(f, 2)
f(av,bv,cv) f(av, bv, cv)
if dtype1 == dtype2 == dtype3: if dtype1 == dtype2 == dtype3:
f = theano.function([a,b,c],c * cst*a*T.dot(a,b),mode=m2) f = theano.function([a, b, c],
c * cst * a * T.dot(a, b),
mode=m2)
topo = f.maker.env.toposort() topo = f.maker.env.toposort()
check_dot22scalar(f, 2) check_dot22scalar(f, 2)
f(sv,sv,sv) f(sv, sv, sv)
f = theano.function([a,b,c],cst*c *a*T.dot(a,b),mode=mode_blas_opt) f = theano.function([a, b, c],
cst * c * a * T.dot(a, b),
mode=mode_blas_opt)
topo = f.maker.env.toposort() topo = f.maker.env.toposort()
#currently the canonizer don't always merge all Mul together... #currently the canonizer don't always
# dot22scalar optimizer does not do a recursive search # merge all Mul together... dot22scalar
# therefore, it doesn't find potential matches of the scalar. # optimizer does not do a recursive search
# TODO: combine with the 'canonicalization' that is part of the Gemm optimizer. # therefore, it doesn't find potential
# matches of the scalar. TODO: combine
# with the 'canonicalization' that is part
# of the Gemm optimizer.
# #
# assert _dot22scalar in [x.op for x in topo] # assert _dot22scalar in [x.op for x in topo]
# assert len(topo)==2 # assert len(topo)==2
f(sv,sv,sv) f(sv, sv, sv)
f = theano.function([a,b,c],c * a*cst*T.dot(a,b),mode=m2) f = theano.function([a, b, c],
c * a * cst * T.dot(a, b),
mode=m2)
topo = f.maker.env.toposort() topo = f.maker.env.toposort()
check_dot22scalar(f, 2) check_dot22scalar(f, 2)
f(sv,sv,sv) f(sv, sv, sv)
cmp((3,4),(4,5),(3,5)) cmp((3, 4), (4, 5), (3, 5))
cmp((0,4),(4,5),(0,5)) cmp((0, 4), (4, 5), (0, 5))
cmp((3,0),(0,5),(3,5)) cmp((3, 0), (0, 5), (3, 5))
cmp((3,4),(4,0),(3,0),(0,0)) cmp((3, 4), (4, 0), (3, 0), (0, 0))
cmp((0,4),(4,0),(0,0)) cmp((0, 4), (4, 0), (0, 0))
cmp((0,0),(0,0),(0,0)) cmp((0, 0), (0, 0), (0, 0))
def test_dot22scalar_cast(): def test_dot22scalar_cast():
...@@ -889,19 +1021,20 @@ def test_dot22scalar_cast(): ...@@ -889,19 +1021,20 @@ def test_dot22scalar_cast():
def test_dot_w_self(): def test_dot_w_self():
# This can trigger problems in the optimization because what would normally be a gemm must # This can trigger problems in the optimization because what would
# not be because the output is aliased to one of the inputs. # normally be a gemm must not be because the output is aliased to
# one of the inputs.
A = shared(value=numpy.ones((2,2))) A = shared(value=numpy.ones((2, 2)))
B = T.matrix() B = T.matrix()
p = T.dot(A,A)*B p = T.dot(A, A) * B
grad = T.grad(T.mean(p), A) grad = T.grad(T.mean(p), A)
f = theano.function([B], p, updates={A : A - grad}) f = theano.function([B], p, updates={A: A - grad})
# tests correctness in debugmode # tests correctness in debugmode
f(numpy.asarray([[0,1], [2,3]], dtype=config.floatX)) f(numpy.asarray([[0, 1], [2, 3]], dtype=config.floatX))
############################################################################### ###############################################################################
...@@ -927,8 +1060,9 @@ class TestGemv(TestCase, unittest_tools.TestOptimizationMixin): ...@@ -927,8 +1060,9 @@ class TestGemv(TestCase, unittest_tools.TestOptimizationMixin):
''' Test vector dot matrix ''' ''' Test vector dot matrix '''
rng = numpy.random.RandomState(unittest_tools.fetch_seed()) rng = numpy.random.RandomState(unittest_tools.fetch_seed())
v = theano.shared(numpy.array(rng.uniform(size=(2,)), dtype='float32')) v = theano.shared(numpy.array(rng.uniform(size=(2,)), dtype='float32'))
m = theano.shared(numpy.array(rng.uniform(size=(2,3)), dtype='float32')) m = theano.shared(numpy.array(rng.uniform(size=(2, 3)),
f = theano.function([], theano.dot(v,m), mode=mode_blas_opt) dtype='float32'))
f = theano.function([], theano.dot(v, m), mode=mode_blas_opt)
# Assert that the dot was optimized somehow # Assert that the dot was optimized somehow
self.assertFunctionContains0(f, T.dot) self.assertFunctionContains0(f, T.dot)
...@@ -942,14 +1076,13 @@ class TestGemv(TestCase, unittest_tools.TestOptimizationMixin): ...@@ -942,14 +1076,13 @@ class TestGemv(TestCase, unittest_tools.TestOptimizationMixin):
borrow=True) borrow=True)
assert numpy.allclose(f(), numpy.dot(v.get_value(), m.get_value())) assert numpy.allclose(f(), numpy.dot(v.get_value(), m.get_value()))
def test_dot_mv(self): def test_dot_mv(self):
''' Test matrix dot vector ''' ''' Test matrix dot vector '''
rng = numpy.random.RandomState(unittest_tools.fetch_seed()) rng = numpy.random.RandomState(unittest_tools.fetch_seed())
v = theano.shared(numpy.array(rng.uniform(size=(2,)), dtype='float32')) v = theano.shared(numpy.array(rng.uniform(size=(2,)), dtype='float32'))
m = theano.shared(numpy.array(rng.uniform(size=(3,2)), m = theano.shared(numpy.array(rng.uniform(size=(3, 2)),
dtype='float32')) dtype='float32'))
f = theano.function([], theano.dot(m,v), mode=mode_blas_opt) f = theano.function([], theano.dot(m, v), mode=mode_blas_opt)
# Assert that the dot was optimized somehow # Assert that the dot was optimized somehow
self.assertFunctionContains0(f, T.dot) self.assertFunctionContains0(f, T.dot)
...@@ -967,34 +1100,36 @@ class TestGemv(TestCase, unittest_tools.TestOptimizationMixin): ...@@ -967,34 +1100,36 @@ class TestGemv(TestCase, unittest_tools.TestOptimizationMixin):
def t_gemv1(m_shp): def t_gemv1(m_shp):
''' test vector2+dot(matrix,vector1) ''' ''' test vector2+dot(matrix,vector1) '''
rng = numpy.random.RandomState(unittest_tools.fetch_seed()) rng = numpy.random.RandomState(unittest_tools.fetch_seed())
v1 = theano.shared(numpy.array(rng.uniform(size=(m_shp[1],)), dtype='float32')) v1 = theano.shared(numpy.array(rng.uniform(size=(m_shp[1],)
), dtype='float32'))
v2_orig = numpy.array(rng.uniform(size=(m_shp[0],)), dtype='float32') v2_orig = numpy.array(rng.uniform(size=(m_shp[0],)), dtype='float32')
v2 = theano.shared(v2_orig) v2 = theano.shared(v2_orig)
m = theano.shared(numpy.array(rng.uniform(size=m_shp), dtype='float32')) m = theano.shared(numpy.array(rng.uniform(size=m_shp),
dtype='float32'))
f = theano.function([], v2+theano.dot(m,v1), mode = mode_blas_opt) f = theano.function([], v2 + theano.dot(m, v1), mode=mode_blas_opt)
# Assert they produce the same output # Assert they produce the same output
assert numpy.allclose(f(), assert numpy.allclose(f(),
numpy.dot(m.get_value(), v1.get_value()) + v2_orig) numpy.dot(m.get_value(), v1.get_value()) + v2_orig)
topo = f.maker.env.toposort() topo = f.maker.env.toposort()
assert len(topo)==1 assert len(topo) == 1
assert isinstance(topo[0].op, Gemv) assert isinstance(topo[0].op, Gemv)
assert topo[0].op.inplace==False assert topo[0].op.inplace == False
#test the inplace version #test the inplace version
g = theano.function([], [], updates={v2:v2+theano.dot(m,v1)} g = theano.function([], [], updates={v2: v2 + theano.dot(m, v1)},
, mode = mode_blas_opt) mode=mode_blas_opt)
# Assert they produce the same output # Assert they produce the same output
g() g()
assert numpy.allclose(v2.get_value(), assert numpy.allclose(v2.get_value(),
numpy.dot(m.get_value(), v1.get_value()) + v2_orig) numpy.dot(m.get_value(), v1.get_value()) + v2_orig)
topo = g.maker.env.toposort() topo = g.maker.env.toposort()
assert len(topo)==1 assert len(topo) == 1
assert isinstance(topo[0].op, Gemv) assert isinstance(topo[0].op, Gemv)
if config.mode != 'FAST_COMPILE': if config.mode != 'FAST_COMPILE':
assert topo[0].op.inplace==True assert topo[0].op.inplace == True
# Do the same tests with a matrix with strides in both dimensions # Do the same tests with a matrix with strides in both dimensions
m.set_value( m.set_value(
...@@ -1008,40 +1143,42 @@ class TestGemv(TestCase, unittest_tools.TestOptimizationMixin): ...@@ -1008,40 +1143,42 @@ class TestGemv(TestCase, unittest_tools.TestOptimizationMixin):
numpy.dot(m.get_value(), v1.get_value()) + v2_orig) numpy.dot(m.get_value(), v1.get_value()) + v2_orig)
def test_gemv1(self): def test_gemv1(self):
self.t_gemv1((3,2)) self.t_gemv1((3, 2))
self.t_gemv1((0,2)) self.t_gemv1((0, 2))
self.t_gemv1((3,0)) self.t_gemv1((3, 0))
self.t_gemv1((0,0)) self.t_gemv1((0, 0))
def test_gemv2(self): def test_gemv2(self):
''' test vector2+dot(vector1,matrix) ''' ''' test vector2+dot(vector1,matrix) '''
rng = numpy.random.RandomState(unittest_tools.fetch_seed()) rng = numpy.random.RandomState(unittest_tools.fetch_seed())
v1 = theano.shared(numpy.array(rng.uniform(size=(2,)), dtype='float32')) v1 = theano.shared(numpy.array(rng.uniform(size=(2,)),
dtype='float32'))
v2_orig = numpy.array(rng.uniform(size=(3,)), dtype='float32') v2_orig = numpy.array(rng.uniform(size=(3,)), dtype='float32')
v2 = theano.shared(v2_orig ) v2 = theano.shared(v2_orig)
m = theano.shared(numpy.array(rng.uniform(size=(2,3)), dtype='float32')) m = theano.shared(numpy.array(rng.uniform(size=(2, 3)),
dtype='float32'))
f = theano.function([], v2+theano.dot(v1,m), mode = mode_blas_opt) f = theano.function([], v2 + theano.dot(v1, m), mode=mode_blas_opt)
# Assert they produce the same output # Assert they produce the same output
assert numpy.allclose(f(), assert numpy.allclose(f(),
numpy.dot(v1.get_value(), m.get_value()) + v2.get_value()) numpy.dot(v1.get_value(), m.get_value()) + v2.get_value())
topo = f.maker.env.toposort() topo = f.maker.env.toposort()
assert sum(isinstance(node.op, Gemv) for node in topo)==1 assert sum(isinstance(node.op, Gemv) for node in topo) == 1
assert topo[-1].op.inplace==False assert topo[-1].op.inplace == False
#test the inplace version #test the inplace version
g = theano.function([], [], updates={v2:v2+theano.dot(v1,m)} g = theano.function([], [], updates={v2: v2 + theano.dot(v1, m)},
, mode = mode_blas_opt) mode=mode_blas_opt)
# Assert they produce the same output # Assert they produce the same output
g() g()
assert numpy.allclose(v2.get_value(), assert numpy.allclose(v2.get_value(),
numpy.dot(v1.get_value(), m.get_value()) + v2_orig) numpy.dot(v1.get_value(), m.get_value()) + v2_orig)
topo = g.maker.env.toposort() topo = g.maker.env.toposort()
assert sum(isinstance(node.op, Gemv) for node in topo)==1 assert sum(isinstance(node.op, Gemv) for node in topo) == 1
if config.mode != 'FAST_COMPILE': if config.mode != 'FAST_COMPILE':
assert topo[-1].op.inplace==True assert topo[-1].op.inplace == True
# Do the same tests with a matrix with strides in both dimensions # Do the same tests with a matrix with strides in both dimensions
m.set_value( m.set_value(
...@@ -1066,7 +1203,7 @@ class TestGemv(TestCase, unittest_tools.TestOptimizationMixin): ...@@ -1066,7 +1203,7 @@ class TestGemv(TestCase, unittest_tools.TestOptimizationMixin):
f = theano.function([A, x, y], z) f = theano.function([A, x, y], z)
# Matrix value # Matrix value
A_val = numpy.ones((5,3), dtype=config.floatX) A_val = numpy.ones((5, 3), dtype=config.floatX)
# Different vector length # Different vector length
ones_3 = numpy.ones(3, dtype=config.floatX) ones_3 = numpy.ones(3, dtype=config.floatX)
ones_4 = numpy.ones(4, dtype=config.floatX) ones_4 = numpy.ones(4, dtype=config.floatX)
...@@ -1090,7 +1227,7 @@ class TestGemv(TestCase, unittest_tools.TestOptimizationMixin): ...@@ -1090,7 +1227,7 @@ class TestGemv(TestCase, unittest_tools.TestOptimizationMixin):
def matrixmultiply(a, b): def matrixmultiply(a, b):
if len(b.shape) == 1: if len(b.shape) == 1:
b_is_vector = True b_is_vector = True
b = b[:,newaxis] b = b[:, newaxis]
else: else:
b_is_vector = False b_is_vector = False
assert a.shape[1] == b.shape[0] assert a.shape[1] == b.shape[0]
...@@ -1099,8 +1236,8 @@ def matrixmultiply(a, b): ...@@ -1099,8 +1236,8 @@ def matrixmultiply(a, b):
for j in xrange(b.shape[1]): for j in xrange(b.shape[1]):
s = 0 s = 0
for k in xrange(a.shape[1]): for k in xrange(a.shape[1]):
s += a[i,k] * b[k, j] s += a[i, k] * b[k, j]
c[i,j] = s c[i, j] = s
if b_is_vector: if b_is_vector:
c = c.reshape((a.shape[0],)) c = c.reshape((a.shape[0],))
return c return c
...@@ -1110,23 +1247,25 @@ class BaseGemv(object): ...@@ -1110,23 +1247,25 @@ class BaseGemv(object):
mode = mode_blas_opt # can be overridden with self.mode mode = mode_blas_opt # can be overridden with self.mode
shared = staticmethod(theano.shared) shared = staticmethod(theano.shared)
def get_data(self,x_stride=1,y_stride=1): def get_data(self, x_stride=1, y_stride=1):
rng = numpy.random.RandomState(unittest_tools.fetch_seed()) rng = numpy.random.RandomState(unittest_tools.fetch_seed())
mult = array(1, dtype=self.dtype) mult = array(1, dtype=self.dtype)
if self.dtype in [complex64,complex128]: if self.dtype in [complex64, complex128]:
mult = array(1 + 1j, dtype=self.dtype) mult = array(1 + 1j, dtype=self.dtype)
alpha = array(1., dtype=self.dtype) * mult alpha = array(1., dtype=self.dtype) * mult
beta = array(1., dtype=self.dtype) * mult beta = array(1., dtype=self.dtype) * mult
a = rng.randn(3,3).astype(self.dtype) * mult a = rng.randn(3, 3).astype(self.dtype) * mult
x = arange(shape(a)[0]*x_stride,dtype=self.dtype) * mult x = arange(shape(a)[0] * x_stride, dtype=self.dtype) * mult
y = arange(shape(a)[1]*y_stride,dtype=self.dtype) * mult y = arange(shape(a)[1] * y_stride, dtype=self.dtype) * mult
return alpha,beta,a,x,y return alpha, beta, a, x, y
def test_simple(self): def test_simple(self):
alpha, beta, a, x, y = [ self.shared(value) for value in self.get_data() ] alpha, beta, a, x, y = [self.shared(value)
desired_oy = alpha.get_value() * matrixmultiply(a.get_value(),x.get_value()) + beta.get_value() * y.get_value() for value in self.get_data()]
desired_oy = alpha.get_value() * matrixmultiply(a.
get_value(), x.get_value()) + beta.get_value() * y.get_value()
oy = alpha * T.dot(a,x) + beta * y oy = alpha * T.dot(a, x) + beta * y
oy_func = theano.function([], oy, mode=self.mode) oy_func = theano.function([], oy, mode=self.mode)
...@@ -1146,7 +1285,7 @@ class BaseGemv(object): ...@@ -1146,7 +1285,7 @@ class BaseGemv(object):
desired_oy = matrixmultiply(a_v, x_v) desired_oy = matrixmultiply(a_v, x_v)
oy = T.dot(a,x) oy = T.dot(a, x)
oy_func = theano.function([], oy, mode=self.mode) oy_func = theano.function([], oy, mode=self.mode)
...@@ -1155,15 +1294,15 @@ class BaseGemv(object): ...@@ -1155,15 +1294,15 @@ class BaseGemv(object):
oy_v = oy_func() oy_v = oy_func()
assert_array_almost_equal(desired_oy, oy_v) assert_array_almost_equal(desired_oy, oy_v)
def test_simple_transpose(self): def test_simple_transpose(self):
vs = self.get_data() vs = self.get_data()
alpha_v, beta_v, a_v, x_v, y_v = vs alpha_v, beta_v, a_v, x_v, y_v = vs
alpha, beta, a, x, y = [ self.shared(v) for v in vs ] alpha, beta, a, x, y = [self.shared(v) for v in vs]
desired_oy = alpha_v * matrixmultiply(transpose(a_v),x_v)+beta_v*y_v desired_oy = alpha_v * matrixmultiply(transpose(a_v),
x_v) + beta_v * y_v
oy = alpha * T.dot(a.T,x)+beta*y oy = alpha * T.dot(a.T, x) + beta * y
oy_func = theano.function([], oy, mode=self.mode) oy_func = theano.function([], oy, mode=self.mode)
...@@ -1173,13 +1312,13 @@ class BaseGemv(object): ...@@ -1173,13 +1312,13 @@ class BaseGemv(object):
assert_array_almost_equal(desired_oy, oy_v) assert_array_almost_equal(desired_oy, oy_v)
def test_x_stride(self): def test_x_stride(self):
vs = self.get_data(x_stride = 2) vs = self.get_data(x_stride=2)
alpha_v, beta_v, a_v, x_v, y_v = vs alpha_v, beta_v, a_v, x_v, y_v = vs
alpha, beta, a, x, y = [ self.shared(v) for v in vs ] alpha, beta, a, x, y = [self.shared(v) for v in vs]
desired_oy = alpha_v * matrixmultiply(a_v,x_v[::2])+beta_v*y_v desired_oy = alpha_v * matrixmultiply(a_v, x_v[::2]) + beta_v * y_v
oy = alpha * T.dot(a,x[::2])+beta*y oy = alpha * T.dot(a, x[::2]) + beta * y
oy_func = theano.function([], oy, mode=self.mode) oy_func = theano.function([], oy, mode=self.mode)
...@@ -1189,13 +1328,14 @@ class BaseGemv(object): ...@@ -1189,13 +1328,14 @@ class BaseGemv(object):
assert_array_almost_equal(desired_oy, oy_v) assert_array_almost_equal(desired_oy, oy_v)
def test_x_stride_transpose(self): def test_x_stride_transpose(self):
vs = self.get_data(x_stride = 2) vs = self.get_data(x_stride=2)
alpha_v, beta_v, a_v, x_v, y_v = vs alpha_v, beta_v, a_v, x_v, y_v = vs
alpha, beta, a, x, y = [ self.shared(v) for v in vs ] alpha, beta, a, x, y = [self.shared(v) for v in vs]
desired_oy = alpha_v * matrixmultiply(transpose(a_v),x_v[::2])+beta_v*y_v desired_oy = alpha_v * matrixmultiply(transpose(a_v), x_v[::
2]) + beta_v * y_v
oy = alpha * T.dot(a.T,x[::2])+beta*y oy = alpha * T.dot(a.T, x[::2]) + beta * y
oy_func = theano.function([], oy, mode=self.mode) oy_func = theano.function([], oy, mode=self.mode)
...@@ -1205,13 +1345,13 @@ class BaseGemv(object): ...@@ -1205,13 +1345,13 @@ class BaseGemv(object):
assert_array_almost_equal(desired_oy, oy_v) assert_array_almost_equal(desired_oy, oy_v)
def test_y_stride(self): def test_y_stride(self):
vs = self.get_data(y_stride = 2) vs = self.get_data(y_stride=2)
alpha_v, beta_v, a_v, x_v, y_v = vs alpha_v, beta_v, a_v, x_v, y_v = vs
alpha, beta, a, x, y = [ self.shared(v) for v in vs ] alpha, beta, a, x, y = [self.shared(v) for v in vs]
desired_oy = alpha_v * matrixmultiply(a_v,x_v)+beta_v*y_v[::2] desired_oy = alpha_v * matrixmultiply(a_v, x_v) + beta_v * y_v[::2]
oy = alpha * T.dot(a,x)+beta*y[::2] oy = alpha * T.dot(a, x) + beta * y[::2]
oy_func = theano.function([], oy, mode=self.mode) oy_func = theano.function([], oy, mode=self.mode)
...@@ -1221,13 +1361,14 @@ class BaseGemv(object): ...@@ -1221,13 +1361,14 @@ class BaseGemv(object):
assert_array_almost_equal(desired_oy, oy_v) assert_array_almost_equal(desired_oy, oy_v)
def test_y_stride_transpose(self): def test_y_stride_transpose(self):
vs = self.get_data(y_stride = 2) vs = self.get_data(y_stride=2)
alpha_v, beta_v, a_v, x_v, y_v = vs alpha_v, beta_v, a_v, x_v, y_v = vs
alpha, beta, a, x, y = [ self.shared(v) for v in vs ] alpha, beta, a, x, y = [self.shared(v) for v in vs]
desired_oy = alpha_v * matrixmultiply(transpose(a_v),x_v)+beta_v*y_v[::2] desired_oy = alpha_v * matrixmultiply(transpose(a_v),
x_v) + beta_v * y_v[::2]
oy = alpha * T.dot(a.T,x)+beta*y[::2] oy = alpha * T.dot(a.T, x) + beta * y[::2]
oy_func = theano.function([], oy, mode=self.mode) oy_func = theano.function([], oy, mode=self.mode)
...@@ -1239,15 +1380,16 @@ class BaseGemv(object): ...@@ -1239,15 +1380,16 @@ class BaseGemv(object):
def test_a_strides(self): def test_a_strides(self):
vs = self.get_data() vs = self.get_data()
alpha_v, beta_v, a_v, x_v, y_v = vs alpha_v, beta_v, a_v, x_v, y_v = vs
alpha, beta, a, x, y = [ self.shared(v) for v in vs ] alpha, beta, a, x, y = [self.shared(v) for v in vs]
a_v = a_v[::-1, ::-1] a_v = a_v[::-1, ::-1]
a.set_value( a.set_value(
a.get_value(borrow=True, return_internal_type=True)[::-1, ::-1], a.get_value(borrow=True,
return_internal_type=True)[::-1, ::-1],
borrow=True) borrow=True)
desired_oy = alpha_v * matrixmultiply(a_v,x_v)+beta_v*y_v desired_oy = alpha_v * matrixmultiply(a_v, x_v) + beta_v * y_v
oy = alpha * T.dot(a,x)+beta*y oy = alpha * T.dot(a, x) + beta * y
oy_func = theano.function([], oy, mode=self.mode) oy_func = theano.function([], oy, mode=self.mode)
...@@ -1259,15 +1401,17 @@ class BaseGemv(object): ...@@ -1259,15 +1401,17 @@ class BaseGemv(object):
def test_a_strides_transpose(self): def test_a_strides_transpose(self):
vs = self.get_data() vs = self.get_data()
alpha_v, beta_v, a_v, x_v, y_v = vs alpha_v, beta_v, a_v, x_v, y_v = vs
alpha, beta, a, x, y = [ self.shared(v) for v in vs ] alpha, beta, a, x, y = [self.shared(v) for v in vs]
a_v = a_v[::-1, ::-1] a_v = a_v[::-1, ::-1]
a.set_value( a.set_value(
a.get_value(borrow=True, return_internal_type=True)[::-1, ::-1], a.get_value(borrow=True,
return_internal_type=True)[::-1, ::-1],
borrow=True) borrow=True)
desired_oy = alpha_v * matrixmultiply(transpose(a_v),x_v)+beta_v*y_v desired_oy = alpha_v * matrixmultiply(transpose(a_v),
x_v) + beta_v * y_v
oy = alpha * T.dot(a.T,x)+beta*y oy = alpha * T.dot(a.T, x) + beta * y
oy_func = theano.function([], oy, mode=self.mode) oy_func = theano.function([], oy, mode=self.mode)
...@@ -1332,6 +1476,7 @@ class TestDgemv(TestCase, BaseGemv, unittest_tools.TestOptimizationMixin): ...@@ -1332,6 +1476,7 @@ class TestDgemv(TestCase, BaseGemv, unittest_tools.TestOptimizationMixin):
## Tests for Ger ## Tests for Ger
############################################################################### ###############################################################################
class TestGer_make_node(TestCase): class TestGer_make_node(TestCase):
def setUp(self): def setUp(self):
self.iv = T.tensor(dtype='int32', broadcastable=(False,)) self.iv = T.tensor(dtype='int32', broadcastable=(False,))
...@@ -1439,19 +1584,21 @@ class TestGer(TestCase, unittest_tools.TestOptimizationMixin): ...@@ -1439,19 +1584,21 @@ class TestGer(TestCase, unittest_tools.TestOptimizationMixin):
""" test local_gemm_to_ger opt""" """ test local_gemm_to_ger opt"""
assert T.blas.local_gemm_to_ger.transform( assert T.blas.local_gemm_to_ger.transform(
gemm_no_inplace( gemm_no_inplace(
self.A, self.a, self.x.dimshuffle(0,'x'), self.A, self.a, self.x.dimshuffle(0, 'x'),
self.y.dimshuffle('x', 0), self.b(0)).owner) self.y.dimshuffle('x', 0), self.b(0)).owner)
def test_b_1_triggers_ger(self): def test_b_1_triggers_ger(self):
""" test local_gemm_to_ger opt""" """ test local_gemm_to_ger opt"""
assert T.blas.local_gemm_to_ger.transform( assert T.blas.local_gemm_to_ger.transform(
gemm_no_inplace( gemm_no_inplace(
self.A, self.a, self.x.dimshuffle(0,'x'), self.A, self.a, self.x.dimshuffle(0, 'x'),
self.y.dimshuffle('x', 0), self.b(1)).owner) self.y.dimshuffle('x', 0), self.b(1)).owner)
def test_b_other_does_not_triggers_ger(self): def test_b_other_does_not_triggers_ger(self):
""" test local_gemm_to_ger opt""" """ test local_gemm_to_ger opt"""
assert not T.blas.local_gemm_to_ger.transform( assert not T.blas.local_gemm_to_ger.transform(
gemm_no_inplace( gemm_no_inplace(
self.A, self.a, self.x.dimshuffle(0,'x'), self.A, self.a, self.x.dimshuffle(0, 'x'),
self.y.dimshuffle('x', 0), self.b(1.5)).owner) self.y.dimshuffle('x', 0), self.b(1.5)).owner)
def test_outer(self): def test_outer(self):
...@@ -1555,6 +1702,7 @@ class TestGer(TestCase, unittest_tools.TestOptimizationMixin): ...@@ -1555,6 +1702,7 @@ class TestGer(TestCase, unittest_tools.TestOptimizationMixin):
f(numpy.random.rand(4).astype(self.dtype), f(numpy.random.rand(4).astype(self.dtype),
numpy.random.rand(5).astype(self.dtype)) numpy.random.rand(5).astype(self.dtype))
class TestBlasStrides(TestCase): class TestBlasStrides(TestCase):
dtype = 'float64' dtype = 'float64'
shared = staticmethod(tensor._shared) shared = staticmethod(tensor._shared)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论