提交 119df018 authored 作者: Frederic Bastien's avatar Frederic Bastien
...@@ -16,6 +16,6 @@ import scalar ...@@ -16,6 +16,6 @@ import scalar
import sparse import sparse
import gradient import gradient
import elemwise import elemwise
import tensor_opt
## import scalar_opt ## import scalar_opt
## import tensor_opt
...@@ -12,7 +12,7 @@ from graph import \ ...@@ -12,7 +12,7 @@ from graph import \
Apply, Result, Constant, Value Apply, Result, Constant, Value
from link import \ from link import \
Linker, LocalLinker, PerformLinker, MetaLinker, Profiler Linker, LocalLinker, PerformLinker, WrapLinker, Profiler
from op import \ from op import \
Op Op
......
...@@ -34,6 +34,9 @@ class MyType(Type): ...@@ -34,6 +34,9 @@ class MyType(Type):
def MyResult(name): def MyResult(name):
return Result(MyType(), None, None, name = name) return Result(MyType(), None, None, name = name)
def MyValue(data):
return graph.Value(MyType(), data = data)
class MyOp(Op): class MyOp(Op):
...@@ -304,6 +307,33 @@ class _test_all(unittest.TestCase): ...@@ -304,6 +307,33 @@ class _test_all(unittest.TestCase):
g.replace(tv, sx) g.replace(tv, sx)
assert g.consistent() assert g.consistent()
def test_value_repl(self):
x, y, z = inputs()
sy = sigmoid(y)
e = add_in_place(x, sy)
g = Env([x,y], [e], False)
assert g.consistent()
g.replace(sy, MyValue("abc"))
assert g.consistent()
def test_value_repl_2(self):
x, y, z = inputs()
sy = sigmoid(y)
e = add_in_place(x, sy)
g = Env([x,y], [e], False)
assert g.consistent()
g.replace(sy, transpose_view(MyValue("abc")))
assert g.consistent()
def test_misc_2(self):
x, y, z = inputs()
tv = transpose_view(x)
e = add_in_place(x, tv)
g = Env([x,y], [e], False)
assert not g.consistent()
g.replace(tv, x)
assert not g.consistent()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
......
...@@ -122,6 +122,40 @@ class _test_PerformLinker(unittest.TestCase): ...@@ -122,6 +122,40 @@ class _test_PerformLinker(unittest.TestCase):
fn = perform_linker(Env(*graph.clone([x, y,r], [e]))).make_function() fn = perform_linker(Env(*graph.clone([x, y,r], [e]))).make_function()
self.failUnless(fn(1.0,2.0,4.5) == 7.5) self.failUnless(fn(1.0,2.0,4.5) == 7.5)
def wrap_linker(env, linkers, wrapper):
lnk = WrapLinker(linkers, wrapper).accept(env)
return lnk
class _test_WrapLinker(unittest.TestCase):
def test0(self):
nodes = []
def wrap(i, node, th):
nodes.append(node.op)
x, y, z = inputs()
e = mul(add(x, y), div(x, y))
fn, i, o = wrap_linker(Env([x, y, z], [e]), [PerformLinker()], wrap).make_thunk()
i[0].data = 1
i[1].data = 2
fn()
self.failUnless(nodes == [div, add, mul], nodes)
self.failUnless(o[0].data is None)
def test1(self):
nodes = []
def wrap(i, node, th):
nodes.append(node.op)
th()
x, y, z = inputs()
e = mul(add(x, y), div(x, y))
fn, i, o = wrap_linker(Env([x, y, z], [e]), [PerformLinker()], wrap).make_thunk()
i[0].data = 1
i[1].data = 2
fn()
self.failUnless(nodes == [div, add, mul], nodes)
self.failUnless(o[0].data == 1.5, o[0].data)
# def test_disconnected_input_output(self): # def test_disconnected_input_output(self):
......
...@@ -788,11 +788,6 @@ class OpWiseCLinker(link.LocalLinker): ...@@ -788,11 +788,6 @@ class OpWiseCLinker(link.LocalLinker):
self.no_recycling = no_recycling self.no_recycling = no_recycling
return self return self
def make_thunk(self, profiler = None, input_storage = None, output_storage = None):
return self.make_all(profiler = profiler,
input_storage = input_storage,
output_storage = output_storage)[:3]
def make_all(self, profiler = None, input_storage = None, output_storage = None): def make_all(self, profiler = None, input_storage = None, output_storage = None):
env = self.env env = self.env
order = env.toposort() order = env.toposort()
......
...@@ -120,6 +120,8 @@ class Env(utils.object2): ...@@ -120,6 +120,8 @@ class Env(utils.object2):
for r in results: for r in results:
if r.owner is None and not isinstance(r, graph.Value) and r not in self.inputs: if r.owner is None and not isinstance(r, graph.Value) and r not in self.inputs:
raise TypeError("Undeclared input", r) raise TypeError("Undeclared input", r)
if not getattr(r, 'env', None) is self:
self.__setup_r__(r)
self.results.add(r) self.results.add(r)
def __import__(self, node, check = True): def __import__(self, node, check = True):
...@@ -230,7 +232,10 @@ class Env(utils.object2): ...@@ -230,7 +232,10 @@ class Env(utils.object2):
raise Exception("Cannot replace %s because it does not belong to this Env" % r) raise Exception("Cannot replace %s because it does not belong to this Env" % r)
if not r.type == new_r.type: if not r.type == new_r.type:
raise TypeError("The type of the replacement must be the same as the type of the original Result.", r, new_r) raise TypeError("The type of the replacement must be the same as the type of the original Result.", r, new_r)
assert r in self.results if r not in self.results:
# this result isn't in the graph... don't raise an exception here, just return silently
# because it makes it easier to implement some optimizations for multiple-output ops
return
for node, i in list(r.clients): for node, i in list(r.clients):
assert node == 'output' and self.outputs[i] is r or node.inputs[i] is r assert node == 'output' and self.outputs[i] is r or node.inputs[i] is r
......
...@@ -165,7 +165,7 @@ class Value(Result): ...@@ -165,7 +165,7 @@ class Value(Result):
return self.name return self.name
return "<" + str(self.data) + ">" #+ "::" + str(self.type) return "<" + str(self.data) + ">" #+ "::" + str(self.type)
def clone(self): def clone(self):
return self.__class__(self.type, self.data) return self.__class__(self.type, self.data, self.name)
def __set_owner(self, value): def __set_owner(self, value):
if value is not None: if value is not None:
raise ValueError("Value instances cannot have an owner.") raise ValueError("Value instances cannot have an owner.")
......
...@@ -166,6 +166,10 @@ def map_storage(env, order, input_storage, output_storage): ...@@ -166,6 +166,10 @@ def map_storage(env, order, input_storage, output_storage):
class LocalLinker(Linker): class LocalLinker(Linker):
"""
Useful base class for L{Linker}s which keep all nodes in the graph, and run a
thunk associated with each node.
"""
def streamline(self, env, thunks, order, no_recycling = [], profiler = None): def streamline(self, env, thunks, order, no_recycling = [], profiler = None):
if profiler is None: if profiler is None:
def f(): def f():
...@@ -187,7 +191,21 @@ class LocalLinker(Linker): ...@@ -187,7 +191,21 @@ class LocalLinker(Linker):
f.profiler = profiler f.profiler = profiler
return f return f
def make_thunk(self, profiler = None, input_storage = None, output_storage = None):
return self.make_all(profiler = profiler,
input_storage = input_storage,
output_storage = output_storage)[:3]
def make_all(self, profiler, input_storage, output_storage):
# By convention, subclasses of LocalLinker should implement this function!
#
# This function should return a tuple of 5 things
# 1. function to run the program
# 2. input storage
# 3. output storage
# 4. thunks: list of nodes' functions in the order they will be run by the function in (1)
# 5. order: list of nodes, in the order they will be run by the function in (1)
raise AbstractFunctionError
class PerformLinker(LocalLinker): class PerformLinker(LocalLinker):
...@@ -206,11 +224,6 @@ class PerformLinker(LocalLinker): ...@@ -206,11 +224,6 @@ class PerformLinker(LocalLinker):
self.no_recycling = no_recycling self.no_recycling = no_recycling
return self return self
def make_thunk(self, profiler = None, input_storage = None, output_storage = None):
return self.make_all(profiler = profiler,
input_storage = input_storage,
output_storage = output_storage)[:3]
def make_all(self, profiler = None, input_storage = None, output_storage = None): def make_all(self, profiler = None, input_storage = None, output_storage = None):
env = self.env env = self.env
order = env.toposort() order = env.toposort()
...@@ -242,60 +255,79 @@ class PerformLinker(LocalLinker): ...@@ -242,60 +255,79 @@ class PerformLinker(LocalLinker):
class MetaLinker(Linker): class WrapLinker(Linker):
""" """
Can run several linkers in parallel. They should all be LocalLinkers This class makes it easier to run several L{LocalLinker}s in parallel, and
and they should all return the same order. A wrapper function must offers some control over how each thunk is run.
be provided to execute the thunks, inspect the nodes, etc.
A wrapper function must be provided, and it can be used to execute the
thunks, inspect the nodes, print stuff out, etc.
@note:
The outputs of the first linker will be returned. The outputs of the first linker will be returned.
@note:
This linker ensures that each linker has its own storage for
inputs and outputs and intermediate results. There is no interference
between linkers.
""" """
def __init__(self, env, linkers, wrapper, no_recycling = []): def __init__(self, linkers, wrapper):
""" """
Initialize a MetaLinker. Initialize a WrapLinker.
@type linkers: list of L{LocalLinker} subclasses, whose make_all()
method returns thunks in the same order.
@param linkers: for each node in the graph, each linker will provide a
thunk. This class makes it possible to iterate over each linker's
program in parallel.
wrapper will be called like @type wrapper: lambda (i, i_node, i_thunk1, i_thunk2, ...) : None
wrapper(i, node, thunk1, thunk2, ...)
One thunk for each linker. wrapper will be executed for each operation @param wrapper: do some user-defined action for the i'th element of the
in order. program. i_thunk<n> is the thunk returned by the n'th linker. (If you
want to run the program, make sure to call the necessary thunks in this
function.)
no_recycling can contain a list of Results that belong to the env.
If a Result is in no_recycling, CLinker will clear the output storage
associated to it during the computation (to avoid reusing it).
""" """
self.env = env
self.linkers = linkers self.linkers = linkers
self.wrapper = wrapper self.wrapper = wrapper
self.no_recycling = no_recycling
def pre(self, f, inputs, order, thunk_groups):
pass
def make_thunk(self, **kwargs): def accept(self, env, no_recycling = []):
""" """
You can pass an alternate env to use with the 'alt_env' @type env: gof.Env
option. @param env: the env which we will link
The 'wrapf' option must be a function that will be used @type no_recycling: a list of Results that belong to env.
to wrap the thunk (eg to add methods to it).
The rest of the options will be passed to all the linkers @param no_recycling: If a Result is in no_recycling, L{WrapLinker} will clear
associated with this MetaLinker. the output storage associated to it (for each linker in linkers) during
the computation to avoid reusing it.
""" """
self.env = env
self.no_recycling = no_recycling
for l in self.linkers:
l.accept(env, no_recycling)
return self
def pre(self, f, inputs, order, thunk_groups):
pass
env = kwargs.pop("alt_env", self.env) def make_thunk(self, **kwargs):
wrapf = kwargs.pop("wrapf", None)
no_recycling = self.no_recycling no_recycling = self.no_recycling
fns, input_lists, output_lists, thunk_lists, order_lists = zip(*[linker(env, no_recycling = no_recycling).make_all(**kwargs) make_all = [l.make_all(**kwargs) for l in self.linkers]
for linker in self.linkers])
fns, input_lists, output_lists, thunk_lists, order_lists \
= zip(*make_all)
order_list0 = order_lists[0] order_list0 = order_lists[0]
for order_list in order_lists[1:]: for order_list in order_lists[1:]:
if not order_list0 == order_list: if not order_list0 == order_list:
raise Exception("All linkers to MetaLinker should execute operations in the same order.") raise Exception("All linkers to WrapLinker should execute operations in the same order.")
inputs0 = input_lists[0] inputs0 = input_lists[0]
outputs0 = output_lists[0] outputs0 = output_lists[0]
...@@ -321,13 +353,10 @@ class MetaLinker(Linker): ...@@ -321,13 +353,10 @@ class MetaLinker(Linker):
pre(f, [input.data for input in input_lists[0]], order, thunk_groups) pre(f, [input.data for input in input_lists[0]], order, thunk_groups)
for i, (thunks, node) in enumerate(zip(thunk_groups, order)): for i, (thunks, node) in enumerate(zip(thunk_groups, order)):
try: try:
wrapper(f, i, node, *thunks) wrapper(i, node, *thunks)
except: except:
raise_with_op(node) raise_with_op(node)
if wrapf is not None:
f = wrapf(f)
return f, inputs0, outputs0 return f, inputs0, outputs0
......
...@@ -352,26 +352,26 @@ class PatternSub(LocalOptimizer): ...@@ -352,26 +352,26 @@ class PatternSub(LocalOptimizer):
""" """
if node.op != self.op: if node.op != self.op:
return False return False
def match(pattern, expr, u, first = False): def match(pattern, expr, u, allow_multiple_clients = False):
if isinstance(pattern, (list, tuple)): if isinstance(pattern, (list, tuple)):
if expr.owner is None: if expr.owner is None:
return False return False
if not (expr.owner.op == pattern[0]) or (not self.allow_multiple_clients and not first and len(expr.clients) > 1): if not (expr.owner.op == pattern[0]) or (not allow_multiple_clients and len(expr.clients) > 1):
return False return False
if len(pattern) - 1 != len(expr.owner.inputs): if len(pattern) - 1 != len(expr.owner.inputs):
return False return False
for p, v in zip(pattern[1:], expr.owner.inputs): for p, v in zip(pattern[1:], expr.owner.inputs):
u = match(p, v, u) u = match(p, v, u, self.allow_multiple_clients)
if not u: if not u:
return False return False
elif isinstance(pattern, dict): elif isinstance(pattern, dict):
try: try:
real_pattern = pattern['pattern'] real_pattern = pattern['pattern']
constraint = pattern['constraint']
except KeyError: except KeyError:
raise KeyError("Malformed pattern: %s (expected keys pattern and constraint)" % pattern) raise KeyError("Malformed pattern: %s (expected key 'pattern')" % pattern)
constraint = pattern.get('constraint', lambda expr: True)
if constraint(expr): if constraint(expr):
return match(real_pattern, expr, u, False) return match(real_pattern, expr, u, pattern.get('allow_multiple_clients', False))
else: else:
return False return False
elif isinstance(pattern, str): elif isinstance(pattern, str):
...@@ -393,7 +393,7 @@ class PatternSub(LocalOptimizer): ...@@ -393,7 +393,7 @@ class PatternSub(LocalOptimizer):
elif isinstance(pattern, str): elif isinstance(pattern, str):
return u[unify.Var(pattern)] return u[unify.Var(pattern)]
else: else:
return pattern return pattern.clone()
u = match(self.in_pattern, node.out, unify.Unification(), True) u = match(self.in_pattern, node.out, unify.Unification(), True)
if u: if u:
...@@ -408,7 +408,7 @@ class PatternSub(LocalOptimizer): ...@@ -408,7 +408,7 @@ class PatternSub(LocalOptimizer):
if isinstance(pattern, (list, tuple)): if isinstance(pattern, (list, tuple)):
return "%s(%s)" % (str(pattern[0]), ", ".join([pattern_to_str(p) for p in pattern[1:]])) return "%s(%s)" % (str(pattern[0]), ", ".join([pattern_to_str(p) for p in pattern[1:]]))
elif isinstance(pattern, dict): elif isinstance(pattern, dict):
return "%s subject to %s" % (pattern_to_str(pattern['pattern']), str(pattern['constraint'])) return "%s subject to %s" % (pattern_to_str(pattern['pattern']), str(pattern.get('constraint', 'no conditions')))
else: else:
return str(pattern) return str(pattern)
return "%s -> %s" % (pattern_to_str(self.in_pattern), pattern_to_str(self.out_pattern)) return "%s -> %s" % (pattern_to_str(self.in_pattern), pattern_to_str(self.out_pattern))
......
...@@ -156,7 +156,7 @@ class numeric_grad: ...@@ -156,7 +156,7 @@ class numeric_grad:
@staticmethod @staticmethod
def abs_rel_err(a,b,eps=1.0e-10): def abs_rel_err(a,b,eps=1.0e-10):
"""Return a small number when a and b are close, relative to how big they are""" """Return a small number when a and b are close, relative to how big they are"""
return abs( (a-b) / (a+b+eps)) return abs(a-b) / (abs(a)+abs(b)+eps)
def max_err(self, g_pt): def max_err(self, g_pt):
"""Return the biggest relative error between g_pt and self.gf""" """Return the biggest relative error between g_pt and self.gf"""
......
...@@ -757,7 +757,7 @@ class Subtensor(Op): ...@@ -757,7 +757,7 @@ class Subtensor(Op):
rest = inputs[1:] rest = inputs[1:]
return [SetSubtensor(self.idx_list)(zeros_like(x), gz, *rest)] + [None] * len(rest) return [SetSubtensor(self.idx_list)(zeros_like(x), gz, *rest)] + [None] * len(rest)
def __eq__(self, others): def __eq__(self, other):
return type(self) == type(other) and self.idx_list == other.idx_list return type(self) == type(other) and self.idx_list == other.idx_list
def __hash__(self): def __hash__(self):
......
...@@ -2,6 +2,17 @@ ...@@ -2,6 +2,17 @@
import gof import gof
from elemwise import Elemwise, DimShuffle from elemwise import Elemwise, DimShuffle
import scalar import scalar
import tensor as T
gemm_pattern_1 = gof.PatternSub((T.sub_inplace,
'd',
(T.mul,
dict(pattern = (T.DimShuffle((), ['x', 'x'], inplace = True), 'a'),
allow_multiple_clients = True),
(T.dot, 'b', 'c'))),
(T.gemm, 'd', 'a', 'b', 'c', T.constant(-1.0)),
allow_multiple_clients = False)
class InplaceOptimizer(gof.Optimizer): class InplaceOptimizer(gof.Optimizer):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论