Commit bf69ab87, authored by Olivier Breuleux

lots of stuff

Parent commit: 67b911ab
......@@ -540,7 +540,8 @@ def verify_grad(testcase, op, pt, n_tests=1, rng=numpy.random, eps=0.0000001, to
num_grad = gradient.numeric_grad(cost_fn, pt)
symbolic_grad = grad(cost, tensor_pt,as_tensor(1.0,name='g_cost'))
#symbolic_grad = exec_grad(cost, tensor_pt,as_tensor(1.0,name='g_cost'))
symbolic_grad = grad.make_node(cost, tensor_pt).outputs
if 0:
print '-------'
print '----------'
......@@ -898,7 +899,7 @@ class T_subtensor(unittest.TestCase):
n = as_tensor(numpy.random.rand(2,3))
z = scal.constant(0)
t = n[z:,z]
gn = grad(sum(exp(t)), n)
gn = exec_grad(sum(exp(t)), n)
gval = eval_outputs([gn])
s0 = 'array([ 2.05362099, 0. , 0. ])'
s1 = 'array([ 1.55009327, 0. , 0. ])'
......@@ -908,7 +909,7 @@ class T_subtensor(unittest.TestCase):
def test_grad_0d(self):
n = as_tensor(numpy.random.rand(2,3))
t = n[1,0]
gn = grad(sum(exp(t)), n)
gn = exec_grad(sum(exp(t)), n)
gval = eval_outputs([gn])
g0 = repr(gval[0,:])
g1 = repr(gval[1,:])
......@@ -937,7 +938,7 @@ class T_Stack(unittest.TestCase):
a = as_tensor(numpy.array([[1, 2, 3], [4, 5, 6]]))
b = as_tensor(numpy.array([[7, 8, 9]]))
s = vertical_stack(a, b)
ga,gb = grad(sum(vertical_stack(a,b)), [a,b])
ga,gb = exec_grad(sum(vertical_stack(a,b)), [a,b])
gval = eval_outputs([ga, gb])
self.failUnless(numpy.all(gval[0] == 1.0))
......@@ -1671,13 +1672,13 @@ class _test_grad(unittest.TestCase):
"""grad: Test passing a single result param"""
o = _test_grad.O()
a1 = o.make_node()
self.failUnless(o.gval0 is grad(a1.outputs[0], a1.inputs[0]))
self.failUnless(o.gval0 is exec_grad(a1.outputs[0], a1.inputs[0]))
def test_Nparam(self):
"""grad: Test passing multiple result params"""
o = _test_grad.O()
a1 = o.make_node()
g0,g1 = grad(a1.outputs[0], a1.inputs)
g0,g1 = exec_grad(a1.outputs[0], a1.inputs)
self.failUnless(o.gval0 is g0)
self.failUnless(o.gval1 is g1)
......@@ -1685,13 +1686,13 @@ class _test_grad(unittest.TestCase):
"""grad: Test returning a single None from grad"""
o = _test_grad.O()
a1 = o.make_node()
self.failUnless(None is grad(a1.outputs[0], a1.outputs[1]))
self.failUnless(None is grad(a1.outputs[0], 'wtf'))
self.failUnless(None is exec_grad(a1.outputs[0], a1.outputs[1]))
self.failUnless(None is exec_grad(a1.outputs[0], 'wtf'))
def test_NNone_rval(self):
"""grad: Test returning some Nones from grad"""
o = _test_grad.O()
a1 = o.make_node()
g0,g1,g2 = grad(a1.outputs[0], a1.inputs + ['wtf'])
g0,g1,g2 = exec_grad(a1.outputs[0], a1.inputs + ['wtf'])
self.failUnless(o.gval0 is g0)
self.failUnless(o.gval1 is g1)
self.failUnless(None is g2)
......
## PENDING REWRITE OF tensor_opt.py
# import unittest
import unittest
# import gof
# from tensor_opt import *
# import tensor
# from tensor import Tensor
# from gof import Env
# from elemwise import DimShuffle
# import numpy
# import scalar_opt
import gof
from tensor_opt import *
import tensor
from tensor import Tensor
from gof import Env
from elemwise import DimShuffle
import numpy
#import scalar_opt
# def inputs(xbc = (0, 0), ybc = (0, 0), zbc = (0, 0)):
# x = Tensor(broadcastable = xbc, dtype = 'float64')('x')
# y = Tensor(broadcastable = ybc, dtype = 'float64')('y')
# z = Tensor(broadcastable = zbc, dtype = 'float64')('z')
# return x, y, z
# ds = DimShuffle
def inputs(xbc = (0, 0), ybc = (0, 0), zbc = (0, 0)):
    """Build three named float64 Tensor results ('x', 'y', 'z') with the
    given broadcastable patterns (defaults: non-broadcastable matrices)."""
    make = lambda bc, name: Tensor(broadcastable = bc, dtype = 'float64')(name)
    return make(xbc, 'x'), make(ybc, 'y'), make(zbc, 'z')
# class _test_inplace_opt(unittest.TestCase):
......@@ -60,39 +58,45 @@
# self.failUnless(str(g) == "[Broadcast{Add}{0: 1}(x, y), Broadcast{Mul}{0: 0}(x, z)]")
# class _test_dimshuffle_lift(unittest.TestCase):
# def test_double_transpose(self):
# x, y, z = inputs()
# e = ds(ds(x, (1, 0)), (1, 0))
# g = Env([x], [e])
# self.failUnless(str(g) == "[InplaceDimShuffle{1,0}(InplaceDimShuffle{1,0}(x))]")
# lift_dimshuffle.optimize(g)
# self.failUnless(str(g) == "[x]")
# def test_merge2(self):
# x, y, z = inputs()
# e = ds(ds(x, (1, 'x', 0)), (2, 0, 'x', 1))
# g = Env([x], [e])
# self.failUnless(str(g) == "[InplaceDimShuffle{2,0,x,1}(InplaceDimShuffle{1,x,0}(x))]", str(g))
# lift_dimshuffle.optimize(g)
# self.failUnless(str(g) == "[InplaceDimShuffle{0,1,x,x}(x)]", str(g))
# def test_elim3(self):
# x, y, z = inputs()
# e = ds(ds(ds(x, (0, 'x', 1)), (2, 0, 'x', 1)), (1, 0))
# g = Env([x], [e])
# self.failUnless(str(g) == "[InplaceDimShuffle{1,0}(InplaceDimShuffle{2,0,x,1}(InplaceDimShuffle{0,x,1}(x)))]", str(g))
# lift_dimshuffle.optimize(g)
# self.failUnless(str(g) == "[x]", str(g))
# def test_lift(self):
# x, y, z = inputs([0]*1, [0]*2, [0]*3)
# e = x + y + z
# g = Env([x, y, z], [e])
# self.failUnless(str(g) == "[Broadcast{Add}(InplaceDimShuffle{x,0,1}(Broadcast{Add}(InplaceDimShuffle{x,0}(x), y)), z)]", str(g))
# lift_dimshuffle.optimize(g)
# self.failUnless(str(g) == "[Broadcast{Add}(Broadcast{Add}(InplaceDimShuffle{x,x,0}(x), InplaceDimShuffle{x,0,1}(y)), z)]", str(g))
# Shorthand for the tests below: apply a DimShuffle with the given
# `new_order` pattern `y` to result `x` (broadcastable taken from x's type).
ds = lambda x, y: DimShuffle(x.type.broadcastable, y)(x)
class _test_dimshuffle_lift(unittest.TestCase):
    """Tests for lift_dimshuffle: merging consecutive DimShuffles and
    lifting them through elemwise operations.

    Fix: removed leftover debug ``print g`` statements from test_lift
    (they polluted test output and are Python-2-only syntax).
    """

    def test_double_transpose(self):
        # A double transpose is the identity, so the graph collapses to [x].
        x, y, z = inputs()
        e = ds(ds(x, (1, 0)), (1, 0))
        g = Env([x], [e])
        self.failUnless(str(g) == "[DimShuffle{1,0}(DimShuffle{1,0}(x))]")
        lift_dimshuffle.optimize(g)
        self.failUnless(str(g) == "[x]")

    def test_merge2(self):
        # Two consecutive DimShuffles merge into a single equivalent one.
        x, y, z = inputs()
        e = ds(ds(x, (1, 'x', 0)), (2, 0, 'x', 1))
        g = Env([x], [e])
        self.failUnless(str(g) == "[DimShuffle{2,0,x,1}(DimShuffle{1,x,0}(x))]", str(g))
        lift_dimshuffle.optimize(g)
        self.failUnless(str(g) == "[DimShuffle{0,1,x,x}(x)]", str(g))

    def test_elim3(self):
        # A chain of three DimShuffles composing to the identity is removed.
        x, y, z = inputs()
        e = ds(ds(ds(x, (0, 'x', 1)), (2, 0, 'x', 1)), (1, 0))
        g = Env([x], [e])
        self.failUnless(str(g) == "[DimShuffle{1,0}(DimShuffle{2,0,x,1}(DimShuffle{0,x,1}(x)))]", str(g))
        lift_dimshuffle.optimize(g)
        self.failUnless(str(g) == "[x]", str(g))

    def test_lift(self):
        # DimShuffles introduced by broadcasting (after macro expansion) are
        # lifted above the adds, directly onto the inputs.
        x, y, z = inputs([False]*1, [False]*2, [False]*3)
        e = x + y + z
        g = Env([x, y, z], [e])
        gof.ExpandMacros().optimize(g)
        self.failUnless(str(g) == "[add(InplaceDimShuffle{x,0,1}(add(InplaceDimShuffle{x,0}(x), y)), z)]", str(g))
        lift_dimshuffle.optimize(g)
        gof.ExpandMacros().optimize(g)
        self.failUnless(str(g) == "[add(add(InplaceDimShuffle{x,x,0}(x), InplaceDimShuffle{x,0,1}(y)), z)]", str(g))
# class _test_cliques(unittest.TestCase):
......@@ -185,8 +189,8 @@
# if __name__ == '__main__':
# unittest.main()
if __name__ == '__main__':
unittest.main()
......
......@@ -104,6 +104,9 @@ class FunctionFactory:
if not isinstance(r, gof.Result):
raise TypeError("All inputs and outputs to FunctionFactory should be Result instances. Received:", type(r), r)
env = std_env(inputs, outputs, disown_inputs = disown_inputs)
gof.ExpandMacros().optimize(env)
#gof.ExpandMacros(lambda node: getattr(node.op, 'level', 0) <= 1).optimize(env)
#gof.ExpandMacros(lambda node: node.op.level == 2).optimize(env)
if None is not optimizer:
optimizer(env)
env.validate()
......
......@@ -29,11 +29,6 @@ def TensorConstant(*inputs, **kwargs):
### DimShuffle ###
##################
## TODO: rule-based version of DimShuffle
## would allow for Transpose, LComplete, RComplete, etc.
## Can be optimized into DimShuffle later on.
class ShuffleRule(Macro):
"""
ABSTRACT Op - it has no perform and no c_code
......@@ -41,20 +36,30 @@ class ShuffleRule(Macro):
Apply ExpandMacros to this node to obtain
an equivalent DimShuffle which can be performed.
"""
def __init__(self, rule = None, name = None):
level = 1
def __init__(self, rule = None, inplace = False, name = None):
if rule is not None:
self.rule = rule
self.inplace = inplace
if inplace:
self.view_map = {0: [0]}
self.name = name
def make_node(self, input, *models):
pattern = self.rule(input.type.broadcastable, *(model.type.broadcastable for model in models))
ib = input.type.broadcastable
return gof.Apply(self,
(input,) + models,
[Tensor(dtype = input.type.dtype,
broadcastable = [x == 'x' for x in pattern]).make_result()])
def expand(self, r):
input, models = r.owner.inputs[0], r.owner.inputs[1:]
broadcastable = [x == 'x' or ib[x] for x in pattern]).make_result()])
def expand(self, node):
input, models = node.inputs[0], node.inputs[1:]
new_order = self.rule(input.type.broadcastable, *(model.type.broadcastable for model in models))
return DimShuffle(input.type.broadcastable, new_order)(input)
#print new_order, node.outputs[0].type, DimShuffle(input.type.broadcastable, new_order)(input).type, node.outputs[0].type == DimShuffle(input.type.broadcastable, new_order)(input).type
if list(new_order) == range(input.type.ndim) and self.inplace:
return [input]
else:
return [DimShuffle(input.type.broadcastable, new_order, self.inplace)(input)]
def __eq__(self, other):
return type(self) == type(other) and self.rule == other.rule
def __hash__(self, other):
......@@ -66,10 +71,13 @@ class ShuffleRule(Macro):
return "ShuffleRule{%s}" % self.role
_transpose = ShuffleRule(rule = lambda input: range(len(input)-1, -1, -1),
inplace = True,
name = 'transpose')
lcomplete = ShuffleRule(rule = lambda input, *models: ['x']*(max([0]+map(len,models))-len(input)) + range(len(input)),
inplace = True,
name = 'lcomplete')
rcomplete = ShuffleRule(rule = lambda input, *models: range(len(input)) + ['x']*(max(map(len,models))-len(input)),
inplace = True,
name = 'rcomplete')
......@@ -170,7 +178,7 @@ class DimShuffle(Op):
ob = []
for value in self.new_order:
if value == 'x':
ob.append(1)
ob.append(True)
else:
ob.append(ib[value])
......@@ -304,8 +312,10 @@ class Elemwise(Op):
shadow = self.scalar_op.make_node(*[Scalar(dtype = t.type.dtype)() for t in inputs])
target_length = max([input.type.ndim for input in inputs])
if len(inputs) > 1:
inputs = [lcomplete(input, *inputs) for input in inputs]
# args = []
# for input in inputs:
# length = input.type.ndim
......@@ -316,7 +326,7 @@ class Elemwise(Op):
# # TODO: use LComplete instead
# args.append(DimShuffle(input.type.broadcastable, ['x']*difference + range(length))(input))
# inputs = args
try:
assert len(set([len(input.type.broadcastable) for input in inputs])) == 1
except (AssertionError, AttributeError):
......
......@@ -5,29 +5,6 @@ from gof import utils
from copy import copy
import re
def _zip(*lists):
if not lists:
return ((), ())
else:
return zip(*lists)
# x = ivector()
# y = ivector()
# e = x + y
# f = Formula(x = x, y = y, e = e)
# y = x + x
# g = Formula(x=x,y=y)
# x2 = x + x
# g = Formula(x=x, x2=x2)
class Formula(utils.object2):
def __init__(self, symtable_d = {}, **symtable_kwargs):
......@@ -87,18 +64,12 @@ class Formula(utils.object2):
################
def __rename__(self, **symequiv):
# Return a new Formula with symbols rebound: each keyword argument maps an
# existing symbol name to its new name; a value of None drops the binding.
# The receiver is left unchanged.
# print "~~~~~~~~~~~~~"
# print symequiv
vars = dict(self.__vars__)
for symbol, replacement in symequiv.iteritems():
if replacement is not None:
# carry the old symbol's value over under its new name
vars[replacement] = self.get(symbol)
# print vars
# print set(symequiv.keys()).difference(set(symequiv.values()))
# print set(symequiv.keys()), set(symequiv.values())
# delete the old names that were renamed away, but keep any name that is
# itself the target of some other rename in this call
for symbol in set(symequiv.keys()).difference(set(symequiv.values())):
del vars[symbol]
# print vars
return Formula(vars)
def rename(self, **symequiv):
......@@ -174,11 +145,7 @@ class Formula(utils.object2):
strings.append("%s = %s" % (output,
pprint.pp.clone_assign(lambda pstate, r: r.name in self.__vars__ and r is not output,
pprint.LeafPrinter()).process(output)))
# strings.append("%s = %s" % (output,
# pprint.pp.process(output)))
#strings.append(str(gof.graph.as_string(self.inputs, self.outputs)))
return "\n".join(strings)
# (self.inputs + utils.difference(self.outputs, node.outputs),[output])[0]
#################
### OPERATORS ###
......@@ -253,10 +220,6 @@ def glue(*formulas):
return reduce(glue2, formulas)
import tensor as T
sep = "---------------------------"
class FormulasMetaclass(type):
def __init__(cls, name, bases, dct):
......@@ -272,402 +235,3 @@ class Formulas(utils.object2):
def __new__(cls):
return cls.__canon__.clone()
# class Test(Formulas):
# x = T.ivector()
# y = T.ivector()
# e = x + y + 21
# x = T.ivector()
# y = T.ivector()
# e = x + y + 21
# f1 = Formula(x = x, y = y, e = e)
# Test() -> f1.clone()
# f = Test()
# print f
# print f.reassign(x = T.ivector())
# print f.reassign(x = T.dvector(), y = T.dvector())
# print f.reassign(x = T.dmatrix(), y = T.dmatrix())
# class Test(Formulas):
# x = T.ivector()
# e = x + 999
# f = Test()
# print f
# print f.reassign(x = T.ivector())
# print f.reassign(x = T.dvector())
# class Layer(Formulas):
# x = T.ivector()
# y = T.ivector()
# x2 = x + y
# # print Layer()
# # print Layer() + 1
# # print Layer() + 2
# # print Layer() + 3
# print Layer() * 3
# def sigmoid(x):
# return 1.0 / (1.0 + T.exp(-x))
#class Update(Formulas):
# param = T.matrix()
# lr, cost = T.scalars(2)
# param_update = param - lr * T.sgrad(cost, param)
#class SumSqrDiff(Formulas):
# target, output = T.rows(2)
# cost = T.sum((target - output)**2)
# class Layer(Formulas):
# input, bias = T.rows(2)
# weights = T.matrix()
# input2 = T.tanh(bias + T.dot(input, weights))
# forward = Layer()*2
# g = glue(forward.rename(input3 = 'output'),
# SumSqrDiff().rename(target = 'input1'),
# *[Update().rename_regex({'param(.*)': ('%s\\1' % param.name)}) for param in forward.get_all('(weight|bias).*')])
# sg = g.__str__()
# print unicode(g)
# lr = 0.01
# def autoassociator_f(x, w, b, c):
# reconstruction = sigmoid(T.dot(sigmoid(T.dot(x, w) + b), w.T) + c)
# rec_error = T.sum((x - reconstruction)**2)
# new_w = w - lr * Th.grad(rec_error, w)
# new_b = b - lr * Th.grad(rec_error, b)
# new_c = c - lr * Th.grad(rec_error, c)
# # f = Th.Function([x, w, b, c], [reconstruction, rec_error, new_w, new_b, new_c])
# f = Th.Function([x, w, b, c], [reconstruction, rec_error, new_w, new_b, new_c], linker_cls = Th.gof.OpWiseCLinker)
# return f
# x, w = T.matrices('xw')
# b, c = T.rows('bc')
# f = autoassociator_f(x, w, b, c)
# w_val, b_val, c_val = numpy.random.rand(10, 10), numpy.random.rand(1, 10), numpy.random.rand(1, 10)
# x_storage = numpy.ndarray((1, 10))
# for i in dataset_1hot(x_storage, numpy.ndarray((1, )), 10000):
# rec, err, w_val, b_val, c_val = f(x_storage, w_val, b_val, c_val)
# if not(i % 100):
# print err
# x = T.ivector()
# y = T.ivector()
# z = x + y
# e = z - 24
# f = Formula(x = x, y = y, z = z, e = e)
# print f
# print sep
# a = T.lvector()
# b = a * a
# f2 = Formula(e = a, b = b)
# print f2
# print sep
# print glue(f, f2)
# print sep
# x1 = T.ivector()
# y1 = x1 + x1
# y2 = T.ivector()
# x2 = y2 + y2
# f1 = Formula(x=x1, y=y1)
# f2 = Formula(x=x2, y=y2)
# print f1
# print sep
# print f2
# print sep
# print glue(f1, f2)
# print sep
# x1 = T.ivector()
# z1 = T.ivector()
# y1 = x1 + z1
# w1 = x1 * x1
# x2 = T.ivector()
# e2 = T.ivector()
# z2 = e2 ** e2
# g2 = z2 + x2
# f1 = Formula(x=x1, z=z1, y=y1, w=w1)
# f2 = Formula(x=x2, e=e2, z=z2, g=g2)
# print sep
# print f1
# print sep
# print f2
# print sep
# print f1 + f2
# x1 = T.ivector()
# z1 = T.ivector()
# y1 = x1 + x1
# w1 = z1 + z1
# e2 = T.ivector()
# w2 = T.ivector()
# x2 = e2 + e2
# g2 = w2 + w2
# f1 = Formula(x=x1, z=z1, y=y1, w=w1)
# f2 = Formula(x=x2, e=e2, w=w2, g=g2)
# print sep
# print f1
# print sep
# print f2
# print sep
# print f1 + f2
# def glue2(f1, f2):
# reassign_f1 = {}
# reassign_f2 = {}
# equiv = {}
# for r1 in f1.inputs:
# name = r1.name
# try:
# r2 = f2.get(name)
# if not r1.type == r2.type:
# raise TypeError("inconsistent typing for %s: %s, %s" % (name, r1.type, r2.type))
# if name in f2.input_names:
# reassign_f2[r2] = r1
# elif name in f2.output_names:
# reassign_f1[r1] = r2
# except AttributeError:
# pass
# for r1 in f1.outputs:
# name = r1.name
# try:
# r2 = f2.get(name)
# if not r1.type == r2.type:
# raise TypeError("inconsistent typing for %s: %s, %s" % (name, r1.type, r2.type))
# if name in f2.input_names:
# reassign_f2[r2] = r1
# elif name in f2.output_names:
# raise Exception("It is not allowed for a variable to be the output of two different formulas: %s" % name)
# except AttributeError:
# pass
# print reassign_f1
# print reassign_f2
# #i0, o0 = gof.graph.clone_with_new_inputs(f1.inputrs+f2.inputrs, f1.outputrs+f2.outputrs,
# # [reassign_f1.get(name, r) for name, r in zip(f1.inputs, f1.inputrs)] +
# # [reassign_f2.get(name, r) for name, r in zip(f2.inputs, f2.inputrs)])
# #print gof.Env([x for x in i0 if x.owner is None], o0)
# #return
# ##equiv = gof.graph.clone_with_new_inputs_get_equiv(f1.inputrs, f1.outputrs, [reassign_f1.get(name, r) for name, r in zip(f1.inputs, f1.inputrs)])
# ##i1, o1 = [equiv[r] for r in f1.inputrs], [equiv[r] for r in f1.outputrs]
# ##_reassign_f2, reassign_f2 = reassign_f2, {}
# ##for name, r in _reassign_f2.items():
# ## print name, r, equiv.get(r, r) is r
# ## reassign_f2[name] = equiv.get(r, r)
# ## i1, o1 = gof.graph.clone_with_new_inputs(f1.inputs, f1.outputs, [reassign_f1.get(name, r) for name, r in zip(f1.input_names, f1.inputs)])
# ## i2, o2 = gof.graph.clone_with_new_inputs(f2.inputs, f2.outputs, [reassign_f2.get(name, r) for name, r in zip(f2.input_names, f2.inputs)])
# i1, o1 = gof.graph.clone_with_equiv(f1.inputs, f1.outputs, reassign_f1)
# vars = {}
# vars.update(zip(f1.input_names, i1))
# vars.update(zip(f1.output_names, o1))
# vars.update(zip(f2.input_names, i2))
# vars.update(zip(f2.output_names, o2))
# #print vars
# #print gof.graph.as_string(i1, o1)
# #print gof.graph.as_string(i2, o2)
# #print "a"
# #o = o1[0]
# #while o.owner is not None:
# # print o, o.owner
# # o = o.owner.inputs[0]
# #print "b", o
# return Formula(vars)
# class FormulasMetaclass(type):
# def __init__(cls, name, bases, dct):
# variables = {}
# for name, var in dct.items():
# if isinstance(var, gof.Result):
# variables[name] = var
# cls.__variables__ = variables
# cls.__canon__ = Formula(cls.__variables__)
# class Formulas(utils.object2):
# __metaclass__ = FormulasMetaclass
# def __new__(cls):
# return cls.__canon__.clone()
# class Test(Formulas):
# x = T.ivector()
# y = T.ivector()
# e = x + y
# class Test2(Formulas):
# e = T.ivector()
# x = T.ivector()
# w = e ** (x / e)
# f = Test() # + Test2()
# print f
# print sep
# print f.prefix("hey_")
# print sep
# print f.suffix("_woot")
# print sep
# print f.increment(1)
# print sep
# print f.normalize()
# print sep
# print f.increment(1).increment(1)
# print sep
# print f.increment(8).suffix("_yep")
# print sep
# print (f + 8).suffix("_yep", push_numbers = True)
# print sep
# print f.suffix("_yep", push_numbers = True)
# print sep
# print f.rename_regex({"(x|y)": "\\1\\1",
# 'e': "OUTPUT"})
# print sep
# print f + "_suffix"
# print sep
# print "prefix_" + f
# print sep
#### Usage case ####
# class Forward1(Formula):
# input, b, c = drows(3)
# w = dmatrix()
# output = dot(sigmoid(dot(w, input) + b), w.T) + c
# class Forward2(Formula):
# input, b, c = drows(3)
# w1, w2 = dmatrices(2)
# output = dot(sigmoid(dot(w1, input) + b), w2) + c
# class SumSqrError(Formula):
# target, output = drows(2)
# cost = sum((target - output)**2)
# class GradUpdate(Formula):
# lr, cost = dscalars(2)
# param = dmatrix()
# param_updated = param + lr * grad(cost, param)
# NNetUpdate = glue(Forward1(), SumSqrError(), [GradUpdate.rename(param = name) for name in ['w', 'b', 'c']])
# class Forward(Formula):
# input, w, b, c = vars(4)
# output = dot(sigmoid(dot(w, input) + b), w.T) + c
# class SumSqrError(Formula):
# target, output = vars(2)
# cost = sum((target - output)**2)
# class GradUpdate(Formula):
# lr, cost, param = vars(3)
# param_updated = param + lr * grad(cost, param)
# NNetUpdate = Forward() + SumSqrError() + [GradUpdate().rename({'param*': name}) for name in 'wbc']
# #NNetUpdate = Forward() + SumSqrError() + [GradUpdate().rename(param = name) for name in ['w', 'b', 'c']]
import link
from functools import partial
class DebugException(Exception):
    """Exception raised by the debugging machinery in this module."""
......
......@@ -30,14 +30,14 @@ class Optimizer:
"""
pass
def optimize(self, env):
def optimize(self, env, *args, **kwargs):
"""
This is meant as a shortcut to::
env.satisfy(opt)
opt.apply(env)
"""
self.add_requirements(env)
self.apply(env)
self.apply(env, *args, **kwargs)
def __call__(self, env):
"""
......@@ -218,8 +218,14 @@ class LocalOpKeyOptGroup(LocalOptGroup):
class ExpandMacro(LocalOptimizer):
def __init__(self, filter = None):
if filter is None:
self.filter = lambda node: True
else:
self.filter = filter
def transform(self, node):
if not isinstance(node.op, op.Macro):
if not isinstance(node.op, op.Macro) or not self.filter(node):
return False
return node.op.expand(node)
......@@ -466,7 +472,7 @@ class NavigatorOptimizer(Optimizer):
def process_node(self, env, node):
replacements = self.local_opt.transform(node)
if replacements is False:
if replacements is False or replacements is None:
return
repl_pairs = zip(node.outputs, replacements)
try:
......@@ -490,13 +496,15 @@ class TopoOptimizer(NavigatorOptimizer):
self.order = order
NavigatorOptimizer.__init__(self, local_opt, ignore_newtrees, failure_callback)
def apply(self, env):
q = deque(graph.io_toposort(env.inputs, env.outputs))
def apply(self, env, start_from = None):
if start_from is None: start_from = env.outputs
q = deque(graph.io_toposort(env.inputs, start_from))
def importer(node):
q.append(node)
def pruner(node):
if node is not current_node:
q.remove(node)
try: q.remove(node)
except ValueError: pass
u = self.attach_updater(env, importer, pruner)
try:
......@@ -529,7 +537,8 @@ class OpKeyOptimizer(NavigatorOptimizer):
if node.op == op: q.append(node)
def pruner(node):
if node is not current_node and node.op == op:
q.remove(node)
try: q.remove(node)
except ValueError: pass
u = self.attach_updater(env, importer, pruner)
try:
while q:
......@@ -554,7 +563,10 @@ class OpKeyOptimizer(NavigatorOptimizer):
### Pre-defined optimizers ###
##############################
expand_macros = TopoOptimizer(ExpandMacro(), ignore_newtrees = False)
def ExpandMacros(filter = None):
    """Build a TopoOptimizer that expands Macro nodes, visiting the graph
    from inputs to outputs; `filter`, if given, restricts which nodes the
    local ExpandMacro optimizer is allowed to transform."""
    local_expander = ExpandMacro(filter = filter)
    return TopoOptimizer(local_expander,
                         order = 'in_to_out',
                         ignore_newtrees = False)
......
......@@ -83,7 +83,7 @@ class DimShufflePrinter:
raise TypeError("Can only print DimShuffle.")
elif isinstance(r.owner.op, ShuffleRule):
#print r, r.owner.op
new_r = r.owner.op.expand(r)
new_r = r.owner.op.expand(r.owner)
#print new_r.owner, isinstance(new_r.owner.op, ShuffleRule)
return self.process(new_r, pstate)
elif isinstance(r.owner.op, DimShuffle):
......@@ -163,16 +163,6 @@ class PPrinter:
return cp
# class ExtendedPPrinter:
# def __init__(self, pprinter, leaf_pprinter):
# self.pprinter = pprinter
# self.leaf_pprinter = pprinter
# def process(self, r, pstate = None):
from tensor import *
from elemwise import Sum, ShuffleRule
......@@ -194,18 +184,18 @@ pp.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, DimShuffle), DimS
pp.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, ShuffleRule), DimShufflePrinter())
print pp.process(x + y * z)
print pp.process((x + y) * z)
print pp.process(x * (y * z))
print pp.process(x / (y / z) / x)
print pp.process((x ** y) ** z)
print pp.process(-x+y)
print pp.process(-x*y)
print pp.process(sum(x))
print pp.process(sum(x * 10))
# print pp.process(x + y * z)
# print pp.process((x + y) * z)
# print pp.process(x * (y * z))
# print pp.process(x / (y / z) / x)
# print pp.process((x ** y) ** z)
# print pp.process(-x+y)
# print pp.process(-x*y)
# print pp.process(sum(x))
# print pp.process(sum(x * 10))
a = Tensor(broadcastable=(False,False,False), dtype='float64')('alpha')
print a.type
print pp.process(DimShuffle((False,)*2, [1, 0])(x) + a)
# a = Tensor(broadcastable=(False,False,False), dtype='float64')('alpha')
# print a.type
# print pp.process(DimShuffle((False,)*2, [1, 0])(x) + a)
print pp.process(x / (y * z))
# print pp.process(x / (y * z))
......@@ -1142,15 +1142,19 @@ gemm = Gemm()
# Gradient
#########################
class SGrad(gof.Op):
class Grad(gof.Macro):
level = 2
def make_node(self, cost, wrt):
return Apply(self, [cost, wrt], [wrt.type()])
def expand(self, r):
cost, wrt = r.owner.inputs
return grad(cost, wrt)
sgrad = SGrad()
def grad(cost, wrt, g_cost=None):
if not isinstance(wrt, list):
wrt = [wrt]
return Apply(self, [cost] + wrt, [_wrt.type() for _wrt in wrt])
def expand(self, node):
cost, wrt = node.inputs[0], node.inputs[1:]
g = exec_grad(cost, wrt)
return g
grad = Grad()
def exec_grad(cost, wrt, g_cost=None):
"""
@type cost: L{Result}
@type wrt: L{Result} or list of L{Result}s.
......
from gof import opt, Env
import gof
from elemwise import Broadcast, DimShuffle
from gof.python25 import any, all
from elemwise import Elemwise, DimShuffle
import scalar
class InplaceOptimizer(opt.OpSpecificOptimizer):
class InplaceOptimizer(gof.Optimizer):
"""
Usage: inplace_optimizer.optimize(env)
......@@ -20,315 +18,363 @@ class InplaceOptimizer(opt.OpSpecificOptimizer):
x + y + z -> x += y += z
(x + y) * (x * y) -> (x += y) *= (x * y) or (x + y) *= (x *= y)
"""
opclass = Broadcast
def apply_on_op(self, env, op):
baseline = op.inplace_pattern
candidate_outputs = [i for i in xrange(len(op.outputs)) if i not in baseline]
candidate_inputs = [i for i in xrange(len(op.inputs)) if i not in baseline.values()]
for candidate_output in candidate_outputs:
for candidate_input in candidate_inputs:
inplace_pattern = dict(baseline, **{candidate_output: candidate_input})
try:
new_op = Broadcast(op.scalar_opclass, op.inputs, inplace_pattern)
env.replace_all(dict(zip(op.outputs, new_op.outputs)))
except:
continue
candidate_inputs.remove(candidate_input)
op = new_op
baseline = inplace_pattern
break
def apply(self, env):
for node in list(env.nodes):
op = node.op
if not isinstance(op, Elemwise):
continue
baseline = op.inplace_pattern
candidate_outputs = [i for i in xrange(len(node.outputs)) if i not in baseline]
candidate_inputs = [i for i in xrange(len(node.inputs)) if i not in baseline.values()]
for candidate_output in candidate_outputs:
for candidate_input in candidate_inputs:
inplace_pattern = dict(baseline, **{candidate_output: candidate_input})
try:
new = Elemwise(op.scalar_op, inplace_pattern).make_node(op.inputs)
env.replace_all_validate(dict(zip(node.outputs, new.outputs)))
except:
continue
candidate_inputs.remove(candidate_input)
node = new
baseline = inplace_pattern
break
def add_requirements(self, env):
env.extend(gof.toolbox.ReplaceValidate)
inplace_optimizer = InplaceOptimizer()
class DimShuffleLifter(opt.Optimizer):
class DimShuffleLifter(gof.LocalOptimizer):
"""
Usage: lift_dimshuffle.optimize(env)
"Lifts" DimShuffle through Broadcast operations and merges
"Lifts" DimShuffle through Elemwise operations and merges
consecutive DimShuffles. Basically, applies the following
transformations on the whole graph:
DimShuffle(Broadcast(x, y)) => Broadcast(DimShuffle(x), DimShuffle(y))
DimShuffle(Elemwise(x, y)) => Elemwise(DimShuffle(x), DimShuffle(y))
DimShuffle(DimShuffle(x)) => DimShuffle(x)
After this transform, clusters of Broadcast operations are
After this transform, clusters of Elemwise operations are
void of DimShuffle operations.
"""
def apply(self, env):
def transform(self, node):
op = node.op
if not isinstance(op, DimShuffle):
return False
input = node.inputs[0]
inode = input.owner
if inode and isinstance(inode.op, Elemwise):
return inode.op.make_node(*[DimShuffle(input.type.broadcastable,
op.new_order,
op.inplace)(input) for input in inode.inputs]).outputs
if inode and isinstance(inode.op, DimShuffle):
new_order = [x == 'x' and 'x' or inode.op.new_order[x] for x in op.new_order]
inplace = op.inplace and inode.op.inplace
iinput = inode.inputs[0]
if new_order == range(len(new_order)):
return [iinput]
else:
return DimShuffle(iinput.type.broadcastable, new_order, inplace).make_node(iinput).outputs
seen = set()
def lift(r):
if r in seen:
return
seen.add(r)
if env.edge(r):
return
op = r.owner
if isinstance(op, DimShuffle):
in_op = op.inputs[0].owner
if isinstance(in_op, DimShuffle):
# DimShuffle(DimShuffle(x)) => DimShuffle(x)
new_order = [x == 'x' and 'x' or in_op.new_order[x] for x in op.new_order]
if new_order == range(len(new_order)):
repl = in_op.inputs[0]
else:
repl = DimShuffle(in_op.inputs[0], new_order).out
env.replace(r, repl)
lift(repl)
return
elif isinstance(in_op, Broadcast):
# DimShuffle(Broadcast(x, y)) => Broadcast(DimShuffle(x), DimShuffle(y))
repl = Broadcast(in_op.scalar_opclass,
[DimShuffle(input, op.new_order).out for input in in_op.inputs],
in_op.inplace_pattern).out
env.replace(r, repl)
r = repl
op = r.owner
for next_r in op.inputs:
lift(next_r)
for output in env.outputs:
lift(output)
lift_dimshuffle = DimShuffleLifter()
def find_cliques(env, through_broadcast = False):
"""
Usage: find_cliques(env, through_broadcast = False)
lift_dimshuffle = gof.TopoOptimizer(DimShuffleLifter(), order = 'out_to_in')
Returns a list of pairs where each pair contains a list
of inputs and a list of outputs such that Env(inputs, outputs)
contains nothing but Broadcast Ops.
If through_broadcast is False, the cliques will only be
allowed to broadcast over the inputs, which means, for
example, that vector operations will not be mixed with
matrix operations.
"""
def seek_from(r):
# walks through the graph until it encounters a
# non-Broadcast operation or (if through_broadcast
# is False) a Result which needs to be broadcasted.
# class DimShuffleLifter(opt.Optimizer):
# """
# Usage: lift_dimshuffle.optimize(env)
# "Lifts" DimShuffle through Broadcast operations and merges
# consecutive DimShuffles. Basically, applies the following
# transformations on the whole graph:
# DimShuffle(Broadcast(x, y)) => Broadcast(DimShuffle(x), DimShuffle(y))
# DimShuffle(DimShuffle(x)) => DimShuffle(x)
# After this transform, clusters of Broadcast operations are
# void of DimShuffle operations.
# """
# def apply(self, env):
# seen = set()
op = r.owner
if env.edge(r) \
or not isinstance(op, Broadcast) \
or len(op.outputs) > 1:
# todo: handle multiple-output broadcast ops
# (needs to update the clique's outputs)
return None
ret = set()
if not through_broadcast:
# check each dimension over all the inputs - if the broadcastable
# fields are not all 0 or all 1 for a particular dimension, then
# broadcasting will be performed along it on the inputs where the
# value is 1 and we will stop.
if any(any(bc) and not all(bc)
for bc in zip(*[input.broadcastable for input in op.inputs])):
ret.update(op.inputs)
return ret
# def lift(r):
# if r in seen:
# return
# seen.add(r)
# if env.edge(r):
# return
# op = r.owner
# if isinstance(op, DimShuffle):
# in_op = op.inputs[0].owner
# if isinstance(in_op, DimShuffle):
# # DimShuffle(DimShuffle(x)) => DimShuffle(x)
# new_order = [x == 'x' and 'x' or in_op.new_order[x] for x in op.new_order]
# if new_order == range(len(new_order)):
# repl = in_op.inputs[0]
# else:
# repl = DimShuffle(in_op.inputs[0], new_order).out
# env.replace(r, repl)
# lift(repl)
# return
# elif isinstance(in_op, Broadcast):
# # DimShuffle(Broadcast(x, y)) => Broadcast(DimShuffle(x), DimShuffle(y))
# repl = Broadcast(in_op.scalar_opclass,
# [DimShuffle(input, op.new_order).out for input in in_op.inputs],
# in_op.inplace_pattern).out
# env.replace(r, repl)
# r = repl
# op = r.owner
# for next_r in op.inputs:
# lift(next_r)
# for output in env.outputs:
# lift(output)
# lift_dimshuffle = DimShuffleLifter()
# def find_cliques(env, through_broadcast = False):
# """
# Usage: find_cliques(env, through_broadcast = False)
# Returns a list of pairs where each pair contains a list
# of inputs and a list of outputs such that Env(inputs, outputs)
# contains nothing but Broadcast Ops.
# If through_broadcast is False, the cliques will only be
# allowed to broadcast over the inputs, which means, for
# example, that vector operations will not be mixed with
# matrix operations.
# """
# def seek_from(r):
# # walks through the graph until it encounters a
# # non-Broadcast operation or (if through_broadcast
# # is False) a Result which needs to be broadcasted.
for input in op.inputs:
res = seek_from(input)
if res is None:
# input is a leaf of our search
ret.add(input)
else:
ret.update(res)
# op = r.owner
# if env.edge(r) \
# or not isinstance(op, Broadcast) \
# or len(op.outputs) > 1:
# # todo: handle multiple-output broadcast ops
# # (needs to update the clique's outputs)
# return None
# ret = set()
# if not through_broadcast:
# # check each dimension over all the inputs - if the broadcastable
# # fields are not all 0 or all 1 for a particular dimension, then
# # broadcasting will be performed along it on the inputs where the
# # value is 1 and we will stop.
# if any(any(bc) and not all(bc)
# for bc in zip(*[input.broadcastable for input in op.inputs])):
# ret.update(op.inputs)
# return ret
# for input in op.inputs:
# res = seek_from(input)
# if res is None:
# # input is a leaf of our search
# ret.add(input)
# else:
# ret.update(res)
return ret
# return ret
cliques = []
def find_cliques_helper(r):
if env.edge(r):
return
clique_inputs = seek_from(r)
if clique_inputs is None:
# Not in a clique, keep going
op = r.owner
if op is not None:
for input in op.inputs:
find_cliques_helper(input)
else:
# We found a clique, add it to the list and
# jump to the leaves.
cliques.append((clique_inputs, [r]))
for input in clique_inputs:
find_cliques_helper(input)
for output in env.outputs:
find_cliques_helper(output)
# todo: merge the cliques if possible
return cliques
class CliqueOptimizer(opt.Optimizer):
"""
Usage: CliqueOptimizer(through_broadcast = False,
scalar_optimizer = None,
make_composite = False).optimize(env)
Finds cliques of Broadcast operations in the env and does either
or both of two things:
# cliques = []
# def find_cliques_helper(r):
# if env.edge(r):
# return
# clique_inputs = seek_from(r)
# if clique_inputs is None:
# # Not in a clique, keep going
# op = r.owner
# if op is not None:
# for input in op.inputs:
# find_cliques_helper(input)
# else:
# # We found a clique, add it to the list and
# # jump to the leaves.
# cliques.append((clique_inputs, [r]))
# for input in clique_inputs:
# find_cliques_helper(input)
# for output in env.outputs:
# find_cliques_helper(output)
# # todo: merge the cliques if possible
# return cliques
# class CliqueOptimizer(opt.Optimizer):
# """
# Usage: CliqueOptimizer(through_broadcast = False,
# scalar_optimizer = None,
# make_composite = False).optimize(env)
# Finds cliques of Broadcast operations in the env and does either
# or both of two things:
* Apply scalar_optimizer on the clique as if the clique was a
group of scalar operations. scalar_optimizer can be any optimization
which applies on scalars. If it is None, no optimization is done.
* Replace the clique with a single Op, optimized to perform the
computations properly. If make_composite is False, no such replacement
is done.
Note: it is recommended to run the lift_dimshuffle optimization before
this one.
"""
def __init__(self, through_broadcast = False, scalar_optimizer = None, make_composite = False):
self.through_broadcast = through_broadcast
self.scalar_optimizer = scalar_optimizer
self.make_composite = make_composite
def apply(self, env):
if self.scalar_optimizer is None and not self.make_composite:
# there's nothing to do with the cliques...
return
# * Apply scalar_optimizer on the clique as if the clique was a
# group of scalar operations. scalar_optimizer can be any optimization
# which applies on scalars. If it is None, no optimization is done.
# * Replace the clique with a single Op, optimized to perform the
# computations properly. If make_composite is False, no such replacement
# is done.
# Note: it is recommended to run the lift_dimshuffle optimization before
# this one.
# """
# def __init__(self, through_broadcast = False, scalar_optimizer = None, make_composite = False):
# self.through_broadcast = through_broadcast
# self.scalar_optimizer = scalar_optimizer
# self.make_composite = make_composite
# def apply(self, env):
# if self.scalar_optimizer is None and not self.make_composite:
# # there's nothing to do with the cliques...
# return
cliques = find_cliques(env, self.through_broadcast)
opt = self.scalar_optimizer
def build_scalar_clique(r, env, equiv):
# Maps a clique of Broadcast Ops to a clique of Scalar Ops with the same
# structure and equivalent operations. equiv contains the mapping.
if r in equiv:
return equiv[r]
op = r.owner
if env.edge(r):
# For each leave we make a Scalar of the corresponding dtype
s = scalar.Scalar(dtype = r.dtype)
_r = r
if isinstance(r.owner, DimShuffle) and all(x == 'x' for x in r.owner.new_order):
_r = r.owner.inputs[0]
if (getattr(r, 'constant', False) or getattr(_r, 'constant', False)) \
and _r.broadcastable == ():
# If we have a constant tensor we map it to a constant scalar.
s.data = _r.data
s.constant = True
equiv[r] = s
return s
s_op = op.scalar_opclass(*[build_scalar_clique(input, env, equiv) for input in op.inputs])
equiv[op] = s_op
for output, s_output in zip(op.outputs, s_op.outputs):
equiv[output] = s_output
return equiv[r]
for c_in, c_out in cliques:
equiv = dict()
g = Env(c_in, c_out)
for output in c_out:
build_scalar_clique(output, g, equiv)
s_g = Env([equiv[r] for r in g.inputs],
[equiv[r] for r in g.outputs])
if opt is not None:
equiv2 = dict() # reverse mapping, from Scalar Op to Tensor Op
for k, v in equiv.items():
equiv2[v] = k
def transform(op, equiv):
# We get a scalar op and we return an equivalent op on tensors.
return Broadcast(op.__class__, [equiv[input] for input in op.inputs])
s_g.add_feature(sync_to(env, equiv2, transform)) # Any change to s_g will now be transferred to g
opt.optimize(s_g)
if self.make_composite:
def follow_inplace(r):
# Tries to find the earliest r2 in g such that r destroys r2
# If no such r2 is found, returns None
op = r.owner
if op is None or r in g.inputs or r in g.orphans():
return None
assert isinstance(op, Broadcast)
destroyed = op.destroy_map().get(r, None)
if destroyed is None:
return None
else:
r2 = destroyed[0]
ret = follow_inplace(r2)
if ret is None:
return r2
else:
return ret
inplace_pattern = {}
for i, output in enumerate(g.outputs):
destroyed = follow_inplace(output)
if destroyed is not None and destroyed in g.inputs:
# we transfer the inplace operation only if it is
# an input that is destroyed
inplace_pattern[i] = g.inputs.index(destroyed)
C = scalar.composite(s_g.inputs, s_g.outputs)
ec = Broadcast(C, g.inputs, inplace_pattern = inplace_pattern)
env.replace_all(dict((o, eco) for o, eco in zip(c_out, ec.outputs)))
def sync_to(target, equiv, transform):
"""
Usage: sync_to(target, equiv, transform)
* target: an Env
* equiv: a dictionary that maps results and ops to results and ops
in target
* transform: a function that takes (op, equiv) as inputs and
returns a new op.
# cliques = find_cliques(env, self.through_broadcast)
# opt = self.scalar_optimizer
# def build_scalar_clique(r, env, equiv):
# # Maps a clique of Broadcast Ops to a clique of Scalar Ops with the same
# # structure and equivalent operations. equiv contains the mapping.
# if r in equiv:
# return equiv[r]
# op = r.owner
# if env.edge(r):
# # For each leave we make a Scalar of the corresponding dtype
# s = scalar.Scalar(dtype = r.dtype)
# _r = r
# if isinstance(r.owner, DimShuffle) and all(x == 'x' for x in r.owner.new_order):
# _r = r.owner.inputs[0]
# if (getattr(r, 'constant', False) or getattr(_r, 'constant', False)) \
# and _r.broadcastable == ():
# # If we have a constant tensor we map it to a constant scalar.
# s.data = _r.data
# s.constant = True
# equiv[r] = s
# return s
# s_op = op.scalar_opclass(*[build_scalar_clique(input, env, equiv) for input in op.inputs])
# equiv[op] = s_op
# for output, s_output in zip(op.outputs, s_op.outputs):
# equiv[output] = s_output
# return equiv[r]
# for c_in, c_out in cliques:
# equiv = dict()
# g = Env(c_in, c_out)
# for output in c_out:
# build_scalar_clique(output, g, equiv)
# s_g = Env([equiv[r] for r in g.inputs],
# [equiv[r] for r in g.outputs])
# if opt is not None:
# equiv2 = dict() # reverse mapping, from Scalar Op to Tensor Op
# for k, v in equiv.items():
# equiv2[v] = k
# def transform(op, equiv):
# # We get a scalar op and we return an equivalent op on tensors.
# return Broadcast(op.__class__, [equiv[input] for input in op.inputs])
# s_g.add_feature(sync_to(env, equiv2, transform)) # Any change to s_g will now be transferred to g
# opt.optimize(s_g)
# if self.make_composite:
# def follow_inplace(r):
# # Tries to find the earliest r2 in g such that r destroys r2
# # If no such r2 is found, returns None
# op = r.owner
# if op is None or r in g.inputs or r in g.orphans():
# return None
# assert isinstance(op, Broadcast)
# destroyed = op.destroy_map().get(r, None)
# if destroyed is None:
# return None
# else:
# r2 = destroyed[0]
# ret = follow_inplace(r2)
# if ret is None:
# return r2
# else:
# return ret
# inplace_pattern = {}
# for i, output in enumerate(g.outputs):
# destroyed = follow_inplace(output)
# if destroyed is not None and destroyed in g.inputs:
# # we transfer the inplace operation only if it is
# # an input that is destroyed
# inplace_pattern[i] = g.inputs.index(destroyed)
# C = scalar.composite(s_g.inputs, s_g.outputs)
# ec = Broadcast(C, g.inputs, inplace_pattern = inplace_pattern)
# env.replace_all(dict((o, eco) for o, eco in zip(c_out, ec.outputs)))
# def sync_to(target, equiv, transform):
# """
# Usage: sync_to(target, equiv, transform)
# * target: an Env
# * equiv: a dictionary that maps results and ops to results and ops
# in target
# * transform: a function that takes (op, equiv) as inputs and
# returns a new op.
Returns a Feature that can be added to an Env and mirrors all
modifications to that env with modifications to the target env.
"""
class Synchronize(gof.Listener, gof.Constraint):
def __init__(self, source):
self.source = source
self.target = target
self.equiv = equiv
self.transform = transform
self.inconsistencies = []
def on_import(self, op1):
if op1 not in self.equiv:
op2 = self.transform(op1, self.equiv)
self.equiv[op1] = op2
for o1, o2 in zip(op1.outputs, op2.outputs):
self.equiv[o1] = o2
def on_prune(self, op1):
if op1 in self.equiv:
op2 = self.equiv[op1]
del self.equiv[op1]
for o1, o2 in zip(op1.outputs, op2.outputs):
del self.equiv[o1]
def on_rewire(self, clients1, r1, new_r1):
if (new_r1, r1) in self.inconsistencies:
self.inconsistencies.remove((new_r1, r1))
return
if not self.source.clients(r1):
try:
target.replace(self.equiv[r1], self.equiv[new_r1])
except:
self.inconsistencies.append((r1, new_r1))
def validate(self):
if self.inconsistencies:
raise InconsistencyError("Could not synchronize when replacing the following pairs: %s" % self.inconsistencies)
return True
return Synchronize
# Returns a Feature that can be added to an Env and mirrors all
# modifications to that env with modifications to the target env.
# """
# class Synchronize(gof.Listener, gof.Constraint):
# def __init__(self, source):
# self.source = source
# self.target = target
# self.equiv = equiv
# self.transform = transform
# self.inconsistencies = []
# def on_import(self, op1):
# if op1 not in self.equiv:
# op2 = self.transform(op1, self.equiv)
# self.equiv[op1] = op2
# for o1, o2 in zip(op1.outputs, op2.outputs):
# self.equiv[o1] = o2
# def on_prune(self, op1):
# if op1 in self.equiv:
# op2 = self.equiv[op1]
# del self.equiv[op1]
# for o1, o2 in zip(op1.outputs, op2.outputs):
# del self.equiv[o1]
# def on_rewire(self, clients1, r1, new_r1):
# if (new_r1, r1) in self.inconsistencies:
# self.inconsistencies.remove((new_r1, r1))
# return
# if not self.source.clients(r1):
# try:
# target.replace(self.equiv[r1], self.equiv[new_r1])
# except:
# self.inconsistencies.append((r1, new_r1))
# def validate(self):
# if self.inconsistencies:
# raise InconsistencyError("Could not synchronize when replacing the following pairs: %s" % self.inconsistencies)
# return True
# return Synchronize
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论