提交 bf69ab87 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

lots of stuff

上级 67b911ab
...@@ -540,7 +540,8 @@ def verify_grad(testcase, op, pt, n_tests=1, rng=numpy.random, eps=0.0000001, to ...@@ -540,7 +540,8 @@ def verify_grad(testcase, op, pt, n_tests=1, rng=numpy.random, eps=0.0000001, to
num_grad = gradient.numeric_grad(cost_fn, pt) num_grad = gradient.numeric_grad(cost_fn, pt)
symbolic_grad = grad(cost, tensor_pt,as_tensor(1.0,name='g_cost')) #symbolic_grad = exec_grad(cost, tensor_pt,as_tensor(1.0,name='g_cost'))
symbolic_grad = grad.make_node(cost, tensor_pt).outputs
if 0: if 0:
print '-------' print '-------'
print '----------' print '----------'
...@@ -898,7 +899,7 @@ class T_subtensor(unittest.TestCase): ...@@ -898,7 +899,7 @@ class T_subtensor(unittest.TestCase):
n = as_tensor(numpy.random.rand(2,3)) n = as_tensor(numpy.random.rand(2,3))
z = scal.constant(0) z = scal.constant(0)
t = n[z:,z] t = n[z:,z]
gn = grad(sum(exp(t)), n) gn = exec_grad(sum(exp(t)), n)
gval = eval_outputs([gn]) gval = eval_outputs([gn])
s0 = 'array([ 2.05362099, 0. , 0. ])' s0 = 'array([ 2.05362099, 0. , 0. ])'
s1 = 'array([ 1.55009327, 0. , 0. ])' s1 = 'array([ 1.55009327, 0. , 0. ])'
...@@ -908,7 +909,7 @@ class T_subtensor(unittest.TestCase): ...@@ -908,7 +909,7 @@ class T_subtensor(unittest.TestCase):
def test_grad_0d(self): def test_grad_0d(self):
n = as_tensor(numpy.random.rand(2,3)) n = as_tensor(numpy.random.rand(2,3))
t = n[1,0] t = n[1,0]
gn = grad(sum(exp(t)), n) gn = exec_grad(sum(exp(t)), n)
gval = eval_outputs([gn]) gval = eval_outputs([gn])
g0 = repr(gval[0,:]) g0 = repr(gval[0,:])
g1 = repr(gval[1,:]) g1 = repr(gval[1,:])
...@@ -937,7 +938,7 @@ class T_Stack(unittest.TestCase): ...@@ -937,7 +938,7 @@ class T_Stack(unittest.TestCase):
a = as_tensor(numpy.array([[1, 2, 3], [4, 5, 6]])) a = as_tensor(numpy.array([[1, 2, 3], [4, 5, 6]]))
b = as_tensor(numpy.array([[7, 8, 9]])) b = as_tensor(numpy.array([[7, 8, 9]]))
s = vertical_stack(a, b) s = vertical_stack(a, b)
ga,gb = grad(sum(vertical_stack(a,b)), [a,b]) ga,gb = exec_grad(sum(vertical_stack(a,b)), [a,b])
gval = eval_outputs([ga, gb]) gval = eval_outputs([ga, gb])
self.failUnless(numpy.all(gval[0] == 1.0)) self.failUnless(numpy.all(gval[0] == 1.0))
...@@ -1671,13 +1672,13 @@ class _test_grad(unittest.TestCase): ...@@ -1671,13 +1672,13 @@ class _test_grad(unittest.TestCase):
"""grad: Test passing a single result param""" """grad: Test passing a single result param"""
o = _test_grad.O() o = _test_grad.O()
a1 = o.make_node() a1 = o.make_node()
self.failUnless(o.gval0 is grad(a1.outputs[0], a1.inputs[0])) self.failUnless(o.gval0 is exec_grad(a1.outputs[0], a1.inputs[0]))
def test_Nparam(self): def test_Nparam(self):
"""grad: Test passing multiple result params""" """grad: Test passing multiple result params"""
o = _test_grad.O() o = _test_grad.O()
a1 = o.make_node() a1 = o.make_node()
g0,g1 = grad(a1.outputs[0], a1.inputs) g0,g1 = exec_grad(a1.outputs[0], a1.inputs)
self.failUnless(o.gval0 is g0) self.failUnless(o.gval0 is g0)
self.failUnless(o.gval1 is g1) self.failUnless(o.gval1 is g1)
...@@ -1685,13 +1686,13 @@ class _test_grad(unittest.TestCase): ...@@ -1685,13 +1686,13 @@ class _test_grad(unittest.TestCase):
"""grad: Test returning a single None from grad""" """grad: Test returning a single None from grad"""
o = _test_grad.O() o = _test_grad.O()
a1 = o.make_node() a1 = o.make_node()
self.failUnless(None is grad(a1.outputs[0], a1.outputs[1])) self.failUnless(None is exec_grad(a1.outputs[0], a1.outputs[1]))
self.failUnless(None is grad(a1.outputs[0], 'wtf')) self.failUnless(None is exec_grad(a1.outputs[0], 'wtf'))
def test_NNone_rval(self): def test_NNone_rval(self):
"""grad: Test returning some Nones from grad""" """grad: Test returning some Nones from grad"""
o = _test_grad.O() o = _test_grad.O()
a1 = o.make_node() a1 = o.make_node()
g0,g1,g2 = grad(a1.outputs[0], a1.inputs + ['wtf']) g0,g1,g2 = exec_grad(a1.outputs[0], a1.inputs + ['wtf'])
self.failUnless(o.gval0 is g0) self.failUnless(o.gval0 is g0)
self.failUnless(o.gval1 is g1) self.failUnless(o.gval1 is g1)
self.failUnless(None is g2) self.failUnless(None is g2)
......
## PENDING REWRITE OF tensor_opt.py ## PENDING REWRITE OF tensor_opt.py
# import unittest import unittest
# import gof import gof
# from tensor_opt import * from tensor_opt import *
# import tensor import tensor
# from tensor import Tensor from tensor import Tensor
# from gof import Env from gof import Env
# from elemwise import DimShuffle from elemwise import DimShuffle
# import numpy import numpy
# import scalar_opt #import scalar_opt
# def inputs(xbc = (0, 0), ybc = (0, 0), zbc = (0, 0)): def inputs(xbc = (0, 0), ybc = (0, 0), zbc = (0, 0)):
# x = Tensor(broadcastable = xbc, dtype = 'float64')('x') x = Tensor(broadcastable = xbc, dtype = 'float64')('x')
# y = Tensor(broadcastable = ybc, dtype = 'float64')('y') y = Tensor(broadcastable = ybc, dtype = 'float64')('y')
# z = Tensor(broadcastable = zbc, dtype = 'float64')('z') z = Tensor(broadcastable = zbc, dtype = 'float64')('z')
# return x, y, z return x, y, z
# ds = DimShuffle
# class _test_inplace_opt(unittest.TestCase): # class _test_inplace_opt(unittest.TestCase):
...@@ -60,39 +58,45 @@ ...@@ -60,39 +58,45 @@
# self.failUnless(str(g) == "[Broadcast{Add}{0: 1}(x, y), Broadcast{Mul}{0: 0}(x, z)]") # self.failUnless(str(g) == "[Broadcast{Add}{0: 1}(x, y), Broadcast{Mul}{0: 0}(x, z)]")
# class _test_dimshuffle_lift(unittest.TestCase): ds = lambda x, y: DimShuffle(x.type.broadcastable, y)(x)
# def test_double_transpose(self): class _test_dimshuffle_lift(unittest.TestCase):
# x, y, z = inputs()
# e = ds(ds(x, (1, 0)), (1, 0)) def test_double_transpose(self):
# g = Env([x], [e]) x, y, z = inputs()
# self.failUnless(str(g) == "[InplaceDimShuffle{1,0}(InplaceDimShuffle{1,0}(x))]") e = ds(ds(x, (1, 0)), (1, 0))
# lift_dimshuffle.optimize(g) g = Env([x], [e])
# self.failUnless(str(g) == "[x]") self.failUnless(str(g) == "[DimShuffle{1,0}(DimShuffle{1,0}(x))]")
lift_dimshuffle.optimize(g)
# def test_merge2(self): self.failUnless(str(g) == "[x]")
# x, y, z = inputs()
# e = ds(ds(x, (1, 'x', 0)), (2, 0, 'x', 1)) def test_merge2(self):
# g = Env([x], [e]) x, y, z = inputs()
# self.failUnless(str(g) == "[InplaceDimShuffle{2,0,x,1}(InplaceDimShuffle{1,x,0}(x))]", str(g)) e = ds(ds(x, (1, 'x', 0)), (2, 0, 'x', 1))
# lift_dimshuffle.optimize(g) g = Env([x], [e])
# self.failUnless(str(g) == "[InplaceDimShuffle{0,1,x,x}(x)]", str(g)) self.failUnless(str(g) == "[DimShuffle{2,0,x,1}(DimShuffle{1,x,0}(x))]", str(g))
lift_dimshuffle.optimize(g)
# def test_elim3(self): self.failUnless(str(g) == "[DimShuffle{0,1,x,x}(x)]", str(g))
# x, y, z = inputs()
# e = ds(ds(ds(x, (0, 'x', 1)), (2, 0, 'x', 1)), (1, 0)) def test_elim3(self):
# g = Env([x], [e]) x, y, z = inputs()
# self.failUnless(str(g) == "[InplaceDimShuffle{1,0}(InplaceDimShuffle{2,0,x,1}(InplaceDimShuffle{0,x,1}(x)))]", str(g)) e = ds(ds(ds(x, (0, 'x', 1)), (2, 0, 'x', 1)), (1, 0))
# lift_dimshuffle.optimize(g) g = Env([x], [e])
# self.failUnless(str(g) == "[x]", str(g)) self.failUnless(str(g) == "[DimShuffle{1,0}(DimShuffle{2,0,x,1}(DimShuffle{0,x,1}(x)))]", str(g))
lift_dimshuffle.optimize(g)
# def test_lift(self): self.failUnless(str(g) == "[x]", str(g))
# x, y, z = inputs([0]*1, [0]*2, [0]*3)
# e = x + y + z def test_lift(self):
# g = Env([x, y, z], [e]) x, y, z = inputs([False]*1, [False]*2, [False]*3)
# self.failUnless(str(g) == "[Broadcast{Add}(InplaceDimShuffle{x,0,1}(Broadcast{Add}(InplaceDimShuffle{x,0}(x), y)), z)]", str(g)) e = x + y + z
# lift_dimshuffle.optimize(g) g = Env([x, y, z], [e])
# self.failUnless(str(g) == "[Broadcast{Add}(Broadcast{Add}(InplaceDimShuffle{x,x,0}(x), InplaceDimShuffle{x,0,1}(y)), z)]", str(g)) gof.ExpandMacros().optimize(g)
self.failUnless(str(g) == "[add(InplaceDimShuffle{x,0,1}(add(InplaceDimShuffle{x,0}(x), y)), z)]", str(g))
print g
lift_dimshuffle.optimize(g)
gof.ExpandMacros().optimize(g)
print g
self.failUnless(str(g) == "[add(add(InplaceDimShuffle{x,x,0}(x), InplaceDimShuffle{x,0,1}(y)), z)]", str(g))
# class _test_cliques(unittest.TestCase): # class _test_cliques(unittest.TestCase):
...@@ -185,8 +189,8 @@ ...@@ -185,8 +189,8 @@
# if __name__ == '__main__': if __name__ == '__main__':
# unittest.main() unittest.main()
......
...@@ -104,6 +104,9 @@ class FunctionFactory: ...@@ -104,6 +104,9 @@ class FunctionFactory:
if not isinstance(r, gof.Result): if not isinstance(r, gof.Result):
raise TypeError("All inputs and outputs to FunctionFactory should be Result instances. Received:", type(r), r) raise TypeError("All inputs and outputs to FunctionFactory should be Result instances. Received:", type(r), r)
env = std_env(inputs, outputs, disown_inputs = disown_inputs) env = std_env(inputs, outputs, disown_inputs = disown_inputs)
gof.ExpandMacros().optimize(env)
#gof.ExpandMacros(lambda node: getattr(node.op, 'level', 0) <= 1).optimize(env)
#gof.ExpandMacros(lambda node: node.op.level == 2).optimize(env)
if None is not optimizer: if None is not optimizer:
optimizer(env) optimizer(env)
env.validate() env.validate()
......
...@@ -29,11 +29,6 @@ def TensorConstant(*inputs, **kwargs): ...@@ -29,11 +29,6 @@ def TensorConstant(*inputs, **kwargs):
### DimShuffle ### ### DimShuffle ###
################## ##################
## TODO: rule-based version of DimShuffle
## would allow for Transpose, LComplete, RComplete, etc.
## Can be optimized into DimShuffle later on.
class ShuffleRule(Macro): class ShuffleRule(Macro):
""" """
ABSTRACT Op - it has no perform and no c_code ABSTRACT Op - it has no perform and no c_code
...@@ -41,20 +36,30 @@ class ShuffleRule(Macro): ...@@ -41,20 +36,30 @@ class ShuffleRule(Macro):
Apply ExpandMacros to this node to obtain Apply ExpandMacros to this node to obtain
an equivalent DimShuffle which can be performed. an equivalent DimShuffle which can be performed.
""" """
def __init__(self, rule = None, name = None): level = 1
def __init__(self, rule = None, inplace = False, name = None):
if rule is not None: if rule is not None:
self.rule = rule self.rule = rule
self.inplace = inplace
if inplace:
self.view_map = {0: [0]}
self.name = name self.name = name
def make_node(self, input, *models): def make_node(self, input, *models):
pattern = self.rule(input.type.broadcastable, *(model.type.broadcastable for model in models)) pattern = self.rule(input.type.broadcastable, *(model.type.broadcastable for model in models))
ib = input.type.broadcastable
return gof.Apply(self, return gof.Apply(self,
(input,) + models, (input,) + models,
[Tensor(dtype = input.type.dtype, [Tensor(dtype = input.type.dtype,
broadcastable = [x == 'x' for x in pattern]).make_result()]) broadcastable = [x == 'x' or ib[x] for x in pattern]).make_result()])
def expand(self, r): def expand(self, node):
input, models = r.owner.inputs[0], r.owner.inputs[1:] input, models = node.inputs[0], node.inputs[1:]
new_order = self.rule(input.type.broadcastable, *(model.type.broadcastable for model in models)) new_order = self.rule(input.type.broadcastable, *(model.type.broadcastable for model in models))
return DimShuffle(input.type.broadcastable, new_order)(input) #print new_order, node.outputs[0].type, DimShuffle(input.type.broadcastable, new_order)(input).type, node.outputs[0].type == DimShuffle(input.type.broadcastable, new_order)(input).type
if list(new_order) == range(input.type.ndim) and self.inplace:
return [input]
else:
return [DimShuffle(input.type.broadcastable, new_order, self.inplace)(input)]
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) and self.rule == other.rule return type(self) == type(other) and self.rule == other.rule
def __hash__(self, other): def __hash__(self, other):
...@@ -66,10 +71,13 @@ class ShuffleRule(Macro): ...@@ -66,10 +71,13 @@ class ShuffleRule(Macro):
return "ShuffleRule{%s}" % self.role return "ShuffleRule{%s}" % self.role
_transpose = ShuffleRule(rule = lambda input: range(len(input)-1, -1, -1), _transpose = ShuffleRule(rule = lambda input: range(len(input)-1, -1, -1),
inplace = True,
name = 'transpose') name = 'transpose')
lcomplete = ShuffleRule(rule = lambda input, *models: ['x']*(max([0]+map(len,models))-len(input)) + range(len(input)), lcomplete = ShuffleRule(rule = lambda input, *models: ['x']*(max([0]+map(len,models))-len(input)) + range(len(input)),
inplace = True,
name = 'lcomplete') name = 'lcomplete')
rcomplete = ShuffleRule(rule = lambda input, *models: range(len(input)) + ['x']*(max(map(len,models))-len(input)), rcomplete = ShuffleRule(rule = lambda input, *models: range(len(input)) + ['x']*(max(map(len,models))-len(input)),
inplace = True,
name = 'rcomplete') name = 'rcomplete')
...@@ -170,7 +178,7 @@ class DimShuffle(Op): ...@@ -170,7 +178,7 @@ class DimShuffle(Op):
ob = [] ob = []
for value in self.new_order: for value in self.new_order:
if value == 'x': if value == 'x':
ob.append(1) ob.append(True)
else: else:
ob.append(ib[value]) ob.append(ib[value])
...@@ -304,8 +312,10 @@ class Elemwise(Op): ...@@ -304,8 +312,10 @@ class Elemwise(Op):
shadow = self.scalar_op.make_node(*[Scalar(dtype = t.type.dtype)() for t in inputs]) shadow = self.scalar_op.make_node(*[Scalar(dtype = t.type.dtype)() for t in inputs])
target_length = max([input.type.ndim for input in inputs]) target_length = max([input.type.ndim for input in inputs])
if len(inputs) > 1: if len(inputs) > 1:
inputs = [lcomplete(input, *inputs) for input in inputs] inputs = [lcomplete(input, *inputs) for input in inputs]
# args = [] # args = []
# for input in inputs: # for input in inputs:
# length = input.type.ndim # length = input.type.ndim
...@@ -316,7 +326,7 @@ class Elemwise(Op): ...@@ -316,7 +326,7 @@ class Elemwise(Op):
# # TODO: use LComplete instead # # TODO: use LComplete instead
# args.append(DimShuffle(input.type.broadcastable, ['x']*difference + range(length))(input)) # args.append(DimShuffle(input.type.broadcastable, ['x']*difference + range(length))(input))
# inputs = args # inputs = args
try: try:
assert len(set([len(input.type.broadcastable) for input in inputs])) == 1 assert len(set([len(input.type.broadcastable) for input in inputs])) == 1
except (AssertionError, AttributeError): except (AssertionError, AttributeError):
......
...@@ -5,29 +5,6 @@ from gof import utils ...@@ -5,29 +5,6 @@ from gof import utils
from copy import copy from copy import copy
import re import re
def _zip(*lists):
if not lists:
return ((), ())
else:
return zip(*lists)
# x = ivector()
# y = ivector()
# e = x + y
# f = Formula(x = x, y = y, e = e)
# y = x + x
# g = Formula(x=x,y=y)
# x2 = x + x
# g = Formula(x=x, x2=x2)
class Formula(utils.object2): class Formula(utils.object2):
def __init__(self, symtable_d = {}, **symtable_kwargs): def __init__(self, symtable_d = {}, **symtable_kwargs):
...@@ -87,18 +64,12 @@ class Formula(utils.object2): ...@@ -87,18 +64,12 @@ class Formula(utils.object2):
################ ################
def __rename__(self, **symequiv): def __rename__(self, **symequiv):
# print "~~~~~~~~~~~~~"
# print symequiv
vars = dict(self.__vars__) vars = dict(self.__vars__)
for symbol, replacement in symequiv.iteritems(): for symbol, replacement in symequiv.iteritems():
if replacement is not None: if replacement is not None:
vars[replacement] = self.get(symbol) vars[replacement] = self.get(symbol)
# print vars
# print set(symequiv.keys()).difference(set(symequiv.values()))
# print set(symequiv.keys()), set(symequiv.values())
for symbol in set(symequiv.keys()).difference(set(symequiv.values())): for symbol in set(symequiv.keys()).difference(set(symequiv.values())):
del vars[symbol] del vars[symbol]
# print vars
return Formula(vars) return Formula(vars)
def rename(self, **symequiv): def rename(self, **symequiv):
...@@ -174,11 +145,7 @@ class Formula(utils.object2): ...@@ -174,11 +145,7 @@ class Formula(utils.object2):
strings.append("%s = %s" % (output, strings.append("%s = %s" % (output,
pprint.pp.clone_assign(lambda pstate, r: r.name in self.__vars__ and r is not output, pprint.pp.clone_assign(lambda pstate, r: r.name in self.__vars__ and r is not output,
pprint.LeafPrinter()).process(output))) pprint.LeafPrinter()).process(output)))
# strings.append("%s = %s" % (output,
# pprint.pp.process(output)))
#strings.append(str(gof.graph.as_string(self.inputs, self.outputs)))
return "\n".join(strings) return "\n".join(strings)
# (self.inputs + utils.difference(self.outputs, node.outputs),[output])[0]
################# #################
### OPERATORS ### ### OPERATORS ###
...@@ -253,10 +220,6 @@ def glue(*formulas): ...@@ -253,10 +220,6 @@ def glue(*formulas):
return reduce(glue2, formulas) return reduce(glue2, formulas)
import tensor as T
sep = "---------------------------"
class FormulasMetaclass(type): class FormulasMetaclass(type):
def __init__(cls, name, bases, dct): def __init__(cls, name, bases, dct):
...@@ -272,402 +235,3 @@ class Formulas(utils.object2): ...@@ -272,402 +235,3 @@ class Formulas(utils.object2):
def __new__(cls): def __new__(cls):
return cls.__canon__.clone() return cls.__canon__.clone()
# class Test(Formulas):
# x = T.ivector()
# y = T.ivector()
# e = x + y + 21
# x = T.ivector()
# y = T.ivector()
# e = x + y + 21
# f1 = Formula(x = x, y = y, e = e)
# Test() -> f1.clone()
# f = Test()
# print f
# print f.reassign(x = T.ivector())
# print f.reassign(x = T.dvector(), y = T.dvector())
# print f.reassign(x = T.dmatrix(), y = T.dmatrix())
# class Test(Formulas):
# x = T.ivector()
# e = x + 999
# f = Test()
# print f
# print f.reassign(x = T.ivector())
# print f.reassign(x = T.dvector())
# class Layer(Formulas):
# x = T.ivector()
# y = T.ivector()
# x2 = x + y
# # print Layer()
# # print Layer() + 1
# # print Layer() + 2
# # print Layer() + 3
# print Layer() * 3
# def sigmoid(x):
# return 1.0 / (1.0 + T.exp(-x))
#class Update(Formulas):
# param = T.matrix()
# lr, cost = T.scalars(2)
# param_update = param - lr * T.sgrad(cost, param)
#class SumSqrDiff(Formulas):
# target, output = T.rows(2)
# cost = T.sum((target - output)**2)
# class Layer(Formulas):
# input, bias = T.rows(2)
# weights = T.matrix()
# input2 = T.tanh(bias + T.dot(input, weights))
# forward = Layer()*2
# g = glue(forward.rename(input3 = 'output'),
# SumSqrDiff().rename(target = 'input1'),
# *[Update().rename_regex({'param(.*)': ('%s\\1' % param.name)}) for param in forward.get_all('(weight|bias).*')])
# sg = g.__str__()
# print unicode(g)
# lr = 0.01
# def autoassociator_f(x, w, b, c):
# reconstruction = sigmoid(T.dot(sigmoid(T.dot(x, w) + b), w.T) + c)
# rec_error = T.sum((x - reconstruction)**2)
# new_w = w - lr * Th.grad(rec_error, w)
# new_b = b - lr * Th.grad(rec_error, b)
# new_c = c - lr * Th.grad(rec_error, c)
# # f = Th.Function([x, w, b, c], [reconstruction, rec_error, new_w, new_b, new_c])
# f = Th.Function([x, w, b, c], [reconstruction, rec_error, new_w, new_b, new_c], linker_cls = Th.gof.OpWiseCLinker)
# return f
# x, w = T.matrices('xw')
# b, c = T.rows('bc')
# f = autoassociator_f(x, w, b, c)
# w_val, b_val, c_val = numpy.random.rand(10, 10), numpy.random.rand(1, 10), numpy.random.rand(1, 10)
# x_storage = numpy.ndarray((1, 10))
# for i in dataset_1hot(x_storage, numpy.ndarray((1, )), 10000):
# rec, err, w_val, b_val, c_val = f(x_storage, w_val, b_val, c_val)
# if not(i % 100):
# print err
# x = T.ivector()
# y = T.ivector()
# z = x + y
# e = z - 24
# f = Formula(x = x, y = y, z = z, e = e)
# print f
# print sep
# a = T.lvector()
# b = a * a
# f2 = Formula(e = a, b = b)
# print f2
# print sep
# print glue(f, f2)
# print sep
# x1 = T.ivector()
# y1 = x1 + x1
# y2 = T.ivector()
# x2 = y2 + y2
# f1 = Formula(x=x1, y=y1)
# f2 = Formula(x=x2, y=y2)
# print f1
# print sep
# print f2
# print sep
# print glue(f1, f2)
# print sep
# x1 = T.ivector()
# z1 = T.ivector()
# y1 = x1 + z1
# w1 = x1 * x1
# x2 = T.ivector()
# e2 = T.ivector()
# z2 = e2 ** e2
# g2 = z2 + x2
# f1 = Formula(x=x1, z=z1, y=y1, w=w1)
# f2 = Formula(x=x2, e=e2, z=z2, g=g2)
# print sep
# print f1
# print sep
# print f2
# print sep
# print f1 + f2
# x1 = T.ivector()
# z1 = T.ivector()
# y1 = x1 + x1
# w1 = z1 + z1
# e2 = T.ivector()
# w2 = T.ivector()
# x2 = e2 + e2
# g2 = w2 + w2
# f1 = Formula(x=x1, z=z1, y=y1, w=w1)
# f2 = Formula(x=x2, e=e2, w=w2, g=g2)
# print sep
# print f1
# print sep
# print f2
# print sep
# print f1 + f2
# def glue2(f1, f2):
# reassign_f1 = {}
# reassign_f2 = {}
# equiv = {}
# for r1 in f1.inputs:
# name = r1.name
# try:
# r2 = f2.get(name)
# if not r1.type == r2.type:
# raise TypeError("inconsistent typing for %s: %s, %s" % (name, r1.type, r2.type))
# if name in f2.input_names:
# reassign_f2[r2] = r1
# elif name in f2.output_names:
# reassign_f1[r1] = r2
# except AttributeError:
# pass
# for r1 in f1.outputs:
# name = r1.name
# try:
# r2 = f2.get(name)
# if not r1.type == r2.type:
# raise TypeError("inconsistent typing for %s: %s, %s" % (name, r1.type, r2.type))
# if name in f2.input_names:
# reassign_f2[r2] = r1
# elif name in f2.output_names:
# raise Exception("It is not allowed for a variable to be the output of two different formulas: %s" % name)
# except AttributeError:
# pass
# print reassign_f1
# print reassign_f2
# #i0, o0 = gof.graph.clone_with_new_inputs(f1.inputrs+f2.inputrs, f1.outputrs+f2.outputrs,
# # [reassign_f1.get(name, r) for name, r in zip(f1.inputs, f1.inputrs)] +
# # [reassign_f2.get(name, r) for name, r in zip(f2.inputs, f2.inputrs)])
# #print gof.Env([x for x in i0 if x.owner is None], o0)
# #return
# ##equiv = gof.graph.clone_with_new_inputs_get_equiv(f1.inputrs, f1.outputrs, [reassign_f1.get(name, r) for name, r in zip(f1.inputs, f1.inputrs)])
# ##i1, o1 = [equiv[r] for r in f1.inputrs], [equiv[r] for r in f1.outputrs]
# ##_reassign_f2, reassign_f2 = reassign_f2, {}
# ##for name, r in _reassign_f2.items():
# ## print name, r, equiv.get(r, r) is r
# ## reassign_f2[name] = equiv.get(r, r)
# ## i1, o1 = gof.graph.clone_with_new_inputs(f1.inputs, f1.outputs, [reassign_f1.get(name, r) for name, r in zip(f1.input_names, f1.inputs)])
# ## i2, o2 = gof.graph.clone_with_new_inputs(f2.inputs, f2.outputs, [reassign_f2.get(name, r) for name, r in zip(f2.input_names, f2.inputs)])
# i1, o1 = gof.graph.clone_with_equiv(f1.inputs, f1.outputs, reassign_f1)
# vars = {}
# vars.update(zip(f1.input_names, i1))
# vars.update(zip(f1.output_names, o1))
# vars.update(zip(f2.input_names, i2))
# vars.update(zip(f2.output_names, o2))
# #print vars
# #print gof.graph.as_string(i1, o1)
# #print gof.graph.as_string(i2, o2)
# #print "a"
# #o = o1[0]
# #while o.owner is not None:
# # print o, o.owner
# # o = o.owner.inputs[0]
# #print "b", o
# return Formula(vars)
# class FormulasMetaclass(type):
# def __init__(cls, name, bases, dct):
# variables = {}
# for name, var in dct.items():
# if isinstance(var, gof.Result):
# variables[name] = var
# cls.__variables__ = variables
# cls.__canon__ = Formula(cls.__variables__)
# class Formulas(utils.object2):
# __metaclass__ = FormulasMetaclass
# def __new__(cls):
# return cls.__canon__.clone()
# class Test(Formulas):
# x = T.ivector()
# y = T.ivector()
# e = x + y
# class Test2(Formulas):
# e = T.ivector()
# x = T.ivector()
# w = e ** (x / e)
# f = Test() # + Test2()
# print f
# print sep
# print f.prefix("hey_")
# print sep
# print f.suffix("_woot")
# print sep
# print f.increment(1)
# print sep
# print f.normalize()
# print sep
# print f.increment(1).increment(1)
# print sep
# print f.increment(8).suffix("_yep")
# print sep
# print (f + 8).suffix("_yep", push_numbers = True)
# print sep
# print f.suffix("_yep", push_numbers = True)
# print sep
# print f.rename_regex({"(x|y)": "\\1\\1",
# 'e': "OUTPUT"})
# print sep
# print f + "_suffix"
# print sep
# print "prefix_" + f
# print sep
#### Usage case ####
# class Forward1(Formula):
# input, b, c = drows(3)
# w = dmatrix()
# output = dot(sigmoid(dot(w, input) + b), w.T) + c
# class Forward2(Formula):
# input, b, c = drows(3)
# w1, w2 = dmatrices(2)
# output = dot(sigmoid(dot(w1, input) + b), w2) + c
# class SumSqrError(Formula):
# target, output = drows(2)
# cost = sum((target - output)**2)
# class GradUpdate(Formula):
# lr, cost = dscalars(2)
# param = dmatrix()
# param_updated = param + lr * grad(cost, param)
# NNetUpdate = glue(Forward1(), SumSqrError(), [GradUpdate.rename(param = name) for name in ['w', 'b', 'c']])
# class Forward(Formula):
# input, w, b, c = vars(4)
# output = dot(sigmoid(dot(w, input) + b), w.T) + c
# class SumSqrError(Formula):
# target, output = vars(2)
# cost = sum((target - output)**2)
# class GradUpdate(Formula):
# lr, cost, param = vars(3)
# param_updated = param + lr * grad(cost, param)
# NNetUpdate = Forward() + SumSqrError() + [GradUpdate().rename({'param*': name}) for name in 'wbc']
# #NNetUpdate = Forward() + SumSqrError() + [GradUpdate().rename(param = name) for name in ['w', 'b', 'c']]
import link import link
from functools import partial from functools import partial
class DebugException(Exception): class DebugException(Exception):
pass pass
......
...@@ -30,14 +30,14 @@ class Optimizer: ...@@ -30,14 +30,14 @@ class Optimizer:
""" """
pass pass
def optimize(self, env): def optimize(self, env, *args, **kwargs):
""" """
This is meant as a shortcut to:: This is meant as a shortcut to::
env.satisfy(opt) env.satisfy(opt)
opt.apply(env) opt.apply(env)
""" """
self.add_requirements(env) self.add_requirements(env)
self.apply(env) self.apply(env, *args, **kwargs)
def __call__(self, env): def __call__(self, env):
""" """
...@@ -218,8 +218,14 @@ class LocalOpKeyOptGroup(LocalOptGroup): ...@@ -218,8 +218,14 @@ class LocalOpKeyOptGroup(LocalOptGroup):
class ExpandMacro(LocalOptimizer): class ExpandMacro(LocalOptimizer):
def __init__(self, filter = None):
if filter is None:
self.filter = lambda node: True
else:
self.filter = filter
def transform(self, node): def transform(self, node):
if not isinstance(node.op, op.Macro): if not isinstance(node.op, op.Macro) or not self.filter(node):
return False return False
return node.op.expand(node) return node.op.expand(node)
...@@ -466,7 +472,7 @@ class NavigatorOptimizer(Optimizer): ...@@ -466,7 +472,7 @@ class NavigatorOptimizer(Optimizer):
def process_node(self, env, node): def process_node(self, env, node):
replacements = self.local_opt.transform(node) replacements = self.local_opt.transform(node)
if replacements is False: if replacements is False or replacements is None:
return return
repl_pairs = zip(node.outputs, replacements) repl_pairs = zip(node.outputs, replacements)
try: try:
...@@ -490,13 +496,15 @@ class TopoOptimizer(NavigatorOptimizer): ...@@ -490,13 +496,15 @@ class TopoOptimizer(NavigatorOptimizer):
self.order = order self.order = order
NavigatorOptimizer.__init__(self, local_opt, ignore_newtrees, failure_callback) NavigatorOptimizer.__init__(self, local_opt, ignore_newtrees, failure_callback)
def apply(self, env): def apply(self, env, start_from = None):
q = deque(graph.io_toposort(env.inputs, env.outputs)) if start_from is None: start_from = env.outputs
q = deque(graph.io_toposort(env.inputs, start_from))
def importer(node): def importer(node):
q.append(node) q.append(node)
def pruner(node): def pruner(node):
if node is not current_node: if node is not current_node:
q.remove(node) try: q.remove(node)
except ValueError: pass
u = self.attach_updater(env, importer, pruner) u = self.attach_updater(env, importer, pruner)
try: try:
...@@ -529,7 +537,8 @@ class OpKeyOptimizer(NavigatorOptimizer): ...@@ -529,7 +537,8 @@ class OpKeyOptimizer(NavigatorOptimizer):
if node.op == op: q.append(node) if node.op == op: q.append(node)
def pruner(node): def pruner(node):
if node is not current_node and node.op == op: if node is not current_node and node.op == op:
q.remove(node) try: q.remove(node)
except ValueError: pass
u = self.attach_updater(env, importer, pruner) u = self.attach_updater(env, importer, pruner)
try: try:
while q: while q:
...@@ -554,7 +563,10 @@ class OpKeyOptimizer(NavigatorOptimizer): ...@@ -554,7 +563,10 @@ class OpKeyOptimizer(NavigatorOptimizer):
### Pre-defined optimizers ### ### Pre-defined optimizers ###
############################## ##############################
expand_macros = TopoOptimizer(ExpandMacro(), ignore_newtrees = False) def ExpandMacros(filter = None):
return TopoOptimizer(ExpandMacro(filter = filter),
order = 'in_to_out',
ignore_newtrees = False)
......
...@@ -83,7 +83,7 @@ class DimShufflePrinter: ...@@ -83,7 +83,7 @@ class DimShufflePrinter:
raise TypeError("Can only print DimShuffle.") raise TypeError("Can only print DimShuffle.")
elif isinstance(r.owner.op, ShuffleRule): elif isinstance(r.owner.op, ShuffleRule):
#print r, r.owner.op #print r, r.owner.op
new_r = r.owner.op.expand(r) new_r = r.owner.op.expand(r.owner)
#print new_r.owner, isinstance(new_r.owner.op, ShuffleRule) #print new_r.owner, isinstance(new_r.owner.op, ShuffleRule)
return self.process(new_r, pstate) return self.process(new_r, pstate)
elif isinstance(r.owner.op, DimShuffle): elif isinstance(r.owner.op, DimShuffle):
...@@ -163,16 +163,6 @@ class PPrinter: ...@@ -163,16 +163,6 @@ class PPrinter:
return cp return cp
# class ExtendedPPrinter:
# def __init__(self, pprinter, leaf_pprinter):
# self.pprinter = pprinter
# self.leaf_pprinter = pprinter
# def process(self, r, pstate = None):
from tensor import * from tensor import *
from elemwise import Sum, ShuffleRule from elemwise import Sum, ShuffleRule
...@@ -194,18 +184,18 @@ pp.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, DimShuffle), DimS ...@@ -194,18 +184,18 @@ pp.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, DimShuffle), DimS
pp.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, ShuffleRule), DimShufflePrinter()) pp.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, ShuffleRule), DimShufflePrinter())
print pp.process(x + y * z) # print pp.process(x + y * z)
print pp.process((x + y) * z) # print pp.process((x + y) * z)
print pp.process(x * (y * z)) # print pp.process(x * (y * z))
print pp.process(x / (y / z) / x) # print pp.process(x / (y / z) / x)
print pp.process((x ** y) ** z) # print pp.process((x ** y) ** z)
print pp.process(-x+y) # print pp.process(-x+y)
print pp.process(-x*y) # print pp.process(-x*y)
print pp.process(sum(x)) # print pp.process(sum(x))
print pp.process(sum(x * 10)) # print pp.process(sum(x * 10))
a = Tensor(broadcastable=(False,False,False), dtype='float64')('alpha') # a = Tensor(broadcastable=(False,False,False), dtype='float64')('alpha')
print a.type # print a.type
print pp.process(DimShuffle((False,)*2, [1, 0])(x) + a) # print pp.process(DimShuffle((False,)*2, [1, 0])(x) + a)
print pp.process(x / (y * z)) # print pp.process(x / (y * z))
...@@ -1142,15 +1142,19 @@ gemm = Gemm() ...@@ -1142,15 +1142,19 @@ gemm = Gemm()
# Gradient # Gradient
######################### #########################
class SGrad(gof.Op): class Grad(gof.Macro):
level = 2
def make_node(self, cost, wrt): def make_node(self, cost, wrt):
return Apply(self, [cost, wrt], [wrt.type()]) if not isinstance(wrt, list):
def expand(self, r): wrt = [wrt]
cost, wrt = r.owner.inputs return Apply(self, [cost] + wrt, [_wrt.type() for _wrt in wrt])
return grad(cost, wrt) def expand(self, node):
sgrad = SGrad() cost, wrt = node.inputs[0], node.inputs[1:]
g = exec_grad(cost, wrt)
def grad(cost, wrt, g_cost=None): return g
grad = Grad()
def exec_grad(cost, wrt, g_cost=None):
""" """
@type cost: L{Result} @type cost: L{Result}
@type wrt: L{Result} or list of L{Result}s. @type wrt: L{Result} or list of L{Result}s.
......
from gof import opt, Env
import gof import gof
from elemwise import Broadcast, DimShuffle from elemwise import Elemwise, DimShuffle
from gof.python25 import any, all
import scalar import scalar
class InplaceOptimizer(gof.Optimizer):
    """
    Usage: inplace_optimizer.optimize(env)

    Tries to make Elemwise operations work in place, e.g.:
    x + y + z -> x += y += z
    (x + y) * (x * y) -> (x += y) *= (x * y) or (x + y) *= (x *= y)

    (Reconstructed from a garbled diff view: post-commit column.)
    """
    def apply(self, env):
        # Snapshot the node list: replacements below mutate env.nodes.
        for node in list(env.nodes):
            op = node.op
            if not isinstance(op, Elemwise):
                continue
            baseline = op.inplace_pattern
            # Outputs not already destroying an input / inputs not
            # already destroyed are candidates for a new inplace pair.
            candidate_outputs = [i for i in xrange(len(node.outputs)) if i not in baseline]
            candidate_inputs = [i for i in xrange(len(node.inputs)) if i not in baseline.values()]
            for candidate_output in candidate_outputs:
                for candidate_input in candidate_inputs:
                    inplace_pattern = dict(baseline, **{candidate_output: candidate_input})
                    try:
                        # NOTE(review): `op.inputs` looks wrong -- `op` is
                        # node.op, which has no .inputs; probably meant
                        # `*node.inputs`. Kept as in the commit; confirm.
                        new = Elemwise(op.scalar_op, inplace_pattern).make_node(op.inputs)
                        env.replace_all_validate(dict(zip(node.outputs, new.outputs)))
                    except:
                        # Deliberate best-effort: an invalid inplace
                        # pattern is simply skipped, not fatal.
                        continue
                    candidate_inputs.remove(candidate_input)
                    node = new
                    baseline = inplace_pattern
                    break

    def add_requirements(self, env):
        # replace_all_validate is provided by the ReplaceValidate toolbox.
        env.extend(gof.toolbox.ReplaceValidate)
inplace_optimizer = InplaceOptimizer()
class DimShuffleLifter(gof.LocalOptimizer):
    """
    Usage: lift_dimshuffle.optimize(env)

    "Lifts" DimShuffle through Elemwise operations and merges
    consecutive DimShuffles. Basically, applies the following
    transformations on the whole graph:
    DimShuffle(Elemwise(x, y)) => Elemwise(DimShuffle(x), DimShuffle(y))
    DimShuffle(DimShuffle(x)) => DimShuffle(x)

    After this transform, clusters of Elemwise operations are
    void of DimShuffle operations.

    (Reconstructed from a garbled diff view: post-commit column.)
    """
    def transform(self, node):
        op = node.op
        if not isinstance(op, DimShuffle):
            # Not applicable to this node.
            return False
        # `input` shadows the builtin here (original code's choice).
        input = node.inputs[0]
        inode = input.owner
        if inode and isinstance(inode.op, Elemwise):
            # Push the shuffle below the Elemwise: shuffle each of its
            # inputs instead (the comprehension rebinds `input`).
            return inode.op.make_node(*[DimShuffle(input.type.broadcastable,
                                                   op.new_order,
                                                   op.inplace)(input) for input in inode.inputs]).outputs
        if inode and isinstance(inode.op, DimShuffle):
            # Compose the two orderings into one DimShuffle.
            new_order = [x == 'x' and 'x' or inode.op.new_order[x] for x in op.new_order]
            inplace = op.inplace and inode.op.inplace
            iinput = inode.inputs[0]
            if new_order == range(len(new_order)):
                # Identity shuffle: drop both ops entirely.
                return [iinput]
            else:
                return DimShuffle(iinput.type.broadcastable, new_order, inplace).make_node(iinput).outputs

lift_dimshuffle = gof.TopoOptimizer(DimShuffleLifter(), order = 'out_to_in')
def lift(r):
if r in seen:
return
seen.add(r)
if env.edge(r):
return
op = r.owner
if isinstance(op, DimShuffle):
in_op = op.inputs[0].owner
if isinstance(in_op, DimShuffle):
# DimShuffle(DimShuffle(x)) => DimShuffle(x)
new_order = [x == 'x' and 'x' or in_op.new_order[x] for x in op.new_order]
if new_order == range(len(new_order)):
repl = in_op.inputs[0]
else:
repl = DimShuffle(in_op.inputs[0], new_order).out
env.replace(r, repl)
lift(repl)
return
elif isinstance(in_op, Broadcast):
# DimShuffle(Broadcast(x, y)) => Broadcast(DimShuffle(x), DimShuffle(y))
repl = Broadcast(in_op.scalar_opclass,
[DimShuffle(input, op.new_order).out for input in in_op.inputs],
in_op.inplace_pattern).out
env.replace(r, repl)
r = repl
op = r.owner
for next_r in op.inputs:
lift(next_r)
for output in env.outputs:
lift(output)
lift_dimshuffle = DimShuffleLifter()
def find_cliques(env, through_broadcast = False):
"""
Usage: find_cliques(env, through_broadcast = False)
Returns a list of pairs where each pair contains a list
of inputs and a list of outputs such that Env(inputs, outputs)
contains nothing but Broadcast Ops.
If through_broadcast is False, the cliques will only be
allowed to broadcast over the inputs, which means, for
example, that vector operations will not be mixed with
matrix operations.
"""
def seek_from(r):
# walks through the graph until it encounters a # class DimShuffleLifter(opt.Optimizer):
# non-Broadcast operation or (if through_broadcast # """
# is False) a Result which needs to be broadcasted. # Usage: lift_dimshuffle.optimize(env)
# "Lifts" DimShuffle through Broadcast operations and merges
# consecutive DimShuffles. Basically, applies the following
# transformations on the whole graph:
# DimShuffle(Broadcast(x, y)) => Broadcast(DimShuffle(x), DimShuffle(y))
# DimShuffle(DimShuffle(x)) => DimShuffle(x)
# After this transform, clusters of Broadcast operations are
# void of DimShuffle operations.
# """
# def apply(self, env):
# seen = set()
op = r.owner # def lift(r):
if env.edge(r) \ # if r in seen:
or not isinstance(op, Broadcast) \ # return
or len(op.outputs) > 1: # seen.add(r)
# todo: handle multiple-output broadcast ops # if env.edge(r):
# (needs to update the clique's outputs) # return
return None # op = r.owner
# if isinstance(op, DimShuffle):
ret = set() # in_op = op.inputs[0].owner
# if isinstance(in_op, DimShuffle):
if not through_broadcast: # # DimShuffle(DimShuffle(x)) => DimShuffle(x)
# check each dimension over all the inputs - if the broadcastable # new_order = [x == 'x' and 'x' or in_op.new_order[x] for x in op.new_order]
# fields are not all 0 or all 1 for a particular dimension, then # if new_order == range(len(new_order)):
# broadcasting will be performed along it on the inputs where the # repl = in_op.inputs[0]
# value is 1 and we will stop. # else:
if any(any(bc) and not all(bc) # repl = DimShuffle(in_op.inputs[0], new_order).out
for bc in zip(*[input.broadcastable for input in op.inputs])): # env.replace(r, repl)
ret.update(op.inputs) # lift(repl)
return ret # return
# elif isinstance(in_op, Broadcast):
# # DimShuffle(Broadcast(x, y)) => Broadcast(DimShuffle(x), DimShuffle(y))
# repl = Broadcast(in_op.scalar_opclass,
# [DimShuffle(input, op.new_order).out for input in in_op.inputs],
# in_op.inplace_pattern).out
# env.replace(r, repl)
# r = repl
# op = r.owner
# for next_r in op.inputs:
# lift(next_r)
# for output in env.outputs:
# lift(output)
# lift_dimshuffle = DimShuffleLifter()
# def find_cliques(env, through_broadcast = False):
# """
# Usage: find_cliques(env, through_broadcast = False)
# Returns a list of pairs where each pair contains a list
# of inputs and a list of outputs such that Env(inputs, outputs)
# contains nothing but Broadcast Ops.
# If through_broadcast is False, the cliques will only be
# allowed to broadcast over the inputs, which means, for
# example, that vector operations will not be mixed with
# matrix operations.
# """
# def seek_from(r):
# # walks through the graph until it encounters a
# # non-Broadcast operation or (if through_broadcast
# # is False) a Result which needs to be broadcasted.
for input in op.inputs: # op = r.owner
res = seek_from(input) # if env.edge(r) \
if res is None: # or not isinstance(op, Broadcast) \
# input is a leaf of our search # or len(op.outputs) > 1:
ret.add(input) # # todo: handle multiple-output broadcast ops
else: # # (needs to update the clique's outputs)
ret.update(res) # return None
# ret = set()
# if not through_broadcast:
# # check each dimension over all the inputs - if the broadcastable
# # fields are not all 0 or all 1 for a particular dimension, then
# # broadcasting will be performed along it on the inputs where the
# # value is 1 and we will stop.
# if any(any(bc) and not all(bc)
# for bc in zip(*[input.broadcastable for input in op.inputs])):
# ret.update(op.inputs)
# return ret
# for input in op.inputs:
# res = seek_from(input)
# if res is None:
# # input is a leaf of our search
# ret.add(input)
# else:
# ret.update(res)
return ret # return ret
cliques = [] # cliques = []
def find_cliques_helper(r): # def find_cliques_helper(r):
if env.edge(r): # if env.edge(r):
return # return
clique_inputs = seek_from(r) # clique_inputs = seek_from(r)
if clique_inputs is None: # if clique_inputs is None:
# Not in a clique, keep going # # Not in a clique, keep going
op = r.owner # op = r.owner
if op is not None: # if op is not None:
for input in op.inputs: # for input in op.inputs:
find_cliques_helper(input) # find_cliques_helper(input)
else: # else:
# We found a clique, add it to the list and # # We found a clique, add it to the list and
# jump to the leaves. # # jump to the leaves.
cliques.append((clique_inputs, [r])) # cliques.append((clique_inputs, [r]))
for input in clique_inputs: # for input in clique_inputs:
find_cliques_helper(input) # find_cliques_helper(input)
for output in env.outputs: # for output in env.outputs:
find_cliques_helper(output) # find_cliques_helper(output)
# todo: merge the cliques if possible # # todo: merge the cliques if possible
return cliques # return cliques
class CliqueOptimizer(opt.Optimizer): # class CliqueOptimizer(opt.Optimizer):
""" # """
Usage: CliqueOptimizer(through_broadcast = False, # Usage: CliqueOptimizer(through_broadcast = False,
scalar_optimizer = None, # scalar_optimizer = None,
make_composite = False).optimize(env) # make_composite = False).optimize(env)
Finds cliques of Broadcast operations in the env and does either # Finds cliques of Broadcast operations in the env and does either
or both of two things: # or both of two things:
* Apply scalar_optimizer on the clique as if the clique was a # * Apply scalar_optimizer on the clique as if the clique was a
group of scalar operations. scalar_optimizer can be any optimization # group of scalar operations. scalar_optimizer can be any optimization
which applies on scalars. If it is None, no optimization is done. # which applies on scalars. If it is None, no optimization is done.
* Replace the clique with a single Op, optimized to perform the # * Replace the clique with a single Op, optimized to perform the
computations properly. If make_composite is False, no such replacement # computations properly. If make_composite is False, no such replacement
is done. # is done.
Note: it is recommended to run the lift_dimshuffle optimization before # Note: it is recommended to run the lift_dimshuffle optimization before
this one. # this one.
""" # """
def __init__(self, through_broadcast = False, scalar_optimizer = None, make_composite = False): # def __init__(self, through_broadcast = False, scalar_optimizer = None, make_composite = False):
self.through_broadcast = through_broadcast # self.through_broadcast = through_broadcast
self.scalar_optimizer = scalar_optimizer # self.scalar_optimizer = scalar_optimizer
self.make_composite = make_composite # self.make_composite = make_composite
def apply(self, env): # def apply(self, env):
if self.scalar_optimizer is None and not self.make_composite: # if self.scalar_optimizer is None and not self.make_composite:
# there's nothing to do with the cliques... # # there's nothing to do with the cliques...
return # return
cliques = find_cliques(env, self.through_broadcast) # cliques = find_cliques(env, self.through_broadcast)
opt = self.scalar_optimizer # opt = self.scalar_optimizer
def build_scalar_clique(r, env, equiv): # def build_scalar_clique(r, env, equiv):
# Maps a clique of Broadcast Ops to a clique of Scalar Ops with the same # # Maps a clique of Broadcast Ops to a clique of Scalar Ops with the same
# structure and equivalent operations. equiv contains the mapping. # # structure and equivalent operations. equiv contains the mapping.
if r in equiv: # if r in equiv:
return equiv[r] # return equiv[r]
op = r.owner # op = r.owner
if env.edge(r): # if env.edge(r):
# For each leave we make a Scalar of the corresponding dtype # # For each leave we make a Scalar of the corresponding dtype
s = scalar.Scalar(dtype = r.dtype) # s = scalar.Scalar(dtype = r.dtype)
_r = r # _r = r
if isinstance(r.owner, DimShuffle) and all(x == 'x' for x in r.owner.new_order): # if isinstance(r.owner, DimShuffle) and all(x == 'x' for x in r.owner.new_order):
_r = r.owner.inputs[0] # _r = r.owner.inputs[0]
if (getattr(r, 'constant', False) or getattr(_r, 'constant', False)) \ # if (getattr(r, 'constant', False) or getattr(_r, 'constant', False)) \
and _r.broadcastable == (): # and _r.broadcastable == ():
# If we have a constant tensor we map it to a constant scalar. # # If we have a constant tensor we map it to a constant scalar.
s.data = _r.data # s.data = _r.data
s.constant = True # s.constant = True
equiv[r] = s # equiv[r] = s
return s # return s
s_op = op.scalar_opclass(*[build_scalar_clique(input, env, equiv) for input in op.inputs]) # s_op = op.scalar_opclass(*[build_scalar_clique(input, env, equiv) for input in op.inputs])
equiv[op] = s_op # equiv[op] = s_op
for output, s_output in zip(op.outputs, s_op.outputs): # for output, s_output in zip(op.outputs, s_op.outputs):
equiv[output] = s_output # equiv[output] = s_output
return equiv[r] # return equiv[r]
for c_in, c_out in cliques: # for c_in, c_out in cliques:
equiv = dict() # equiv = dict()
g = Env(c_in, c_out) # g = Env(c_in, c_out)
for output in c_out: # for output in c_out:
build_scalar_clique(output, g, equiv) # build_scalar_clique(output, g, equiv)
s_g = Env([equiv[r] for r in g.inputs], # s_g = Env([equiv[r] for r in g.inputs],
[equiv[r] for r in g.outputs]) # [equiv[r] for r in g.outputs])
if opt is not None: # if opt is not None:
equiv2 = dict() # reverse mapping, from Scalar Op to Tensor Op # equiv2 = dict() # reverse mapping, from Scalar Op to Tensor Op
for k, v in equiv.items(): # for k, v in equiv.items():
equiv2[v] = k # equiv2[v] = k
def transform(op, equiv): # def transform(op, equiv):
# We get a scalar op and we return an equivalent op on tensors. # # We get a scalar op and we return an equivalent op on tensors.
return Broadcast(op.__class__, [equiv[input] for input in op.inputs]) # return Broadcast(op.__class__, [equiv[input] for input in op.inputs])
s_g.add_feature(sync_to(env, equiv2, transform)) # Any change to s_g will now be transferred to g # s_g.add_feature(sync_to(env, equiv2, transform)) # Any change to s_g will now be transferred to g
opt.optimize(s_g) # opt.optimize(s_g)
if self.make_composite: # if self.make_composite:
def follow_inplace(r): # def follow_inplace(r):
# Tries to find the earliest r2 in g such that r destroys r2 # # Tries to find the earliest r2 in g such that r destroys r2
# If no such r2 is found, returns None # # If no such r2 is found, returns None
op = r.owner # op = r.owner
if op is None or r in g.inputs or r in g.orphans(): # if op is None or r in g.inputs or r in g.orphans():
return None # return None
assert isinstance(op, Broadcast) # assert isinstance(op, Broadcast)
destroyed = op.destroy_map().get(r, None) # destroyed = op.destroy_map().get(r, None)
if destroyed is None: # if destroyed is None:
return None # return None
else: # else:
r2 = destroyed[0] # r2 = destroyed[0]
ret = follow_inplace(r2) # ret = follow_inplace(r2)
if ret is None: # if ret is None:
return r2 # return r2
else: # else:
return ret # return ret
inplace_pattern = {} # inplace_pattern = {}
for i, output in enumerate(g.outputs): # for i, output in enumerate(g.outputs):
destroyed = follow_inplace(output) # destroyed = follow_inplace(output)
if destroyed is not None and destroyed in g.inputs: # if destroyed is not None and destroyed in g.inputs:
# we transfer the inplace operation only if it is # # we transfer the inplace operation only if it is
# an input that is destroyed # # an input that is destroyed
inplace_pattern[i] = g.inputs.index(destroyed) # inplace_pattern[i] = g.inputs.index(destroyed)
C = scalar.composite(s_g.inputs, s_g.outputs) # C = scalar.composite(s_g.inputs, s_g.outputs)
ec = Broadcast(C, g.inputs, inplace_pattern = inplace_pattern) # ec = Broadcast(C, g.inputs, inplace_pattern = inplace_pattern)
env.replace_all(dict((o, eco) for o, eco in zip(c_out, ec.outputs))) # env.replace_all(dict((o, eco) for o, eco in zip(c_out, ec.outputs)))
def sync_to(target, equiv, transform): # def sync_to(target, equiv, transform):
""" # """
Usage: sync_to(target, equiv, transform) # Usage: sync_to(target, equiv, transform)
* target: an Env # * target: an Env
* equiv: a dictionary that maps results and ops to results and ops # * equiv: a dictionary that maps results and ops to results and ops
in target # in target
* transform: a function that takes (op, equiv) as inputs and # * transform: a function that takes (op, equiv) as inputs and
returns a new op. # returns a new op.
Returns a Feature that can be added to an Env and mirrors all # Returns a Feature that can be added to an Env and mirrors all
modifications to that env with modifications to the target env. # modifications to that env with modifications to the target env.
""" # """
class Synchronize(gof.Listener, gof.Constraint): # class Synchronize(gof.Listener, gof.Constraint):
def __init__(self, source): # def __init__(self, source):
self.source = source # self.source = source
self.target = target # self.target = target
self.equiv = equiv # self.equiv = equiv
self.transform = transform # self.transform = transform
self.inconsistencies = [] # self.inconsistencies = []
def on_import(self, op1): # def on_import(self, op1):
if op1 not in self.equiv: # if op1 not in self.equiv:
op2 = self.transform(op1, self.equiv) # op2 = self.transform(op1, self.equiv)
self.equiv[op1] = op2 # self.equiv[op1] = op2
for o1, o2 in zip(op1.outputs, op2.outputs): # for o1, o2 in zip(op1.outputs, op2.outputs):
self.equiv[o1] = o2 # self.equiv[o1] = o2
def on_prune(self, op1): # def on_prune(self, op1):
if op1 in self.equiv: # if op1 in self.equiv:
op2 = self.equiv[op1] # op2 = self.equiv[op1]
del self.equiv[op1] # del self.equiv[op1]
for o1, o2 in zip(op1.outputs, op2.outputs): # for o1, o2 in zip(op1.outputs, op2.outputs):
del self.equiv[o1] # del self.equiv[o1]
def on_rewire(self, clients1, r1, new_r1): # def on_rewire(self, clients1, r1, new_r1):
if (new_r1, r1) in self.inconsistencies: # if (new_r1, r1) in self.inconsistencies:
self.inconsistencies.remove((new_r1, r1)) # self.inconsistencies.remove((new_r1, r1))
return # return
if not self.source.clients(r1): # if not self.source.clients(r1):
try: # try:
target.replace(self.equiv[r1], self.equiv[new_r1]) # target.replace(self.equiv[r1], self.equiv[new_r1])
except: # except:
self.inconsistencies.append((r1, new_r1)) # self.inconsistencies.append((r1, new_r1))
def validate(self): # def validate(self):
if self.inconsistencies: # if self.inconsistencies:
raise InconsistencyError("Could not synchronize when replacing the following pairs: %s" % self.inconsistencies) # raise InconsistencyError("Could not synchronize when replacing the following pairs: %s" % self.inconsistencies)
return True # return True
return Synchronize # return Synchronize
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论