提交 b6b2c608 authored 作者: James Bergstra's avatar James Bergstra

code in a mess, but gemm-optimization works on more systematic test cases…

code in a mess, but gemm-optimization works on more systematic test cases including josephs NAACL graph
上级 43291f46
...@@ -63,11 +63,20 @@ def register_optimizer(name, opt): ...@@ -63,11 +63,20 @@ def register_optimizer(name, opt):
raise ValueError('Optimizer name already taken: %s' % name) raise ValueError('Optimizer name already taken: %s' % name)
predefined_optimizers[name] = opt predefined_optimizers[name] = opt
class AddDestroyHandler(gof.Optimizer):
def apply(self, env):
pass
def add_requirements(self, env):
super(AddDestroyHandler, self).add_requirements(env)
env.extend(gof.DestroyHandler())
optdb = gof.SequenceDB() optdb = gof.SequenceDB()
optdb.register('merge1', gof.MergeOptimizer(), 0, 'fast_run', 'fast_compile') optdb.register('merge1', gof.MergeOptimizer(), 0, 'fast_run', 'fast_compile')
optdb.register('canonicalize', gof.EquilibriumDB(), 1, 'fast_run') optdb.register('canonicalize', gof.EquilibriumDB(), 1, 'fast_run')
optdb.register('specialize', gof.EquilibriumDB(), 2, 'fast_run') optdb.register('specialize', gof.EquilibriumDB(), 2, 'fast_run')
optdb.register('merge2', gof.EquilibriumDB(), 100, 'fast_run') optdb.register('merge2', gof.EquilibriumDB(), 49, 'fast_run')
optdb.register('add_destroy_handler', AddDestroyHandler(), 49.5, 'fast_run', 'inplace')
optdb.register('merge3', gof.EquilibriumDB(), 100, 'fast_run')
class Mode(object): class Mode(object):
......
...@@ -20,15 +20,14 @@ from link import \ ...@@ -20,15 +20,14 @@ from link import \
from op import \ from op import \
Op Op
from opt import \ from opt import (Optimizer, optimizer, SeqOptimizer,
Optimizer, optimizer, SeqOptimizer, \ MergeOptimizer, MergeOptMerge,
MergeOptimizer, MergeOptMerge, \ LocalOptimizer, local_optimizer, LocalOptGroup,
LocalOptimizer, local_optimizer, LocalOptGroup, \ OpSub, OpRemove, PatternSub,
OpSub, OpRemove, PatternSub, \ NavigatorOptimizer, TopoOptimizer, EquilibriumOptimizer,
NavigatorOptimizer, TopoOptimizer, EquilibriumOptimizer, \ keep_going, warn,
keep_going, warn, \ InplaceOptimizer, PureThenInplaceOptimizer,
InplaceOptimizer, PureThenInplaceOptimizer OpKeyOptimizer)
#LocalOpKeyOptGroup, OpKeyOptimizer
from optdb import \ from optdb import \
DB, Query, \ DB, Query, \
......
...@@ -265,6 +265,11 @@ class LocalOptimizer(object): ...@@ -265,6 +265,11 @@ class LocalOptimizer(object):
raise utils.AbstractFunctionError() raise utils.AbstractFunctionError()
def add_requirements(self, env):
"""If this local optimization wants to add some requirements to the env,
This is the place to do it."""
env.extend(toolbox.ReplaceValidate())
class FromFunctionLocalOptimizer(LocalOptimizer): class FromFunctionLocalOptimizer(LocalOptimizer):
"""WRITEME""" """WRITEME"""
...@@ -273,8 +278,6 @@ class FromFunctionLocalOptimizer(LocalOptimizer): ...@@ -273,8 +278,6 @@ class FromFunctionLocalOptimizer(LocalOptimizer):
self._tracks = tracks self._tracks = tracks
def tracks(self): def tracks(self):
return self._tracks return self._tracks
def add_requirements(self, env):
env.extend(toolbox.ReplaceValidate())
def __str__(self): def __str__(self):
return getattr(self, 'name', '<FromFunctionLocalOptimizer instance>') return getattr(self, 'name', '<FromFunctionLocalOptimizer instance>')
...@@ -551,7 +554,7 @@ class NavigatorOptimizer(Optimizer): ...@@ -551,7 +554,7 @@ class NavigatorOptimizer(Optimizer):
def __init__(self, local_opt, ignore_newtrees = 'auto', failure_callback = None): def __init__(self, local_opt, ignore_newtrees = 'auto', failure_callback = None):
""" """
:param local_opt: a LocalOptimizer to apply over a Env. :param local_opt: a LocalOptimizer to apply over a Env (or None is Ok too).
:param ignore_newtrees: :param ignore_newtrees:
- True: new subgraphs returned by an optimization is not a candidate for optimization - True: new subgraphs returned by an optimization is not a candidate for optimization
- False: new subgraphs returned by an optimization is a candidate for optimization - False: new subgraphs returned by an optimization is a candidate for optimization
...@@ -617,6 +620,24 @@ class NavigatorOptimizer(Optimizer): ...@@ -617,6 +620,24 @@ class NavigatorOptimizer(Optimizer):
env.remove_feature(u) env.remove_feature(u)
def process_node(self, env, node, lopt = None): def process_node(self, env, node, lopt = None):
"""
This function will use `lopt` to `transform` the `node`. The `transform` method will
return either False or a list of Results that are intended to replace `node.outputs`.
If the env accepts the replacement, then the optimization is successful, and this
function returns True.
If there are no replacement candidates or the env rejects the replacements, this
function returns False.
:param env: an Env
:param node: an Apply instance in `env`
:param lopt: a LocalOptimizer instance that may have a better idea for how to compute
node's outputs.
:rtype: Bool
:returns: True iff the `node`'s outputs were replaced in the `env`.
"""
lopt = lopt or self.local_opt lopt = lopt or self.local_opt
try: try:
replacements = lopt.transform(node) replacements = lopt.transform(node)
...@@ -633,23 +654,21 @@ class NavigatorOptimizer(Optimizer): ...@@ -633,23 +654,21 @@ class NavigatorOptimizer(Optimizer):
env.replace_all_validate(repl_pairs) env.replace_all_validate(repl_pairs)
return True return True
except Exception, e: except Exception, e:
# This means the replacements were rejected by the env.
#
# This is not supposed to happen. The default failure_callback will print a
# traceback as a warning.
if self.failure_callback is not None: if self.failure_callback is not None:
self.failure_callback(e, self, repl_pairs) self.failure_callback(e, self, repl_pairs)
#DEBUG DONT PUSH
#print lopt
#print dir(lopt)
#raise
#END
return False return False
else: else:
raise raise
def add_requirements(self, env): def add_requirements(self, env):
super(NavigatorOptimizer, self).add_requirements(env)
env.extend(toolbox.ReplaceValidate()) env.extend(toolbox.ReplaceValidate())
if self.local_opt:
self.local_opt.add_requirements(env)
class TopoOptimizer(NavigatorOptimizer): class TopoOptimizer(NavigatorOptimizer):
"""WRITEME""" """WRITEME"""
...@@ -722,7 +741,7 @@ class OpKeyOptimizer(NavigatorOptimizer): ...@@ -722,7 +741,7 @@ class OpKeyOptimizer(NavigatorOptimizer):
- NodeFinder - NodeFinder
- ReplaceValidate - ReplaceValidate
""" """
NavigatorOptimizer.add_requirements(self, env) super(OpKeyOptimizer, self).add_requirements(env)
env.extend(toolbox.NodeFinder()) env.extend(toolbox.NodeFinder())
......
...@@ -13,6 +13,8 @@ class DB(object): ...@@ -13,6 +13,8 @@ class DB(object):
def __init__(self): def __init__(self):
self.__db__ = defaultdict(set) self.__db__ = defaultdict(set)
self._names = set() self._names = set()
self.name = None #will be reset by register
#(via obj.name by the thing doing the registering)
def register(self, name, obj, *tags): def register(self, name, obj, *tags):
# N.B. obj is not an instance of class Optimizer. # N.B. obj is not an instance of class Optimizer.
...@@ -21,6 +23,8 @@ class DB(object): ...@@ -21,6 +23,8 @@ class DB(object):
if not isinstance(obj, (DB, opt.Optimizer, opt.LocalOptimizer)): if not isinstance(obj, (DB, opt.Optimizer, opt.LocalOptimizer)):
raise Exception('wtf', obj) raise Exception('wtf', obj)
if self.name is not None:
tags = tags + (self.name,)
obj.name = name obj.name = name
if name in self.__db__: if name in self.__db__:
raise ValueError('The name of the object cannot be an existing tag or the name of another existing object.', obj, name) raise ValueError('The name of the object cannot be an existing tag or the name of another existing object.', obj, name)
...@@ -118,9 +122,10 @@ class EquilibriumDB(DB): ...@@ -118,9 +122,10 @@ class EquilibriumDB(DB):
class SequenceDB(DB): class SequenceDB(DB):
def __init__(self): def __init__(self, failure_callback = opt.warn):
super(SequenceDB, self).__init__() super(SequenceDB, self).__init__()
self.__priority__ = {} self.__priority__ = {}
self.failure_callback = failure_callback
def register(self, name, obj, priority, *tags): def register(self, name, obj, priority, *tags):
super(SequenceDB, self).register(name, obj, *tags) super(SequenceDB, self).register(name, obj, *tags)
...@@ -130,6 +135,6 @@ class SequenceDB(DB): ...@@ -130,6 +135,6 @@ class SequenceDB(DB):
opts = super(SequenceDB, self).query(*tags, **kwtags) opts = super(SequenceDB, self).query(*tags, **kwtags)
opts = list(opts) opts = list(opts)
opts.sort(key = lambda obj: self.__priority__[obj.name]) opts.sort(key = lambda obj: self.__priority__[obj.name])
return opt.SeqOptimizer(opts, failure_callback = opt.warn) return opt.SeqOptimizer(opts, failure_callback = self.failure_callback)
"""Ops and optimizations for using BLAS function calls to evaluate linear algebra expressions""" """Ops and optimizations for using BLAS function calls to evaluate linear algebra expressions"""
import os, sys import os, sys, traceback
import numpy import numpy
from ..gof import (utils, Op, Apply, view_roots, PatternSub, from ..gof import (utils, Op, Apply, view_roots, PatternSub, DestroyHandler,
InplaceOptimizer, SeqOptimizer, warn, local_optimizer) SeqOptimizer, warn, local_optimizer, LocalOptimizer, OpKeyOptimizer,
InconsistencyError)
from ..printing import pprint, FunctionPrinter from ..printing import pprint, FunctionPrinter
from .opt import register_specialize, out2in, insert_inplace_optimizer from .opt import register_specialize, out2in, insert_inplace_optimizer
import basic as T import basic as T
from ..tensor import as_tensor
#NB: this clobbers the builtin 'compile' symbol #NB: this clobbers the builtin 'compile' symbol
from .. import compile #to register the optimizer built by this file from .. import compile #to register the optimizer built by this file
from .blas_headers import cblas_header_text, blas_header_text from .blas_headers import cblas_header_text, blas_header_text
JOSEPHS_BUG_SOLVED = False
@utils.memoize @utils.memoize
def ldflags(): def ldflags():
"""Return a list of libraries against which an Op's object file should be """Return a list of libraries against which an Op's object file should be
...@@ -270,7 +267,7 @@ class Gemm(GemmRelated): ...@@ -270,7 +267,7 @@ class Gemm(GemmRelated):
E_z_uniq = 'argument z aliased to x or y' E_z_uniq = 'argument z aliased to x or y'
destroy_map = {0: [0]} destroy_map = {0: [0]}
def make_node(self, *inputs): def make_node(self, *inputs):
inputs = map(as_tensor, inputs) inputs = map(T.as_tensor, inputs)
if len(inputs) != 5: if len(inputs) != 5:
raise TypeError("Wrong number of inputs for %s (expected 5, got %s)" % (self, len(inputs))) raise TypeError("Wrong number of inputs for %s (expected 5, got %s)" % (self, len(inputs)))
z, a, x, y, b = inputs z, a, x, y, b = inputs
...@@ -348,87 +345,110 @@ class Gemm(GemmRelated): ...@@ -348,87 +345,110 @@ class Gemm(GemmRelated):
#undef REAL #undef REAL
""" """
def c_code(self, node, name, (_z, _a, _x, _y, _b), (_zout, ), sub): def c_code(self, node, name, (_z, _a, _x, _y, _b), (_zout, ), sub): #DEBUG
full_code = self.build_gemm_call() % dict(locals(), **sub) full_code = self.build_gemm_call() % dict(locals(), **sub)
return full_code return full_code
gemm = Gemm() gemm = Gemm()
pprint.assign(gemm, FunctionPrinter('gemm')) pprint.assign(gemm, FunctionPrinter('gemm'))
class Dot22(GemmRelated): def res_is_a(node, op, maxclients=None):
"""Compute a matrix-matrix product. return node.owner \
This is a specialization of the more general Dot() and node.owner.op == op \
and (len(node.clients) <= maxclients if maxclients is not None else True)
class GemmLocalOptimizer(LocalOptimizer):
"""This is a massive beast for recognizing all the ways that a subtraction could be
replaced by a GEMM
It depends on `local_transposed_dot` to canonicalize the graph a bit by swapping
dot(a,b).T -> dot(b.T, a.T)
""" """
def make_node(self, x, y):
assert _is_real_matrix(x)
assert y.type == x.type #makes sure y is a matrix
bz = [False, False]
outputs = [T.tensor(x.type.dtype, bz)]
return Apply(self, [x,y], outputs)
def perform(self, node, (x, y), (z, )): def __init__(self):
try: super(LocalOptimizer, self).__init__()
z[0] = numpy.asarray(numpy.dot(x, y))
except ValueError, e:
# The error raised by numpy has no shape information, we mean to add that
e.args = e.args + (x.shape, y.shape)
raise
def __str__(self):
return "_dot22"
setup_z_Nz_Sz = """ def op_key(self):
if ((NULL == %(_z)s) return [T.add, T.sub]
|| (%(_z)s->dimensions[0] != %(_x)s->dimensions[0])
|| (%(_z)s->dimensions[1] != %(_y)s->dimensions[1]))
{
if (NULL != %(_z)s) Py_XDECREF(%(_z)s);
npy_intp dims[2];
dims[0] = %(_x)s->dimensions[0];
dims[1] = %(_y)s->dimensions[1];
%(_z)s = (PyArrayObject*)PyArray_SimpleNew(2, dims, type_num_%(_x)s);
if(!%(_z)s) {
PyErr_SetString(PyExc_MemoryError, "failed to alloc dot22 output");
%(fail)s
}
}
Nz = %(_z)s->dimensions;
Sz = %(_z)s->strides;
""" def add_requirements(self, env):
check_ab_double_or_float = "" super(GemmLocalOptimizer,self).add_requirements(env)
case_float_ab_constants = """ env.extend(DestroyHandler())
float a = 1.0;
float b = 0.0;
"""
case_double_ab_constants = """
double a = 1.0;
double b = 0.0;
"""
def c_code(self, node, name, (_x, _y), (_z, ), sub):
full_code = self.build_gemm_call() % dict(locals(), **sub)
return full_code
_dot22 = Dot22()
@local_optimizer([T.dot]) def transform(self, node):
def local_dot_to_dot22(node): _as_scalar, _is_real_matrix, _as_isolated_scalar_times_matrix, beta_L_plus_alpha_M\
if node.op == T.dot: = (GemmLocalOptimizer._as_scalar,
x,y = node.inputs GemmLocalOptimizer._is_real_matrix,
if _is_real_matrix(x) and y.type == x.type: GemmLocalOptimizer._as_isolated_scalar_times_matrix,
return [_dot22(*node.inputs)] GemmLocalOptimizer.beta_L_plus_alpha_M)
if node.op == T.sub:
L, R = node.inputs
if not _is_real_matrix(L):
return False
if not _is_real_matrix(R):
return False
tmp = _as_isolated_scalar_times_matrix(L)
try:
sL, mL = tmp
except:
sL, mL = 1.0, L
tmp = _as_isolated_scalar_times_matrix(R)
try:
sR, mR = tmp
except:
sR, mR = 1.0, R
rval = beta_L_plus_alpha_M(sL, mL, -sR, mR)
return rval
if node.op == T.add:
sM_list = []
other_inputs = []
for input in node.inputs:
tmp = _as_isolated_scalar_times_matrix(input)
if tmp:
sM_list.append(tmp)
elif _is_real_matrix(input):
sM_list.append((1.0, input))
else:
other_inputs.append(input)
if len(sM_list) == 2:
(sL, mL), (sR, mR) = sM_list
gemm_of_sM_list = beta_L_plus_alpha_M(sL, mL, sR, mR)
if gemm_of_sM_list:
#we turned the two candidates into a gemm
# now we have to add the other_inputs and return the replacement graph
if other_inputs:
return [T.add(*(other_inputs + gemm_of_sM_list))]
else:
return gemm_of_sM_list
else: else:
for i in xrange(len(sM_list) - 1):
for j in xrange(i+1, len(sM_list)):
sL, mL = sM_list[i]
sR, mR = sM_list[j]
gemm_of_sM_list = beta_L_plus_alpha_M(sL, mL, sR, mR)
if gemm_of_sM_list:
assert len(gemm_of_sM_list) == 1
inputs_without_ij = \
[input for k, input in enumerate(node.inputs) if k not in (i,j)]
return [T.add( *(inputs_without_ij + gemm_of_sM_list + other_inputs))]
return False return False
if JOSEPHS_BUG_SOLVED:
register_specialize(local_dot_to_dot22)
def _is_a(node, op, maxclients=None): @staticmethod
return node.owner \ def failure_callback(exc, nav, repl_pairs):
and node.owner.op == op \ """WRITEME"""
and len(node.clients) <= maxclients if maxclients is not None else True if not isinstance(exc, InconsistencyError):
traceback.print_exc()
else:
print 'GEMM caused cycle, forget it.'
def _as_scalar(res): @staticmethod
def _as_scalar(res):
"""Return None or a TensorResult whose type is in T.float_scalar_types""" """Return None or a TensorResult whose type is in T.float_scalar_types"""
if res.owner and isinstance(res.owner.op, T.DimShuffle): if res.owner and isinstance(res.owner.op, T.DimShuffle):
return _as_scalar(res.owner.inputs[0]) return GemmLocalOptimizer._as_scalar(res.owner.inputs[0])
elif res.type in T.float_scalar_types: elif res.type in T.float_scalar_types:
return res return res
elif isinstance(res, T.Constant) and res.data.size == 1: elif isinstance(res, T.Constant) and res.data.size == 1:
...@@ -436,13 +456,20 @@ def _as_scalar(res): ...@@ -436,13 +456,20 @@ def _as_scalar(res):
else: else:
return None return None
def _is_real_matrix(res): @staticmethod
def _is_real_matrix(res):
return res.type in T.float_matrix_types \ return res.type in T.float_matrix_types \
and res.broadcastable[0] == False \ and res.broadcastable[0] == False \
and res.broadcastable[1] == False #cope with tuple vs. list and res.broadcastable[1] == False #cope with tuple vs. list
def _as_isolated_scalar_times_matrix(res): @staticmethod
if _is_a(res, T.mul, 1): def _as_isolated_scalar_times_matrix(res):
_as_scalar, _is_real_matrix, _as_isolated_scalar_times_matrix, beta_L_plus_alpha_M\
= (GemmLocalOptimizer._as_scalar,
GemmLocalOptimizer._is_real_matrix,
GemmLocalOptimizer._as_isolated_scalar_times_matrix,
GemmLocalOptimizer.beta_L_plus_alpha_M)
if res_is_a(res, T.mul, 1):
if len(res.owner.inputs) == 2: if len(res.owner.inputs) == 2:
L, R = res.owner.inputs L, R = res.owner.inputs
sL = _as_scalar(L) sL = _as_scalar(L)
...@@ -466,105 +493,122 @@ def _as_isolated_scalar_times_matrix(res): ...@@ -466,105 +493,122 @@ def _as_isolated_scalar_times_matrix(res):
rval = (T.mul(*scalars), matrices[0]) rval = (T.mul(*scalars), matrices[0])
return rval return rval
@staticmethod
def beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip = True): def beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip = True):
#print 'BETA L + ALPHA M', beta, L, alpha, M, recurse_flip #print 'BETA L + ALPHA M', beta, L, alpha, M, recurse_flip
#EXPRESSION: (beta * L) + (alpha * M) #EXPRESSION: (beta * L) + (alpha * M)
if _is_a(M, _dot22, 1): if True:
if res_is_a(L, T.sqrt):
print 'CLIENTS OF L', L, L.clients
if res_is_a(M, _dot22, 1):
Ml, Mr = M.owner.inputs Ml, Mr = M.owner.inputs
rval = [gemm(L, alpha, Ml, Mr, beta)] rval = [gemm(L, alpha, Ml, Mr, beta)]
print 'GEMM 0', rval, beta, L, alpha, M
return rval return rval
if _is_a(M, gemm, 1): if False and res_is_a(M, gemm, 1):
#EXPRESSION: (beta * L) + (alpha * (gemm(G, a, u, v, b))) #EXPRESSION: (beta * L) + (alpha * (gemm(G, a, u, v, b)))
#EXPRESSION: (beta * L) + alpha * (b * G) + alpha * a * dot(u, v) #EXPRESSION: (beta * L) + alpha * (b * G) + alpha * a * dot(u, v)
G, a, u, v, b = M.owner.inputs G, a, u, v, b = M.owner.inputs
#print 'GEMM', G, L #print 'GEMM', G, L
if _is_a(G, _dot22, 1): if res_is_a(G, _dot22, 1):
#EXPRESSION: (beta * L) + (alpha * (gemm(dot(x,y), a, u, v, b))) #EXPRESSION: (beta * L) + (alpha * (gemm(dot(x,y), a, u, v, b)))
x, y = G.owner.inputs x, y = G.owner.inputs
#EXPRESSION: (beta * L) + (alpha * ((b*dot(x,y) + (a * dot(u, v))))) #EXPRESSION: (beta * L) + (alpha * ((b*dot(x,y) + (a * dot(u, v)))))
#EXPRESSION: (beta * L) + (alpha*b*dot(x,y)) + (alpha * a * dot(u, v)) #EXPRESSION: (beta * L) + (alpha*b*dot(x,y)) + (alpha * a * dot(u, v))
#print 'GEMM 1', G, L
rval = [gemm(gemm(L, alpha * b, x, y, beta), alpha * a, u, v, 1.0)] rval = [gemm(gemm(L, alpha * b, x, y, beta), alpha * a, u, v, 1.0)]
print 'GEMM 1', rval
return rval return rval
elif G is L: if (G is L):
#EXPRESSION: (beta * L) + (alpha*b*L) + (alpha * a * dot(u, v)) #EXPRESSION: (beta * L) + (alpha*b*L) + (alpha * a * dot(u, v))
rval = [gemm(L, alpha*a, u, v, alpha * b + beta)] rval = [gemm(L, alpha*a, u, v, alpha * b + beta)]
#print 'GEMM 2', rval print 'GEMM 2', rval
return rval return rval
elif 1.0 != alpha: if (1.0 != alpha):
#at the very least, move the alpha inside the gemm #at the very least, move the alpha inside the gemm
rval = [beta * L + gemm(G, alpha * a, u, v, alpha * b)] rval = [beta * L + gemm(G, alpha * a, u, v, alpha * b)]
#print 'GEMM 3', G, L print 'GEMM 3', rval
return rval return rval
if recurse_flip: if recurse_flip:
return beta_L_plus_alpha_M(alpha, M, beta, L, recurse_flip = False) return GemmLocalOptimizer.beta_L_plus_alpha_M(alpha, M, beta, L, recurse_flip = False)
else: else:
return False return False
@local_optimizer([T.sub]) #I think that three passes should suffice to catch all the GEMMs.
def local_sub_to_gemm(node): # TODO: This could be an equilibriumOptmizer, but I don't know how to combine an OpKeyOptimizer and
if node.op == T.sub: # an EquilibriumOptimizer.
L, R = node.inputs compile.optdb.register('inplace_gemm_0', OpKeyOptimizer(GemmLocalOptimizer(),
if not _is_real_matrix(L): failure_callback=GemmLocalOptimizer.failure_callback), 70.00, 'fast_run', 'inplace')
return False compile.optdb.register('inplace_gemm_1', OpKeyOptimizer(GemmLocalOptimizer(),
if not _is_real_matrix(R): failure_callback=GemmLocalOptimizer.failure_callback), 70.01, 'fast_run', 'inplace')
return False compile.optdb.register('inplace_gemm_2', OpKeyOptimizer(GemmLocalOptimizer(),
failure_callback=GemmLocalOptimizer.failure_callback), 70.02, 'fast_run', 'inplace')
tmp = _as_isolated_scalar_times_matrix(L) class Dot22(GemmRelated):
try: """Compute a matrix-matrix product.
sL, mL = tmp This is a specialization of the more general Dot()
except: """
sL, mL = 1.0, L def make_node(self, x, y):
assert GemmLocalOptimizer._is_real_matrix(x)
assert y.type == x.type #makes sure y is a matrix
bz = [False, False]
outputs = [T.tensor(x.type.dtype, bz)]
return Apply(self, [x,y], outputs)
tmp = _as_isolated_scalar_times_matrix(R) def perform(self, node, (x, y), (z, )):
try: try:
sR, mR = tmp z[0] = numpy.asarray(numpy.dot(x, y))
except: except ValueError, e:
sR, mR = 1.0, R # The error raised by numpy has no shape information, we mean to add that
rval = beta_L_plus_alpha_M(sL, mL, -sR, mR) e.args = e.args + (x.shape, y.shape)
return rval raise
return False def __str__(self):
if JOSEPHS_BUG_SOLVED: return "_dot22"
register_specialize(local_sub_to_gemm)
@local_optimizer([T.add]) setup_z_Nz_Sz = """
def local_add_to_gemm(node): if ((NULL == %(_z)s)
"""This is a massive beast for recognizing all the ways that a subtraction could be || (%(_z)s->dimensions[0] != %(_x)s->dimensions[0])
replaced by a GEMM || (%(_z)s->dimensions[1] != %(_y)s->dimensions[1]))
{
if (NULL != %(_z)s) Py_XDECREF(%(_z)s);
npy_intp dims[2];
dims[0] = %(_x)s->dimensions[0];
dims[1] = %(_y)s->dimensions[1];
%(_z)s = (PyArrayObject*)PyArray_SimpleNew(2, dims, type_num_%(_x)s);
if(!%(_z)s) {
PyErr_SetString(PyExc_MemoryError, "failed to alloc dot22 output");
%(fail)s
}
}
Nz = %(_z)s->dimensions;
Sz = %(_z)s->strides;
It depends on `local_transposed_dot` to canonicalize the graph a bit by swapping
dot(a,b).T -> dot(b.T, a.T)
""" """
if node.op == T.add: check_ab_double_or_float = ""
sM_list = [] case_float_ab_constants = """
for input in node.inputs: float a = 1.0;
tmp = _as_isolated_scalar_times_matrix(input) float b = 0.0;
if tmp: """
sM_list.append(tmp) case_double_ab_constants = """
elif _is_real_matrix(input): double a = 1.0;
sM_list.append((1.0, input)) double b = 0.0;
"""
def c_code(self, node, name, (_x, _y), (_z, ), sub): #DEBUG
full_code = self.build_gemm_call() % dict(locals(), **sub)
return full_code
_dot22 = Dot22()
if len(sM_list) == 2: @local_optimizer([T.dot])
sL, mL = sM_list[0] def local_dot_to_dot22(node):
sR, mR = sM_list[1] if node.op == T.dot:
return beta_L_plus_alpha_M(sL, mL, sR, mR) x,y = node.inputs
if GemmLocalOptimizer._is_real_matrix(x) and y.type == x.type:
return [_dot22(*node.inputs)]
else: else:
for i in xrange(len(sM_list) - 1):
for j in xrange(i+1, len(sM_list)):
sL, mL = sM_list[i]
sR, mR = sM_list[j]
rval = beta_L_plus_alpha_M(sL, mL, sR, mR)
if rval:
assert len(rval) == 1
inputs_without_ij = \
[input for k, input in enumerate(node.inputs) if k not in (i,j)]
return [T.add( *(inputs_without_ij + rval))]
return False return False
if JOSEPHS_BUG_SOLVED: register_specialize(local_dot_to_dot22)
register_specialize(local_add_to_gemm)
...@@ -316,7 +316,7 @@ class Elemwise(Op): ...@@ -316,7 +316,7 @@ class Elemwise(Op):
scalars scalars
* inplace_pattern: a dictionary that maps the index of an output to the * inplace_pattern: a dictionary that maps the index of an output to the
index of an input so the output is calculated inplace using index of an input so the output is calculated inplace using
the input's storage. the input's storage. (Just like destroymap, but without the lists.)
""" """
self.name = name self.name = name
self.scalar_op = scalar_op self.scalar_op = scalar_op
...@@ -357,16 +357,21 @@ class Elemwise(Op): ...@@ -357,16 +357,21 @@ class Elemwise(Op):
args.append(input) args.append(input)
else: else:
# TODO: use LComplete instead # TODO: use LComplete instead
args.append(DimShuffle(input.type.broadcastable, ['x']*difference + range(length), inplace = True)(input)) args.append(DimShuffle(
input.type.broadcastable,
['x']*difference + range(length),
inplace = True)(input))
inputs = args inputs = args
# # Following conditions should always be true? #HERE: all the broadcast dims have the same length now
# try:
# assert len(set([len(input.type.broadcastable) for input in inputs])) == 1
# except (AssertionError, AttributeError):
# raise TypeError("All inputs to a Broadcast subclass must be Tensor instances and their broadcastable fields must all have the same length.", inputs)
#cleverness: we iterate over the first, second, third broadcast flag of all inputs in
#parallel... the all() gives us each output broadcastable bit in turn.
#it is multiplied by nout because Elemwise supports multiple outputs (nout of them)
out_broadcastables = [[all(bcast) for bcast in zip(*[input.type.broadcastable for input in inputs])]] * shadow.nout out_broadcastables = [[all(bcast) for bcast in zip(*[input.type.broadcastable for input in inputs])]] * shadow.nout
#inplace_pattern maps output idx -> input idx
inplace_pattern = self.inplace_pattern inplace_pattern = self.inplace_pattern
if inplace_pattern: if inplace_pattern:
for overwriter, overwritten in inplace_pattern.items(): for overwriter, overwritten in inplace_pattern.items():
...@@ -374,21 +379,32 @@ class Elemwise(Op): ...@@ -374,21 +379,32 @@ class Elemwise(Op):
if ib and not ob: if ib and not ob:
raise ValueError("Operation cannot be done inplace on an input with broadcasted dimensions.") raise ValueError("Operation cannot be done inplace on an input with broadcasted dimensions.")
out_dtypes = [o.type.dtype for o in shadow.outputs] out_dtypes = [o.type.dtype for o in shadow.outputs]
if any(inputs[i].type.dtype != out_dtypes[o] for i, o in inplace_pattern.items()): if any(inputs[i].type.dtype != out_dtypes[o] for o, i in inplace_pattern.items()):
raise TypeError("Cannot do an inplace operation on incompatible data types.", [i.type.dtype for i in inputs], out_dtypes) raise TypeError("Cannot do an inplace operation on incompatible data types.",
([i.type.dtype for i in inputs], out_dtypes, inplace_pattern))
outputs = [Tensor(dtype = dtype, broadcastable = broadcastable)() for dtype, broadcastable in zip(out_dtypes, out_broadcastables)] outputs = [Tensor(dtype = dtype, broadcastable = broadcastable)() for dtype, broadcastable in zip(out_dtypes, out_broadcastables)]
return Apply(self, inputs, outputs) return Apply(self, inputs, outputs)
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) and self.scalar_op == other.scalar_op and self.inplace_pattern == other.inplace_pattern if type(self) == type(other):
items = self.inplace_pattern.items()
other_items = other.inplace_pattern.items()
items.sort()
other_items.sort()
return self.scalar_op == other.scalar_op and items == other_items
return False
def __hash__(self): def __hash__(self):
return hash(self.scalar_op) ^ hash(tuple(self.inplace_pattern.items())) items = self.inplace_pattern.items()
items.sort()
return hash(self.scalar_op) ^ hash(tuple(items))
def __str__(self): def __str__(self):
if self.name is None: if self.name is None:
if self.inplace_pattern: if self.inplace_pattern:
return "Elemwise{%s}%s" % (self.scalar_op, str(self.inplace_pattern)) items = self.inplace_pattern.items()
items.sort()
return "Elemwise{%s}%s" % (self.scalar_op, str(items))
else: else:
return "Elemwise{%s}" % (self.scalar_op) return "Elemwise{%s}" % (self.scalar_op)
else: else:
...@@ -467,6 +483,7 @@ class Elemwise(Op): ...@@ -467,6 +483,7 @@ class Elemwise(Op):
storage[0] = odat storage[0] = odat
else: else:
for i, (output, storage) in enumerate(zip(node.outputs, output_storage)): for i, (output, storage) in enumerate(zip(node.outputs, output_storage)):
#i is an output idx
if i in self.inplace_pattern: if i in self.inplace_pattern:
odat = inputs[self.inplace_pattern[i]] odat = inputs[self.inplace_pattern[i]]
else: else:
...@@ -500,7 +517,7 @@ class Elemwise(Op): ...@@ -500,7 +517,7 @@ class Elemwise(Op):
defines = "" defines = ""
undefs = "" undefs = ""
dmap = dict([(node.outputs[i], [node.inputs[o]]) for i, o in self.inplace_pattern.items()]) dmap = dict([(node.outputs[o], [node.inputs[i]]) for o, i in self.inplace_pattern.items()])
idtypes = [input.type.dtype_specs()[1] for input in inputs] idtypes = [input.type.dtype_specs()[1] for input in inputs]
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
from .. import gof from .. import gof
from ..gof import opt from ..gof import opt, InconsistencyError
from elemwise import Elemwise, DimShuffle from elemwise import Elemwise, DimShuffle
from .. import scalar from .. import scalar
import basic as T import basic as T
...@@ -32,7 +32,8 @@ def in2out(*local_opts, **kwargs): ...@@ -32,7 +32,8 @@ def in2out(*local_opts, **kwargs):
def _insert_inplace_optimizer(env): @gof.optimizer
def insert_inplace_optimizer(env):
""" """
Usage: inplace_optimizer.optimize(env) Usage: inplace_optimizer.optimize(env)
...@@ -59,17 +60,18 @@ def _insert_inplace_optimizer(env): ...@@ -59,17 +60,18 @@ def _insert_inplace_optimizer(env):
try: try:
new = Elemwise( new = Elemwise(
op.scalar_op.__class__( op.scalar_op.__class__(
scalar.transfer_type(*[inplace_pattern.get(i, None) for i in xrange(len(node.outputs))])), scalar.transfer_type(
*[inplace_pattern.get(i, None) \
for i in xrange(len(node.outputs))])),
inplace_pattern).make_node(*node.inputs) inplace_pattern).make_node(*node.inputs)
env.replace_all_validate(zip(node.outputs, new.outputs)) env.replace_all_validate(zip(node.outputs, new.outputs))
except Exception, e: except (ValueError, TypeError, InconsistencyError), e:
continue continue
candidate_inputs.remove(candidate_input) candidate_inputs.remove(candidate_input)
node = new node = new
baseline = inplace_pattern baseline = inplace_pattern
break break
insert_inplace_optimizer = gof.optimizer(_insert_inplace_optimizer) compile.optdb.register('inplace_opt', insert_inplace_optimizer, 75, 'fast_run', 'inplace')
def register_canonicalize(lopt, *tags, **kwargs): def register_canonicalize(lopt, *tags, **kwargs):
name = (kwargs and kwargs.pop('name')) or lopt.__name__ name = (kwargs and kwargs.pop('name')) or lopt.__name__
...@@ -310,7 +312,7 @@ def local_fill_cut(node): ...@@ -310,7 +312,7 @@ def local_fill_cut(node):
register_canonicalize(local_fill_cut) register_canonicalize(local_fill_cut)
register_canonicalize(gof.OpRemove(T.tensor_copy), name='remove_tensor_copy' ) #register_canonicalize(gof.OpRemove(T.tensor_copy), name='remove_tensor_copy' ) #DEBUG
@gof.local_optimizer([None, T.fill]) @gof.local_optimizer([None, T.fill])
def local_fill_sink(node): def local_fill_sink(node):
...@@ -650,38 +652,6 @@ def local_mul_specialize(node): ...@@ -650,38 +652,6 @@ def local_mul_specialize(node):
return False return False
register_specialize(local_mul_specialize) register_specialize(local_mul_specialize)
if 0: #TODO: replace this with a c version of any InplaceDimShuffle
class _TransposeInplace(T.Op):
view_map = {0: [0]}
def make_node(self, input):
return T.Apply(self, [input],
[T.tensor(dtype = input.type.dtype,
broadcastable = reversed(input.type.broadcastable))])
def perform(self, node, (x, ), (z, )):
z[0] = x.T
def c_code(self, node, name, (x, ), (z, ), sub):
return """
PyArrayObject* transposed = (PyArrayObject*)PyArray_Transpose(%(x)s, NULL);
if (%(z)s) {
Py_XDECREF(%(z)s);
}
%(z)s = transposed;
""" % locals()
def __str__(self):
return "_TransposeInplace"
_transpose_inplace = _TransposeInplace()
@gof.local_optimizer([T.DimShuffle([False,False],[1,0],inplace=True)])
def local_dimshuffle_transposeinplace(node):
if node.op == T.DimShuffle([False,False],[1,0],inplace=True):
return [_transpose_inplace(node.inputs[0])]
return False
register_specialize(local_dimshuffle_transposeinplace)
register_canonicalize(local_mul_canonizer, name = 'local_mul_canonizer') register_canonicalize(local_mul_canonizer, name = 'local_mul_canonizer')
...@@ -844,287 +814,3 @@ local_transposed_dot = gof.PatternSub((inplace_matrix_transpose, (T.dot, 'x', 'y ...@@ -844,287 +814,3 @@ local_transposed_dot = gof.PatternSub((inplace_matrix_transpose, (T.dot, 'x', 'y
register_canonicalize(local_transposed_dot, name='local_transposed_dot') register_canonicalize(local_transposed_dot, name='local_transposed_dot')
# def _math_optimizer():
# pass_1 = in2out(local_fill_sink)
# pass_2 = out2in(local_dimshuffle_lift, local_shape_lift, local_fill_lift)#, local_fill_cut)
# pass_3 = out2in(local_subtensor_make_vector, local_fill_cut)
# canonizer = in2out(local_add_canonizer,
# local_mul_canonizer,
# local_fill_sink)
# pass_4 = out2in(local_greedy_distributor)
# return gof.SeqOptimizer(pass_1,
# pass_2,
# pass_3,
# neg_to_mul,
# canonizer,
# pass_4,
# mul_to_neg)
# math_optimizer = _math_optimizer()
# compile.register_optimizer('math',
# gof.MergeOptMerge(
# gof.PureThenInplaceOptimizer(
# math_optimizer,
# inplace_optimizer)))
# compile.register_mode('SANITY_CHECK', compile.Mode('c&py', 'math'))
# compile.register_mode('FAST_RUN', compile.Mode('c|py', 'math'))
# compile.register_mode('EXPENSIVE_OPTIMIZATIONS', compile.Mode('c|py', 'math'))
# @gof.local_optimizer
# def local_clique_fusion(node):
# aaaaaaaaaaaaaaaaaaaaaaa
# def find_cliques(env, through_broadcast = False):
# """
# Usage: find_cliques(env, through_broadcast = False)
# Returns a list of pairs where each pair contains a list
# of inputs and a list of outputs such that Env(inputs, outputs)
# contains nothing but Broadcast Ops.
# If through_broadcast is False, the cliques will only be
# allowed to broadcast over the inputs, which means, for
# example, that vector operations will not be mixed with
# matrix operations.
# """
# def seek_from(r):
# # walks through the graph until it encounters a
# # non-Broadcast operation or (if through_broadcast
# # is False) a Result which needs to be broadcasted.
# op = r.owner
# if env.edge(r) \
# or not isinstance(op, Broadcast) \
# or len(op.outputs) > 1:
# # todo: handle multiple-output broadcast ops
# # (needs to update the clique's outputs)
# return None
# ret = set()
# if not through_broadcast:
# # check each dimension over all the inputs - if the broadcastable
# # fields are not all 0 or all 1 for a particular dimension, then
# # broadcasting will be performed along it on the inputs where the
# # value is 1 and we will stop.
# if any(any(bc) and not all(bc)
# for bc in zip(*[input.broadcastable for input in op.inputs])):
# ret.update(op.inputs)
# return ret
# for input in op.inputs:
# res = seek_from(input)
# if res is None:
# # input is a leaf of our search
# ret.add(input)
# else:
# ret.update(res)
# return ret
# cliques = []
# def find_cliques_helper(r):
# if env.edge(r):
# return
# clique_inputs = seek_from(r)
# if clique_inputs is None:
# # Not in a clique, keep going
# op = r.owner
# if op is not None:
# for input in op.inputs:
# find_cliques_helper(input)
# else:
# # We found a clique, add it to the list and
# # jump to the leaves.
# cliques.append((clique_inputs, [r]))
# for input in clique_inputs:
# find_cliques_helper(input)
# for output in env.outputs:
# find_cliques_helper(output)
# # todo: merge the cliques if possible
# return cliques
# class CliqueOptimizer(opt.Optimizer):
# """
# Usage: CliqueOptimizer(through_broadcast = False,
# scalar_optimizer = None,
# make_composite = False).optimize(env)
# Finds cliques of Broadcast operations in the env and does either
# or both of two things:
# * Apply scalar_optimizer on the clique as if the clique was a
# group of scalar operations. scalar_optimizer can be any optimization
# which applies on scalars. If it is None, no optimization is done.
# * Replace the clique with a single Op, optimized to perform the
# computations properly. If make_composite is False, no such replacement
# is done.
# Note: it is recommended to run the lift_dimshuffle optimization before
# this one.
# """
# def __init__(self, through_broadcast = False, scalar_optimizer = None, make_composite = False):
# self.through_broadcast = through_broadcast
# self.scalar_optimizer = scalar_optimizer
# self.make_composite = make_composite
# def apply(self, env):
# if self.scalar_optimizer is None and not self.make_composite:
# # there's nothing to do with the cliques...
# return
# cliques = find_cliques(env, self.through_broadcast)
# opt = self.scalar_optimizer
# def build_scalar_clique(r, env, equiv):
# # Maps a clique of Broadcast Ops to a clique of Scalar Ops with the same
# # structure and equivalent operations. equiv contains the mapping.
# if r in equiv:
# return equiv[r]
# op = r.owner
# if env.edge(r):
# # For each leave we make a Scalar of the corresponding dtype
# s = scalar.Scalar(dtype = r.dtype)
# _r = r
# if isinstance(r.owner, DimShuffle) and all(x == 'x' for x in r.owner.new_order):
# _r = r.owner.inputs[0]
# if (getattr(r, 'constant', False) or getattr(_r, 'constant', False)) \
# and _r.broadcastable == ():
# # If we have a constant tensor we map it to a constant scalar.
# s.data = _r.data
# s.constant = True
# equiv[r] = s
# return s
# s_op = op.scalar_opclass(*[build_scalar_clique(input, env, equiv) for input in op.inputs])
# equiv[op] = s_op
# for output, s_output in zip(op.outputs, s_op.outputs):
# equiv[output] = s_output
# return equiv[r]
# for c_in, c_out in cliques:
# equiv = dict()
# g = Env(c_in, c_out)
# for output in c_out:
# build_scalar_clique(output, g, equiv)
# s_g = Env([equiv[r] for r in g.inputs],
# [equiv[r] for r in g.outputs])
# if opt is not None:
# equiv2 = dict() # reverse mapping, from Scalar Op to Tensor Op
# for k, v in equiv.items():
# equiv2[v] = k
# def transform(op, equiv):
# # We get a scalar op and we return an equivalent op on tensors.
# return Broadcast(op.__class__, [equiv[input] for input in op.inputs])
# s_g.add_feature(sync_to(env, equiv2, transform)) # Any change to s_g will now be transferred to g
# opt.optimize(s_g)
# if self.make_composite:
# def follow_inplace(r):
# # Tries to find the earliest r2 in g such that r destroys r2
# # If no such r2 is found, returns None
# op = r.owner
# if op is None or r in g.inputs or r in g.orphans():
# return None
# assert isinstance(op, Broadcast)
# destroyed = op.destroy_map().get(r, None)
# if destroyed is None:
# return None
# else:
# r2 = destroyed[0]
# ret = follow_inplace(r2)
# if ret is None:
# return r2
# else:
# return ret
# inplace_pattern = {}
# for i, output in enumerate(g.outputs):
# destroyed = follow_inplace(output)
# if destroyed is not None and destroyed in g.inputs:
# # we transfer the inplace operation only if it is
# # an input that is destroyed
# inplace_pattern[i] = g.inputs.index(destroyed)
# C = scalar.composite(s_g.inputs, s_g.outputs)
# ec = Broadcast(C, g.inputs, inplace_pattern = inplace_pattern)
# env.replace_all(dict((o, eco) for o, eco in zip(c_out, ec.outputs)))
# def sync_to(target, equiv, transform):
# """
# Usage: sync_to(target, equiv, transform)
# * target: an Env
# * equiv: a dictionary that maps results and ops to results and ops
# in target
# * transform: a function that takes (op, equiv) as inputs and
# returns a new op.
# Returns a Feature that can be added to an Env and mirrors all
# modifications to that env with modifications to the target env.
# """
# class Synchronize(gof.Listener, gof.Constraint):
# def __init__(self, source):
# self.source = source
# self.target = target
# self.equiv = equiv
# self.transform = transform
# self.inconsistencies = []
# def on_import(self, op1):
# if op1 not in self.equiv:
# op2 = self.transform(op1, self.equiv)
# self.equiv[op1] = op2
# for o1, o2 in zip(op1.outputs, op2.outputs):
# self.equiv[o1] = o2
# def on_prune(self, op1):
# if op1 in self.equiv:
# op2 = self.equiv[op1]
# del self.equiv[op1]
# for o1, o2 in zip(op1.outputs, op2.outputs):
# del self.equiv[o1]
# def on_rewire(self, clients1, r1, new_r1):
# if (new_r1, r1) in self.inconsistencies:
# self.inconsistencies.remove((new_r1, r1))
# return
# if not self.source.clients(r1):
# try:
# target.replace(self.equiv[r1], self.equiv[new_r1])
# except:
# self.inconsistencies.append((r1, new_r1))
# def validate(self):
# if self.inconsistencies:
# raise InconsistencyError("Could not synchronize when replacing the following pairs: %s" % self.inconsistencies)
# return True
# return Synchronize
...@@ -3,10 +3,13 @@ import theano.tensor as T ...@@ -3,10 +3,13 @@ import theano.tensor as T
from ...gof import Env from ...gof import Env
import numpy import numpy
from theano.tensor.blas import * from theano.tensor.blas import *
from theano.tensor.blas import _as_scalar, _dot22, _is_real_matrix from theano.tensor.blas import _dot22, res_is_a
from unittest import TestCase from unittest import TestCase
from copy import copy from copy import copy
_as_scalar = GemmLocalOptimizer._as_scalar
_is_real_matrix = GemmLocalOptimizer._is_real_matrix
from theano import In, Out from theano import In, Out
from .test_basic import (_approx_eq, as_tensor, function, from .test_basic import (_approx_eq, as_tensor, function,
compile, value, constant, inplace, eval_outputs) compile, value, constant, inplace, eval_outputs)
...@@ -185,6 +188,15 @@ class t_gemm(TestCase): ...@@ -185,6 +188,15 @@ class t_gemm(TestCase):
return return
self.fail() self.fail()
def test_res_is_a():
X,Y,Z,a,b = XYZab()
assert not res_is_a(a, T.sqrt)
assert not res_is_a(a+a, T.sqrt)
assert res_is_a(T.sqrt(a+a), T.sqrt)
#leave the maxclients stuff untested because it requires being in an env.
class t_as_scalar(TestCase): class t_as_scalar(TestCase):
def test0(self): def test0(self):
"""Test that it works on scalar constants""" """Test that it works on scalar constants"""
...@@ -227,85 +239,167 @@ class T_real_matrix(TestCase): ...@@ -227,85 +239,167 @@ class T_real_matrix(TestCase):
self.failUnless(_is_real_matrix(T.DimShuffle([False,False], [1, 0])(T.dmatrix()))) self.failUnless(_is_real_matrix(T.DimShuffle([False,False], [1, 0])(T.dmatrix())))
self.failUnless(not _is_real_matrix(T.DimShuffle([False], ['x', 0])(T.dvector()))) self.failUnless(not _is_real_matrix(T.DimShuffle([False], ['x', 0])(T.dvector())))
if JOSEPHS_BUG_SOLVED: def fail(msg):
class T_gemm_opt(TestCase): print 'FAIL', msg
"""This test suite ensures that Gemm is inserted where it belongs, and that the resulting assert False
functions compute the same things as the originals."""
def XYZab(self): """This test suite ensures that Gemm is inserted where it belongs, and that the resulting
functions compute the same things as the originals."""
def XYZab():
return T.dmatrix(), T.dmatrix(), T.dmatrix(), T.dscalar(), T.dscalar() return T.dmatrix(), T.dmatrix(), T.dmatrix(), T.dscalar(), T.dscalar()
def just_gemm(self, i, o, ishapes = [(4,3), (3,5), (4,5), (), ()]): class Failure(Exception):
def on_fail(): pass
for node in f.maker.env.toposort():
print 'GRAPH', node
self.fail()
class Warning(Exception):
pass
def just_gemm(i, o, ishapes = [(4,3), (3,5), (4,5), (), ()]):
try:
f = function([In(ii, mutable=True) for ii in i],o, mode='FAST_RUN') f = function([In(ii, mutable=True) for ii in i],o, mode='FAST_RUN')
for node in f.maker.env.nodes: for node in f.maker.env.nodes:
if node.op == T.dot: on_fail() if node.op == T.dot: raise Warning('dot in graph')
if node.op == _dot22: on_fail() if node.op == _dot22: raise Warning('_dot22 in graph')
g = function(i, o, mode='FAST_COMPILE') g = function(i, o, mode=compile.Mode(linker='py', optimizer=None))
for node in g.maker.env.nodes: for node in g.maker.env.nodes:
if node.op == gemm: on_fail() if node.op == gemm: raise Warning('gemm in graph')
rng = numpy.random.RandomState(234) rng = numpy.random.RandomState(234)
r0 = f(*[rng.randn(*sh) for sh in ishapes]) r0 = f(*[rng.randn(*sh) for sh in ishapes])
rng = numpy.random.RandomState(234) rng = numpy.random.RandomState(234)
r1 = g(*[rng.randn(*sh) for sh in ishapes]) r1 = g(*[rng.randn(*sh) for sh in ishapes])
if numpy.max(numpy.abs(r0[0] - r1[0])) > 1.0e-8: max_abs_err = numpy.max(numpy.abs(r0[0] - r1[0]))
self.fail() if max_abs_err > 1.0e-8:
raise Failure('GEMM is computing the wrong output. max_rel_err =', max_abs_err)
except Failure:
for node in f.maker.env.toposort():
print 'GRAPH', node
raise
except Warning:
for node in f.maker.env.toposort():
print 'GRAPH', node
def test0(self):
def test_gemm_opt0():
"""Many subgraphs whose dots can be eliminated""" """Many subgraphs whose dots can be eliminated"""
X,Y,Z,a,b = self.XYZab() X,Y,Z,a,b = XYZab()
self.just_gemm([X,Y,Z,a,b], [T.dot(X,Y) * a + Z * b]) just_gemm([X,Y,Z,a,b], [T.dot(X,Y) * a + Z * b])
self.just_gemm([X,Y,Z,a,b], [a * T.dot(X,Y) + b * Z]) just_gemm([X,Y,Z,a,b], [a * T.dot(X,Y) + b * Z])
self.just_gemm([X,Y,Z,a,b], [b * Z + a * T.dot(X,Y)]) just_gemm([X,Y,Z,a,b], [b * Z + a * T.dot(X,Y)])
self.just_gemm([X,Y,Z,a,b], [T.dot(X,Y) * a - Z * b]) just_gemm([X,Y,Z,a,b], [T.dot(X,Y) * a - Z * b])
self.just_gemm([X,Y,Z,a,b], [a * T.dot(X,Y) - b * Z]) just_gemm([X,Y,Z,a,b], [a * T.dot(X,Y) - b * Z])
self.just_gemm([X,Y,Z,a,b], [b * Z - a * T.dot(X,Y)]) just_gemm([X,Y,Z,a,b], [b * Z - a * T.dot(X,Y)])
#with transposes (transposes should be pushed through dot in canonicalize) #with transposes (transposes should be pushed through dot in canonicalize)
self.just_gemm([X,Y,Z,a,b], [b * Z.T - a * T.dot(Y.T,X.T)]) just_gemm([X,Y,Z,a,b], [b * Z.T - a * T.dot(Y.T,X.T)])
self.just_gemm([X,Y,Z,a,b], [b * Z.T + a * b * T.dot(X,Y).T]) just_gemm([X,Y,Z,a,b], [b * Z.T + a * b * T.dot(X,Y).T])
#with N multiplications instead of just one #with N multiplications instead of just one
self.just_gemm([X,Y,Z,a,b], [(b * b) * Z * a + (a * a) * T.dot(X,Y) * b]) just_gemm([X,Y,Z,a,b], [(b * b) * Z * a + (a * a) * T.dot(X,Y) * b])
self.just_gemm([X,Y,Z,a,b], [Z + T.dot(X,Y)]) just_gemm([X,Y,Z,a,b], [Z + T.dot(X,Y)])
self.just_gemm([X,Y,Z,a,b], [Z*b + T.dot(X,Y)]) just_gemm([X,Y,Z,a,b], [Z*b + T.dot(X,Y)])
self.just_gemm([X,Y,Z,a,b], [Z + a*b*a*T.dot(X,Y)]) just_gemm([X,Y,Z,a,b], [Z + a*b*a*T.dot(X,Y)])
self.just_gemm([X,Y,Z,a,b], [(b * b) * Z * a - (a * a) * T.dot(X,Y) * b]) just_gemm([X,Y,Z,a,b], [(b * b) * Z * a - (a * a) * T.dot(X,Y) * b])
self.just_gemm([X,Y,Z,a,b], [Z - T.dot(X,Y)]) just_gemm([X,Y,Z,a,b], [Z - T.dot(X,Y)])
self.just_gemm([X,Y,Z,a,b], [Z*b - T.dot(X,Y)]) just_gemm([X,Y,Z,a,b], [Z*b - T.dot(X,Y)])
self.just_gemm([X,Y,Z,a,b], [Z - a*b*a*T.dot(X,Y)]) just_gemm([X,Y,Z,a,b], [Z - a*b*a*T.dot(X,Y)])
# with > 2 terms in the overall addition
self.just_gemm([X,Y,Z,a,b], [Z + Z + T.dot(X,Y) + Z])
def test_double_gemm(self): def test_gemm_opt_double_gemm():
"""This is the pattern that shows up in the autoencoder""" """This is the pattern that shows up in the autoencoder"""
X,Y,Z,a,b = T.dmatrix(), T.dmatrix(), T.dmatrix(), T.dscalar(), T.dscalar() X,Y,Z,a,b = T.dmatrix(), T.dmatrix(), T.dmatrix(), T.dscalar(), T.dscalar()
R, S, c = T.dmatrix(), T.dmatrix(), T.dscalar() R, S, c = T.dmatrix(), T.dmatrix(), T.dscalar()
self.just_gemm([X,Y,Z,a,b, R, S, c], [Z *c + a * T.dot(X,Y) + b * T.dot(R,S).T], just_gemm([X,Y,Z,a,b, R, S, c], [Z *c + a * T.dot(X,Y) + b * T.dot(R,S).T],
ishapes=[(4,3), (3,5), (4,5), (), (), (5,9), (9,4), ()]) ishapes=[(4,3), (3,5), (4,5), (), (), (5,9), (9,4), ()])
def wishlist(self): ishapes=[(4,3), (3,5), (4,5), (), (), (5,9), (9,4), ()]
i = [X,Y,Z,a,b, R, S, c]
o = [a * T.dot(X,Y) + gemm(Z, b, S.T, R.T, 1.0)]
try:
f = function([In(ii, mutable=True) for ii in i],o, mode='FAST_RUN')
for node in f.maker.env.nodes:
if node.op == T.dot: raise Failure('dot in graph')
if node.op == _dot22: raise Failure('_dot22 in graph')
g = function(i, o, mode=compile.Mode(linker='py', optimizer=None))
#for node in g.maker.env.nodes:
# if node.op == gemm: raise Failure('gemm in graph')
rng = numpy.random.RandomState(234)
r0 = f(*[rng.randn(*sh) for sh in ishapes])
rng = numpy.random.RandomState(234)
r1 = g(*[rng.randn(*sh) for sh in ishapes])
max_abs_err = numpy.max(numpy.abs(r0[0] - r1[0]))
if max_abs_err > 1.0e-8:
raise Failure('GEMM is computing the wrong output. max_rel_err =', max_abs_err)
except Failure:
for node in f.maker.env.toposort():
print 'GRAPH', node
raise
def wishlist_gemm_opt():
X,Y,Z,a,b = T.dmatrix(), T.dmatrix(), T.dmatrix(), T.dscalar(), T.dscalar() X,Y,Z,a,b = T.dmatrix(), T.dmatrix(), T.dmatrix(), T.dscalar(), T.dscalar()
#with >2 additions of the same T.dot(X,Y term #with >2 additions of the same T.dot(X,Y term
self.just_gemm([X,Y,Z,a,b], [Z + T.dot(X,Y) + T.dot(X,Y)]) just_gemm([X,Y,Z,a,b], [Z + T.dot(X,Y) + T.dot(X,Y)])
self.just_gemm([X,Y,Z,a,b], [(b * b) * Z * a + (a * a) * T.dot(X,Y) + b * T.dot(X,Y)]) just_gemm([X,Y,Z,a,b], [(b * b) * Z * a + (a * a) * T.dot(X,Y) + b * T.dot(X,Y)])
def test_gemm_with_vector():
"""Many subgraphs whose dots can be eliminated.
This adds a vector two the previous test, which triggers the long-sought GEMM bug.
"""
X,Y,Z,a,b = XYZab()
v = T.vector()
def my_just_gemm(o):
i = [X,Y,Z,a,b,v]
ishapes = [(4,3), (3,5), (4,5), (), (), (5,)]
rval = just_gemm(i, o, ishapes=ishapes)
my_just_gemm([v + T.dot(X,Y) * a + Z * b])
my_just_gemm([v + a * T.dot(X,Y) + b * Z])
my_just_gemm([v + b * Z + a * T.dot(X,Y)])
my_just_gemm([v + T.dot(X,Y) * a - Z * b])
my_just_gemm([v + a * T.dot(X,Y) - b * Z])
my_just_gemm([v + b * Z - a * T.dot(X,Y)])
def test_vector_stuff(self): #with N multiplications instead of just one
my_just_gemm([v + (b * b) * Z * a + (a * a) * T.dot(X,Y) * b])
my_just_gemm([v + Z + T.dot(X,Y)])
my_just_gemm([v + Z*b + T.dot(X,Y)])
my_just_gemm([v + Z + a*b*a*T.dot(X,Y)])
my_just_gemm([v + (b * b) * Z * a - (a * a) * T.dot(X,Y) * b])
my_just_gemm([Z - T.dot(X,Y) + v])
my_just_gemm([Z*b - T.dot(X,Y) + v])
my_just_gemm([Z - a*b*a*T.dot(X,Y) + v])
def test_gemm_opt_vector_stuff():
X,Y,Z,a,b = T.dmatrix(), T.dmatrix(), T.dmatrix(), T.dscalar(), T.dscalar() X,Y,Z,a,b = T.dmatrix(), T.dmatrix(), T.dmatrix(), T.dscalar(), T.dscalar()
u,v = T.dvector(), T.dvector() u,v = T.dvector(), T.dvector()
f = function([a, u, v], a + T.dot(u,v), mode='FAST_RUN') f = function([a, u, v], a + T.dot(u,v), mode='FAST_RUN')
self.failIf(gemm in [n.op for n in f.maker.env.nodes]) if gemm in [n.op for n in f.maker.env.nodes]:
raise Failure('gemm in graph')
f = function([a, u, X,Y], a * u + T.dot(X,Y), mode='FAST_RUN') f = function([a, u, X,Y], a * u + T.dot(X,Y), mode='FAST_RUN')
self.failIf(gemm in [n.op for n in f.maker.env.nodes]) if (gemm in [n.op for n in f.maker.env.nodes]):
raise Failure('gemm in graph')
def test_inplace0():
#should fail to insert gemm because gemm would create cycles
X,Y,Z,a,b = T.dmatrix(), T.dmatrix(), T.dmatrix(), T.dscalar(), T.dscalar()
R, S, c = T.dmatrix(), T.dmatrix(), T.dscalar()
f = function([X,Y,Z,a,b, R, S, c],
[Z * (Z *c + a * T.dot(X,Y) + b * T.dot(R,S).T)], mode='FAST_RUN')
if (gemm in [n.op for n in f.maker.env.nodes]):
raise Failure('gemm in graph')
def test_inplace1():
X,Y,Z,a,b = XYZab()
# with > 2 terms in the overall addition
f = function([X,Y,Z,a,b],
[Z + Z + T.dot(X,Y)], mode='FAST_RUN')
if (gemm in [n.op for n in f.maker.env.nodes]):
raise Failure('gemm in graph')
...@@ -155,14 +155,14 @@ class QuadraticDenoisingAA(T.RModule): ...@@ -155,14 +155,14 @@ class QuadraticDenoisingAA(T.RModule):
updates = dict((p, p - self.lr * g) for p, g in zip(self.params, gradients)) updates = dict((p, p - self.lr * g) for p, g in zip(self.params, gradients))
# INTERFACE METHODS # INTERFACE METHODS
self.update = theano.Method(self.input, self.ncost, updates) #self.update = theano.Method(self.input, self.ncost, updates)
self.compute_cost = theano.Method(self.input, self.cost) #self.compute_cost = theano.Method(self.input, self.cost)
self.noisify = theano.Method(self.input, self.corrupted_input) #self.noisify = theano.Method(self.input, self.corrupted_input)
self.reconstruction = theano.Method(self.input, self.output) #self.reconstruction = theano.Method(self.input, self.output)
self.representation = theano.Method(self.input, self.hidden) #self.representation = theano.Method(self.input, self.hidden)
self.reconstruction_through_noise = theano.Method(self.input, [self.corrupted_input, self.noutput]) #self.reconstruction_through_noise = theano.Method(self.input, [self.corrupted_input, self.noutput])
self.validate = theano.Method(self.input, [self.cost, self.output]) #self.validate = theano.Method(self.input, [self.cost, self.output])
def _instance_initialize(self, obj, input_size, hidden_size, seed, lr, qfilter_relscale): def _instance_initialize(self, obj, input_size, hidden_size, seed, lr, qfilter_relscale):
""" """
...@@ -291,16 +291,16 @@ class Module_Nclass(module.FancyModule): ...@@ -291,16 +291,16 @@ class Module_Nclass(module.FancyModule):
#define the apply method #define the apply method
self.pred = T.argmax(linear_output, axis=1) self.pred = T.argmax(linear_output, axis=1)
self.apply = module.Method([self.input], self.pred) #self.apply = module.Method([self.input], self.pred)
self.validate = module.Method([self.input, self.targ], [self.cost, self.argmax, self.max_pr]) #self.validate = module.Method([self.input, self.targ], [self.cost, self.argmax, self.max_pr])
self.softmax_output = module.Method([self.input], self.softmax_unsupervised) #self.softmax_output = module.Method([self.input], self.softmax_unsupervised)
if self.params: if self.params:
gparams = T.grad(sum_xent, self.params) gparams = T.grad(sum_xent, self.params)
self.update = module.Method([self.input, self.targ], sum_xent, #self.update = module.Method([self.input, self.targ], sum_xent,
updates = dict((p, p - self.lr * g) for p, g in zip(self.params, gparams))) #updates = dict((p, p - self.lr * g) for p, g in zip(self.params, gparams)))
class ConvolutionalMLPInstance(module.FancyModuleInstance, Loss01): class ConvolutionalMLPInstance(module.FancyModuleInstance, Loss01):
#initialize is called by Module.make #initialize is called by Module.make
...@@ -366,11 +366,6 @@ class ConvolutionalMLP(module.FancyModule): ...@@ -366,11 +366,6 @@ class ConvolutionalMLP(module.FancyModule):
) )
) )
# to_update = []
# all_kits = []
# input_update = self.input_representations[0].update
# input_update.resolve_all()
for i in self.inputs[1:]: for i in self.inputs[1:]:
self.input_representations.append( self.input_representations.append(
QDAA( QDAA(
...@@ -411,11 +406,17 @@ class ConvolutionalMLP(module.FancyModule): ...@@ -411,11 +406,17 @@ class ConvolutionalMLP(module.FancyModule):
] + self.hidden.qfilters ] + self.hidden.qfilters
input_pretraining_cost = sum(i.ncost for i in self.input_representations) input_pretraining_cost = sum(i.ncost for i in self.input_representations)
hidden_pretraining_cost = self.hidden.ncost hidden_pretraining_cost = self.hidden.ncost
input_pretraining_gradients = T.grad(input_pretraining_cost, input_pretraining_params) input_pretraining_gradients = T.grad(input_pretraining_cost,
input_pretraining_params)
hidden_pretraining_gradients = T.grad(hidden_pretraining_cost, hidden_pretraining_params) hidden_pretraining_gradients = T.grad(hidden_pretraining_cost, hidden_pretraining_params)
pretraining_updates = dict((p, p - self.lr * g) for p, g in zip(input_pretraining_params, input_pretraining_gradients) + pretraining_updates = \
zip(hidden_pretraining_params, hidden_pretraining_gradients)) dict((p, p - self.lr * g) for p, g in \
self.pretraining_update = module.Method(self.inputs, [input_pretraining_cost, hidden_pretraining_cost], pretraining_updates) zip(input_pretraining_params, input_pretraining_gradients) \
+ zip(hidden_pretraining_params, hidden_pretraining_gradients))
self.pretraining_update = module.Method(self.inputs,
[input_pretraining_cost, hidden_pretraining_cost],
pretraining_updates)
finetuning_params = \ finetuning_params = \
[self.input_representations[0].w1, self.input_representations[0].b1] + self.input_representations[0].qfilters + \ [self.input_representations[0].w1, self.input_representations[0].b1] + self.input_representations[0].qfilters + \
...@@ -426,9 +427,8 @@ class ConvolutionalMLP(module.FancyModule): ...@@ -426,9 +427,8 @@ class ConvolutionalMLP(module.FancyModule):
finetuning_updates = dict((p, p - self.lr * g) for p, g in zip(finetuning_params, finetuning_gradients)) finetuning_updates = dict((p, p - self.lr * g) for p, g in zip(finetuning_params, finetuning_gradients))
self.finetuning_update = module.Method(self.inputs + [self.targ], self.output.cost, finetuning_updates) self.finetuning_update = module.Method(self.inputs + [self.targ], self.output.cost, finetuning_updates)
#self.validate = module.Method(self.inputs + [self.targ], [self.output.cost, self.output.argmax, self.output.max_pr])
self.validate = module.Method(self.inputs + [self.targ], [self.output.cost, self.output.argmax, self.output.max_pr]) #self.softmax_output = module.Method(self.inputs, self.output.softmax_unsupervised)
self.softmax_output = module.Method(self.inputs, self.output.softmax_unsupervised)
def create(window_size=3, def create(window_size=3,
input_dimension=9, input_dimension=9,
...@@ -462,15 +462,21 @@ JTEST = theano.compile.mode.optdb.query(*sys.argv[2:]) ...@@ -462,15 +462,21 @@ JTEST = theano.compile.mode.optdb.query(*sys.argv[2:])
print 'JTEST', JTEST print 'JTEST', JTEST
theano.compile.register_optimizer('JTEST', JTEST) theano.compile.register_optimizer('JTEST', JTEST)
if __name__ == '__main__': if __name__ == '__main__':
optimizer = eval(sys.argv[1]) optimizer = eval(sys.argv[1])
m = create(compile_mode = theano.Mode(linker='c|py', optimizer=optimizer)) m = create(compile_mode = theano.Mode(linker='c|py', optimizer=optimizer))
prog_str = [] prog_str = []
for i, node in enumerate(m.finetuning_update.maker.env.toposort()): idx_of_node = {}
#print ' ', i, node for i, node in enumerate(m.pretraining_update.maker.env.toposort()):
idx_of_node[node] = i
if False and i > -1:
print ' ', i, node, [(ii, idx_of_node.get(ii.owner, 'IN')) for ii in node.inputs]
prog_str.append(str(node)) prog_str.append(str(node))
print "PROGRAM LEN %i HASH %i"% (len(m.finetuning_update.maker.env.nodes), reduce(lambda a, b: hash(a) ^ hash(b),prog_str)) #print input_pretraining_gradients[4].owner.inputs
#print input_pretraining_gradients[4].owner.inputs[1].owner.inputs
#sys.exit()
print "PROGRAM LEN %i HASH %i"% (len(m.pretraining_update.maker.env.nodes), reduce(lambda a, b: hash(a) ^ hash(b),prog_str))
rng = N.random.RandomState(23904) rng = N.random.RandomState(23904)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论