提交 119df018 authored 作者: Frederic Bastien's avatar Frederic Bastien
......@@ -16,6 +16,6 @@ import scalar
import sparse
import gradient
import elemwise
import tensor_opt
## import scalar_opt
## import tensor_opt
......@@ -12,7 +12,7 @@ from graph import \
Apply, Result, Constant, Value
from link import \
Linker, LocalLinker, PerformLinker, MetaLinker, Profiler
Linker, LocalLinker, PerformLinker, WrapLinker, Profiler
from op import \
Op
......
......@@ -34,6 +34,9 @@ class MyType(Type):
def MyResult(name):
return Result(MyType(), None, None, name = name)
def MyValue(data):
return graph.Value(MyType(), data = data)
class MyOp(Op):
......@@ -304,6 +307,33 @@ class _test_all(unittest.TestCase):
g.replace(tv, sx)
assert g.consistent()
def test_value_repl(self):
x, y, z = inputs()
sy = sigmoid(y)
e = add_in_place(x, sy)
g = Env([x,y], [e], False)
assert g.consistent()
g.replace(sy, MyValue("abc"))
assert g.consistent()
def test_value_repl_2(self):
x, y, z = inputs()
sy = sigmoid(y)
e = add_in_place(x, sy)
g = Env([x,y], [e], False)
assert g.consistent()
g.replace(sy, transpose_view(MyValue("abc")))
assert g.consistent()
def test_misc_2(self):
x, y, z = inputs()
tv = transpose_view(x)
e = add_in_place(x, tv)
g = Env([x,y], [e], False)
assert not g.consistent()
g.replace(tv, x)
assert not g.consistent()
if __name__ == '__main__':
unittest.main()
......
......@@ -122,6 +122,40 @@ class _test_PerformLinker(unittest.TestCase):
fn = perform_linker(Env(*graph.clone([x, y,r], [e]))).make_function()
self.failUnless(fn(1.0,2.0,4.5) == 7.5)
def wrap_linker(env, linkers, wrapper):
lnk = WrapLinker(linkers, wrapper).accept(env)
return lnk
class _test_WrapLinker(unittest.TestCase):
def test0(self):
nodes = []
def wrap(i, node, th):
nodes.append(node.op)
x, y, z = inputs()
e = mul(add(x, y), div(x, y))
fn, i, o = wrap_linker(Env([x, y, z], [e]), [PerformLinker()], wrap).make_thunk()
i[0].data = 1
i[1].data = 2
fn()
self.failUnless(nodes == [div, add, mul], nodes)
self.failUnless(o[0].data is None)
def test1(self):
nodes = []
def wrap(i, node, th):
nodes.append(node.op)
th()
x, y, z = inputs()
e = mul(add(x, y), div(x, y))
fn, i, o = wrap_linker(Env([x, y, z], [e]), [PerformLinker()], wrap).make_thunk()
i[0].data = 1
i[1].data = 2
fn()
self.failUnless(nodes == [div, add, mul], nodes)
self.failUnless(o[0].data == 1.5, o[0].data)
# def test_disconnected_input_output(self):
......
......@@ -788,11 +788,6 @@ class OpWiseCLinker(link.LocalLinker):
self.no_recycling = no_recycling
return self
def make_thunk(self, profiler = None, input_storage = None, output_storage = None):
return self.make_all(profiler = profiler,
input_storage = input_storage,
output_storage = output_storage)[:3]
def make_all(self, profiler = None, input_storage = None, output_storage = None):
env = self.env
order = env.toposort()
......
......@@ -120,6 +120,8 @@ class Env(utils.object2):
for r in results:
if r.owner is None and not isinstance(r, graph.Value) and r not in self.inputs:
raise TypeError("Undeclared input", r)
if not getattr(r, 'env', None) is self:
self.__setup_r__(r)
self.results.add(r)
def __import__(self, node, check = True):
......@@ -230,7 +232,10 @@ class Env(utils.object2):
raise Exception("Cannot replace %s because it does not belong to this Env" % r)
if not r.type == new_r.type:
raise TypeError("The type of the replacement must be the same as the type of the original Result.", r, new_r)
assert r in self.results
if r not in self.results:
# this result isn't in the graph... don't raise an exception here, just return silently
# because it makes it easier to implement some optimizations for multiple-output ops
return
for node, i in list(r.clients):
assert node == 'output' and self.outputs[i] is r or node.inputs[i] is r
......
......@@ -165,7 +165,7 @@ class Value(Result):
return self.name
return "<" + str(self.data) + ">" #+ "::" + str(self.type)
def clone(self):
return self.__class__(self.type, self.data)
return self.__class__(self.type, self.data, self.name)
def __set_owner(self, value):
if value is not None:
raise ValueError("Value instances cannot have an owner.")
......
......@@ -166,6 +166,10 @@ def map_storage(env, order, input_storage, output_storage):
class LocalLinker(Linker):
"""
Useful base class for L{Linker}s which keep all nodes in the graph, and run a
thunk associated with each node.
"""
def streamline(self, env, thunks, order, no_recycling = [], profiler = None):
if profiler is None:
def f():
......@@ -187,7 +191,21 @@ class LocalLinker(Linker):
f.profiler = profiler
return f
def make_thunk(self, profiler = None, input_storage = None, output_storage = None):
return self.make_all(profiler = profiler,
input_storage = input_storage,
output_storage = output_storage)[:3]
def make_all(self, profiler, input_storage, output_storage):
# By convention, subclasses of LocalLinker should implement this function!
#
# This function should return a tuple of 5 things
# 1. function to run the program
# 2. input storage
# 3. output storage
# 4. thunks: list of nodes' functions in the order they will be run by the function in (1)
# 5. order: list of nodes, in the order they will be run by the function in (1)
raise AbstractFunctionError
class PerformLinker(LocalLinker):
......@@ -206,11 +224,6 @@ class PerformLinker(LocalLinker):
self.no_recycling = no_recycling
return self
def make_thunk(self, profiler = None, input_storage = None, output_storage = None):
return self.make_all(profiler = profiler,
input_storage = input_storage,
output_storage = output_storage)[:3]
def make_all(self, profiler = None, input_storage = None, output_storage = None):
env = self.env
order = env.toposort()
......@@ -242,60 +255,79 @@ class PerformLinker(LocalLinker):
class MetaLinker(Linker):
class WrapLinker(Linker):
"""
Can run several linkers in parallel. They should all be LocalLinkers
and they should all return the same order. A wrapper function must
be provided to execute the thunks, inspect the nodes, etc.
This class makes it easier to run several L{LocalLinker}s in parallel, and
offers some control over how each thunk is run.
A wrapper function must be provided, and it can be used to execute the
thunks, inspect the nodes, print stuff out, etc.
@note:
The outputs of the first linker will be returned.
@note:
This linker ensures that each linker has its own storage for
inputs and outputs and intermediate results. There is no interference
between linkers.
"""
def __init__(self, env, linkers, wrapper, no_recycling = []):
def __init__(self, linkers, wrapper):
"""
Initialize a MetaLinker.
Initialize a WrapLinker.
@type linkers: list of L{LocalLinker} subclasses, whose make_all()
method returns thunks in the same order.
@param linkers: for each node in the graph, each linker will provide a
thunk. This class makes it possible to iterate over each linker's
program in parallel.
wrapper will be called like
wrapper(i, node, thunk1, thunk2, ...)
@type wrapper: lambda (i, i_node, i_thunk1, i_thunk2, ...) : None
One thunk for each linker. wrapper will be executed for each operation
in order.
@param wrapper: do some user-defined action for the i'th element of the
program. i_thunk<n> is the thunk returned by the n'th linker. (If you
want to run the program, make sure to call the necessary thunks in this
function.)
no_recycling can contain a list of Results that belong to the env.
If a Result is in no_recycling, CLinker will clear the output storage
associated to it during the computation (to avoid reusing it).
"""
self.env = env
self.linkers = linkers
self.wrapper = wrapper
self.no_recycling = no_recycling
def pre(self, f, inputs, order, thunk_groups):
pass
def make_thunk(self, **kwargs):
def accept(self, env, no_recycling = []):
"""
You can pass an alternate env to use with the 'alt_env'
option.
@type env: gof.Env
@param env: the env which we will link
The 'wrapf' option must be a function that will be used
to wrap the thunk (eg to add methods to it).
@type no_recycling: a list of Results that belong to env.
The rest of the options will be passed to all the linkers
associated with this MetaLinker.
@param no_recycling: If a Result is in no_recycling, L{WrapLinker} will clear
the output storage associated to it (for each linker in linkers) during
the computation to avoid reusing it.
"""
self.env = env
self.no_recycling = no_recycling
for l in self.linkers:
l.accept(env, no_recycling)
return self
def pre(self, f, inputs, order, thunk_groups):
pass
env = kwargs.pop("alt_env", self.env)
wrapf = kwargs.pop("wrapf", None)
def make_thunk(self, **kwargs):
no_recycling = self.no_recycling
fns, input_lists, output_lists, thunk_lists, order_lists = zip(*[linker(env, no_recycling = no_recycling).make_all(**kwargs)
for linker in self.linkers])
make_all = [l.make_all(**kwargs) for l in self.linkers]
fns, input_lists, output_lists, thunk_lists, order_lists \
= zip(*make_all)
order_list0 = order_lists[0]
for order_list in order_lists[1:]:
if not order_list0 == order_list:
raise Exception("All linkers to MetaLinker should execute operations in the same order.")
raise Exception("All linkers to WrapLinker should execute operations in the same order.")
inputs0 = input_lists[0]
outputs0 = output_lists[0]
......@@ -321,13 +353,10 @@ class MetaLinker(Linker):
pre(f, [input.data for input in input_lists[0]], order, thunk_groups)
for i, (thunks, node) in enumerate(zip(thunk_groups, order)):
try:
wrapper(f, i, node, *thunks)
wrapper(i, node, *thunks)
except:
raise_with_op(node)
if wrapf is not None:
f = wrapf(f)
return f, inputs0, outputs0
......
......@@ -352,26 +352,26 @@ class PatternSub(LocalOptimizer):
"""
if node.op != self.op:
return False
def match(pattern, expr, u, first = False):
def match(pattern, expr, u, allow_multiple_clients = False):
if isinstance(pattern, (list, tuple)):
if expr.owner is None:
return False
if not (expr.owner.op == pattern[0]) or (not self.allow_multiple_clients and not first and len(expr.clients) > 1):
if not (expr.owner.op == pattern[0]) or (not allow_multiple_clients and len(expr.clients) > 1):
return False
if len(pattern) - 1 != len(expr.owner.inputs):
return False
for p, v in zip(pattern[1:], expr.owner.inputs):
u = match(p, v, u)
u = match(p, v, u, self.allow_multiple_clients)
if not u:
return False
elif isinstance(pattern, dict):
try:
real_pattern = pattern['pattern']
constraint = pattern['constraint']
except KeyError:
raise KeyError("Malformed pattern: %s (expected keys pattern and constraint)" % pattern)
raise KeyError("Malformed pattern: %s (expected key 'pattern')" % pattern)
constraint = pattern.get('constraint', lambda expr: True)
if constraint(expr):
return match(real_pattern, expr, u, False)
return match(real_pattern, expr, u, pattern.get('allow_multiple_clients', False))
else:
return False
elif isinstance(pattern, str):
......@@ -393,7 +393,7 @@ class PatternSub(LocalOptimizer):
elif isinstance(pattern, str):
return u[unify.Var(pattern)]
else:
return pattern
return pattern.clone()
u = match(self.in_pattern, node.out, unify.Unification(), True)
if u:
......@@ -408,7 +408,7 @@ class PatternSub(LocalOptimizer):
if isinstance(pattern, (list, tuple)):
return "%s(%s)" % (str(pattern[0]), ", ".join([pattern_to_str(p) for p in pattern[1:]]))
elif isinstance(pattern, dict):
return "%s subject to %s" % (pattern_to_str(pattern['pattern']), str(pattern['constraint']))
return "%s subject to %s" % (pattern_to_str(pattern['pattern']), str(pattern.get('constraint', 'no conditions')))
else:
return str(pattern)
return "%s -> %s" % (pattern_to_str(self.in_pattern), pattern_to_str(self.out_pattern))
......
......@@ -156,7 +156,7 @@ class numeric_grad:
@staticmethod
def abs_rel_err(a,b,eps=1.0e-10):
"""Return a small number when a and b are close, relative to how big they are"""
return abs( (a-b) / (a+b+eps))
return abs(a-b) / (abs(a)+abs(b)+eps)
def max_err(self, g_pt):
"""Return the biggest relative error between g_pt and self.gf"""
......
......@@ -757,7 +757,7 @@ class Subtensor(Op):
rest = inputs[1:]
return [SetSubtensor(self.idx_list)(zeros_like(x), gz, *rest)] + [None] * len(rest)
def __eq__(self, others):
def __eq__(self, other):
return type(self) == type(other) and self.idx_list == other.idx_list
def __hash__(self):
......
......@@ -2,6 +2,17 @@
import gof
from elemwise import Elemwise, DimShuffle
import scalar
import tensor as T
gemm_pattern_1 = gof.PatternSub((T.sub_inplace,
'd',
(T.mul,
dict(pattern = (T.DimShuffle((), ['x', 'x'], inplace = True), 'a'),
allow_multiple_clients = True),
(T.dot, 'b', 'c'))),
(T.gemm, 'd', 'a', 'b', 'c', T.constant(-1.0)),
allow_multiple_clients = False)
class InplaceOptimizer(gof.Optimizer):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论