提交 119df018 authored 作者: Frederic Bastien's avatar Frederic Bastien
...@@ -16,6 +16,6 @@ import scalar ...@@ -16,6 +16,6 @@ import scalar
import sparse import sparse
import gradient import gradient
import elemwise import elemwise
import tensor_opt
## import scalar_opt ## import scalar_opt
## import tensor_opt
...@@ -12,7 +12,7 @@ from graph import \ ...@@ -12,7 +12,7 @@ from graph import \
Apply, Result, Constant, Value Apply, Result, Constant, Value
from link import \ from link import \
Linker, LocalLinker, PerformLinker, MetaLinker, Profiler Linker, LocalLinker, PerformLinker, WrapLinker, Profiler
from op import \ from op import \
Op Op
......
...@@ -34,6 +34,9 @@ class MyType(Type): ...@@ -34,6 +34,9 @@ class MyType(Type):
def MyResult(name): def MyResult(name):
return Result(MyType(), None, None, name = name) return Result(MyType(), None, None, name = name)
def MyValue(data):
return graph.Value(MyType(), data = data)
class MyOp(Op): class MyOp(Op):
...@@ -304,6 +307,33 @@ class _test_all(unittest.TestCase): ...@@ -304,6 +307,33 @@ class _test_all(unittest.TestCase):
g.replace(tv, sx) g.replace(tv, sx)
assert g.consistent() assert g.consistent()
def test_value_repl(self):
x, y, z = inputs()
sy = sigmoid(y)
e = add_in_place(x, sy)
g = Env([x,y], [e], False)
assert g.consistent()
g.replace(sy, MyValue("abc"))
assert g.consistent()
def test_value_repl_2(self):
x, y, z = inputs()
sy = sigmoid(y)
e = add_in_place(x, sy)
g = Env([x,y], [e], False)
assert g.consistent()
g.replace(sy, transpose_view(MyValue("abc")))
assert g.consistent()
def test_misc_2(self):
x, y, z = inputs()
tv = transpose_view(x)
e = add_in_place(x, tv)
g = Env([x,y], [e], False)
assert not g.consistent()
g.replace(tv, x)
assert not g.consistent()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
......
...@@ -122,6 +122,40 @@ class _test_PerformLinker(unittest.TestCase): ...@@ -122,6 +122,40 @@ class _test_PerformLinker(unittest.TestCase):
fn = perform_linker(Env(*graph.clone([x, y,r], [e]))).make_function() fn = perform_linker(Env(*graph.clone([x, y,r], [e]))).make_function()
self.failUnless(fn(1.0,2.0,4.5) == 7.5) self.failUnless(fn(1.0,2.0,4.5) == 7.5)
def wrap_linker(env, linkers, wrapper):
lnk = WrapLinker(linkers, wrapper).accept(env)
return lnk
class _test_WrapLinker(unittest.TestCase):
def test0(self):
nodes = []
def wrap(i, node, th):
nodes.append(node.op)
x, y, z = inputs()
e = mul(add(x, y), div(x, y))
fn, i, o = wrap_linker(Env([x, y, z], [e]), [PerformLinker()], wrap).make_thunk()
i[0].data = 1
i[1].data = 2
fn()
self.failUnless(nodes == [div, add, mul], nodes)
self.failUnless(o[0].data is None)
def test1(self):
nodes = []
def wrap(i, node, th):
nodes.append(node.op)
th()
x, y, z = inputs()
e = mul(add(x, y), div(x, y))
fn, i, o = wrap_linker(Env([x, y, z], [e]), [PerformLinker()], wrap).make_thunk()
i[0].data = 1
i[1].data = 2
fn()
self.failUnless(nodes == [div, add, mul], nodes)
self.failUnless(o[0].data == 1.5, o[0].data)
# def test_disconnected_input_output(self): # def test_disconnected_input_output(self):
......
...@@ -788,11 +788,6 @@ class OpWiseCLinker(link.LocalLinker): ...@@ -788,11 +788,6 @@ class OpWiseCLinker(link.LocalLinker):
self.no_recycling = no_recycling self.no_recycling = no_recycling
return self return self
def make_thunk(self, profiler = None, input_storage = None, output_storage = None):
return self.make_all(profiler = profiler,
input_storage = input_storage,
output_storage = output_storage)[:3]
def make_all(self, profiler = None, input_storage = None, output_storage = None): def make_all(self, profiler = None, input_storage = None, output_storage = None):
env = self.env env = self.env
order = env.toposort() order = env.toposort()
......
...@@ -120,6 +120,8 @@ class Env(utils.object2): ...@@ -120,6 +120,8 @@ class Env(utils.object2):
for r in results: for r in results:
if r.owner is None and not isinstance(r, graph.Value) and r not in self.inputs: if r.owner is None and not isinstance(r, graph.Value) and r not in self.inputs:
raise TypeError("Undeclared input", r) raise TypeError("Undeclared input", r)
if not getattr(r, 'env', None) is self:
self.__setup_r__(r)
self.results.add(r) self.results.add(r)
def __import__(self, node, check = True): def __import__(self, node, check = True):
...@@ -230,7 +232,10 @@ class Env(utils.object2): ...@@ -230,7 +232,10 @@ class Env(utils.object2):
raise Exception("Cannot replace %s because it does not belong to this Env" % r) raise Exception("Cannot replace %s because it does not belong to this Env" % r)
if not r.type == new_r.type: if not r.type == new_r.type:
raise TypeError("The type of the replacement must be the same as the type of the original Result.", r, new_r) raise TypeError("The type of the replacement must be the same as the type of the original Result.", r, new_r)
assert r in self.results if r not in self.results:
# this result isn't in the graph... don't raise an exception here, just return silently
# because it makes it easier to implement some optimizations for multiple-output ops
return
for node, i in list(r.clients): for node, i in list(r.clients):
assert node == 'output' and self.outputs[i] is r or node.inputs[i] is r assert node == 'output' and self.outputs[i] is r or node.inputs[i] is r
......
...@@ -165,7 +165,7 @@ class Value(Result): ...@@ -165,7 +165,7 @@ class Value(Result):
return self.name return self.name
return "<" + str(self.data) + ">" #+ "::" + str(self.type) return "<" + str(self.data) + ">" #+ "::" + str(self.type)
def clone(self): def clone(self):
return self.__class__(self.type, self.data) return self.__class__(self.type, self.data, self.name)
def __set_owner(self, value): def __set_owner(self, value):
if value is not None: if value is not None:
raise ValueError("Value instances cannot have an owner.") raise ValueError("Value instances cannot have an owner.")
......
...@@ -166,6 +166,10 @@ def map_storage(env, order, input_storage, output_storage): ...@@ -166,6 +166,10 @@ def map_storage(env, order, input_storage, output_storage):
class LocalLinker(Linker): class LocalLinker(Linker):
"""
Useful base class for L{Linker}s which keep all nodes in the graph, and run a
thunk associated with each node.
"""
def streamline(self, env, thunks, order, no_recycling = [], profiler = None): def streamline(self, env, thunks, order, no_recycling = [], profiler = None):
if profiler is None: if profiler is None:
def f(): def f():
...@@ -187,7 +191,21 @@ class LocalLinker(Linker): ...@@ -187,7 +191,21 @@ class LocalLinker(Linker):
f.profiler = profiler f.profiler = profiler
return f return f
def make_thunk(self, profiler = None, input_storage = None, output_storage = None):
return self.make_all(profiler = profiler,
input_storage = input_storage,
output_storage = output_storage)[:3]
def make_all(self, profiler, input_storage, output_storage):
# By convention, subclasses of LocalLinker should implement this function!
#
# This function should return a tuple of 5 things
# 1. function to run the program
# 2. input storage
# 3. output storage
# 4. thunks: list of nodes' functions in the order they will be run by the function in (1)
# 5. order: list of nodes, in the order they will be run by the function in (1)
raise AbstractFunctionError
class PerformLinker(LocalLinker): class PerformLinker(LocalLinker):
...@@ -206,11 +224,6 @@ class PerformLinker(LocalLinker): ...@@ -206,11 +224,6 @@ class PerformLinker(LocalLinker):
self.no_recycling = no_recycling self.no_recycling = no_recycling
return self return self
def make_thunk(self, profiler = None, input_storage = None, output_storage = None):
return self.make_all(profiler = profiler,
input_storage = input_storage,
output_storage = output_storage)[:3]
def make_all(self, profiler = None, input_storage = None, output_storage = None): def make_all(self, profiler = None, input_storage = None, output_storage = None):
env = self.env env = self.env
order = env.toposort() order = env.toposort()
...@@ -242,60 +255,79 @@ class PerformLinker(LocalLinker): ...@@ -242,60 +255,79 @@ class PerformLinker(LocalLinker):
class MetaLinker(Linker): class WrapLinker(Linker):
""" """
Can run several linkers in parallel. They should all be LocalLinkers This class makes it easier to run several L{LocalLinker}s in parallel, and
and they should all return the same order. A wrapper function must offers some control over how each thunk is run.
be provided to execute the thunks, inspect the nodes, etc.
A wrapper function must be provided, and it can be used to execute the
thunks, inspect the nodes, print stuff out, etc.
@note:
The outputs of the first linker will be returned. The outputs of the first linker will be returned.
@note:
This linker ensures that each linker has its own storage for
inputs and outputs and intermediate results. There is no interference
between linkers.
""" """
def __init__(self, env, linkers, wrapper, no_recycling = []): def __init__(self, linkers, wrapper):
""" """
Initialize a MetaLinker. Initialize a WrapLinker.
@type linkers: list of L{LocalLinker} subclasses, whose make_all()
method returns thunks in the same order.
wrapper will be called like @param linkers: for each node in the graph, each linker will provide a
wrapper(i, node, thunk1, thunk2, ...) thunk. This class makes it possible to iterate over each linker's
program in parallel.
One thunk for each linker. wrapper will be executed for each operation @type wrapper: lambda (i, i_node, i_thunk1, i_thunk2, ...) : None
in order.
@param wrapper: do some user-defined action for the i'th element of the
program. i_thunk<n> is the thunk returned by the n'th linker. (If you
want to run the program, make sure to call the necessary thunks in this
function.)
no_recycling can contain a list of Results that belong to the env.
If a Result is in no_recycling, CLinker will clear the output storage
associated to it during the computation (to avoid reusing it).
""" """
self.env = env
self.linkers = linkers self.linkers = linkers
self.wrapper = wrapper self.wrapper = wrapper
self.no_recycling = no_recycling
def pre(self, f, inputs, order, thunk_groups):
pass
def make_thunk(self, **kwargs): def accept(self, env, no_recycling = []):
""" """
You can pass an alternate env to use with the 'alt_env' @type env: gof.Env
option. @param env: the env which we will link
The 'wrapf' option must be a function that will be used @type no_recycling: a list of Results that belong to env.
to wrap the thunk (eg to add methods to it).
@param no_recycling: If a Result is in no_recycling, L{WrapLinker} will clear
the output storage associated to it (for each linker in linkers) during
the computation to avoid reusing it.
The rest of the options will be passed to all the linkers
associated with this MetaLinker.
""" """
self.env = env
self.no_recycling = no_recycling
for l in self.linkers:
l.accept(env, no_recycling)
return self
def pre(self, f, inputs, order, thunk_groups):
pass
env = kwargs.pop("alt_env", self.env) def make_thunk(self, **kwargs):
wrapf = kwargs.pop("wrapf", None)
no_recycling = self.no_recycling no_recycling = self.no_recycling
fns, input_lists, output_lists, thunk_lists, order_lists = zip(*[linker(env, no_recycling = no_recycling).make_all(**kwargs) make_all = [l.make_all(**kwargs) for l in self.linkers]
for linker in self.linkers])
fns, input_lists, output_lists, thunk_lists, order_lists \
= zip(*make_all)
order_list0 = order_lists[0] order_list0 = order_lists[0]
for order_list in order_lists[1:]: for order_list in order_lists[1:]:
if not order_list0 == order_list: if not order_list0 == order_list:
raise Exception("All linkers to MetaLinker should execute operations in the same order.") raise Exception("All linkers to WrapLinker should execute operations in the same order.")
inputs0 = input_lists[0] inputs0 = input_lists[0]
outputs0 = output_lists[0] outputs0 = output_lists[0]
...@@ -321,13 +353,10 @@ class MetaLinker(Linker): ...@@ -321,13 +353,10 @@ class MetaLinker(Linker):
pre(f, [input.data for input in input_lists[0]], order, thunk_groups) pre(f, [input.data for input in input_lists[0]], order, thunk_groups)
for i, (thunks, node) in enumerate(zip(thunk_groups, order)): for i, (thunks, node) in enumerate(zip(thunk_groups, order)):
try: try:
wrapper(f, i, node, *thunks) wrapper(i, node, *thunks)
except: except:
raise_with_op(node) raise_with_op(node)
if wrapf is not None:
f = wrapf(f)
return f, inputs0, outputs0 return f, inputs0, outputs0
......
...@@ -352,26 +352,26 @@ class PatternSub(LocalOptimizer): ...@@ -352,26 +352,26 @@ class PatternSub(LocalOptimizer):
""" """
if node.op != self.op: if node.op != self.op:
return False return False
def match(pattern, expr, u, first = False): def match(pattern, expr, u, allow_multiple_clients = False):
if isinstance(pattern, (list, tuple)): if isinstance(pattern, (list, tuple)):
if expr.owner is None: if expr.owner is None:
return False return False
if not (expr.owner.op == pattern[0]) or (not self.allow_multiple_clients and not first and len(expr.clients) > 1): if not (expr.owner.op == pattern[0]) or (not allow_multiple_clients and len(expr.clients) > 1):
return False return False
if len(pattern) - 1 != len(expr.owner.inputs): if len(pattern) - 1 != len(expr.owner.inputs):
return False return False
for p, v in zip(pattern[1:], expr.owner.inputs): for p, v in zip(pattern[1:], expr.owner.inputs):
u = match(p, v, u) u = match(p, v, u, self.allow_multiple_clients)
if not u: if not u:
return False return False
elif isinstance(pattern, dict): elif isinstance(pattern, dict):
try: try:
real_pattern = pattern['pattern'] real_pattern = pattern['pattern']
constraint = pattern['constraint']
except KeyError: except KeyError:
raise KeyError("Malformed pattern: %s (expected keys pattern and constraint)" % pattern) raise KeyError("Malformed pattern: %s (expected key 'pattern')" % pattern)
constraint = pattern.get('constraint', lambda expr: True)
if constraint(expr): if constraint(expr):
return match(real_pattern, expr, u, False) return match(real_pattern, expr, u, pattern.get('allow_multiple_clients', False))
else: else:
return False return False
elif isinstance(pattern, str): elif isinstance(pattern, str):
...@@ -393,7 +393,7 @@ class PatternSub(LocalOptimizer): ...@@ -393,7 +393,7 @@ class PatternSub(LocalOptimizer):
elif isinstance(pattern, str): elif isinstance(pattern, str):
return u[unify.Var(pattern)] return u[unify.Var(pattern)]
else: else:
return pattern return pattern.clone()
u = match(self.in_pattern, node.out, unify.Unification(), True) u = match(self.in_pattern, node.out, unify.Unification(), True)
if u: if u:
...@@ -408,7 +408,7 @@ class PatternSub(LocalOptimizer): ...@@ -408,7 +408,7 @@ class PatternSub(LocalOptimizer):
if isinstance(pattern, (list, tuple)): if isinstance(pattern, (list, tuple)):
return "%s(%s)" % (str(pattern[0]), ", ".join([pattern_to_str(p) for p in pattern[1:]])) return "%s(%s)" % (str(pattern[0]), ", ".join([pattern_to_str(p) for p in pattern[1:]]))
elif isinstance(pattern, dict): elif isinstance(pattern, dict):
return "%s subject to %s" % (pattern_to_str(pattern['pattern']), str(pattern['constraint'])) return "%s subject to %s" % (pattern_to_str(pattern['pattern']), str(pattern.get('constraint', 'no conditions')))
else: else:
return str(pattern) return str(pattern)
return "%s -> %s" % (pattern_to_str(self.in_pattern), pattern_to_str(self.out_pattern)) return "%s -> %s" % (pattern_to_str(self.in_pattern), pattern_to_str(self.out_pattern))
......
...@@ -156,7 +156,7 @@ class numeric_grad: ...@@ -156,7 +156,7 @@ class numeric_grad:
@staticmethod @staticmethod
def abs_rel_err(a,b,eps=1.0e-10): def abs_rel_err(a,b,eps=1.0e-10):
"""Return a small number when a and b are close, relative to how big they are""" """Return a small number when a and b are close, relative to how big they are"""
return abs( (a-b) / (a+b+eps)) return abs(a-b) / (abs(a)+abs(b)+eps)
def max_err(self, g_pt): def max_err(self, g_pt):
"""Return the biggest relative error between g_pt and self.gf""" """Return the biggest relative error between g_pt and self.gf"""
......
...@@ -757,7 +757,7 @@ class Subtensor(Op): ...@@ -757,7 +757,7 @@ class Subtensor(Op):
rest = inputs[1:] rest = inputs[1:]
return [SetSubtensor(self.idx_list)(zeros_like(x), gz, *rest)] + [None] * len(rest) return [SetSubtensor(self.idx_list)(zeros_like(x), gz, *rest)] + [None] * len(rest)
def __eq__(self, others): def __eq__(self, other):
return type(self) == type(other) and self.idx_list == other.idx_list return type(self) == type(other) and self.idx_list == other.idx_list
def __hash__(self): def __hash__(self):
......
...@@ -2,6 +2,17 @@ ...@@ -2,6 +2,17 @@
import gof import gof
from elemwise import Elemwise, DimShuffle from elemwise import Elemwise, DimShuffle
import scalar import scalar
import tensor as T
gemm_pattern_1 = gof.PatternSub((T.sub_inplace,
'd',
(T.mul,
dict(pattern = (T.DimShuffle((), ['x', 'x'], inplace = True), 'a'),
allow_multiple_clients = True),
(T.dot, 'b', 'c'))),
(T.gemm, 'd', 'a', 'b', 'c', T.constant(-1.0)),
allow_multiple_clients = False)
class InplaceOptimizer(gof.Optimizer): class InplaceOptimizer(gof.Optimizer):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论