提交 b6b2c608 authored 作者: James Bergstra's avatar James Bergstra

code in a mess, but gemm-optimization works on more systematic test cases…

code in a mess, but gemm-optimization works on more systematic test cases including Joseph's NAACL graph
上级 43291f46
......@@ -63,11 +63,20 @@ def register_optimizer(name, opt):
raise ValueError('Optimizer name already taken: %s' % name)
predefined_optimizers[name] = opt
class AddDestroyHandler(gof.Optimizer):
    """An Optimizer that rewrites nothing itself.

    Its only job is to attach a DestroyHandler feature to the env via
    `add_requirements`, so that destructive (inplace) ops registered later
    in the optimizer sequence are validated for consistency.
    """

    def apply(self, env):
        # Deliberately a no-op: this optimizer exists only for its
        # requirements, not for any graph transformation.
        pass

    def add_requirements(self, env):
        """Install the parent requirements, then a DestroyHandler."""
        super(AddDestroyHandler, self).add_requirements(env)
        env.extend(gof.DestroyHandler())
optdb = gof.SequenceDB()
optdb.register('merge1', gof.MergeOptimizer(), 0, 'fast_run', 'fast_compile')
optdb.register('canonicalize', gof.EquilibriumDB(), 1, 'fast_run')
optdb.register('specialize', gof.EquilibriumDB(), 2, 'fast_run')
optdb.register('merge2', gof.EquilibriumDB(), 100, 'fast_run')
optdb.register('merge2', gof.EquilibriumDB(), 49, 'fast_run')
optdb.register('add_destroy_handler', AddDestroyHandler(), 49.5, 'fast_run', 'inplace')
optdb.register('merge3', gof.EquilibriumDB(), 100, 'fast_run')
class Mode(object):
......
......@@ -20,15 +20,14 @@ from link import \
from op import \
Op
from opt import \
Optimizer, optimizer, SeqOptimizer, \
MergeOptimizer, MergeOptMerge, \
LocalOptimizer, local_optimizer, LocalOptGroup, \
OpSub, OpRemove, PatternSub, \
NavigatorOptimizer, TopoOptimizer, EquilibriumOptimizer, \
keep_going, warn, \
InplaceOptimizer, PureThenInplaceOptimizer
#LocalOpKeyOptGroup, OpKeyOptimizer
from opt import (Optimizer, optimizer, SeqOptimizer,
MergeOptimizer, MergeOptMerge,
LocalOptimizer, local_optimizer, LocalOptGroup,
OpSub, OpRemove, PatternSub,
NavigatorOptimizer, TopoOptimizer, EquilibriumOptimizer,
keep_going, warn,
InplaceOptimizer, PureThenInplaceOptimizer,
OpKeyOptimizer)
from optdb import \
DB, Query, \
......
......@@ -265,6 +265,11 @@ class LocalOptimizer(object):
raise utils.AbstractFunctionError()
def add_requirements(self, env):
"""If this local optimization wants to add some requirements to the env,
This is the place to do it."""
env.extend(toolbox.ReplaceValidate())
class FromFunctionLocalOptimizer(LocalOptimizer):
"""WRITEME"""
......@@ -273,8 +278,6 @@ class FromFunctionLocalOptimizer(LocalOptimizer):
self._tracks = tracks
def tracks(self):
return self._tracks
def add_requirements(self, env):
env.extend(toolbox.ReplaceValidate())
def __str__(self):
return getattr(self, 'name', '<FromFunctionLocalOptimizer instance>')
......@@ -551,7 +554,7 @@ class NavigatorOptimizer(Optimizer):
def __init__(self, local_opt, ignore_newtrees = 'auto', failure_callback = None):
"""
:param local_opt: a LocalOptimizer to apply over a Env.
:param local_opt: a LocalOptimizer to apply over a Env (or None is Ok too).
:param ignore_newtrees:
- True: new subgraphs returned by an optimization is not a candidate for optimization
- False: new subgraphs returned by an optimization is a candidate for optimization
......@@ -617,6 +620,24 @@ class NavigatorOptimizer(Optimizer):
env.remove_feature(u)
def process_node(self, env, node, lopt = None):
"""
This function will use `lopt` to `transform` the `node`. The `transform` method will
return either False or a list of Results that are intended to replace `node.outputs`.
If the env accepts the replacement, then the optimization is successful, and this
function returns True.
If there are no replacement candidates or the env rejects the replacements, this
function returns False.
:param env: an Env
:param node: an Apply instance in `env`
:param lopt: a LocalOptimizer instance that may have a better idea for how to compute
node's outputs.
:rtype: Bool
:returns: True iff the `node`'s outputs were replaced in the `env`.
"""
lopt = lopt or self.local_opt
try:
replacements = lopt.transform(node)
......@@ -633,23 +654,21 @@ class NavigatorOptimizer(Optimizer):
env.replace_all_validate(repl_pairs)
return True
except Exception, e:
# This means the replacements were rejected by the env.
#
# This is not supposed to happen. The default failure_callback will print a
# traceback as a warning.
if self.failure_callback is not None:
self.failure_callback(e, self, repl_pairs)
#DEBUG DONT PUSH
#print lopt
#print dir(lopt)
#raise
#END
return False
else:
raise
def add_requirements(self, env):
super(NavigatorOptimizer, self).add_requirements(env)
env.extend(toolbox.ReplaceValidate())
if self.local_opt:
self.local_opt.add_requirements(env)
class TopoOptimizer(NavigatorOptimizer):
"""WRITEME"""
......@@ -722,7 +741,7 @@ class OpKeyOptimizer(NavigatorOptimizer):
- NodeFinder
- ReplaceValidate
"""
NavigatorOptimizer.add_requirements(self, env)
super(OpKeyOptimizer, self).add_requirements(env)
env.extend(toolbox.NodeFinder())
......
......@@ -13,6 +13,8 @@ class DB(object):
def __init__(self):
self.__db__ = defaultdict(set)
self._names = set()
self.name = None #will be reset by register
#(via obj.name by the thing doing the registering)
def register(self, name, obj, *tags):
# N.B. obj is not an instance of class Optimizer.
......@@ -21,6 +23,8 @@ class DB(object):
if not isinstance(obj, (DB, opt.Optimizer, opt.LocalOptimizer)):
raise Exception('wtf', obj)
if self.name is not None:
tags = tags + (self.name,)
obj.name = name
if name in self.__db__:
raise ValueError('The name of the object cannot be an existing tag or the name of another existing object.', obj, name)
......@@ -118,9 +122,10 @@ class EquilibriumDB(DB):
class SequenceDB(DB):
def __init__(self):
def __init__(self, failure_callback = opt.warn):
super(SequenceDB, self).__init__()
self.__priority__ = {}
self.failure_callback = failure_callback
def register(self, name, obj, priority, *tags):
super(SequenceDB, self).register(name, obj, *tags)
......@@ -130,6 +135,6 @@ class SequenceDB(DB):
opts = super(SequenceDB, self).query(*tags, **kwtags)
opts = list(opts)
opts.sort(key = lambda obj: self.__priority__[obj.name])
return opt.SeqOptimizer(opts, failure_callback = opt.warn)
return opt.SeqOptimizer(opts, failure_callback = self.failure_callback)
"""Ops and optimizations for using BLAS function calls to evaluate linear algebra expressions"""
import os, sys
import os, sys, traceback
import numpy
from ..gof import (utils, Op, Apply, view_roots, PatternSub,
InplaceOptimizer, SeqOptimizer, warn, local_optimizer)
from ..gof import (utils, Op, Apply, view_roots, PatternSub, DestroyHandler,
SeqOptimizer, warn, local_optimizer, LocalOptimizer, OpKeyOptimizer,
InconsistencyError)
from ..printing import pprint, FunctionPrinter
from .opt import register_specialize, out2in, insert_inplace_optimizer
import basic as T
from ..tensor import as_tensor
#NB: this clobbers the builtin 'compile' symbol
from .. import compile #to register the optimizer built by this file
from .blas_headers import cblas_header_text, blas_header_text
JOSEPHS_BUG_SOLVED = False
@utils.memoize
def ldflags():
"""Return a list of libraries against which an Op's object file should be
......@@ -270,7 +267,7 @@ class Gemm(GemmRelated):
E_z_uniq = 'argument z aliased to x or y'
destroy_map = {0: [0]}
def make_node(self, *inputs):
inputs = map(as_tensor, inputs)
inputs = map(T.as_tensor, inputs)
if len(inputs) != 5:
raise TypeError("Wrong number of inputs for %s (expected 5, got %s)" % (self, len(inputs)))
z, a, x, y, b = inputs
......@@ -348,19 +345,215 @@ class Gemm(GemmRelated):
#undef REAL
"""
def c_code(self, node, name, (_z, _a, _x, _y, _b), (_zout, ), sub):
def c_code(self, node, name, (_z, _a, _x, _y, _b), (_zout, ), sub): #DEBUG
full_code = self.build_gemm_call() % dict(locals(), **sub)
return full_code
gemm = Gemm()
pprint.assign(gemm, FunctionPrinter('gemm'))
def res_is_a(node, op, maxclients=None):
    """Return whether `node` was produced by `op`.

    If `maxclients` is given, additionally require that `node` has at most
    that many clients.  Truthiness of the return value is what callers
    rely on.
    """
    owner = node.owner
    if not owner or owner.op != op:
        return False
    if maxclients is None:
        return True
    return len(node.clients) <= maxclients
class GemmLocalOptimizer(LocalOptimizer):
"""This is a massive beast for recognizing all the ways that a subtraction could be
replaced by a GEMM
It depends on `local_transposed_dot` to canonicalize the graph a bit by swapping
dot(a,b).T -> dot(b.T, a.T)
"""
def __init__(self):
super(LocalOptimizer, self).__init__()
def op_key(self):
return [T.add, T.sub]
def add_requirements(self, env):
super(GemmLocalOptimizer,self).add_requirements(env)
env.extend(DestroyHandler())
def transform(self, node):
_as_scalar, _is_real_matrix, _as_isolated_scalar_times_matrix, beta_L_plus_alpha_M\
= (GemmLocalOptimizer._as_scalar,
GemmLocalOptimizer._is_real_matrix,
GemmLocalOptimizer._as_isolated_scalar_times_matrix,
GemmLocalOptimizer.beta_L_plus_alpha_M)
if node.op == T.sub:
L, R = node.inputs
if not _is_real_matrix(L):
return False
if not _is_real_matrix(R):
return False
tmp = _as_isolated_scalar_times_matrix(L)
try:
sL, mL = tmp
except:
sL, mL = 1.0, L
tmp = _as_isolated_scalar_times_matrix(R)
try:
sR, mR = tmp
except:
sR, mR = 1.0, R
rval = beta_L_plus_alpha_M(sL, mL, -sR, mR)
return rval
if node.op == T.add:
sM_list = []
other_inputs = []
for input in node.inputs:
tmp = _as_isolated_scalar_times_matrix(input)
if tmp:
sM_list.append(tmp)
elif _is_real_matrix(input):
sM_list.append((1.0, input))
else:
other_inputs.append(input)
if len(sM_list) == 2:
(sL, mL), (sR, mR) = sM_list
gemm_of_sM_list = beta_L_plus_alpha_M(sL, mL, sR, mR)
if gemm_of_sM_list:
#we turned the two candidates into a gemm
# now we have to add the other_inputs and return the replacement graph
if other_inputs:
return [T.add(*(other_inputs + gemm_of_sM_list))]
else:
return gemm_of_sM_list
else:
for i in xrange(len(sM_list) - 1):
for j in xrange(i+1, len(sM_list)):
sL, mL = sM_list[i]
sR, mR = sM_list[j]
gemm_of_sM_list = beta_L_plus_alpha_M(sL, mL, sR, mR)
if gemm_of_sM_list:
assert len(gemm_of_sM_list) == 1
inputs_without_ij = \
[input for k, input in enumerate(node.inputs) if k not in (i,j)]
return [T.add( *(inputs_without_ij + gemm_of_sM_list + other_inputs))]
return False
@staticmethod
def failure_callback(exc, nav, repl_pairs):
"""WRITEME"""
if not isinstance(exc, InconsistencyError):
traceback.print_exc()
else:
print 'GEMM caused cycle, forget it.'
@staticmethod
def _as_scalar(res):
"""Return None or a TensorResult whose type is in T.float_scalar_types"""
if res.owner and isinstance(res.owner.op, T.DimShuffle):
return GemmLocalOptimizer._as_scalar(res.owner.inputs[0])
elif res.type in T.float_scalar_types:
return res
elif isinstance(res, T.Constant) and res.data.size == 1:
return res.data.flatten()[0]
else:
return None
@staticmethod
def _is_real_matrix(res):
return res.type in T.float_matrix_types \
and res.broadcastable[0] == False \
and res.broadcastable[1] == False #cope with tuple vs. list
@staticmethod
def _as_isolated_scalar_times_matrix(res):
_as_scalar, _is_real_matrix, _as_isolated_scalar_times_matrix, beta_L_plus_alpha_M\
= (GemmLocalOptimizer._as_scalar,
GemmLocalOptimizer._is_real_matrix,
GemmLocalOptimizer._as_isolated_scalar_times_matrix,
GemmLocalOptimizer.beta_L_plus_alpha_M)
if res_is_a(res, T.mul, 1):
if len(res.owner.inputs) == 2:
L, R = res.owner.inputs
sL = _as_scalar(L)
sR = _as_scalar(R)
if (sL is not None) and _is_real_matrix(R):
return (sL, R)
if (sR is not None) and _is_real_matrix(L):
return (sR, L)
else:
scalars = []
matrices = []
for input in res.owner.inputs:
scalar_input = _as_scalar(input)
if scalar_input is not None:
scalars.append(scalar_input)
elif _is_real_matrix(input):
matrices.append(input)
else:
return None
if len(matrices) == 1:
rval = (T.mul(*scalars), matrices[0])
return rval
@staticmethod
def beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip = True):
#print 'BETA L + ALPHA M', beta, L, alpha, M, recurse_flip
#EXPRESSION: (beta * L) + (alpha * M)
if True:
if res_is_a(L, T.sqrt):
print 'CLIENTS OF L', L, L.clients
if res_is_a(M, _dot22, 1):
Ml, Mr = M.owner.inputs
rval = [gemm(L, alpha, Ml, Mr, beta)]
print 'GEMM 0', rval, beta, L, alpha, M
return rval
if False and res_is_a(M, gemm, 1):
#EXPRESSION: (beta * L) + (alpha * (gemm(G, a, u, v, b)))
#EXPRESSION: (beta * L) + alpha * (b * G) + alpha * a * dot(u, v)
G, a, u, v, b = M.owner.inputs
#print 'GEMM', G, L
if res_is_a(G, _dot22, 1):
#EXPRESSION: (beta * L) + (alpha * (gemm(dot(x,y), a, u, v, b)))
x, y = G.owner.inputs
#EXPRESSION: (beta * L) + (alpha * ((b*dot(x,y) + (a * dot(u, v)))))
#EXPRESSION: (beta * L) + (alpha*b*dot(x,y)) + (alpha * a * dot(u, v))
rval = [gemm(gemm(L, alpha * b, x, y, beta), alpha * a, u, v, 1.0)]
print 'GEMM 1', rval
return rval
if (G is L):
#EXPRESSION: (beta * L) + (alpha*b*L) + (alpha * a * dot(u, v))
rval = [gemm(L, alpha*a, u, v, alpha * b + beta)]
print 'GEMM 2', rval
return rval
if (1.0 != alpha):
#at the very least, move the alpha inside the gemm
rval = [beta * L + gemm(G, alpha * a, u, v, alpha * b)]
print 'GEMM 3', rval
return rval
if recurse_flip:
return GemmLocalOptimizer.beta_L_plus_alpha_M(alpha, M, beta, L, recurse_flip = False)
else:
return False
#I think that three passes should suffice to catch all the GEMMs.
# TODO: This could be an EquilibriumOptimizer, but I don't know how to combine an OpKeyOptimizer and
# an EquilibriumOptimizer.
# Three identical passes are registered at consecutive priorities so that
# gemm opportunities exposed by one pass can be picked up by the next.
compile.optdb.register('inplace_gemm_0', OpKeyOptimizer(GemmLocalOptimizer(),
    failure_callback=GemmLocalOptimizer.failure_callback), 70.00, 'fast_run', 'inplace')
compile.optdb.register('inplace_gemm_1', OpKeyOptimizer(GemmLocalOptimizer(),
    failure_callback=GemmLocalOptimizer.failure_callback), 70.01, 'fast_run', 'inplace')
compile.optdb.register('inplace_gemm_2', OpKeyOptimizer(GemmLocalOptimizer(),
    failure_callback=GemmLocalOptimizer.failure_callback), 70.02, 'fast_run', 'inplace')
class Dot22(GemmRelated):
"""Compute a matrix-matrix product.
This is a specialization of the more general Dot()
"""
def make_node(self, x, y):
assert _is_real_matrix(x)
assert GemmLocalOptimizer._is_real_matrix(x)
assert y.type == x.type #makes sure y is a matrix
bz = [False, False]
outputs = [T.tensor(x.type.dtype, bz)]
......@@ -404,7 +597,7 @@ class Dot22(GemmRelated):
double a = 1.0;
double b = 0.0;
"""
def c_code(self, node, name, (_x, _y), (_z, ), sub):
def c_code(self, node, name, (_x, _y), (_z, ), sub): #DEBUG
full_code = self.build_gemm_call() % dict(locals(), **sub)
return full_code
_dot22 = Dot22()
......@@ -413,158 +606,9 @@ _dot22 = Dot22()
def local_dot_to_dot22(node):
if node.op == T.dot:
x,y = node.inputs
if _is_real_matrix(x) and y.type == x.type:
if GemmLocalOptimizer._is_real_matrix(x) and y.type == x.type:
return [_dot22(*node.inputs)]
else:
return False
if JOSEPHS_BUG_SOLVED:
register_specialize(local_dot_to_dot22)
def _is_a(node, op, maxclients=None):
return node.owner \
and node.owner.op == op \
and len(node.clients) <= maxclients if maxclients is not None else True
def _as_scalar(res):
    """Return None or a TensorResult whose type is in T.float_scalar_types.

    DimShuffles are looked through recursively; a one-element Constant is
    collapsed to its python scalar value.
    """
    if res.owner and isinstance(res.owner.op, T.DimShuffle):
        # unwrap (broadcasting) DimShuffles and retry on their input
        return _as_scalar(res.owner.inputs[0])
    if res.type in T.float_scalar_types:
        return res
    if isinstance(res, T.Constant) and res.data.size == 1:
        return res.data.flatten()[0]
    return None
def _is_real_matrix(res):
    """True iff `res` is a float matrix with both dimensions
    non-broadcastable."""
    if res.type not in T.float_matrix_types:
        return False
    bcast = res.broadcastable
    # `== False` copes with broadcastable being a tuple vs. a list
    return (bcast[0] == False) and (bcast[1] == False)
def _as_isolated_scalar_times_matrix(res):
    """If `res` is a mul (with at most one client) of scalars and exactly one
    real matrix, return the pair (scalar, matrix).

    Returns None (implicitly, by falling off the end) in every other case.
    """
    if _is_a(res, T.mul, 1):
        if len(res.owner.inputs) == 2:
            # common case: one scalar on either side of the matrix
            L, R = res.owner.inputs
            sL = _as_scalar(L)
            sR = _as_scalar(R)
            if (sL is not None) and _is_real_matrix(R):
                return (sL, R)
            if (sR is not None) and _is_real_matrix(L):
                return (sR, L)
        else:
            # n-ary mul: every input must be a scalar, except exactly one
            # real matrix
            scalars = []
            matrices = []
            for input in res.owner.inputs:
                scalar_input = _as_scalar(input)
                if scalar_input is not None:
                    scalars.append(scalar_input)
                elif _is_real_matrix(input):
                    matrices.append(input)
                else:
                    # an input that is neither scalar nor matrix: give up
                    return None
            if len(matrices) == 1:
                rval = (T.mul(*scalars), matrices[0])
                return rval
def beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip = True):
    """Try to express (beta * L) + (alpha * M) as a single gemm graph.

    Returns a one-element list [replacement] on success, or False.
    When recurse_flip is True, the symmetric form (alpha * M) + (beta * L)
    is attempted (once) before giving up.
    """
    #print 'BETA L + ALPHA M', beta, L, alpha, M, recurse_flip
    #EXPRESSION: (beta * L) + (alpha * M)
    if _is_a(M, _dot22, 1):
        # M is a single-client matrix product: fold it into one gemm call
        Ml, Mr = M.owner.inputs
        rval = [gemm(L, alpha, Ml, Mr, beta)]
        return rval
    if _is_a(M, gemm, 1):
        # M is itself a single-client gemm: try to merge the two
        #EXPRESSION: (beta * L) + (alpha * (gemm(G, a, u, v, b)))
        #EXPRESSION: (beta * L) + alpha * (b * G) + alpha * a * dot(u, v)
        G, a, u, v, b = M.owner.inputs
        #print 'GEMM', G, L
        if _is_a(G, _dot22, 1):
            #EXPRESSION: (beta * L) + (alpha * (gemm(dot(x,y), a, u, v, b)))
            x, y = G.owner.inputs
            #EXPRESSION: (beta * L) + (alpha * ((b*dot(x,y) + (a * dot(u, v)))))
            #EXPRESSION: (beta * L) + (alpha*b*dot(x,y)) + (alpha * a * dot(u, v))
            #print 'GEMM 1', G, L
            rval = [gemm(gemm(L, alpha * b, x, y, beta), alpha * a, u, v, 1.0)]
            return rval
        elif G is L:
            #EXPRESSION: (beta * L) + (alpha*b*L) + (alpha * a * dot(u, v))
            rval = [gemm(L, alpha*a, u, v, alpha * b + beta)]
            #print 'GEMM 2', rval
            return rval
        elif 1.0 != alpha:
            #at the very least, move the alpha inside the gemm
            rval = [beta * L + gemm(G, alpha * a, u, v, alpha * b)]
            #print 'GEMM 3', G, L
            return rval
    if recurse_flip:
        # try the symmetric problem; recurse_flip=False prevents recursion
        return beta_L_plus_alpha_M(alpha, M, beta, L, recurse_flip = False)
    else:
        return False
@local_optimizer([T.sub])
def local_sub_to_gemm(node):
    """Rewrite a subtraction of (scaled) real matrices as a single gemm.

    L - R is treated as (sL * mL) + (-sR * mR) and handed to
    `beta_L_plus_alpha_M`.  Returns the replacement list or False.
    """
    if node.op == T.sub:
        L, R = node.inputs
        if not _is_real_matrix(L):
            return False
        if not _is_real_matrix(R):
            return False
        # BUGFIX: bare `except:` clauses used to guard these unpacks,
        # swallowing *any* error (including real bugs); the helper returns
        # None on "no match", so test for that sentinel explicitly.
        tmp = _as_isolated_scalar_times_matrix(L)
        if tmp is not None:
            sL, mL = tmp
        else:
            sL, mL = 1.0, L
        tmp = _as_isolated_scalar_times_matrix(R)
        if tmp is not None:
            sR, mR = tmp
        else:
            sR, mR = 1.0, R
        rval = beta_L_plus_alpha_M(sL, mL, -sR, mR)
        return rval
    return False
if JOSEPHS_BUG_SOLVED:
register_specialize(local_sub_to_gemm)
@local_optimizer([T.add])
def local_add_to_gemm(node):
    """This is a massive beast for recognizing all the ways that an addition
    could be replaced by a GEMM.

    It depends on `local_transposed_dot` to canonicalize the graph a bit by
    swapping dot(a,b).T -> dot(b.T, a.T)
    """
    if node.op == T.add:
        sM_list = []        # (scalar, matrix) gemm candidates
        sM_idx = []         # position in node.inputs of each candidate
        other_inputs = []   # inputs that cannot take part in a gemm
        for k, input in enumerate(node.inputs):
            tmp = _as_isolated_scalar_times_matrix(input)
            if tmp:
                sM_list.append(tmp)
                sM_idx.append(k)
            elif _is_real_matrix(input):
                sM_list.append((1.0, input))
                sM_idx.append(k)
            else:
                # BUGFIX: these inputs used to be silently discarded, so the
                # replacement graph computed a different value than the
                # original addition.
                other_inputs.append(input)
        if len(sM_list) == 2:
            (sL, mL), (sR, mR) = sM_list
            rval = beta_L_plus_alpha_M(sL, mL, sR, mR)
            if rval:
                # add back any non-gemm inputs before returning
                if other_inputs:
                    return [T.add(*(other_inputs + rval))]
                return rval
        else:
            # try every pair of candidates until one forms a gemm
            for i in xrange(len(sM_list) - 1):
                for j in xrange(i+1, len(sM_list)):
                    sL, mL = sM_list[i]
                    sR, mR = sM_list[j]
                    rval = beta_L_plus_alpha_M(sL, mL, sR, mR)
                    if rval:
                        assert len(rval) == 1
                        # BUGFIX: i and j index sM_list, not node.inputs; the
                        # old filter used them directly on node.inputs, which
                        # is wrong whenever any input went to other_inputs.
                        inputs_without_ij = \
                            [input for k, input in enumerate(node.inputs)
                             if k not in (sM_idx[i], sM_idx[j])]
                        return [T.add( *(inputs_without_ij + rval))]
    return False
if JOSEPHS_BUG_SOLVED:
register_specialize(local_add_to_gemm)
register_specialize(local_dot_to_dot22)
......@@ -316,7 +316,7 @@ class Elemwise(Op):
scalars
* inplace_pattern: a dictionary that maps the index of an output to the
index of an input so the output is calculated inplace using
the input's storage.
the input's storage. (Just like destroymap, but without the lists.)
"""
self.name = name
self.scalar_op = scalar_op
......@@ -357,16 +357,21 @@ class Elemwise(Op):
args.append(input)
else:
# TODO: use LComplete instead
args.append(DimShuffle(input.type.broadcastable, ['x']*difference + range(length), inplace = True)(input))
args.append(DimShuffle(
input.type.broadcastable,
['x']*difference + range(length),
inplace = True)(input))
inputs = args
# # Following conditions should always be true?
# try:
# assert len(set([len(input.type.broadcastable) for input in inputs])) == 1
# except (AssertionError, AttributeError):
# raise TypeError("All inputs to a Broadcast subclass must be Tensor instances and their broadcastable fields must all have the same length.", inputs)
#HERE: all the broadcast dims have the same length now
#cleverness: we iterate over the first, second, third broadcast flag of all inputs in
#parallel... the all() gives us each output broadcastable bit in turn.
#it is multiplied by nout because Elemwise supports multiple outputs (nout of them)
out_broadcastables = [[all(bcast) for bcast in zip(*[input.type.broadcastable for input in inputs])]] * shadow.nout
#inplace_pattern maps output idx -> input idx
inplace_pattern = self.inplace_pattern
if inplace_pattern:
for overwriter, overwritten in inplace_pattern.items():
......@@ -374,21 +379,32 @@ class Elemwise(Op):
if ib and not ob:
raise ValueError("Operation cannot be done inplace on an input with broadcasted dimensions.")
out_dtypes = [o.type.dtype for o in shadow.outputs]
if any(inputs[i].type.dtype != out_dtypes[o] for i, o in inplace_pattern.items()):
raise TypeError("Cannot do an inplace operation on incompatible data types.", [i.type.dtype for i in inputs], out_dtypes)
if any(inputs[i].type.dtype != out_dtypes[o] for o, i in inplace_pattern.items()):
raise TypeError("Cannot do an inplace operation on incompatible data types.",
([i.type.dtype for i in inputs], out_dtypes, inplace_pattern))
outputs = [Tensor(dtype = dtype, broadcastable = broadcastable)() for dtype, broadcastable in zip(out_dtypes, out_broadcastables)]
return Apply(self, inputs, outputs)
def __eq__(self, other):
return type(self) == type(other) and self.scalar_op == other.scalar_op and self.inplace_pattern == other.inplace_pattern
if type(self) == type(other):
items = self.inplace_pattern.items()
other_items = other.inplace_pattern.items()
items.sort()
other_items.sort()
return self.scalar_op == other.scalar_op and items == other_items
return False
def __hash__(self):
return hash(self.scalar_op) ^ hash(tuple(self.inplace_pattern.items()))
items = self.inplace_pattern.items()
items.sort()
return hash(self.scalar_op) ^ hash(tuple(items))
def __str__(self):
if self.name is None:
if self.inplace_pattern:
return "Elemwise{%s}%s" % (self.scalar_op, str(self.inplace_pattern))
items = self.inplace_pattern.items()
items.sort()
return "Elemwise{%s}%s" % (self.scalar_op, str(items))
else:
return "Elemwise{%s}" % (self.scalar_op)
else:
......@@ -467,6 +483,7 @@ class Elemwise(Op):
storage[0] = odat
else:
for i, (output, storage) in enumerate(zip(node.outputs, output_storage)):
#i is an output idx
if i in self.inplace_pattern:
odat = inputs[self.inplace_pattern[i]]
else:
......@@ -500,7 +517,7 @@ class Elemwise(Op):
defines = ""
undefs = ""
dmap = dict([(node.outputs[i], [node.inputs[o]]) for i, o in self.inplace_pattern.items()])
dmap = dict([(node.outputs[o], [node.inputs[i]]) for o, i in self.inplace_pattern.items()])
idtypes = [input.type.dtype_specs()[1] for input in inputs]
......
......@@ -5,7 +5,7 @@
from .. import gof
from ..gof import opt
from ..gof import opt, InconsistencyError
from elemwise import Elemwise, DimShuffle
from .. import scalar
import basic as T
......@@ -32,7 +32,8 @@ def in2out(*local_opts, **kwargs):
def _insert_inplace_optimizer(env):
@gof.optimizer
def insert_inplace_optimizer(env):
"""
Usage: inplace_optimizer.optimize(env)
......@@ -59,17 +60,18 @@ def _insert_inplace_optimizer(env):
try:
new = Elemwise(
op.scalar_op.__class__(
scalar.transfer_type(*[inplace_pattern.get(i, None) for i in xrange(len(node.outputs))])),
scalar.transfer_type(
*[inplace_pattern.get(i, None) \
for i in xrange(len(node.outputs))])),
inplace_pattern).make_node(*node.inputs)
env.replace_all_validate(zip(node.outputs, new.outputs))
except Exception, e:
except (ValueError, TypeError, InconsistencyError), e:
continue
candidate_inputs.remove(candidate_input)
node = new
baseline = inplace_pattern
break
insert_inplace_optimizer = gof.optimizer(_insert_inplace_optimizer)
compile.optdb.register('inplace_opt', insert_inplace_optimizer, 75, 'fast_run', 'inplace')
def register_canonicalize(lopt, *tags, **kwargs):
name = (kwargs and kwargs.pop('name')) or lopt.__name__
......@@ -310,7 +312,7 @@ def local_fill_cut(node):
register_canonicalize(local_fill_cut)
register_canonicalize(gof.OpRemove(T.tensor_copy), name='remove_tensor_copy' )
#register_canonicalize(gof.OpRemove(T.tensor_copy), name='remove_tensor_copy' ) #DEBUG
@gof.local_optimizer([None, T.fill])
def local_fill_sink(node):
......@@ -650,38 +652,6 @@ def local_mul_specialize(node):
return False
register_specialize(local_mul_specialize)
if 0: #TODO: replace this with a c version of any InplaceDimShuffle
class _TransposeInplace(T.Op):
view_map = {0: [0]}
def make_node(self, input):
return T.Apply(self, [input],
[T.tensor(dtype = input.type.dtype,
broadcastable = reversed(input.type.broadcastable))])
def perform(self, node, (x, ), (z, )):
z[0] = x.T
def c_code(self, node, name, (x, ), (z, ), sub):
return """
PyArrayObject* transposed = (PyArrayObject*)PyArray_Transpose(%(x)s, NULL);
if (%(z)s) {
Py_XDECREF(%(z)s);
}
%(z)s = transposed;
""" % locals()
def __str__(self):
return "_TransposeInplace"
_transpose_inplace = _TransposeInplace()
@gof.local_optimizer([T.DimShuffle([False,False],[1,0],inplace=True)])
def local_dimshuffle_transposeinplace(node):
if node.op == T.DimShuffle([False,False],[1,0],inplace=True):
return [_transpose_inplace(node.inputs[0])]
return False
register_specialize(local_dimshuffle_transposeinplace)
register_canonicalize(local_mul_canonizer, name = 'local_mul_canonizer')
......@@ -844,287 +814,3 @@ local_transposed_dot = gof.PatternSub((inplace_matrix_transpose, (T.dot, 'x', 'y
register_canonicalize(local_transposed_dot, name='local_transposed_dot')
# def _math_optimizer():
# pass_1 = in2out(local_fill_sink)
# pass_2 = out2in(local_dimshuffle_lift, local_shape_lift, local_fill_lift)#, local_fill_cut)
# pass_3 = out2in(local_subtensor_make_vector, local_fill_cut)
# canonizer = in2out(local_add_canonizer,
# local_mul_canonizer,
# local_fill_sink)
# pass_4 = out2in(local_greedy_distributor)
# return gof.SeqOptimizer(pass_1,
# pass_2,
# pass_3,
# neg_to_mul,
# canonizer,
# pass_4,
# mul_to_neg)
# math_optimizer = _math_optimizer()
# compile.register_optimizer('math',
# gof.MergeOptMerge(
# gof.PureThenInplaceOptimizer(
# math_optimizer,
# inplace_optimizer)))
# compile.register_mode('SANITY_CHECK', compile.Mode('c&py', 'math'))
# compile.register_mode('FAST_RUN', compile.Mode('c|py', 'math'))
# compile.register_mode('EXPENSIVE_OPTIMIZATIONS', compile.Mode('c|py', 'math'))
# @gof.local_optimizer
# def local_clique_fusion(node):
# aaaaaaaaaaaaaaaaaaaaaaa
# def find_cliques(env, through_broadcast = False):
# """
# Usage: find_cliques(env, through_broadcast = False)
# Returns a list of pairs where each pair contains a list
# of inputs and a list of outputs such that Env(inputs, outputs)
# contains nothing but Broadcast Ops.
# If through_broadcast is False, the cliques will only be
# allowed to broadcast over the inputs, which means, for
# example, that vector operations will not be mixed with
# matrix operations.
# """
# def seek_from(r):
# # walks through the graph until it encounters a
# # non-Broadcast operation or (if through_broadcast
# # is False) a Result which needs to be broadcasted.
# op = r.owner
# if env.edge(r) \
# or not isinstance(op, Broadcast) \
# or len(op.outputs) > 1:
# # todo: handle multiple-output broadcast ops
# # (needs to update the clique's outputs)
# return None
# ret = set()
# if not through_broadcast:
# # check each dimension over all the inputs - if the broadcastable
# # fields are not all 0 or all 1 for a particular dimension, then
# # broadcasting will be performed along it on the inputs where the
# # value is 1 and we will stop.
# if any(any(bc) and not all(bc)
# for bc in zip(*[input.broadcastable for input in op.inputs])):
# ret.update(op.inputs)
# return ret
# for input in op.inputs:
# res = seek_from(input)
# if res is None:
# # input is a leaf of our search
# ret.add(input)
# else:
# ret.update(res)
# return ret
# cliques = []
# def find_cliques_helper(r):
# if env.edge(r):
# return
# clique_inputs = seek_from(r)
# if clique_inputs is None:
# # Not in a clique, keep going
# op = r.owner
# if op is not None:
# for input in op.inputs:
# find_cliques_helper(input)
# else:
# # We found a clique, add it to the list and
# # jump to the leaves.
# cliques.append((clique_inputs, [r]))
# for input in clique_inputs:
# find_cliques_helper(input)
# for output in env.outputs:
# find_cliques_helper(output)
# # todo: merge the cliques if possible
# return cliques
# class CliqueOptimizer(opt.Optimizer):
# """
# Usage: CliqueOptimizer(through_broadcast = False,
# scalar_optimizer = None,
# make_composite = False).optimize(env)
# Finds cliques of Broadcast operations in the env and does either
# or both of two things:
# * Apply scalar_optimizer on the clique as if the clique was a
# group of scalar operations. scalar_optimizer can be any optimization
# which applies on scalars. If it is None, no optimization is done.
# * Replace the clique with a single Op, optimized to perform the
# computations properly. If make_composite is False, no such replacement
# is done.
# Note: it is recommended to run the lift_dimshuffle optimization before
# this one.
# """
# def __init__(self, through_broadcast = False, scalar_optimizer = None, make_composite = False):
# self.through_broadcast = through_broadcast
# self.scalar_optimizer = scalar_optimizer
# self.make_composite = make_composite
# def apply(self, env):
# if self.scalar_optimizer is None and not self.make_composite:
# # there's nothing to do with the cliques...
# return
# cliques = find_cliques(env, self.through_broadcast)
# opt = self.scalar_optimizer
# def build_scalar_clique(r, env, equiv):
# # Maps a clique of Broadcast Ops to a clique of Scalar Ops with the same
# # structure and equivalent operations. equiv contains the mapping.
# if r in equiv:
# return equiv[r]
# op = r.owner
# if env.edge(r):
# # For each leave we make a Scalar of the corresponding dtype
# s = scalar.Scalar(dtype = r.dtype)
# _r = r
# if isinstance(r.owner, DimShuffle) and all(x == 'x' for x in r.owner.new_order):
# _r = r.owner.inputs[0]
# if (getattr(r, 'constant', False) or getattr(_r, 'constant', False)) \
# and _r.broadcastable == ():
# # If we have a constant tensor we map it to a constant scalar.
# s.data = _r.data
# s.constant = True
# equiv[r] = s
# return s
# s_op = op.scalar_opclass(*[build_scalar_clique(input, env, equiv) for input in op.inputs])
# equiv[op] = s_op
# for output, s_output in zip(op.outputs, s_op.outputs):
# equiv[output] = s_output
# return equiv[r]
# for c_in, c_out in cliques:
# equiv = dict()
# g = Env(c_in, c_out)
# for output in c_out:
# build_scalar_clique(output, g, equiv)
# s_g = Env([equiv[r] for r in g.inputs],
# [equiv[r] for r in g.outputs])
# if opt is not None:
# equiv2 = dict() # reverse mapping, from Scalar Op to Tensor Op
# for k, v in equiv.items():
# equiv2[v] = k
# def transform(op, equiv):
# # We get a scalar op and we return an equivalent op on tensors.
# return Broadcast(op.__class__, [equiv[input] for input in op.inputs])
# s_g.add_feature(sync_to(env, equiv2, transform)) # Any change to s_g will now be transferred to g
# opt.optimize(s_g)
# if self.make_composite:
# def follow_inplace(r):
# # Tries to find the earliest r2 in g such that r destroys r2
# # If no such r2 is found, returns None
# op = r.owner
# if op is None or r in g.inputs or r in g.orphans():
# return None
# assert isinstance(op, Broadcast)
# destroyed = op.destroy_map().get(r, None)
# if destroyed is None:
# return None
# else:
# r2 = destroyed[0]
# ret = follow_inplace(r2)
# if ret is None:
# return r2
# else:
# return ret
# inplace_pattern = {}
# for i, output in enumerate(g.outputs):
# destroyed = follow_inplace(output)
# if destroyed is not None and destroyed in g.inputs:
# # we transfer the inplace operation only if it is
# # an input that is destroyed
# inplace_pattern[i] = g.inputs.index(destroyed)
# C = scalar.composite(s_g.inputs, s_g.outputs)
# ec = Broadcast(C, g.inputs, inplace_pattern = inplace_pattern)
# env.replace_all(dict((o, eco) for o, eco in zip(c_out, ec.outputs)))
# def sync_to(target, equiv, transform):
# """
# Usage: sync_to(target, equiv, transform)
# * target: an Env
# * equiv: a dictionary that maps results and ops to results and ops
# in target
# * transform: a function that takes (op, equiv) as inputs and
# returns a new op.
# Returns a Feature that can be added to an Env and mirrors all
# modifications to that env with modifications to the target env.
# """
# class Synchronize(gof.Listener, gof.Constraint):
# def __init__(self, source):
# self.source = source
# self.target = target
# self.equiv = equiv
# self.transform = transform
# self.inconsistencies = []
# def on_import(self, op1):
# if op1 not in self.equiv:
# op2 = self.transform(op1, self.equiv)
# self.equiv[op1] = op2
# for o1, o2 in zip(op1.outputs, op2.outputs):
# self.equiv[o1] = o2
# def on_prune(self, op1):
# if op1 in self.equiv:
# op2 = self.equiv[op1]
# del self.equiv[op1]
# for o1, o2 in zip(op1.outputs, op2.outputs):
# del self.equiv[o1]
# def on_rewire(self, clients1, r1, new_r1):
# if (new_r1, r1) in self.inconsistencies:
# self.inconsistencies.remove((new_r1, r1))
# return
# if not self.source.clients(r1):
# try:
# target.replace(self.equiv[r1], self.equiv[new_r1])
# except:
# self.inconsistencies.append((r1, new_r1))
# def validate(self):
# if self.inconsistencies:
# raise InconsistencyError("Could not synchronize when replacing the following pairs: %s" % self.inconsistencies)
# return True
# return Synchronize
......@@ -3,10 +3,13 @@ import theano.tensor as T
from ...gof import Env
import numpy
from theano.tensor.blas import *
from theano.tensor.blas import _as_scalar, _dot22, _is_real_matrix
from theano.tensor.blas import _dot22, res_is_a
from unittest import TestCase
from copy import copy
# Re-export the optimizer's static helpers under the module-level names used
# by the tests below.  NOTE(review): these appear to have been free functions
# in theano.tensor.blas before this change (see the old import line above) --
# confirm GemmLocalOptimizer exposes them as staticmethods.
_as_scalar = GemmLocalOptimizer._as_scalar
_is_real_matrix = GemmLocalOptimizer._is_real_matrix
from theano import In, Out
from .test_basic import (_approx_eq, as_tensor, function,
compile, value, constant, inplace, eval_outputs)
......@@ -185,6 +188,15 @@ class t_gemm(TestCase):
return
self.fail()
def test_res_is_a():
    """Check that res_is_a recognizes results produced by a given op.

    NOTE(review): defined without ``self`` -- if this line range actually
    sits inside a TestCase subclass (the surrounding hunk suggests it may),
    the unittest runner would fail to invoke it; confirm placement.
    """
    _, _, _, a, _ = XYZab()
    # A bare scalar variable is not the result of sqrt...
    assert not res_is_a(a, T.sqrt)
    # ...and neither is a sum of scalars.
    assert not res_is_a(a + a, T.sqrt)
    # Applying sqrt makes the predicate succeed.
    assert res_is_a(T.sqrt(a + a), T.sqrt)
    # The maxclients behaviour is left untested here because it requires the
    # results to live inside an env.
class t_as_scalar(TestCase):
def test0(self):
"""Test that it works on scalar constants"""
......@@ -227,85 +239,167 @@ class T_real_matrix(TestCase):
self.failUnless(_is_real_matrix(T.DimShuffle([False,False], [1, 0])(T.dmatrix())))
self.failUnless(not _is_real_matrix(T.DimShuffle([False], ['x', 0])(T.dvector())))
if JOSEPHS_BUG_SOLVED:
    # Old unittest-style version of the gemm-optimization suite, disabled
    # behind a flag.  NOTE(review): JOSEPHS_BUG_SOLVED is not defined
    # anywhere in this chunk -- confirm it exists at module level, otherwise
    # importing this module raises NameError.
    class T_gemm_opt(TestCase):
        """This test suite ensures that Gemm is inserted where it belongs, and that the resulting
        functions compute the same things as the originals."""
        def XYZab(self):
            # Fresh symbolic inputs: three double matrices, two double scalars.
            return T.dmatrix(), T.dmatrix(), T.dmatrix(), T.dscalar(), T.dscalar()
        def just_gemm(self, i, o, ishapes = [(4,3), (3,5), (4,5), (), ()]):
            # Compile outputs `o` from inputs `i` twice -- optimized and
            # unoptimized -- then check that forbidden ops are absent and
            # that both functions compute the same values.
            # NOTE(review): mutable default argument; harmless here because
            # ishapes is only iterated, never mutated.
            def on_fail():
                # Dump the optimized graph before failing the test case.
                for node in f.maker.env.toposort():
                    print 'GRAPH', node
                self.fail()
            # Inputs are marked mutable so gemm may operate in-place.
            f = function([In(ii, mutable=True) for ii in i],o, mode='FAST_RUN')
            for node in f.maker.env.nodes:
                # A surviving dot/_dot22 means the gemm insertion failed.
                if node.op == T.dot: on_fail()
                if node.op == _dot22: on_fail()
            g = function(i, o, mode='FAST_COMPILE')
            for node in g.maker.env.nodes:
                # gemm must not appear without the FAST_RUN optimizer.
                if node.op == gemm: on_fail()
            # Re-seed with the same value so f and g see identical inputs.
            rng = numpy.random.RandomState(234)
            r0 = f(*[rng.randn(*sh) for sh in ishapes])
            rng = numpy.random.RandomState(234)
            r1 = g(*[rng.randn(*sh) for sh in ishapes])
            # Compare only the first outputs, to within absolute tolerance.
            if numpy.max(numpy.abs(r0[0] - r1[0])) > 1.0e-8:
                self.fail()
        def test0(self):
            """Many subgraphs whose dots can be eliminated"""
            X,Y,Z,a,b = self.XYZab()
            self.just_gemm([X,Y,Z,a,b], [T.dot(X,Y) * a + Z * b])
            self.just_gemm([X,Y,Z,a,b], [a * T.dot(X,Y) + b * Z])
            self.just_gemm([X,Y,Z,a,b], [b * Z + a * T.dot(X,Y)])
            self.just_gemm([X,Y,Z,a,b], [T.dot(X,Y) * a - Z * b])
            self.just_gemm([X,Y,Z,a,b], [a * T.dot(X,Y) - b * Z])
            self.just_gemm([X,Y,Z,a,b], [b * Z - a * T.dot(X,Y)])
            #with transposes (transposes should be pushed through dot in canonicalize)
            self.just_gemm([X,Y,Z,a,b], [b * Z.T - a * T.dot(Y.T,X.T)])
            self.just_gemm([X,Y,Z,a,b], [b * Z.T + a * b * T.dot(X,Y).T])
            #with N multiplications instead of just one
            self.just_gemm([X,Y,Z,a,b], [(b * b) * Z * a + (a * a) * T.dot(X,Y) * b])
            self.just_gemm([X,Y,Z,a,b], [Z + T.dot(X,Y)])
            self.just_gemm([X,Y,Z,a,b], [Z*b + T.dot(X,Y)])
            self.just_gemm([X,Y,Z,a,b], [Z + a*b*a*T.dot(X,Y)])
            self.just_gemm([X,Y,Z,a,b], [(b * b) * Z * a - (a * a) * T.dot(X,Y) * b])
            self.just_gemm([X,Y,Z,a,b], [Z - T.dot(X,Y)])
            self.just_gemm([X,Y,Z,a,b], [Z*b - T.dot(X,Y)])
            self.just_gemm([X,Y,Z,a,b], [Z - a*b*a*T.dot(X,Y)])
            # with > 2 terms in the overall addition
            self.just_gemm([X,Y,Z,a,b], [Z + Z + T.dot(X,Y) + Z])
        def test_double_gemm(self):
            """This is the pattern that shows up in the autoencoder"""
            X,Y,Z,a,b = T.dmatrix(), T.dmatrix(), T.dmatrix(), T.dscalar(), T.dscalar()
            R, S, c = T.dmatrix(), T.dmatrix(), T.dscalar()
            self.just_gemm([X,Y,Z,a,b, R, S, c], [Z *c + a * T.dot(X,Y) + b * T.dot(R,S).T],
                    ishapes=[(4,3), (3,5), (4,5), (), (), (5,9), (9,4), ()])
        def wishlist(self):
            # Not named test_* on purpose: cases the optimizer does not
            # handle yet (more than two additions of the same dot term).
            X,Y,Z,a,b = T.dmatrix(), T.dmatrix(), T.dmatrix(), T.dscalar(), T.dscalar()
            #with >2 additions of the same T.dot(X,Y term
            self.just_gemm([X,Y,Z,a,b], [Z + T.dot(X,Y) + T.dot(X,Y)])
            self.just_gemm([X,Y,Z,a,b], [(b * b) * Z * a + (a * a) * T.dot(X,Y) + b * T.dot(X,Y)])
        def test_vector_stuff(self):
            # gemm only applies to matrix-matrix products: expressions built
            # from vectors/scalars must not acquire a gemm node.
            X,Y,Z,a,b = T.dmatrix(), T.dmatrix(), T.dmatrix(), T.dscalar(), T.dscalar()
            u,v = T.dvector(), T.dvector()
            f = function([a, u, v], a + T.dot(u,v), mode='FAST_RUN')
            self.failIf(gemm in [n.op for n in f.maker.env.nodes])
            f = function([a, u, X,Y], a * u + T.dot(X,Y), mode='FAST_RUN')
            self.failIf(gemm in [n.op for n in f.maker.env.nodes])
def fail(msg):
print 'FAIL', msg
assert False
"""This test suite ensures that Gemm is inserted where it belongs, and that the resulting
functions compute the same things as the originals."""
def XYZab():
    """Return fresh symbolic inputs: three dmatrix's followed by two dscalar's."""
    matrices = [T.dmatrix() for _ in range(3)]
    scalars = [T.dscalar() for _ in range(2)]
    return matrices[0], matrices[1], matrices[2], scalars[0], scalars[1]
class Failure(Exception):
    """Fatal test condition: the optimized graph computed wrong values or
    contains an op that should have been eliminated."""
class Warning(Exception):
    """Non-fatal test condition: the graph is dumped but the test continues.

    NOTE(review): this shadows the builtin ``Warning`` for the rest of the
    module; renaming would be safer but requires touching every raise/except
    site (e.g. in just_gemm).
    """
def just_gemm(i, o, ishapes = [(4,3), (3,5), (4,5), (), ()]):
try:
f = function([In(ii, mutable=True) for ii in i],o, mode='FAST_RUN')
for node in f.maker.env.nodes:
if node.op == T.dot: raise Warning('dot in graph')
if node.op == _dot22: raise Warning('_dot22 in graph')
g = function(i, o, mode=compile.Mode(linker='py', optimizer=None))
for node in g.maker.env.nodes:
if node.op == gemm: raise Warning('gemm in graph')
rng = numpy.random.RandomState(234)
r0 = f(*[rng.randn(*sh) for sh in ishapes])
rng = numpy.random.RandomState(234)
r1 = g(*[rng.randn(*sh) for sh in ishapes])
max_abs_err = numpy.max(numpy.abs(r0[0] - r1[0]))
if max_abs_err > 1.0e-8:
raise Failure('GEMM is computing the wrong output. max_rel_err =', max_abs_err)
except Failure:
for node in f.maker.env.toposort():
print 'GRAPH', node
raise
except Warning:
for node in f.maker.env.toposort():
print 'GRAPH', node
def test_gemm_opt0():
    """Many subgraphs whose dots can be eliminated"""
    X, Y, Z, a, b = XYZab()
    # Every expression below should compile down to a single gemm call.
    candidates = [
        T.dot(X, Y) * a + Z * b,
        a * T.dot(X, Y) + b * Z,
        b * Z + a * T.dot(X, Y),
        T.dot(X, Y) * a - Z * b,
        a * T.dot(X, Y) - b * Z,
        b * Z - a * T.dot(X, Y),
        # with transposes (transposes should be pushed through dot in
        # canonicalize)
        b * Z.T - a * T.dot(Y.T, X.T),
        b * Z.T + a * b * T.dot(X, Y).T,
        # with N multiplications instead of just one
        (b * b) * Z * a + (a * a) * T.dot(X, Y) * b,
        Z + T.dot(X, Y),
        Z * b + T.dot(X, Y),
        Z + a * b * a * T.dot(X, Y),
        (b * b) * Z * a - (a * a) * T.dot(X, Y) * b,
        Z - T.dot(X, Y),
        Z * b - T.dot(X, Y),
        Z - a * b * a * T.dot(X, Y),
    ]
    for expr in candidates:
        just_gemm([X, Y, Z, a, b], [expr])
def test_gemm_opt_double_gemm():
    """This is the pattern that shows up in the autoencoder"""
    X,Y,Z,a,b = T.dmatrix(), T.dmatrix(), T.dmatrix(), T.dscalar(), T.dscalar()
    R, S, c = T.dmatrix(), T.dmatrix(), T.dscalar()
    # First check: both dots should fold into gemm via the shared helper.
    just_gemm([X,Y,Z,a,b, R, S, c], [Z *c + a * T.dot(X,Y) + b * T.dot(R,S).T],
            ishapes=[(4,3), (3,5), (4,5), (), (), (5,9), (9,4), ()])
    # Second check, done by hand: one gemm is already written explicitly in
    # the output expression, and only the remaining dot must be optimized.
    ishapes=[(4,3), (3,5), (4,5), (), (), (5,9), (9,4), ()]
    i = [X,Y,Z,a,b, R, S, c]
    o = [a * T.dot(X,Y) + gemm(Z, b, S.T, R.T, 1.0)]
    try:
        # Inputs are mutable so gemm may operate in-place.
        f = function([In(ii, mutable=True) for ii in i],o, mode='FAST_RUN')
        for node in f.maker.env.nodes:
            if node.op == T.dot: raise Failure('dot in graph')
            if node.op == _dot22: raise Failure('_dot22 in graph')
        # Reference function compiled with no optimizer at all.  The gemm
        # check is disabled here because the output expression itself
        # contains a gemm node.
        g = function(i, o, mode=compile.Mode(linker='py', optimizer=None))
        #for node in g.maker.env.nodes:
        #    if node.op == gemm: raise Failure('gemm in graph')
        # Same seed for both runs so f and g see identical inputs.
        rng = numpy.random.RandomState(234)
        r0 = f(*[rng.randn(*sh) for sh in ishapes])
        rng = numpy.random.RandomState(234)
        r1 = g(*[rng.randn(*sh) for sh in ishapes])
        max_abs_err = numpy.max(numpy.abs(r0[0] - r1[0]))
        if max_abs_err > 1.0e-8:
            # NOTE(review): message says max_rel_err but the quantity is an
            # absolute error; also the value is passed as a second exception
            # arg rather than formatted into the string.
            raise Failure('GEMM is computing the wrong output. max_rel_err =', max_abs_err)
    except Failure:
        # Dump the optimized graph before propagating the failure.
        for node in f.maker.env.toposort():
            print 'GRAPH', node
        raise
def wishlist_gemm_opt():
    """Cases the optimizer cannot handle yet (deliberately not named test_*).

    Both expressions contain more than two additions of the same
    T.dot(X, Y) term.
    """
    X, Y, Z, a, b = T.dmatrix(), T.dmatrix(), T.dmatrix(), T.dscalar(), T.dscalar()
    just_gemm([X, Y, Z, a, b], [Z + T.dot(X, Y) + T.dot(X, Y)])
    just_gemm([X, Y, Z, a, b],
              [(b * b) * Z * a + (a * a) * T.dot(X, Y) + b * T.dot(X, Y)])
def test_gemm_with_vector():
    """Many subgraphs whose dots can be eliminated.

    This adds a vector to the previous test, which triggers the
    long-sought GEMM bug.
    """
    X, Y, Z, a, b = XYZab()
    v = T.vector()

    def my_just_gemm(o):
        # Same matrix/scalar shapes as test_gemm_opt0, plus a length-5
        # vector for v.
        just_gemm([X, Y, Z, a, b, v], o,
                  ishapes=[(4, 3), (3, 5), (4, 5), (), (), (5,)])

    my_just_gemm([v + T.dot(X, Y) * a + Z * b])
    my_just_gemm([v + a * T.dot(X, Y) + b * Z])
    my_just_gemm([v + b * Z + a * T.dot(X, Y)])
    my_just_gemm([v + T.dot(X, Y) * a - Z * b])
    my_just_gemm([v + a * T.dot(X, Y) - b * Z])
    my_just_gemm([v + b * Z - a * T.dot(X, Y)])
    # with N multiplications instead of just one
    my_just_gemm([v + (b * b) * Z * a + (a * a) * T.dot(X, Y) * b])
    my_just_gemm([v + Z + T.dot(X, Y)])
    my_just_gemm([v + Z * b + T.dot(X, Y)])
    my_just_gemm([v + Z + a * b * a * T.dot(X, Y)])
    my_just_gemm([v + (b * b) * Z * a - (a * a) * T.dot(X, Y) * b])
    my_just_gemm([Z - T.dot(X, Y) + v])
    my_just_gemm([Z * b - T.dot(X, Y) + v])
    my_just_gemm([Z - a * b * a * T.dot(X, Y) + v])
def test_gemm_opt_vector_stuff():
    """gemm must not be inserted for vector/scalar-only dot products."""
    X, Y, Z, a, b = T.dmatrix(), T.dmatrix(), T.dmatrix(), T.dscalar(), T.dscalar()
    u, v = T.dvector(), T.dvector()

    # vector . vector produces a scalar: no gemm possible.
    f = function([a, u, v], a + T.dot(u, v), mode='FAST_RUN')
    if any(node.op == gemm for node in f.maker.env.nodes):
        raise Failure('gemm in graph')

    # vector added to a matrix product: still no gemm expected.
    f = function([a, u, X, Y], a * u + T.dot(X, Y), mode='FAST_RUN')
    if any(node.op == gemm for node in f.maker.env.nodes):
        raise Failure('gemm in graph')
def test_inplace0():
    """gemm must NOT be inserted here, because it would create a cycle."""
    X, Y, Z, a, b = T.dmatrix(), T.dmatrix(), T.dmatrix(), T.dscalar(), T.dscalar()
    R, S, c = T.dmatrix(), T.dmatrix(), T.dscalar()
    f = function([X, Y, Z, a, b, R, S, c],
                 [Z * (Z * c + a * T.dot(X, Y) + b * T.dot(R, S).T)],
                 mode='FAST_RUN')
    if any(node.op == gemm for node in f.maker.env.nodes):
        raise Failure('gemm in graph')
def test_inplace1():
    """No gemm when the overall addition has more than two terms."""
    X, Y, Z, a, b = XYZab()
    f = function([X, Y, Z, a, b],
                 [Z + Z + T.dot(X, Y)],
                 mode='FAST_RUN')
    if any(node.op == gemm for node in f.maker.env.nodes):
        raise Failure('gemm in graph')
......@@ -155,14 +155,14 @@ class QuadraticDenoisingAA(T.RModule):
updates = dict((p, p - self.lr * g) for p, g in zip(self.params, gradients))
# INTERFACE METHODS
self.update = theano.Method(self.input, self.ncost, updates)
self.compute_cost = theano.Method(self.input, self.cost)
self.noisify = theano.Method(self.input, self.corrupted_input)
self.reconstruction = theano.Method(self.input, self.output)
self.representation = theano.Method(self.input, self.hidden)
self.reconstruction_through_noise = theano.Method(self.input, [self.corrupted_input, self.noutput])
#self.update = theano.Method(self.input, self.ncost, updates)
#self.compute_cost = theano.Method(self.input, self.cost)
#self.noisify = theano.Method(self.input, self.corrupted_input)
#self.reconstruction = theano.Method(self.input, self.output)
#self.representation = theano.Method(self.input, self.hidden)
#self.reconstruction_through_noise = theano.Method(self.input, [self.corrupted_input, self.noutput])
self.validate = theano.Method(self.input, [self.cost, self.output])
#self.validate = theano.Method(self.input, [self.cost, self.output])
def _instance_initialize(self, obj, input_size, hidden_size, seed, lr, qfilter_relscale):
"""
......@@ -291,16 +291,16 @@ class Module_Nclass(module.FancyModule):
#define the apply method
self.pred = T.argmax(linear_output, axis=1)
self.apply = module.Method([self.input], self.pred)
#self.apply = module.Method([self.input], self.pred)
self.validate = module.Method([self.input, self.targ], [self.cost, self.argmax, self.max_pr])
self.softmax_output = module.Method([self.input], self.softmax_unsupervised)
#self.validate = module.Method([self.input, self.targ], [self.cost, self.argmax, self.max_pr])
#self.softmax_output = module.Method([self.input], self.softmax_unsupervised)
if self.params:
gparams = T.grad(sum_xent, self.params)
self.update = module.Method([self.input, self.targ], sum_xent,
updates = dict((p, p - self.lr * g) for p, g in zip(self.params, gparams)))
#self.update = module.Method([self.input, self.targ], sum_xent,
#updates = dict((p, p - self.lr * g) for p, g in zip(self.params, gparams)))
class ConvolutionalMLPInstance(module.FancyModuleInstance, Loss01):
#initialize is called by Module.make
......@@ -366,11 +366,6 @@ class ConvolutionalMLP(module.FancyModule):
)
)
# to_update = []
# all_kits = []
# input_update = self.input_representations[0].update
# input_update.resolve_all()
for i in self.inputs[1:]:
self.input_representations.append(
QDAA(
......@@ -411,11 +406,17 @@ class ConvolutionalMLP(module.FancyModule):
] + self.hidden.qfilters
input_pretraining_cost = sum(i.ncost for i in self.input_representations)
hidden_pretraining_cost = self.hidden.ncost
input_pretraining_gradients = T.grad(input_pretraining_cost, input_pretraining_params)
input_pretraining_gradients = T.grad(input_pretraining_cost,
input_pretraining_params)
hidden_pretraining_gradients = T.grad(hidden_pretraining_cost, hidden_pretraining_params)
pretraining_updates = dict((p, p - self.lr * g) for p, g in zip(input_pretraining_params, input_pretraining_gradients) +
zip(hidden_pretraining_params, hidden_pretraining_gradients))
self.pretraining_update = module.Method(self.inputs, [input_pretraining_cost, hidden_pretraining_cost], pretraining_updates)
pretraining_updates = \
dict((p, p - self.lr * g) for p, g in \
zip(input_pretraining_params, input_pretraining_gradients) \
+ zip(hidden_pretraining_params, hidden_pretraining_gradients))
self.pretraining_update = module.Method(self.inputs,
[input_pretraining_cost, hidden_pretraining_cost],
pretraining_updates)
finetuning_params = \
[self.input_representations[0].w1, self.input_representations[0].b1] + self.input_representations[0].qfilters + \
......@@ -426,9 +427,8 @@ class ConvolutionalMLP(module.FancyModule):
finetuning_updates = dict((p, p - self.lr * g) for p, g in zip(finetuning_params, finetuning_gradients))
self.finetuning_update = module.Method(self.inputs + [self.targ], self.output.cost, finetuning_updates)
self.validate = module.Method(self.inputs + [self.targ], [self.output.cost, self.output.argmax, self.output.max_pr])
self.softmax_output = module.Method(self.inputs, self.output.softmax_unsupervised)
#self.validate = module.Method(self.inputs + [self.targ], [self.output.cost, self.output.argmax, self.output.max_pr])
#self.softmax_output = module.Method(self.inputs, self.output.softmax_unsupervised)
def create(window_size=3,
input_dimension=9,
......@@ -462,15 +462,21 @@ JTEST = theano.compile.mode.optdb.query(*sys.argv[2:])
print 'JTEST', JTEST
theano.compile.register_optimizer('JTEST', JTEST)
if __name__ == '__main__':
optimizer = eval(sys.argv[1])
m = create(compile_mode = theano.Mode(linker='c|py', optimizer=optimizer))
prog_str = []
for i, node in enumerate(m.finetuning_update.maker.env.toposort()):
#print ' ', i, node
idx_of_node = {}
for i, node in enumerate(m.pretraining_update.maker.env.toposort()):
idx_of_node[node] = i
if False and i > -1:
print ' ', i, node, [(ii, idx_of_node.get(ii.owner, 'IN')) for ii in node.inputs]
prog_str.append(str(node))
print "PROGRAM LEN %i HASH %i"% (len(m.finetuning_update.maker.env.nodes), reduce(lambda a, b: hash(a) ^ hash(b),prog_str))
#print input_pretraining_gradients[4].owner.inputs
#print input_pretraining_gradients[4].owner.inputs[1].owner.inputs
#sys.exit()
print "PROGRAM LEN %i HASH %i"% (len(m.pretraining_update.maker.env.nodes), reduce(lambda a, b: hash(a) ^ hash(b),prog_str))
rng = N.random.RandomState(23904)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论