提交 119df018 — 作者: Frederic Bastien
......@@ -16,6 +16,6 @@ import scalar
import sparse
import gradient
import elemwise
import tensor_opt
## import scalar_opt
## import tensor_opt
......@@ -12,7 +12,7 @@ from graph import \
Apply, Result, Constant, Value
from link import \
Linker, LocalLinker, PerformLinker, MetaLinker, Profiler
Linker, LocalLinker, PerformLinker, WrapLinker, Profiler
from op import \
Op
......
......@@ -34,6 +34,9 @@ class MyType(Type):
def MyResult(name):
    """Build a fresh, ownerless Result of MyType labelled *name*."""
    result_type = MyType()
    return Result(result_type, None, None, name=name)
def MyValue(data):
    """Wrap *data* in a graph.Value carrying a MyType."""
    value_type = MyType()
    return graph.Value(value_type, data=data)
class MyOp(Op):
......@@ -304,6 +307,33 @@ class _test_all(unittest.TestCase):
g.replace(tv, sx)
assert g.consistent()
def test_value_repl(self):
    """Replacing a sigmoid result with a Value keeps the env consistent."""
    x, y, z = inputs()
    sig_y = sigmoid(y)
    expr = add_in_place(x, sig_y)
    env = Env([x, y], [expr], False)
    assert env.consistent()
    env.replace(sig_y, MyValue("abc"))
    assert env.consistent()
def test_value_repl_2(self):
    """Replacement by an expression built on a Value also stays consistent."""
    x, y, z = inputs()
    sig_y = sigmoid(y)
    expr = add_in_place(x, sig_y)
    env = Env([x, y], [expr], False)
    assert env.consistent()
    env.replace(sig_y, transpose_view(MyValue("abc")))
    assert env.consistent()
def test_misc_2(self):
    """An in-place add aliased with a view of its own input is inconsistent,
    and substituting x itself for the view does not repair it."""
    x, y, z = inputs()
    view = transpose_view(x)
    expr = add_in_place(x, view)
    env = Env([x, y], [expr], False)
    assert not env.consistent()
    env.replace(view, x)
    assert not env.consistent()
# Allow running this test module directly from the command line.
if __name__ == '__main__':
    unittest.main()
......
......@@ -122,6 +122,40 @@ class _test_PerformLinker(unittest.TestCase):
fn = perform_linker(Env(*graph.clone([x, y,r], [e]))).make_function()
self.failUnless(fn(1.0,2.0,4.5) == 7.5)
def wrap_linker(env, linkers, wrapper):
    """Build a WrapLinker over *linkers* using *wrapper*, attached to *env*."""
    return WrapLinker(linkers, wrapper).accept(env)
class _test_WrapLinker(unittest.TestCase):
    """Tests for WrapLinker driving a single PerformLinker program."""

    def test0(self):
        # The wrapper records each op but never calls its thunk,
        # so nothing is computed and the output storage stays empty.
        seen_ops = []
        def wrap(i, node, th):
            seen_ops.append(node.op)
        x, y, z = inputs()
        e = mul(add(x, y), div(x, y))
        fn, i, o = wrap_linker(Env([x, y, z], [e]), [PerformLinker()], wrap).make_thunk()
        i[0].data = 1
        i[1].data = 2
        fn()
        self.failUnless(seen_ops == [div, add, mul], seen_ops)
        self.failUnless(o[0].data is None)

    def test1(self):
        # Here the wrapper also invokes each thunk, so the program
        # actually computes (1 + 2) * (1 / 2) = 1.5.
        seen_ops = []
        def wrap(i, node, th):
            seen_ops.append(node.op)
            th()
        x, y, z = inputs()
        e = mul(add(x, y), div(x, y))
        fn, i, o = wrap_linker(Env([x, y, z], [e]), [PerformLinker()], wrap).make_thunk()
        i[0].data = 1
        i[1].data = 2
        fn()
        self.failUnless(seen_ops == [div, add, mul], seen_ops)
        self.failUnless(o[0].data == 1.5, o[0].data)
# def test_disconnected_input_output(self):
......
......@@ -788,11 +788,6 @@ class OpWiseCLinker(link.LocalLinker):
self.no_recycling = no_recycling
return self
def make_thunk(self, profiler=None, input_storage=None, output_storage=None):
    """Build the program via make_all() and return only its first three
    elements: (function, input storage, output storage)."""
    full = self.make_all(profiler=profiler,
                         input_storage=input_storage,
                         output_storage=output_storage)
    return full[:3]
def make_all(self, profiler = None, input_storage = None, output_storage = None):
env = self.env
order = env.toposort()
......
......@@ -120,6 +120,8 @@ class Env(utils.object2):
for r in results:
if r.owner is None and not isinstance(r, graph.Value) and r not in self.inputs:
raise TypeError("Undeclared input", r)
if not getattr(r, 'env', None) is self:
self.__setup_r__(r)
self.results.add(r)
def __import__(self, node, check = True):
......@@ -230,7 +232,10 @@ class Env(utils.object2):
raise Exception("Cannot replace %s because it does not belong to this Env" % r)
if not r.type == new_r.type:
raise TypeError("The type of the replacement must be the same as the type of the original Result.", r, new_r)
assert r in self.results
if r not in self.results:
# this result isn't in the graph... don't raise an exception here, just return silently
# because it makes it easier to implement some optimizations for multiple-output ops
return
for node, i in list(r.clients):
assert node == 'output' and self.outputs[i] is r or node.inputs[i] is r
......
......@@ -165,7 +165,7 @@ class Value(Result):
return self.name
return "<" + str(self.data) + ">" #+ "::" + str(self.type)
def clone(self):
    """Return a new Value with the same type and data.

    The name is propagated to the clone as well, so cloned graphs keep
    their human-readable labels.  (The diff residue here contained both
    the old two-argument call, which silently dropped the name, and this
    corrected three-argument form; only the corrected form is kept.)
    """
    return self.__class__(self.type, self.data, self.name)
def __set_owner(self, value):
if value is not None:
raise ValueError("Value instances cannot have an owner.")
......
......@@ -166,6 +166,10 @@ def map_storage(env, order, input_storage, output_storage):
class LocalLinker(Linker):
"""
Useful base class for L{Linker}s which keep all nodes in the graph, and run a
thunk associated with each node.
"""
def streamline(self, env, thunks, order, no_recycling = [], profiler = None):
if profiler is None:
def f():
......@@ -187,7 +191,21 @@ class LocalLinker(Linker):
f.profiler = profiler
return f
def make_thunk(self, profiler=None, input_storage=None, output_storage=None):
    """Delegate to make_all() and keep just (function, inputs, outputs),
    discarding the thunk and order lists."""
    everything = self.make_all(profiler=profiler,
                               input_storage=input_storage,
                               output_storage=output_storage)
    return everything[:3]
def make_all(self, profiler, input_storage, output_storage):
    """Abstract hook: subclasses of LocalLinker must implement this.

    Implementations must return a tuple of five things:
      1. a function that runs the program
      2. the input storage
      3. the output storage
      4. thunks: one callable per node, in the order (1) runs them
      5. order: the nodes themselves, in that same execution order
    """
    raise AbstractFunctionError
class PerformLinker(LocalLinker):
......@@ -206,11 +224,6 @@ class PerformLinker(LocalLinker):
self.no_recycling = no_recycling
return self
def make_thunk(self, profiler=None, input_storage=None, output_storage=None):
    """Return the (function, input storage, output storage) triple from
    make_all(), ignoring the trailing thunk/order lists."""
    made = self.make_all(profiler=profiler,
                         input_storage=input_storage,
                         output_storage=output_storage)
    return made[:3]
def make_all(self, profiler = None, input_storage = None, output_storage = None):
env = self.env
order = env.toposort()
......@@ -242,60 +255,79 @@ class PerformLinker(LocalLinker):
class MetaLinker(Linker):
class WrapLinker(Linker):
"""
Can run several linkers in parallel. They should all be LocalLinkers
and they should all return the same order. A wrapper function must
be provided to execute the thunks, inspect the nodes, etc.
This class makes it easier to run several L{LocalLinker}s in parallel, and
offers some control over how each thunk is run.
A wrapper function must be provided, and it can be used to execute the
thunks, inspect the nodes, print stuff out, etc.
@note:
The outputs of the first linker will be returned.
@note:
This linker ensures that each linker has its own storage for
inputs and outputs and intermediate results. There is no interference
between linkers.
"""
def __init__(self, env, linkers, wrapper, no_recycling = []):
def __init__(self, linkers, wrapper):
"""
Initialize a MetaLinker.
Initialize a WrapLinker.
@type linkers: list of L{LocalLinker} subclasses, whose make_all()
method returns thunks in the same order.
wrapper will be called like
wrapper(i, node, thunk1, thunk2, ...)
@param linkers: for each node in the graph, each linker will provide a
thunk. This class makes it possible to iterate over each linker's
program in parallel.
One thunk for each linker. wrapper will be executed for each operation
in order.
@type wrapper: lambda (i, i_node, i_thunk1, i_thunk2, ...) : None
@param wrapper: do some user-defined action for the i'th element of the
program. i_thunk<n> is the thunk returned by the n'th linker. (If you
want to run the program, make sure to call the necessary thunks in this
function.)
no_recycling can contain a list of Results that belong to the env.
If a Result is in no_recycling, CLinker will clear the output storage
associated to it during the computation (to avoid reusing it).
"""
self.env = env
self.linkers = linkers
self.wrapper = wrapper
self.no_recycling = no_recycling
def pre(self, f, inputs, order, thunk_groups):
    """Hook invoked once before the wrapped program runs; no-op by default.

    @param f: the callable about to be executed
    @param inputs: data of the first linker's input storage
    @param order: nodes in execution order
    @param thunk_groups: per-node tuples of thunks, one per linker
    """
    pass
def make_thunk(self, **kwargs):
def accept(self, env, no_recycling = []):
"""
You can pass an alternate env to use with the 'alt_env'
option.
@type env: gof.Env
@param env: the env which we will link
The 'wrapf' option must be a function that will be used
to wrap the thunk (eg to add methods to it).
@type no_recycling: a list of Results that belong to env.
@param no_recycling: If a Result is in no_recycling, L{WrapLinker} will clear
the output storage associated to it (for each linker in linkers) during
the computation to avoid reusing it.
The rest of the options will be passed to all the linkers
associated with this MetaLinker.
"""
self.env = env
self.no_recycling = no_recycling
for l in self.linkers:
l.accept(env, no_recycling)
return self
def pre(self, f, inputs, order, thunk_groups):
    """Hook invoked once before the wrapped program runs; no-op by default.
    Subclasses may override it to inspect or initialize state.

    @param f: the callable about to be executed
    @param inputs: data of the first linker's input storage
    @param order: nodes in execution order
    @param thunk_groups: per-node tuples of thunks, one per linker
    """
    pass
env = kwargs.pop("alt_env", self.env)
wrapf = kwargs.pop("wrapf", None)
def make_thunk(self, **kwargs):
no_recycling = self.no_recycling
fns, input_lists, output_lists, thunk_lists, order_lists = zip(*[linker(env, no_recycling = no_recycling).make_all(**kwargs)
for linker in self.linkers])
make_all = [l.make_all(**kwargs) for l in self.linkers]
fns, input_lists, output_lists, thunk_lists, order_lists \
= zip(*make_all)
order_list0 = order_lists[0]
for order_list in order_lists[1:]:
if not order_list0 == order_list:
raise Exception("All linkers to MetaLinker should execute operations in the same order.")
raise Exception("All linkers to WrapLinker should execute operations in the same order.")
inputs0 = input_lists[0]
outputs0 = output_lists[0]
......@@ -321,13 +353,10 @@ class MetaLinker(Linker):
pre(f, [input.data for input in input_lists[0]], order, thunk_groups)
for i, (thunks, node) in enumerate(zip(thunk_groups, order)):
try:
wrapper(f, i, node, *thunks)
wrapper(i, node, *thunks)
except:
raise_with_op(node)
if wrapf is not None:
f = wrapf(f)
return f, inputs0, outputs0
......
......@@ -352,26 +352,26 @@ class PatternSub(LocalOptimizer):
"""
if node.op != self.op:
return False
def match(pattern, expr, u, first = False):
def match(pattern, expr, u, allow_multiple_clients = False):
if isinstance(pattern, (list, tuple)):
if expr.owner is None:
return False
if not (expr.owner.op == pattern[0]) or (not self.allow_multiple_clients and not first and len(expr.clients) > 1):
if not (expr.owner.op == pattern[0]) or (not allow_multiple_clients and len(expr.clients) > 1):
return False
if len(pattern) - 1 != len(expr.owner.inputs):
return False
for p, v in zip(pattern[1:], expr.owner.inputs):
u = match(p, v, u)
u = match(p, v, u, self.allow_multiple_clients)
if not u:
return False
elif isinstance(pattern, dict):
try:
real_pattern = pattern['pattern']
constraint = pattern['constraint']
except KeyError:
raise KeyError("Malformed pattern: %s (expected keys pattern and constraint)" % pattern)
raise KeyError("Malformed pattern: %s (expected key 'pattern')" % pattern)
constraint = pattern.get('constraint', lambda expr: True)
if constraint(expr):
return match(real_pattern, expr, u, False)
return match(real_pattern, expr, u, pattern.get('allow_multiple_clients', False))
else:
return False
elif isinstance(pattern, str):
......@@ -393,7 +393,7 @@ class PatternSub(LocalOptimizer):
elif isinstance(pattern, str):
return u[unify.Var(pattern)]
else:
return pattern
return pattern.clone()
u = match(self.in_pattern, node.out, unify.Unification(), True)
if u:
......@@ -408,7 +408,7 @@ class PatternSub(LocalOptimizer):
if isinstance(pattern, (list, tuple)):
return "%s(%s)" % (str(pattern[0]), ", ".join([pattern_to_str(p) for p in pattern[1:]]))
elif isinstance(pattern, dict):
return "%s subject to %s" % (pattern_to_str(pattern['pattern']), str(pattern['constraint']))
return "%s subject to %s" % (pattern_to_str(pattern['pattern']), str(pattern.get('constraint', 'no conditions')))
else:
return str(pattern)
return "%s -> %s" % (pattern_to_str(self.in_pattern), pattern_to_str(self.out_pattern))
......
......@@ -156,7 +156,7 @@ class numeric_grad:
@staticmethod
def abs_rel_err(a, b, eps=1.0e-10):
    """Return a small number when a and b are close, relative to how big they are.

    The denominator uses abs(a) + abs(b) rather than a + b: with the old
    form, opposite-signed values that nearly cancel (a ~ -b) made the
    denominator collapse to eps and the error explode.  eps still guards
    against division by zero when both a and b are exactly 0.
    """
    return abs(a - b) / (abs(a) + abs(b) + eps)
def max_err(self, g_pt):
"""Return the biggest relative error between g_pt and self.gf"""
......
......@@ -757,7 +757,7 @@ class Subtensor(Op):
rest = inputs[1:]
return [SetSubtensor(self.idx_list)(zeros_like(x), gz, *rest)] + [None] * len(rest)
def __eq__(self, other):
    """Two Subtensor ops compare equal iff they are instances of the same
    class with equal idx_list.

    The parameter was previously misspelled ``others`` while the body used
    ``other``, which raised NameError on every comparison; the diff residue
    here held both the broken and corrected signatures — only the corrected
    one is kept.
    """
    return type(self) == type(other) and self.idx_list == other.idx_list
def __hash__(self):
......
......@@ -2,6 +2,17 @@
import gof
from elemwise import Elemwise, DimShuffle
import scalar
import tensor as T
# Rewrite pattern: sub_inplace(d, mul(dimshuffle_0d_to_2d(a), dot(b, c)))
#   ->             gemm(d, a, b, c, -1.0)
# i.e. fold "d -= a * dot(b, c)" into a single gemm call.
# NOTE(review): the DimShuffle sub-pattern sets allow_multiple_clients=True
# while the overall substitution does not — presumably so a broadcast scalar
# shared by other consumers can still match; confirm against PatternSub's
# client-count semantics.
gemm_pattern_1 = gof.PatternSub((T.sub_inplace,
'd',
(T.mul,
dict(pattern = (T.DimShuffle((), ['x', 'x'], inplace = True), 'a'),
allow_multiple_clients = True),
(T.dot, 'b', 'c'))),
(T.gemm, 'd', 'a', 'b', 'c', T.constant(-1.0)),
allow_multiple_clients = False)
class InplaceOptimizer(gof.Optimizer):
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论