Commit b801ba52, authored by Olivier Breuleux

merge

......@@ -25,11 +25,11 @@ class _test_DimShuffle(unittest.TestCase):
ib = [(entry == 1) for entry in xsh]
x = Tensor('float64', ib)('x')
e = DimShuffle(ib, shuffle)(x)
f = linker(Env([x], [e])).make_function()
f = copy(linker).accept(Env([x], [e])).make_function()
assert f(numpy.ones(xsh)).shape == zsh
def test_perform(self):
self.with_linker(gof.PerformLinker)
self.with_linker(gof.PerformLinker())
class _test_Broadcast(unittest.TestCase):
......@@ -47,7 +47,7 @@ class _test_Broadcast(unittest.TestCase):
x = Tensor('float64', [(entry == 1) for entry in xsh])('x')
y = Tensor('float64', [(entry == 1) for entry in ysh])('y')
e = Elemwise(add)(x, y)
f = linker(Env([x, y], [e])).make_function()
f = copy(linker).accept(Env([x, y], [e])).make_function()
xv = numpy.asarray(numpy.random.rand(*xsh))
yv = numpy.asarray(numpy.random.rand(*ysh))
zv = xv + yv
......@@ -66,7 +66,7 @@ class _test_Broadcast(unittest.TestCase):
x = Tensor('float64', [(entry == 1) for entry in xsh])('x')
y = Tensor('float64', [(entry == 1) for entry in ysh])('y')
e = Elemwise(Add(transfer_type(0)), {0:0})(x, y)
f = linker(Env([x, y], [e])).make_function()
f = copy(linker).accept(Env([x, y], [e])).make_function()
xv = numpy.asarray(numpy.random.rand(*xsh))
yv = numpy.asarray(numpy.random.rand(*ysh))
zv = xv + yv
......@@ -76,22 +76,22 @@ class _test_Broadcast(unittest.TestCase):
self.failUnless((xv == zv).all())
def test_perform(self):
self.with_linker(gof.PerformLinker)
self.with_linker(gof.PerformLinker())
def test_c(self):
self.with_linker(gof.CLinker)
self.with_linker(gof.CLinker())
def test_perform_inplace(self):
self.with_linker_inplace(gof.PerformLinker)
self.with_linker_inplace(gof.PerformLinker())
def test_c_inplace(self):
self.with_linker_inplace(gof.CLinker)
self.with_linker_inplace(gof.CLinker())
def test_fill(self):
x = Tensor('float64', [0, 0])('x')
y = Tensor('float64', [1, 1])('y')
e = Elemwise(Second(transfer_type(0)), {0:0})(x, y)
f = gof.CLinker(Env([x, y], [e])).make_function()
f = gof.CLinker().accept(Env([x, y], [e])).make_function()
xv = numpy.ones((5, 5))
yv = numpy.random.rand(1, 1)
f(xv, yv)
......@@ -101,7 +101,7 @@ class _test_Broadcast(unittest.TestCase):
x = Tensor('float64', [0, 0, 0, 0, 0])('x')
y = Tensor('float64', [0, 0, 0, 0, 0])('y')
e = Elemwise(add)(x, y)
f = gof.CLinker(Env([x, y], [e])).make_function()
f = gof.CLinker().accept(Env([x, y], [e])).make_function()
xv = numpy.random.rand(2, 2, 2, 2, 2)
yv = numpy.random.rand(2, 2, 2, 2, 2).transpose(4, 0, 3, 1, 2)
zv = xv + yv
......@@ -110,7 +110,7 @@ class _test_Broadcast(unittest.TestCase):
def test_same_inputs(self):
x = Tensor('float64', [0, 0])('x')
e = Elemwise(add)(x, x)
f = gof.CLinker(Env([x], [e])).make_function()
f = gof.CLinker().accept(Env([x], [e])).make_function()
xv = numpy.random.rand(2, 2)
zv = xv + xv
assert (f(xv) == zv).all()
......@@ -129,7 +129,7 @@ class _test_CAReduce(unittest.TestCase):
x = Tensor('float64', [(entry == 1) for entry in xsh])('x')
e = CAReduce(add, axis = tosum)(x)
if tosum is None: tosum = range(len(xsh))
f = linker(Env([x], [e])).make_function()
f = copy(linker).accept(Env([x], [e])).make_function()
xv = numpy.asarray(numpy.random.rand(*xsh))
zv = xv
for axis in reversed(sorted(tosum)):
......@@ -137,10 +137,10 @@ class _test_CAReduce(unittest.TestCase):
self.failUnless((numpy.abs(f(xv) - zv) < 1e-10).all())
def test_perform(self):
self.with_linker(gof.PerformLinker)
self.with_linker(gof.PerformLinker())
def test_c(self):
self.with_linker(gof.CLinker)
self.with_linker(gof.CLinker())
if __name__ == '__main__':
......
......@@ -17,7 +17,7 @@ class _test_ScalarOps(unittest.TestCase):
x, y, z = inputs()
e = mul(add(x, y), div(x, y))
g = Env([x, y], [e])
fn = gof.DualLinker(g).make_function()
fn = gof.DualLinker().accept(g).make_function()
assert fn(1.0, 2.0) == 1.5
......@@ -30,7 +30,7 @@ class _test_composite(unittest.TestCase):
c = C.make_node(x, y)
# print c.c_code(['x', 'y'], ['z'], dict(id = 0))
g = Env([x, y], [c.out])
fn = gof.DualLinker(g).make_function()
fn = gof.DualLinker().accept(g).make_function()
assert fn(1.0, 2.0) == 1.5
def test_with_constants(self):
......@@ -41,7 +41,7 @@ class _test_composite(unittest.TestCase):
assert "70.0" in c.op.c_code(c, 'dummy', ['x', 'y'], ['z'], dict(id = 0))
# print c.c_code(['x', 'y'], ['z'], dict(id = 0))
g = Env([x, y], [c.out])
fn = gof.DualLinker(g).make_function()
fn = gof.DualLinker().accept(g).make_function()
assert fn(1.0, 2.0) == 36.0
def test_many_outputs(self):
......@@ -53,79 +53,79 @@ class _test_composite(unittest.TestCase):
c = C.make_node(x, y, z)
# print c.c_code(['x', 'y', 'z'], ['out0', 'out1', 'out2'], dict(id = 0))
g = Env([x, y, z], c.outputs)
fn = gof.DualLinker(g).make_function()
fn = gof.DualLinker().accept(g).make_function()
assert fn(1.0, 2.0, 3.0) == [6.0, 7.0, 0.5]
class _test_logical(unittest.TestCase):
def test_gt(self):
x, y, z = inputs()
fn = gof.DualLinker(Env([x,y], [x > y])).make_function()
fn = gof.DualLinker().accept(Env([x,y], [x > y])).make_function()
for a,b in ((3.,9), (3,0.9), (3,3)):
self.failUnless(fn(a,b) == (a>b))
def test_lt(self):
x, y, z = inputs()
fn = gof.DualLinker(Env([x,y], [x < y])).make_function()
fn = gof.DualLinker().accept(Env([x,y], [x < y])).make_function()
for a,b in ((3.,9), (3,0.9), (3,3)):
self.failUnless(fn(a,b) == (a<b))
def test_le(self):
x, y, z = inputs()
fn = gof.DualLinker(Env([x,y], [x <= y])).make_function()
fn = gof.DualLinker().accept(Env([x,y], [x <= y])).make_function()
for a,b in ((3.,9), (3,0.9), (3,3)):
self.failUnless(fn(a,b) == (a<=b))
def test_ge(self):
x, y, z = inputs()
fn = gof.DualLinker(Env([x,y], [x >= y])).make_function()
fn = gof.DualLinker().accept(Env([x,y], [x >= y])).make_function()
for a,b in ((3.,9), (3,0.9), (3,3)):
self.failUnless(fn(a,b) == (a>=b))
def test_eq(self):
x, y, z = inputs()
fn = gof.DualLinker(Env([x,y], [eq(x,y)])).make_function()
fn = gof.DualLinker().accept(Env([x,y], [eq(x,y)])).make_function()
for a,b in ((3.,9), (3,0.9), (3,3)):
self.failUnless(fn(a,b) == (a==b))
def test_neq(self):
x, y, z = inputs()
fn = gof.DualLinker(Env([x,y], [neq(x,y)])).make_function()
fn = gof.DualLinker().accept(Env([x,y], [neq(x,y)])).make_function()
for a,b in ((3.,9), (3,0.9), (3,3)):
self.failUnless(fn(a,b) == (a!=b))
def test_or(self):
x, y, z = ints('xyz')
fn = gof.DualLinker(Env([x,y], [x|y])).make_function()
fn = gof.DualLinker().accept(Env([x,y], [x|y])).make_function()
for a,b in ((0,1), (0,0), (1,0), (1,1)):
self.failUnless(fn(a,b) == (a|b), (a,b))
def test_xor(self):
x, y, z = ints('xyz')
fn = gof.DualLinker(Env([x,y], [x^y])).make_function()
fn = gof.DualLinker().accept(Env([x,y], [x^y])).make_function()
for a,b in ((0,1), (0,0), (1,0), (1,1)):
self.failUnless(fn(a,b) == (a ^ b), (a,b))
def test_and(self):
x, y, z = ints('xyz')
fn = gof.DualLinker(Env([x,y], [and_(x, y)])).make_function()
fn = gof.DualLinker().accept(Env([x,y], [and_(x, y)])).make_function()
for a,b in ((0,1), (0,0), (1,0), (1,1)):
self.failUnless(fn(a,b) == (a & b), (a,b))
x, y, z = ints('xyz')
fn = gof.DualLinker(Env([x,y], [x & y])).make_function()
fn = gof.DualLinker().accept(Env([x,y], [x & y])).make_function()
for a,b in ((0,1), (0,0), (1,0), (1,1)):
self.failUnless(fn(a,b) == (a & b), (a,b))
def test_not(self):
x, y, z = ints('xyz')
fn = gof.DualLinker(Env([x,y], [invert(x)])).make_function()
fn = gof.DualLinker().accept(Env([x,y], [invert(x)])).make_function()
for a,b in ((0,1), (0,0), (1,0), (1,1)):
self.failUnless(fn(a,b) == ~a, (a,))
x, y, z = ints('xyz')
fn = gof.DualLinker(Env([x,y], [~x])).make_function()
fn = gof.DualLinker().accept(Env([x,y], [~x])).make_function()
for a,b in ((0,1), (0,0), (1,0), (1,1)):
self.failUnless(fn(a,b) == ~a, (a,))
......
......@@ -56,7 +56,7 @@ def make_tester(name, op, expected, checks = {}, good = {}, bad_build = {}, bad_
try:
f = function(inputrs, node.outputs,
linker = lambda env, **kwargs: gof.DualLinker(env, checker = _numpy_checker, **kwargs),
linker = 'c&py', ##lambda env, **kwargs: gof.DualLinker(env, checker = _numpy_checker, **kwargs),
unpack_single = False,
optimizer = None)
except:
......@@ -115,7 +115,7 @@ def make_tester(name, op, expected, checks = {}, good = {}, bad_build = {}, bad_
try:
f = function(inputrs, node.outputs,
linker = lambda env, **kwargs: gof.DualLinker(env, checker = _numpy_checker, **kwargs),
linker = 'c&py', #lambda env, **kwargs: gof.DualLinker(env, checker = _numpy_checker, **kwargs),
unpack_single = False,
optimizer = None)
except:
......@@ -1045,7 +1045,7 @@ class T_add(unittest.TestCase):
("*", lambda x,y: x*y),
("/", lambda x,y: x/y))
for s, fn in tests:
f = function([a,b], [fn(a, b)], linker = gof.CLinker)
f = function([a,b], [fn(a, b)], linker = 'c')
self.failUnless(numpy.all(fn(a.data, b.data) == f(a.data, b.data)))
def test_grad_scalar_l(self):
......@@ -1354,9 +1354,9 @@ class t_gemm(unittest.TestCase):
else:
self.failIf(numpy.all(z_orig == z))
cmp_linker(copy(z), a, x, y, b, gof.cc.OpWiseCLinker)
cmp_linker(copy(z), a, x, y, b, gof.cc.CLinker)
cmp_linker(copy(z), a, x, y, b, gof.link.PerformLinker)
cmp_linker(copy(z), a, x, y, b, 'c|py')
cmp_linker(copy(z), a, x, y, b, 'c')
cmp_linker(copy(z), a, x, y, b, 'py')
def test0a(self):
Gemm.debug = True
......@@ -1456,7 +1456,7 @@ class t_gemm(unittest.TestCase):
B = self.rand(4,5)[:,:4]
C = self.rand(4,5)[:,:4]
def t(z,x,y,a=1.0, b=0.0,l=gof.cc.OpWiseCLinker,dt='float64'):
def t(z,x,y,a=1.0, b=0.0,l='c|py',dt='float64'):
z,a,x,y,b = [numpy.asarray(p,dtype=dt) for p in z,a,x,y,b]
z_orig = z.copy()
z_after = self._gemm(z, a, x, y, b)
......@@ -1699,7 +1699,7 @@ class _test_grad(unittest.TestCase):
if __name__ == '__main__':
if 0:
if 1:
unittest.main()
else:
suite = unittest.TestLoader()
......
......@@ -3,17 +3,18 @@
import numpy
import gof
import sys
from copy import copy
#TODO: put together some default optimizations (TRAC #67)
def exec_py_opt(inputs, outputs, features=[]):
"""Return an optimized graph running purely python implementations"""
return Function(intputs, outputs, features, exec_py_opt.optimizer, gof.link.PerformLinker, False)
return Function(intputs, outputs, features, exec_py_opt.optimizer, gof.link.PerformLinker(), False)
exec_py_opt.optimizer = None
def exec_opt(inputs, outputs, features=[]):
"""Return a fast implementation"""
return Function(intputs, outputs, features, exec_opt.optimizer, gof.link.PerformLinker, False)
return Function(intputs, outputs, features, exec_opt.optimizer, gof.link.PerformLinker(), False)
exec_opt.optimizer = None
class _DefaultOptimizer(object):
......@@ -28,18 +29,20 @@ def _mark_indestructible(results):
for r in results:
r.tag.indestructible = True
def linker_cls_python_and_c(env, **kwargs):
"""Use this as the linker_cls argument to Function.__init__ to compare
python and C implementations"""
def checker(x, y):
x, y = x[0], y[0]
if isinstance(x, numpy.ndarray) or isinstance(y, numpy.ndarray):
if x.dtype != y.dtype or x.shape != y.shape or numpy.any(abs(x - y) > 1e-10):
raise Exception("Output mismatch.", {'performlinker': x, 'clinker': y})
else:
if x != y:
# def linker_cls_python_and_c(env, **kwargs):
# """Use this as the linker_cls argument to Function.__init__ to compare
# python and C implementations"""
def check_equal(x, y):
    """Checker for DualLinker: raise on any mismatch between the python
    and C outputs.

    x and y are one-element storage cells holding the value computed by
    the PerformLinker and the CLinker respectively.  ndarrays must agree
    in dtype, shape and (within 1e-10) in value; anything else is
    compared with !=.
    """
    a, b = x[0], y[0]
    if isinstance(a, numpy.ndarray) or isinstance(b, numpy.ndarray):
        differ = (a.dtype != b.dtype
                  or a.shape != b.shape
                  or numpy.any(abs(a - b) > 1e-10))
    else:
        differ = a != b
    if differ:
        raise Exception("Output mismatch.", {'performlinker': a, 'clinker': b})
return gof.DualLinker(env, checker, **kwargs)
# return gof.DualLinker(checker, **kwargs).accept(env)
def infer_reuse_pattern(env, outputs_to_disown):
......@@ -86,10 +89,10 @@ def std_opt(env):
predefined_linkers = {
'py' : gof.link.PerformLinker,
'c' : gof.cc.CLinker,
'c|py' : gof.cc.OpWiseCLinker,
'c&py' : linker_cls_python_and_c
'py' : gof.PerformLinker(),
'c' : gof.CLinker(),
'c|py' : gof.OpWiseCLinker(),
'c&py' : gof.DualLinker(checker = check_equal)
}
class FunctionFactory:
......@@ -105,14 +108,14 @@ class FunctionFactory:
optimizer(env)
env.validate()
self.env = env
linker = predefined_linkers.get(linker, linker)
if not callable(linker):
raise ValueError("'linker' parameter of FunctionFactory should be a callable that takes an env as argument " \
linker = copy(predefined_linkers.get(linker, linker))
if not hasattr(linker, 'accept'):
raise ValueError("'linker' parameter of FunctionFactory should be a Linker with an accept method " \
"or one of ['py', 'c', 'c|py', 'c&py']")
if borrow_outputs:
self.linker = linker(env)
self.linker = linker.accept(env)
else:
self.linker = linker(env, no_recycling = infer_reuse_pattern(env, env.outputs))
self.linker = linker.accept(env, no_recycling = infer_reuse_pattern(env, env.outputs))
def create(self, profiler = None, unpack_single = True, strict = 'if_destroyed'):
......
......@@ -2,7 +2,7 @@
import elemwise_cgen as cgen
import numpy
from gof import Op, Apply
from gof import Op, Macro, Apply
import scalar
from scalar import Scalar
import gof
......@@ -29,6 +29,50 @@ def TensorConstant(*inputs, **kwargs):
### DimShuffle ###
##################
## TODO: rule-based version of DimShuffle
## would allow for Transpose, LComplete, RComplete, etc.
## Can be optimized into DimShuffle later on.
class ShuffleRule(Macro):
    """
    ABSTRACT Op - it has no perform and no c_code.

    A ShuffleRule computes a dimshuffle *pattern* from the broadcastable
    fields of its input (and optional model inputs) via self.rule.  Apply
    ExpandMacros to a node of this op to obtain an equivalent DimShuffle
    which can actually be performed.
    """

    def __init__(self, rule = None, name = None):
        # rule: callable(input_broadcastable, *model_broadcastables) that
        # returns the new dimension order: a sequence of input axis
        # indices mixed with 'x' entries for inserted broadcastable dims.
        # When rule is None, a class-level `rule` attribute (if any) is
        # used instead.
        if rule is not None:
            self.rule = rule
        self.name = name

    def make_node(self, input, *models):
        # The output is broadcastable exactly where the rule inserts an
        # 'x' dimension.
        pattern = self.rule(input.type.broadcastable, *(model.type.broadcastable for model in models))
        return gof.Apply(self,
                         (input,) + models,
                         [Tensor(dtype = input.type.dtype,
                                 broadcastable = [x == 'x' for x in pattern]).make_result()])

    def expand(self, r):
        # Replace the macro node that produced r with an equivalent,
        # performable DimShuffle on the same input.
        input, models = r.owner.inputs[0], r.owner.inputs[1:]
        new_order = self.rule(input.type.broadcastable, *(model.type.broadcastable for model in models))
        return DimShuffle(input.type.broadcastable, new_order)(input)

    def __eq__(self, other):
        return type(self) == type(other) and self.rule == other.rule

    def __hash__(self):
        # BUG FIX: __hash__ used to take a spurious `other` argument,
        # which made every hash(op) call raise TypeError.
        return hash(self.rule)

    def __str__(self):
        if self.name is not None:
            return self.name
        else:
            # BUG FIX: this read `self.role`, an attribute that is never
            # set, so printing an unnamed rule raised AttributeError.
            return "ShuffleRule{%s}" % self.rule
# Prebuilt shuffle rules.  Each rule maps the broadcastable pattern of the
# input (and, for the completion rules, of the model inputs) to a new
# dimension order: integers index input dimensions, 'x' inserts a
# broadcastable dimension.

# Reverse the dimension order (generalized transpose).
_transpose = ShuffleRule(rule = lambda input: range(len(input)-1, -1, -1),
                         name = 'transpose')

# Pad with broadcastable dimensions on the *left* until the input has as
# many dimensions as the largest model (no-op when there are no models).
lcomplete = ShuffleRule(rule = lambda input, *models: ['x']*(max([0]+map(len,models))-len(input)) + range(len(input)),
                        name = 'lcomplete')

# Pad with broadcastable dimensions on the *right*.
# BUG FIX: this used max(map(len, models)), which raises ValueError when
# called with no models; use the same max([0] + ...) guard as lcomplete so
# both completion rules degrade to a no-op.
rcomplete = ShuffleRule(rule = lambda input, *models: range(len(input)) + ['x']*(max([0]+map(len,models))-len(input)),
                        name = 'rcomplete')
class DimShuffle(Op):
"""
......@@ -182,6 +226,23 @@ class DimShuffle(Op):
return DimShuffle(gz.type.broadcastable, grad_order)(gz),
# class LComplete(Op):
# view_map = {0: [0]}
# def make_node(self, x, y):
# x, y = map(as_tensor, (x, y))
# xd, yd = x.type.ndim, y.type.ndim
# if xd > yd:
# raise TypeError("The tensor to left-complete has more dimensions than the model.")
# return gof.Apply(self,
# [x, y],
# [Tensor(dtype = x.type.dtype,
# broadcastable = (True,)*(yd-xd) + x.type.broadcastable).make_result()])
# def perform(self, node, (x, y), (z, )):
# return x.reshape((1, )*(y.ndim - x.ndim) + tuple(x.shape))
# def grad(self, node, (x, ), (gz, )):
# xd, gzd = x.type.ndim, gz.type.ndim
# return DimShuffle(gz.broadcastable, range(gzd-xd, xd))(gz)
################
### Elemwise ###
......@@ -243,20 +304,23 @@ class Elemwise(Op):
shadow = self.scalar_op.make_node(*[Scalar(dtype = t.type.dtype)() for t in inputs])
target_length = max([input.type.ndim for input in inputs])
args = []
for input in inputs:
length = input.type.ndim
difference = target_length - length
if not difference:
args.append(input)
else:
args.append(DimShuffle(range(length), ['x']*difference + range(length))(input))
inputs = args
# try:
# assert len(set([len(input.type.broadcastable) for input in inputs])) == 1
# except (AssertionError, AttributeError):
# raise TypeError("All inputs to a Broadcast subclass must be Tensor instances and their broadcastable fields must all have the same length.", inputs)
if len(inputs) > 1:
inputs = [lcomplete(input, *inputs) for input in inputs]
# args = []
# for input in inputs:
# length = input.type.ndim
# difference = target_length - length
# if not difference:
# args.append(input)
# else:
# # TODO: use LComplete instead
# args.append(DimShuffle(input.type.broadcastable, ['x']*difference + range(length))(input))
# inputs = args
try:
assert len(set([len(input.type.broadcastable) for input in inputs])) == 1
except (AssertionError, AttributeError):
raise TypeError("All inputs to a Broadcast subclass must be Tensor instances and their broadcastable fields must all have the same length.", inputs)
out_broadcastables = [[all(bcast) for bcast in zip(*[input.type.broadcastable for input in inputs])]] * shadow.nout
inplace_pattern = self.inplace_pattern
......@@ -508,7 +572,13 @@ class CAReduce(Op):
output = Tensor(dtype = input.type.dtype,
broadcastable = [x for i, x in enumerate(input.type.broadcastable) if i not in axis])()
return Apply(self, [input], [output])
def __eq__(self, other):
    """Two CAReduce ops are interchangeable iff they apply the same
    scalar op over the same axes."""
    return type(self) == type(other) and self.scalar_op == other.scalar_op and self.axis == other.axis
def __hash__(self):
    """Hash consistently with __eq__ (combines scalar_op and axis)."""
    # NOTE(review): assumes self.axis is hashable (None or a tuple); a
    # plain list here would raise TypeError -- confirm what the
    # constructor/make_node actually stores in self.axis.
    return hash(self.scalar_op) ^ hash(self.axis)
def __str__(self):
if self.axis is not None:
return "Reduce{%s}{%s}" % (self.scalar_op, ", ".join(str(x) for x in self.axis))
......
......@@ -11,6 +11,23 @@ def _zip(*lists):
else:
return zip(*lists)
# x = ivector()
# y = ivector()
# e = x + y
# f = Formula(x = x, y = y, e = e)
# y = x + x
# g = Formula(x=x,y=y)
# x2 = x + x
# g = Formula(x=x, x2=x2)
class Formula(utils.object2):
def __init__(self, symtable_d = {}, **symtable_kwargs):
......@@ -257,11 +274,23 @@ class Formulas(utils.object2):
# class Test(Formulas):
# x = T.ivector()
# y = T.ivector()
# e = x + y + 21
# x = T.ivector()
# y = T.ivector()
# e = x + y + 21
# f1 = Formula(x = x, y = y, e = e)
# Test() -> f1.clone()
# f = Test()
# print f
# print f.reassign(x = T.ivector())
......@@ -317,6 +346,8 @@ class Formulas(utils.object2):
# lr = 0.01
# def autoassociator_f(x, w, b, c):
# reconstruction = sigmoid(T.dot(sigmoid(T.dot(x, w) + b), w.T) + c)
......
......@@ -4,7 +4,7 @@ from env import InconsistencyError, Env
from ext import DestroyHandler, view_roots
from graph import Apply, Result, Constant, Value
from link import Linker, LocalLinker, PerformLinker, MetaLinker, Profiler
from op import Op
from op import Op, Macro
from opt import Optimizer, DummyOpt, SeqOptimizer, LocalOptimizer, OpSpecificOptimizer, OpSubOptimizer, OpRemover, PatternOptimizer, MergeOptimizer, MergeOptMerge
from toolbox import Bookkeeper, History, Validator, ReplaceValidate, NodeFinder, PrintListener
from type import Type, Generic, generic
......
......@@ -140,7 +140,7 @@ class _test_CLinker(unittest.TestCase):
def test_straightforward(self):
x, y, z = inputs()
e = add(mul(add(x, y), div(x, y)), sub(sub(x, y), z))
lnk = CLinker(Env([x, y, z], [e]))
lnk = CLinker().accept(Env([x, y, z], [e]))
fn = lnk.make_function()
self.failUnless(fn(2.0, 2.0, 2.0) == 2.0)
......@@ -158,7 +158,7 @@ class _test_CLinker(unittest.TestCase):
x, y, z = inputs()
z = Constant(tdouble, 4.12345678)
e = add(mul(add(x, y), div(x, y)), sub(sub(x, y), z))
lnk = CLinker(Env([x, y], [e]))
lnk = CLinker().accept(Env([x, y], [e]))
fn = lnk.make_function()
self.failUnless(abs(fn(2.0, 2.0) + 0.12345678) < 1e-9)
self.failUnless("4.12345678" in lnk.code_gen()) # we expect the number to be inlined
......@@ -166,7 +166,7 @@ class _test_CLinker(unittest.TestCase):
def test_single_node(self):
x, y, z = inputs()
node = add.make_node(x, y)
lnk = CLinker(Env(node.inputs, node.outputs))
lnk = CLinker().accept(Env(node.inputs, node.outputs))
fn = lnk.make_function()
self.failUnless(fn(2.0, 7.0) == 9)
......@@ -174,7 +174,7 @@ class _test_CLinker(unittest.TestCase):
# Testing that duplicate inputs are allowed.
x, y, z = inputs()
e = add(x, x)
lnk = CLinker(Env([x, x], [e]))
lnk = CLinker().accept(Env([x, x], [e]))
fn = lnk.make_function()
self.failUnless(fn(2.0, 2.0) == 4)
# note: for now the behavior of fn(2.0, 7.0) is undefined
......@@ -183,7 +183,7 @@ class _test_CLinker(unittest.TestCase):
# Testing that duplicates are allowed inside the graph
x, y, z = inputs()
e = add(mul(y, y), add(x, z))
lnk = CLinker(Env([x, y, z], [e]))
lnk = CLinker().accept(Env([x, y, z], [e]))
fn = lnk.make_function()
self.failUnless(fn(1.0, 2.0, 3.0) == 8.0)
......@@ -194,7 +194,7 @@ class _test_OpWiseCLinker(unittest.TestCase):
def test_straightforward(self):
x, y, z = inputs()
e = add(mul(add(x, y), div(x, y)), sub(sub(x, y), z))
lnk = OpWiseCLinker(Env([x, y, z], [e]))
lnk = OpWiseCLinker().accept(Env([x, y, z], [e]))
fn = lnk.make_function()
self.failUnless(fn(2.0, 2.0, 2.0) == 2.0)
......@@ -202,7 +202,7 @@ class _test_OpWiseCLinker(unittest.TestCase):
x, y, z = inputs()
x = Constant(tdouble, 7.2, name = 'x')
e = add(mul(x, y), mul(y, z))
lnk = OpWiseCLinker(Env([y, z], [e]))
lnk = OpWiseCLinker().accept(Env([y, z], [e]))
fn = lnk.make_function()
res = fn(1.5, 3.0)
self.failUnless(res == 15.3, res)
......@@ -220,7 +220,7 @@ class _test_DualLinker(unittest.TestCase):
def test_straightforward(self):
x, y, z = inputs()
e = add(mul(x, y), mul(y, z)) # add and mul are correct in C and in Python
lnk = DualLinker(Env([x, y, z], [e]), checker = _my_checker)
lnk = DualLinker(checker = _my_checker).accept(Env([x, y, z], [e]))
fn = lnk.make_function()
res = fn(7.2, 1.5, 3.0)
self.failUnless(res == 15.3, res)
......@@ -229,12 +229,12 @@ class _test_DualLinker(unittest.TestCase):
x, y, z = inputs()
e = sub(mul(x, y), mul(y, z)) # sub is correct in C but erroneous in Python
g = Env([x, y, z], [e])
lnk = DualLinker(g, checker = _my_checker)
lnk = DualLinker(checker = _my_checker).accept(g)
fn = lnk.make_function()
self.failUnless(CLinker(g).make_function()(1.0, 2.0, 3.0) == -4.0) # good
self.failUnless(OpWiseCLinker(g).make_function()(1.0, 2.0, 3.0) == -4.0) # good
self.failUnless(PerformLinker(g).make_function()(1.0, 2.0, 3.0) == -10.0) # (purposely) wrong
self.failUnless(CLinker().accept(g).make_function()(1.0, 2.0, 3.0) == -4.0) # good
self.failUnless(OpWiseCLinker().accept(g).make_function()(1.0, 2.0, 3.0) == -4.0) # good
self.failUnless(PerformLinker().accept(g).make_function()(1.0, 2.0, 3.0) == -10.0) # (purposely) wrong
try:
# this runs OpWiseCLinker and PerformLinker in parallel and feeds
......
......@@ -70,7 +70,7 @@ def inputs():
return x, y, z
def perform_linker(env):
lnk = PerformLinker(env)
lnk = PerformLinker().accept(env)
return lnk
def Env(inputs, outputs):
......
......@@ -3,7 +3,7 @@ import unittest
from type import Type
from graph import Result, Apply, Constant
from op import Op
from op import Op, Macro
from opt import *
from env import Env
from toolbox import *
......@@ -415,6 +415,38 @@ class _test_MergeOptimizer(unittest.TestCase):
class _test_ExpandMacro(unittest.TestCase):
    """Smoke tests for macro expansion: each test only checks that the
    optimizer runs without raising (there are no assertions on the
    resulting graph)."""

    def test_straightforward(self):
        # A macro whose expansion swaps its two inputs under op1.
        class Macro1(Macro):
            def make_node(self, x, y):
                return Apply(self, [x, y], [MyType()()])
            def expand(self, node):
                # NOTE(review): `node` is ignored; this closes over the
                # test-local x and y (bound below, before optimize runs)
                # instead of using node.inputs -- confirm this is intended.
                return [op1(y, x)]
        x, y, z = inputs()
        e = Macro1()(x, y)
        g = Env([x, y], [e])
        print g
        expand_macros.optimize(g)
        print g

    def test_loopy(self):
        # A macro that expands into another instance of itself; exercises
        # the optimizer's handling of expansions that never bottom out.
        class Macro1(Macro):
            def make_node(self, x, y):
                return Apply(self, [x, y], [MyType()()])
            def expand(self, node):
                return [Macro1()(y, x)]
        x, y, z = inputs()
        e = Macro1()(x, y)
        g = Env([x, y], [e])
        print g
        #expand_macros.optimize(g)
        # ignore_newtrees keeps the top-down pass from chasing the freshly
        # expanded (again-a-macro) nodes forever.
        TopDownOptimizer(ExpandMacro(), ignore_newtrees = True).optimize(g)
        print g
if __name__ == '__main__':
unittest.main()
......
......@@ -339,10 +339,16 @@ class CLinker(link.Linker):
associated to it during the computation (to avoid reusing it).
"""
def __init__(self, env, no_recycling = []):
def __init__(self):
self.env = None
def accept(self, env, no_recycling = []):
if self.env is not None and self.env is not env:
raise Exception("Cannot accept from a Linker that is already tied to another Env.")
self.env = env
self.fetch_results()
self.no_recycling = no_recycling
return self
def fetch_results(self):
"""
......@@ -771,10 +777,16 @@ class OpWiseCLinker(link.LocalLinker):
associated to it during the computation (to avoid reusing it).
"""
def __init__(self, env, fallback_on_perform = True, no_recycling = []):
self.env = env
def __init__(self, fallback_on_perform = True):
self.env = None
self.fallback_on_perform = fallback_on_perform
def accept(self, env, no_recycling = []):
if self.env is not None and self.env is not env:
raise Exception("Cannot accept from a Linker that is already tied to another Env.")
self.env = env
self.no_recycling = no_recycling
return self
def make_thunk(self, profiler = None, input_storage = None, output_storage = None):
return self.make_all(profiler = profiler,
......@@ -795,7 +807,7 @@ class OpWiseCLinker(link.LocalLinker):
try:
e = Env(*graph.clone(node.inputs, node.outputs))
e.toposort = lambda: e.nodes
cl = CLinker(e, [r for r, r2 in zip(e.outputs, node.outputs) if r2 in no_recycling])
cl = CLinker().accept(e, [r for r, r2 in zip(e.outputs, node.outputs) if r2 in no_recycling])
thunk, node_input_filters, node_output_filters = cl.make_thunk(
input_storage = node_input_storage,
output_storage = node_output_storage)
......@@ -848,7 +860,7 @@ class DualLinker(link.Linker):
function.
"""
def __init__(self, env, checker = _default_checker, no_recycling = []):
def __init__(self, checker = _default_checker):
"""
Initialize a DualLinker.
......@@ -871,17 +883,23 @@ class DualLinker(link.Linker):
If a Result is in no_recycling, CLinker will clear the output storage
associated to it during the computation (to avoid reusing it).
"""
self.env = env
self.env = None
self.checker = checker
def accept(self, env, no_recycling = []):
if self.env is not None and self.env is not env:
raise Exception("Cannot accept from a Linker that is already tied to another Env.")
self.env = env
self.no_recycling = no_recycling
return self
def make_thunk(self, **kwargs):
env = self.env
no_recycling = self.no_recycling
_f, i1, o1, thunks1, order1 = link.PerformLinker(env, no_recycling = no_recycling).make_all(**kwargs)
_f, i2, o2, thunks2, order2 = OpWiseCLinker(env, no_recycling = no_recycling).make_all(**kwargs)
_f, i1, o1, thunks1, order1 = link.PerformLinker().accept(env, no_recycling = no_recycling).make_all(**kwargs)
_f, i2, o2, thunks2, order2 = OpWiseCLinker().accept(env, no_recycling = no_recycling).make_all(**kwargs)
def f():
for input1, input2 in zip(i1, i2):
......
import sys
from copy import copy
from functools import partial

import link
class DebugException(Exception):
    """Exception raised by DebugLinker instrumentation.

    Instances are annotated with debugging context (e.g. 'debugger',
    'step', 'node', 'thunks', ...) as plain attributes before being
    raised.
    """
class DebugLinker(link.MetaLinker):
def __init__(self,
env,
linkers,
debug_pre = [],
debug_post = [],
copy_originals = False,
check_types = True,
compare_results = True,
no_recycling = [],
compare_fn = lambda x, y: x == y):
link.MetaLinker.__init__(self, env = env,
linkers = linkers,
wrapper = self.wrapper,
no_recycling = no_recycling)
self.compare_fn = compare_fn
self.copy_originals = copy_originals
if check_types not in [None, True]:
self.check_types = check_types
if compare_results not in [None, True]:
self.compare_results = compare_results
if not isinstance(debug_pre, (list, tuple)):
debug_pre = [debug_pre]
self.debug_pre = debug_pre
if not isinstance(debug_post, (list, tuple)):
debug_post = [debug_post]
self.debug_post = debug_post
if check_types is not None:
self.debug_post.append(self.check_types)
if compare_results is not None:
self.debug_post.append(self.compare_results)
def store_value(self, i, node, *thunks):
th1 = thunks[0]
for r, oval in zip(node.outputs, th1.outputs):
r.step = i
r.value = oval[0]
if self.copy_originals:
r.original_value = copy(oval[0])
def check_types(self, debug, i, node, *thunks):
for thunk, linker in zip(thunks, self.linkers):
for r in node.outputs:
try:
r.type.filter(r.value, strict = True)
except TypeError, e:
exc_type, exc_value, exc_trace = sys.exc_info()
exc = DebugException(e, "The output %s was filled with data with the wrong type using linker " \
("%s. This happened at step %i of the program." % (r, linker, i)) + \
"For more info, inspect this exception's 'original_exception', 'debugger', " \
"'output_at_fault', 'step', 'node', 'thunk' and 'linker' fields.")
exc.debugger = self
exc.original_exception = e
exc.output_at_fault = r
exc.step = i
exc.node = node
exc.thunk = thunk
exc.linker = linker
raise DebugException, exc, exc_trace
def compare_results(self, debug, i, node, *thunks):
thunk0 = thunks[0]
linker0 = self.linkers[0]
for thunk, linker in zip(thunks[1:], self.linkers[1:]):
for o, output0, output in zip(node.outputs, thunk0.outputs, thunk.outputs):
if not self.compare_fn(output0[0], output[0]):
exc = DebugException(("The results from %s and %s for output %s are not the same. This happened at step %i." % (linker0, linker, o, step)) + \
"For more info, inspect this exception's 'debugger', 'output', 'output_value1', 'output_value2', " \
"'step', 'node', 'thunk1', 'thunk2', 'linker1' and 'linker2' fields.")
exc.debugger = self
exc.output = o
exc.output_value1 = output0
exc.output_value2 = output
exc.step = i
exc.node = node
exc.thunk1 = thunk0
exc.thunk2 = thunk
exc.linker1 = linker0
exc.linker2 = linker
raise exc
def pre(self, f, inputs, order, thunk_groups):
env = f.env
for r in env.results:
if r.owner is None:
r.step = "value" # this will be overwritten if r is an input
else:
r.step = None
r.value = None
r.original_value = None
if r.owner is None and r not in env.inputs:
r.value = r.data
if self.copy_originals:
r.original_value = copy(r.data)
for idx, (i, r) in enumerate(zip(inputs, env.inputs)):
r.step = "input %i" % idx
r.value = i
if self.copy_originals:
r.original_value = copy(i)
for node, thunk_group in zip(order, thunk_groups):
node.step = None
def wrapper(self, th, i, node, *thunks):
try:
node.step = i
for f in self.debug_pre:
f(th, i, node, *thunks)
for thunk in thunks:
thunk()
self.store_value(i, node, *thunks)
for f in self.debug_post:
f(th, i, node, *thunks)
except Exception, e:
exc_type, exc_value, exc_trace = sys.exc_info()
if isinstance(e, DebugException):
raise
exc = DebugException(e, ("An exception occurred while processing node %s at step %i of the program." % (node, i)) + \
"For more info, inspect this exception's 'original_exception', 'debugger', 'step', 'node' and 'thunks' fields.")
exc.debugger = self
exc.original_exception = e
exc.step = i
exc.node = node
exc.thunks = thunks
raise DebugException, exc, exc_trace
def make_thunk(self, **kwargs):
inplace = kwargs.pop("inplace", False)
if inplace:
e, equiv = self.env, None
else:
e, equiv = self.env.clone_get_equiv()
class Debug:
def __init__(self, thunk, env, equiv):
self.thunk = thunk
self.env = env
self.equiv = equiv
def __call__(self):
self.thunk()
def __getitem__(self, item):
equiv = self.equiv
if not isinstance(item, Apply) and not isinstance(item, Result):
raise TypeError("__getitem__ expects an Apply or Result instance.")
if not hasattr(item, 'env') or item.env is not e:
if equiv is None:
raise Exception("item does not belong to this graph and has no equivalent")
else:
return equiv[item]
else:
return item
bk = self.no_recycling
self.no_recycling = map(equiv.__getitem__, self.no_recycling)
th, inputs, outputs = link.MetaLinker.make_thunk(self, alt_env = e, wrapf = lambda f: Debug(f, e, equiv), **kwargs)
self.no_recycling = bk
return th, inputs, outputs
......@@ -76,21 +76,36 @@ class Apply(utils.object2):
cp = self.__class__(self.op, self.inputs, [output.clone() for output in self.outputs])
cp.tag = copy(self.tag)
return cp
def clone_with_new_inputs(self, inputs, strict = True):
    """
    Return an Apply node like self, but applied to `inputs`.

    If strict is True, each new input's type must equal the corresponding
    current input's type, and the clone's outputs keep self's output types.
    If strict is False and some input type differs, the node is rebuilt via
    self.op.make_node(*inputs), so the output types depend on the types of
    the new inputs and the op's behavior wrt them.

    Raises TypeError when strict is True and an input type changed.

    (Reconstructed: the previous text interleaved the pre- and post-change
    versions of this method from a diff; dead commented-out code removed.)
    """
    remake_node = False
    for curr, new in zip(self.inputs, inputs):
        if not curr.type == new.type:
            if strict:
                raise TypeError("Cannot change the type of this input.", curr, new)
            else:
                remake_node = True
    if remake_node:
        # Rebuild through the op so output types are recomputed.  Carry the
        # original tag over, letting the new node's own tag entries win.
        # (Done in three steps: scratchpad.__update__ mutates in place and
        # cannot be relied on to return the updated object.)
        new_tag = copy(self.tag)
        new_tag.__update__(new_node_tag) if False else None  # placeholder removed below
        new_node = self.op.make_node(*inputs)
        new_tag = copy(self.tag)
        new_tag.__update__(new_node.tag)
        new_node.tag = new_tag
    else:
        new_node = self.clone()
        new_node.inputs = inputs
    return new_node
nin = property(lambda self: len(self.inputs), doc = 'same as len(self.inputs)')
......@@ -367,6 +382,94 @@ def clone_get_equiv(i, o, copy_inputs_and_orphans = True):
return d
# def clone_with_new_inputs(i, o, new_i):
# equiv = clone_with_new_inputs_get_equiv(i, o, new_i)
# return [equiv[input] for input in i], [equiv[output] for output in o]
# def clone_with_new_inputs_get_equiv(i, o, new_i, copy_orphans = True):
# # note: this does not exactly mirror Apply.clone_with_new_inputs
# # here it is possible to give different types to new_i and then
# # make_node is called on the ops instead of clone_with_new_inputs
# # whenever the type is different.
# d = {}
# for input, new_input in zip(i, new_i):
# d[input] = new_input
# def clone_helper(result):
# if result in d:
# return d[result]
# node = result.owner
# if node is None: # result is an orphan
# if copy_orphans:
# cpy = result.clone()
# d[result] = cpy
# else:
# d[result] = result
# return d[result]
# else:
# cloned_inputs = [clone_helper(input) for input in node.inputs]
# if any(input != cloned_input for input, cloned_input in zip(node.inputs, cloned_inputs)):
# new_node = node.op.make_node(*cloned_inputs)
# else:
# new_node = node.clone_with_new_inputs(cloned_inputs)
# d[node] = new_node
# for output, new_output in zip(node.outputs, new_node.outputs):
# d[output] = new_output
# return d[result]
# for output in o:
# clone_helper(output)
# return d
def clone_with_equiv(i, o, d, missing_input_policy = 'fail', orphan_policy = 'copy'):
    """
    Clone the subgraph between inputs `i` and outputs `o`, reusing and
    extending the equivalence dict `d` (maps original results/nodes to their
    replacements).

    missing_input_policy: 'fail' raises ValueError on an input result absent
        from d; 'keep' reuses the original result.
    orphan_policy: 'copy' clones orphan Values; 'keep' reuses them.

    Returns ([d[input] for input in i], [d[output] for output in o]).
    `d` is also updated in place.
    """
    def clone_helper(result):
        # Memoized (via d) recursive clone of `result` and everything above it.
        if result in d:
            return d[result]
        node = result.owner
        if node is None: # result is an input or an orphan not in d
            if isinstance(result, Value):
                if orphan_policy == 'copy':
                    d[result] = copy(result)
                elif orphan_policy == 'keep':
                    d[result] = result
                else:
                    raise ValueError("unknown orphan_policy: '%s'" % orphan_policy)
            else:
                if missing_input_policy == 'fail':
                    raise ValueError("missing input: %s" % result)
                elif missing_input_policy == 'keep':
                    d[result] = result
                else:
                    raise ValueError("unknown missing_input_policy: '%s'" % missing_input_policy)
            return d[result]
        else:
            cloned_inputs = [clone_helper(input) for input in node.inputs]
            if all(input is cloned_input for input, cloned_input in zip(node.inputs, cloned_inputs)):
                # Nothing above this node changed: reuse it as-is.
                new_node = node
            else:
                # strict=False lets the op recompute output types when the
                # cloned inputs' types differ from the originals'.
                new_node = node.clone_with_new_inputs(cloned_inputs, strict = False)
            # if any(input != cloned_input for input, cloned_input in zip(node.inputs, cloned_inputs)):
            #     new_node = node.op.make_node(*cloned_inputs)
            # else:
            #     new_node = node.clone_with_new_inputs(cloned_inputs)
            d[node] = new_node
            for output, new_output in zip(node.outputs, new_node.outputs):
                d[output] = new_output
            return d[result]
    for output in o:
        clone_helper(output)
    return [d[input] for input in i], [d[output] for output in o]
def general_toposort(r_out, deps):
"""
@note: deps(i) should behave like a pure function (no funny business with
......
......@@ -196,9 +196,15 @@ class PerformLinker(LocalLinker):
the L{Env} in the order given by L{Env.toposort}.
"""
def __init__(self, env, no_recycling = []):
def __init__(self):
    # The env is bound later via accept(); a fresh PerformLinker starts untied.
    self.env = None
def accept(self, env, no_recycling = []):
    """Tie this linker to `env` and remember the results whose storage must
    not be reused across calls.  Returns self so the call can be chained,
    e.g. copy(linker).accept(env).make_function().

    Raises if this linker is already tied to a different Env.
    NOTE(review): the mutable default [] is shared across calls; it is only
    stored/read here, but callers should not mutate it.
    """
    if self.env is not None and self.env is not env:
        raise Exception("Cannot accept from a Linker that is already tied to another Env.")
    self.env = env
    self.no_recycling = no_recycling
    return self
def make_thunk(self, profiler = None, input_storage = None, output_storage = None):
return self.make_all(profiler = profiler,
......
......@@ -126,3 +126,15 @@ class Op(utils.object2):
class Macro(Op):
    """
    Abstract Op that carries no implementation of its own: it stands for a
    subgraph and is turned into a computable graph through expand().
    """
    def expand(self, *outputs):
        """Return a node representing the expansion of this macro.

        Subclasses must override this; the base version is abstract.
        """
        raise utils.AbstractFunctionError()
......@@ -9,6 +9,7 @@ from env import InconsistencyError
import utils
import unify
import toolbox
import op
class Optimizer:
......@@ -480,3 +481,69 @@ def MergeOptMerge(opt):
merger = MergeOptimizer()
return SeqOptimizer([merger, opt, merger])
class LocalOptimizer:
    """Interface for node-level rewrites.

    applies(node) decides whether this optimizer handles the node;
    transform(node) produces the replacement outputs.  Both are abstract.
    """
    def applies(self, node):
        """Return True when this optimizer can rewrite `node` (abstract)."""
        raise utils.AbstractFunctionError()
    def transform(self, node):
        """Return replacements for `node`'s outputs (abstract)."""
        raise utils.AbstractFunctionError()
class ExpandMacro:
    """Local rewrite rule that replaces a Macro node by its expanded graph."""
    def applies(self, node):
        """Match exactly the nodes whose op is a Macro."""
        return isinstance(node.op, op.Macro)
    def transform(self, node):
        """Delegate the rewrite to the macro's own expand()."""
        return node.op.expand(node)
from collections import deque
class TopDownOptimizer(Optimizer):
    """Applies a LocalOptimizer to every node of an env in topological order.

    Unless ignore_newtrees is True, nodes imported into the env while
    optimizing are queued and visited as well.
    """
    def __init__(self, local_opt, ignore_newtrees = False):
        # local_opt: object providing applies(node) and transform(node).
        # ignore_newtrees: if True, nodes created during optimization are
        # not revisited.
        self.local_opt = local_opt
        self.ignore_newtrees = ignore_newtrees
    def apply(self, env):
        ignore_newtrees = self.ignore_newtrees
        q = deque()
        class Updater:
            # Env feature: seeds the work queue with a full toposort on attach
            # and (optionally) tracks nodes imported/pruned while we run.
            # The methods close over q and current_node from the enclosing
            # function scope.
            def on_attach(self, env):
                for node in graph.io_toposort(env.inputs, env.outputs):
                    q.appendleft(node)
            if not ignore_newtrees:
                def on_import(self, env, node):
                    q.appendleft(node)
                def on_prune(self, env, node):
                    # The node being replaced right now was already popped
                    # from the queue; only remove other pruned nodes.
                    if node is not current_node:
                        q.remove(node)
        u = Updater()
        env.extend(u)
        while q:
            node = q.popleft()
            current_node = node
            if not self.local_opt.applies(node):
                continue
            replacements = self.local_opt.transform(node)
            # Pairwise replace each of the node's outputs, validating the env.
            for output, replacement in zip(node.outputs, replacements):
                env.replace_validate(output, replacement)
        env.remove_feature(u)
    def add_requirements(self, env):
        # Best-effort: the env may already have a ReplaceValidate feature,
        # in which case extend() is expected to fail harmlessly.
        try:
            env.extend(toolbox.ReplaceValidate())
        except: pass
# Ready-made optimizer that expands all Macro nodes in an env.
expand_macros = TopDownOptimizer(ExpandMacro())
......@@ -32,6 +32,8 @@ class object2(object):
class scratchpad:
    """Arbitrary attribute bag used for tagging graph objects.

    Attributes are set directly by users; these helpers manipulate them
    in bulk.
    """
    def clear(self):
        """Remove every attribute stored on this scratchpad."""
        self.__dict__.clear()
    def __update__(self, other):
        """Copy all of `other`'s attributes onto self.

        Returns self so the call can be used in an expression, e.g.
        copy(tag).__update__(node.tag).  (Bug fix: previously returned None,
        breaking callers that used the return value.)
        """
        self.__dict__.update(other.__dict__)
        return self
    def __str__(self):
        # Bug fix: this used to *print* the description and return None,
        # which makes str(obj) raise TypeError; return the string instead.
        return "scratch" + str(self.__dict__)
......
import gof
class PrinterState(gof.utils.scratchpad):
    """Attribute bag holding printing options (pprinter, precedence, ...).

    Can be built from another scratchpad, a mapping, and/or keyword options;
    clone() derives a new state with selected options overridden.
    """
    def __init__(self, props = {}, **more_props):
        if isinstance(props, gof.utils.scratchpad):
            self.__update__(props)
        else:
            for key in props:
                self.__dict__[key] = props[key]
        for key in more_props:
            self.__dict__[key] = more_props[key]
    def clone(self, props = {}, **more_props):
        """Return a copy of this state with `props`/`more_props` applied on top."""
        overrides = dict(props)
        overrides.update(more_props)
        return PrinterState(self, **overrides)
class OperatorPrinter:
    """Prints a node in infix/prefix operator notation.

    Parenthesizes the expression whenever the surrounding context binds
    more tightly than this operator's precedence.
    """
    def __init__(self, operator, precedence, assoc = 'left'):
        self.operator = operator
        self.precedence = precedence
        self.assoc = assoc
    def process(self, output, pstate):
        node = output.owner
        if node is None:
            raise TypeError("operator %s cannot represent a result with no associated operation" % self.operator)
        pprinter = pstate.pprinter
        # Parenthesize when the enclosing expression binds tighter than us.
        parenthesize = getattr(pstate, 'precedence', -999999) > self.precedence
        last = len(node.inputs) - 1
        pieces = []
        for pos, operand in enumerate(node.inputs):
            # Operands on the non-associative side are printed with a slightly
            # higher precedence so equal-precedence subtrees get parenthesized.
            if (self.assoc == 'left' and pos != 0) or (self.assoc == 'right' and pos != last):
                child_precedence = self.precedence + 1e-6
            else:
                child_precedence = self.precedence
            pieces.append(pprinter.process(operand, pstate.clone(precedence = child_precedence)))
        if len(pieces) == 1:
            # Single operand: render as a prefix operator (e.g. unary minus).
            text = self.operator + pieces[0]
        else:
            text = (" %s " % self.operator).join(pieces)
        return "(%s)" % text if parenthesize else text
class FunctionPrinter:
    """Prints a node as name(input1, input2, ...).

    One name is supplied per output of the node; the name matching the
    printed output's position is used.
    """
    def __init__(self, *names):
        # names[i] labels the node's i-th output.
        self.names = names
    def process(self, output, pstate):
        """Return the function-call style rendering of `output`."""
        pprinter = pstate.pprinter
        node = output.owner
        if node is None:
            # Bug fix: this message referenced self.function, which does not
            # exist (the raise itself raised AttributeError); report the
            # configured names instead.  Dead local `names` also removed.
            raise TypeError("function %s cannot represent a result with no associated operation" % (self.names,))
        # Select the name matching this output's position among node.outputs.
        name = self.names[node.outputs.index(output)]
        return "%s(%s)" % (name, ", ".join([pprinter.process(input, pstate.clone(precedence = -1000))
                                            for input in node.inputs]))
class DimShufflePrinter:
    # Pretty-prints DimShuffle applications in compact forms:
    #   identity order      -> the operand itself
    #   fully reversed      -> "<operand>.T"
    #   leading 'x' entries -> "[...]" wrappers (added broadcastable dims)
    #   anything else       -> explicit DimShuffle{...}(...) notation
    def __p(self, new_order, pstate, r):
        # Recursive helper consuming leading 'x' entries of `new_order`.
        # NOTE(review): the list-vs-range comparisons below rely on Python 2,
        # where range() returns a list; under Python 3 they would never match.
        if new_order != () and new_order[0] == 'x':
            return "[%s]" % self.__p(new_order[1:], pstate, r)
        if list(new_order) == range(r.type.ndim):
            return pstate.pprinter.process(r)
        if list(new_order) == list(reversed(range(r.type.ndim))):
            return "%s.T" % pstate.pprinter.process(r)
        return "DimShuffle{%s}(%s)" % (", ".join(map(str, new_order)), pstate.pprinter.process(r))
    def process(self, r, pstate):
        # Accepts results produced by DimShuffle directly, or by a ShuffleRule
        # (which is first expanded and then re-processed).
        if r.owner is None:
            raise TypeError("Can only print DimShuffle.")
        elif isinstance(r.owner.op, ShuffleRule):
            #print r, r.owner.op
            new_r = r.owner.op.expand(r)
            #print new_r.owner, isinstance(new_r.owner.op, ShuffleRule)
            return self.process(new_r, pstate)
        elif isinstance(r.owner.op, DimShuffle):
            # NOTE(review): `ord` shadows the builtin; left unchanged here.
            ord = r.owner.op.new_order
            return self.__p(ord, pstate, r.owner.inputs[0])
        else:
            raise TypeError("Can only print DimShuffle.")
class DefaultPrinter:
    """Fallback printer: renders any node as op(input1, input2, ...),
    delegating leaves (results without an owner) to LeafPrinter."""
    def __init__(self):
        pass
    def process(self, r, pstate):
        node = r.owner
        if node is None:
            return LeafPrinter().process(r, pstate)
        rendered = [pstate.pprinter.process(inp, pstate.clone(precedence = -1000))
                    for inp in node.inputs]
        return "%s(%s)" % (str(node.op), ", ".join(rendered))
class LeafPrinter:
    """Renders a leaf result, substituting greek glyphs for known names."""
    def process(self, r, pstate):
        name = r.name
        if name in greek:
            return greek[name]
        return str(r)
# Unicode symbols used as operator glyphs by the pretty-printer.
special = dict(middle_dot = u"\u00B7",
big_sigma = u"\u03A3")
# Result names replaced by greek letters when printed as leaves (LeafPrinter).
greek = dict(alpha = u"\u03B1",
beta = u"\u03B2",
gamma = u"\u03B3",
delta = u"\u03B4",
epsilon = u"\u03B5")
# Shared printer instances.  Precedence grows with binding tightness
# (** = 0; * / and dot = -1; + - and sum = -2); 'either' marks commutative
# operators whose operands never need extra parentheses for associativity.
ppow = OperatorPrinter('**', 0, 'right')
pmul = OperatorPrinter('*', -1, 'either')
pdiv = OperatorPrinter('/', -1, 'left')
padd = OperatorPrinter('+', -2, 'either')
psub = OperatorPrinter('-', -2, 'left')
pdot = OperatorPrinter(special['middle_dot'], -1, 'left')
psum = OperatorPrinter(special['big_sigma']+' ', -2, 'left')
plog = FunctionPrinter('log')
class PPrinter:
    """Dispatching pretty-printer.

    Holds (condition, printer) rules; the most recently assigned rule wins.
    process() returns None when no rule matches.
    """
    def __init__(self):
        self.printers = []
    def assign(self, condition, printer):
        """Register `printer` for results matching `condition`.

        `condition` is either a predicate (pstate, r) -> bool, or an Op
        instance, in which case results owned by that op are matched.
        """
        if isinstance(condition, gof.Op):
            the_op = condition
            condition = lambda pstate, r: r.owner is not None and r.owner.op == the_op
        self.printers.insert(0, (condition, printer))
    def process(self, r, pstate = None):
        """Render `r` with the first matching rule (newest first)."""
        if pstate is None:
            pstate = PrinterState(pprinter = self)
        for condition, printer in self.printers:
            if condition(pstate, r):
                return printer.process(r, pstate)
    def clone(self):
        """Shallow copy with an independent rule list."""
        duplicate = copy(self)
        duplicate.printers = list(self.printers)
        return duplicate
    def clone_assign(self, condition, printer):
        """Return a clone of this printer with one extra rule installed."""
        duplicate = self.clone()
        duplicate.assign(condition, printer)
        return duplicate
# class ExtendedPPrinter:
# def __init__(self, pprinter, leaf_pprinter):
# self.pprinter = pprinter
# self.leaf_pprinter = pprinter
# def process(self, r, pstate = None):
from tensor import *
from elemwise import Sum, ShuffleRule
# --- Demo / smoke test of the pretty-printer ---
# NOTE(review): this runs at import time (side effects: prints to stdout);
# Python 2 print statements.
x, y, z = matrices('xyz')
pp = PPrinter()
# Catch-all default printer is registered first, i.e. at lowest priority.
pp.assign(lambda pstate, r: True, DefaultPrinter())
pp.assign(add, padd)
pp.assign(mul, pmul)
pp.assign(sub, psub)
# neg shares the '-' printer: one operand makes it print as prefix minus.
pp.assign(neg, psub)
pp.assign(div, pdiv)
pp.assign(pow, ppow)
pp.assign(dot, pdot)
pp.assign(Sum(), FunctionPrinter('sum'))
pp.assign(sgrad, FunctionPrinter('d'))
pp.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, DimShuffle), DimShufflePrinter())
pp.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, ShuffleRule), DimShufflePrinter())
print pp.process(x + y * z)
print pp.process((x + y) * z)
print pp.process(x * (y * z))
print pp.process(x / (y / z) / x)
print pp.process((x ** y) ** z)
print pp.process(-x+y)
print pp.process(-x*y)
print pp.process(sum(x))
print pp.process(sum(x * 10))
# 'alpha' exercises the greek-letter substitution in LeafPrinter.
a = Tensor(broadcastable=(False,False,False), dtype='float64')('alpha')
print a.type
print pp.process(DimShuffle((False,)*2, [1, 0])(x) + a)
print pp.process(x / (y * z))
......@@ -253,6 +253,11 @@ def tensor(*args, **kwargs):
def _multi(*fns):
def f2(f, names):
if isinstance(names, int):
if names == 1:
return f()
else:
return [f() for i in xrange(names)]
if len(names) == 1:
return f(names)
else:
......@@ -639,6 +644,7 @@ def transpose(x, **kwargs):
class Subtensor(Op):
"""Return a subtensor view
......@@ -908,7 +914,7 @@ class Dot(Op):
def grad(self, (x, y), (gz,)):
    # Matrix-product gradients: d/dx = gz . y^T and d/dy = x^T . gz.
    # NOTE(review): Python 2 tuple-unpacking parameters.
    return dot(gz, y.T), dot(x.T, gz)
def __str__(self):
return "Dot"
return "dot"
dot = Dot()
class Gemm(Op):
......@@ -1136,6 +1142,14 @@ gemm = Gemm()
# Gradient
#########################
class SGrad(gof.Op):
    """Symbolic gradient placeholder: stands for grad(cost, wrt) in the graph
    until expanded (macro-style)."""
    def make_node(self, cost, wrt):
        # The output has the same type as `wrt`: a gradient has the shape of
        # the variable differentiated with respect to.
        return Apply(self, [cost, wrt], [wrt.type()])
    def expand(self, r):
        # Replace the placeholder node's output by the actual gradient graph.
        cost, wrt = r.owner.inputs
        return grad(cost, wrt)
# Shared singleton instance.
sgrad = SGrad()
def grad(cost, wrt, g_cost=None):
"""
@type cost: L{Result}
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论