提交 2fd6ea22 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

switching from Op/Result to Op+Apply/Type+Result

上级 df258db8
import gof import gof
import tensor import tensor
import sparse #import sparse
import compile import compile
import gradient import gradient
import tensor_opt #import tensor_opt
import scalar_opt import scalar_opt
from tensor import *
from compile import * from compile import *
from tensor_opt import * #from tensor_opt import *
from scalar_opt import * from scalar_opt import *
from gradient import * from gradient import *
from tensor import *
...@@ -11,14 +11,14 @@ import tensor ...@@ -11,14 +11,14 @@ import tensor
from elemwise import * from elemwise import *
def inputs(): # def inputs():
x = modes.build(Tensor('float64', (0, 0), name = 'x')) # x = modes.build(Tensor('float64', (0, 0), name = 'x'))
y = modes.build(Tensor('float64', (1, 0), name = 'y')) # y = modes.build(Tensor('float64', (1, 0), name = 'y'))
z = modes.build(Tensor('float64', (0, 0), name = 'z')) # z = modes.build(Tensor('float64', (0, 0), name = 'z'))
return x, y, z # return x, y, z
def env(inputs, outputs, validate = True, features = []): # def env(inputs, outputs, validate = True, features = []):
return Env(inputs, outputs, features = features, consistency_check = validate) # return Env(inputs, outputs, features = features, consistency_check = validate)
class _test_DimShuffle(unittest.TestCase): class _test_DimShuffle(unittest.TestCase):
...@@ -31,10 +31,11 @@ class _test_DimShuffle(unittest.TestCase): ...@@ -31,10 +31,11 @@ class _test_DimShuffle(unittest.TestCase):
((2, 3, 4), ('x', 2, 1, 0, 'x'), (1, 4, 3, 2, 1)), ((2, 3, 4), ('x', 2, 1, 0, 'x'), (1, 4, 3, 2, 1)),
((1, 4, 3, 2, 1), (3, 2, 1), (2, 3, 4)), ((1, 4, 3, 2, 1), (3, 2, 1), (2, 3, 4)),
((1, 1, 4), (1, 2), (1, 4))]: ((1, 1, 4), (1, 2), (1, 4))]:
x = modes.build(Tensor('float64', [1 * (entry == 1) for entry in xsh], name = 'x')) ib = [(entry == 1) for entry in xsh]
e = DimShuffle(x, shuffle).out x = Tensor('float64', ib)('x')
e = DimShuffle(ib, shuffle)(x)
# print shuffle, e.owner.grad(e.owner.inputs, e.owner.outputs).owner.new_order # print shuffle, e.owner.grad(e.owner.inputs, e.owner.outputs).owner.new_order
f = linker(env([x], [e])).make_function(inplace=False) f = linker(Env([x], [e])).make_function()
assert f(numpy.ones(xsh)).shape == zsh assert f(numpy.ones(xsh)).shape == zsh
def test_perform(self): def test_perform(self):
...@@ -53,10 +54,10 @@ class _test_Broadcast(unittest.TestCase): ...@@ -53,10 +54,10 @@ class _test_Broadcast(unittest.TestCase):
((2, 3, 4, 5), (1, 3, 1, 5)), ((2, 3, 4, 5), (1, 3, 1, 5)),
((2, 3, 4, 5), (1, 1, 1, 1)), ((2, 3, 4, 5), (1, 1, 1, 1)),
((), ())]: ((), ())]:
x = modes.build(Tensor('float64', [1 * (entry == 1) for entry in xsh], name = 'x')) x = Tensor('float64', [(entry == 1) for entry in xsh])('x')
y = modes.build(Tensor('float64', [1 * (entry == 1) for entry in ysh], name = 'y')) y = Tensor('float64', [(entry == 1) for entry in ysh])('y')
e = Broadcast(Add, (x, y)).out e = Elemwise(add)(x, y)
f = linker(env([x, y], [e])).make_function(inplace = False) f = linker(Env([x, y], [e])).make_function()
# xv = numpy.array(range(numpy.product(xsh))) # xv = numpy.array(range(numpy.product(xsh)))
# xv = xv.reshape(xsh) # xv = xv.reshape(xsh)
# yv = numpy.array(range(numpy.product(ysh))) # yv = numpy.array(range(numpy.product(ysh)))
...@@ -80,10 +81,10 @@ class _test_Broadcast(unittest.TestCase): ...@@ -80,10 +81,10 @@ class _test_Broadcast(unittest.TestCase):
((2, 3, 4, 5), (1, 3, 1, 5)), ((2, 3, 4, 5), (1, 3, 1, 5)),
((2, 3, 4, 5), (1, 1, 1, 1)), ((2, 3, 4, 5), (1, 1, 1, 1)),
((), ())]: ((), ())]:
x = modes.build(Tensor('float64', [1 * (entry == 1) for entry in xsh], name = 'x')) x = Tensor('float64', [(entry == 1) for entry in xsh])('x')
y = modes.build(Tensor('float64', [1 * (entry == 1) for entry in ysh], name = 'y')) y = Tensor('float64', [(entry == 1) for entry in ysh])('y')
e = Broadcast(Add, (x, y), {0:0}).out e = Elemwise(Add(transfer_type(0)), {0:0})(x, y)
f = linker(env([x, y], [e])).make_function(inplace = False) f = linker(Env([x, y], [e])).make_function()
xv = numpy.asarray(numpy.random.rand(*xsh)) xv = numpy.asarray(numpy.random.rand(*xsh))
yv = numpy.asarray(numpy.random.rand(*ysh)) yv = numpy.asarray(numpy.random.rand(*ysh))
zv = xv + yv zv = xv + yv
...@@ -105,29 +106,29 @@ class _test_Broadcast(unittest.TestCase): ...@@ -105,29 +106,29 @@ class _test_Broadcast(unittest.TestCase):
self.with_linker_inplace(gof.CLinker) self.with_linker_inplace(gof.CLinker)
def test_fill(self): def test_fill(self):
x = modes.build(Tensor('float64', [0, 0], name = 'x')) x = Tensor('float64', [0, 0])('x')
y = modes.build(Tensor('float64', [1, 1], name = 'y')) y = Tensor('float64', [1, 1])('y')
e = Broadcast(Second, (x, y), {0:0}).out e = Elemwise(Second(transfer_type(0)), {0:0})(x, y)
f = gof.CLinker(env([x, y], [e])).make_function(inplace = False) f = gof.CLinker(Env([x, y], [e])).make_function()
xv = numpy.ones((5, 5)) xv = numpy.ones((5, 5))
yv = numpy.random.rand(1, 1) yv = numpy.random.rand(1, 1)
f(xv, yv) f(xv, yv)
assert (xv == yv).all() assert (xv == yv).all()
def test_weird_strides(self): def test_weird_strides(self):
x = modes.build(Tensor('float64', [0, 0, 0, 0, 0], name = 'x')) x = Tensor('float64', [0, 0, 0, 0, 0])('x')
y = modes.build(Tensor('float64', [0, 0, 0, 0, 0], name = 'y')) y = Tensor('float64', [0, 0, 0, 0, 0])('y')
e = Broadcast(Add, (x, y)).out e = Elemwise(add)(x, y)
f = gof.CLinker(env([x, y], [e])).make_function(inplace = False) f = gof.CLinker(Env([x, y], [e])).make_function()
xv = numpy.random.rand(2, 2, 2, 2, 2) xv = numpy.random.rand(2, 2, 2, 2, 2)
yv = numpy.random.rand(2, 2, 2, 2, 2).transpose(4, 0, 3, 1, 2) yv = numpy.random.rand(2, 2, 2, 2, 2).transpose(4, 0, 3, 1, 2)
zv = xv + yv zv = xv + yv
assert (f(xv, yv) == zv).all() assert (f(xv, yv) == zv).all()
def test_same_inputs(self): def test_same_inputs(self):
x = modes.build(Tensor('float64', [0, 0], name = 'x')) x = Tensor('float64', [0, 0])('x')
e = Broadcast(Add, (x, x)).out e = Elemwise(add)(x, x)
f = gof.CLinker(env([x], [e])).make_function(inplace = False) f = gof.CLinker(Env([x], [e])).make_function()
xv = numpy.random.rand(2, 2) xv = numpy.random.rand(2, 2)
zv = xv + xv zv = xv + xv
assert (f(xv) == zv).all() assert (f(xv) == zv).all()
...@@ -136,15 +137,17 @@ class _test_Broadcast(unittest.TestCase): ...@@ -136,15 +137,17 @@ class _test_Broadcast(unittest.TestCase):
class _test_CAReduce(unittest.TestCase): class _test_CAReduce(unittest.TestCase):
def with_linker(self, linker): def with_linker(self, linker):
for xsh, tosum in [((5, 6), (0, 1)), for xsh, tosum in [((5, 6), None),
((5, 6), (0, 1)),
((5, 6), (0, )), ((5, 6), (0, )),
((5, 6), (1, )), ((5, 6), (1, )),
((5, 6), ()), ((5, 6), ()),
((2, 3, 4, 5), (0, 1, 3)), ((2, 3, 4, 5), (0, 1, 3)),
((), ())]: ((), ())]:
x = modes.build(Tensor('float64', [1 * (entry == 1) for entry in xsh], name = 'x')) x = Tensor('float64', [(entry == 1) for entry in xsh])('x')
e = CAReduce(Add, [x], axis = tosum).out e = CAReduce(add, axis = tosum)(x)
f = linker(env([x], [e])).make_function(inplace = False) if tosum is None: tosum = range(len(xsh))
f = linker(Env([x], [e])).make_function()
xv = numpy.asarray(numpy.random.rand(*xsh)) xv = numpy.asarray(numpy.random.rand(*xsh))
zv = xv zv = xv
for axis in reversed(sorted(tosum)): for axis in reversed(sorted(tosum)):
......
import unittest import unittest
from gof import Result, Op, Env, modes from gof import Result, Op, Env
import gof import gof
from scalar import * from scalar import *
def inputs(): def inputs():
x = modes.build(as_scalar(1.0, 'x')) return floats('xyz')
y = modes.build(as_scalar(2.0, 'y'))
z = modes.build(as_scalar(3.0, 'z'))
return x, y, z
def env(inputs, outputs, validate = True, features = []):
return Env(inputs, outputs, features = features, consistency_check = validate)
class _test_ScalarOps(unittest.TestCase): class _test_ScalarOps(unittest.TestCase):
...@@ -22,7 +16,7 @@ class _test_ScalarOps(unittest.TestCase): ...@@ -22,7 +16,7 @@ class _test_ScalarOps(unittest.TestCase):
def test_straightforward(self): def test_straightforward(self):
x, y, z = inputs() x, y, z = inputs()
e = mul(add(x, y), div(x, y)) e = mul(add(x, y), div(x, y))
g = env([x, y], [e]) g = Env([x, y], [e])
fn = gof.DualLinker(g).make_function() fn = gof.DualLinker(g).make_function()
assert fn(1.0, 2.0) == 1.5 assert fn(1.0, 2.0) == 1.5
...@@ -32,25 +26,21 @@ class _test_composite(unittest.TestCase): ...@@ -32,25 +26,21 @@ class _test_composite(unittest.TestCase):
def test_straightforward(self): def test_straightforward(self):
x, y, z = inputs() x, y, z = inputs()
e = mul(add(x, y), div(x, y)) e = mul(add(x, y), div(x, y))
C = composite([x, y], [e]) C = Composite([x, y], [e])
c = C(x, y) c = C.make_node(x, y)
# print c.c_code(['x', 'y'], ['z'], dict(id = 0)) # print c.c_code(['x', 'y'], ['z'], dict(id = 0))
c.perform() g = Env([x, y], [c.out])
assert c.outputs[0].data == 1.5
g = env([x, y], [c.out])
fn = gof.DualLinker(g).make_function() fn = gof.DualLinker(g).make_function()
assert fn(1.0, 2.0) == 1.5 assert fn(1.0, 2.0) == 1.5
def test_with_constants(self): def test_with_constants(self):
x, y, z = inputs() x, y, z = inputs()
e = mul(add(70.0, y), div(x, y)) e = mul(add(70.0, y), div(x, y))
C = composite([x, y], [e]) C = Composite([x, y], [e])
c = C(x, y) c = C.make_node(x, y)
assert "70.0" in c.c_code(['x', 'y'], ['z'], dict(id = 0)) assert "70.0" in c.op.c_code(c, 'dummy', ['x', 'y'], ['z'], dict(id = 0))
# print c.c_code(['x', 'y'], ['z'], dict(id = 0)) # print c.c_code(['x', 'y'], ['z'], dict(id = 0))
c.perform() g = Env([x, y], [c.out])
assert c.outputs[0].data == 36.0
g = env([x, y], [c.out])
fn = gof.DualLinker(g).make_function() fn = gof.DualLinker(g).make_function()
assert fn(1.0, 2.0) == 36.0 assert fn(1.0, 2.0) == 36.0
...@@ -59,14 +49,10 @@ class _test_composite(unittest.TestCase): ...@@ -59,14 +49,10 @@ class _test_composite(unittest.TestCase):
e0 = x + y + z e0 = x + y + z
e1 = x + y * z e1 = x + y * z
e2 = x / y e2 = x / y
C = composite([x, y, z], [e0, e1, e2]) C = Composite([x, y, z], [e0, e1, e2])
c = C(x, y, z) c = C.make_node(x, y, z)
# print c.c_code(['x', 'y', 'z'], ['out0', 'out1', 'out2'], dict(id = 0)) # print c.c_code(['x', 'y', 'z'], ['out0', 'out1', 'out2'], dict(id = 0))
c.perform() g = Env([x, y, z], c.outputs)
assert c.outputs[0].data == 6.0
assert c.outputs[1].data == 7.0
assert c.outputs[2].data == 0.5
g = env([x, y, z], c.outputs)
fn = gof.DualLinker(g).make_function() fn = gof.DualLinker(g).make_function()
assert fn(1.0, 2.0, 3.0) == [6.0, 7.0, 0.5] assert fn(1.0, 2.0, 3.0) == [6.0, 7.0, 0.5]
......
...@@ -10,53 +10,43 @@ from scalar_opt import * ...@@ -10,53 +10,43 @@ from scalar_opt import *
def inputs(): def inputs():
x = Scalar('float64', name = 'x') return floats('xyz')
y = Scalar('float64', name = 'y')
z = Scalar('float64', name = 'z')
a = Scalar('float64', name = 'a')
return x, y, z
def more_inputs(): def more_inputs():
a = Scalar('float64', name = 'a') return floats('abcd')
b = Scalar('float64', name = 'b')
c = Scalar('float64', name = 'c')
d = Scalar('float64', name = 'd')
return a, b, c, d
class _test_opts(unittest.TestCase): class _test_opts(unittest.TestCase):
def test_pow_to_sqr(self): def test_pow_to_sqr(self):
x, y, z = inputs() x, y, z = floats('xyz')
e = x ** 2.0 e = x ** 2.0
g = Env([x], [e]) g = Env([x], [e])
assert str(g) == "[Pow(x, 2.0)]" assert str(g) == "[pow(x, 2.0)]"
gof.ConstantFinder().optimize(g)
pow2sqr_float.optimize(g) pow2sqr_float.optimize(g)
assert str(g) == "[Sqr(x)]" assert str(g) == "[sqr(x)]"
# class _test_canonize(unittest.TestCase): class _test_canonize(unittest.TestCase):
# def test_muldiv(self): # def test_muldiv(self):
# x, y, z = inputs() # x, y, z = inputs()
# a, b, c, d = more_inputs() # a, b, c, d = more_inputs()
# # e = (2.0 * x) / (2.0 * y) # # e = (2.0 * x) / (2.0 * y)
# e = (2.0 * x) / (4.0 * y) # # e = (2.0 * x) / (4.0 * y)
# # e = x / (y / z) # # e = x / (y / z)
# # e = (x * y) / x # # e = (x * y) / x
# # e = (x / y) * (y / z) * (z / x) # # e = (x / y) * (y / z) * (z / x)
# # e = (a / b) * (b / c) * (c / d) # # e = (a / b) * (b / c) * (c / d)
# # e = (a * b) / (b * c) / (c * d) # # e = (a * b) / (b * c) / (c * d)
# # e = 2 * x / 2 # # e = 2 * x / 2
# # e = x / y / x # e = x / y / x
# g = Env([x, y, z, a, b, c, d], [e]) # g = Env([x, y, z, a, b, c, d], [e])
# print g # print g
# gof.ConstantFinder().optimize(g)
# mulfn = lambda *inputs: reduce(lambda x, y: x * y, (1,) + inputs) # mulfn = lambda *inputs: reduce(lambda x, y: x * y, (1,) + inputs)
# divfn = lambda x, y: x / y # divfn = lambda x, y: x / y
# invfn = lambda x: 1 / x # invfn = lambda x: 1 / x
# Canonizer(Mul, Div, Inv, mulfn, divfn, invfn).optimize(g) # Canonizer(mul, div, inv, mulfn, divfn, invfn).optimize(g)
# print g # print g
# def test_plusmin(self): # def test_plusmin(self):
...@@ -101,35 +91,53 @@ class _test_opts(unittest.TestCase): ...@@ -101,35 +91,53 @@ class _test_opts(unittest.TestCase):
# print g # print g
# def test_group_powers(self): # def test_group_powers(self):
# x, y, z = inputs() # x, y, z, a, b, c, d = floats('xyzabcd')
# a, b, c, d = more_inputs()
###################
# c1, c2 = constant(1.), constant(2.)
# #e = pow(x, c1) * pow(x, y) / pow(x, 7.0) # <-- fucked
# #f = -- moving from div(mul.out, pow.out) to pow(x, sub.out)
# e = div(mul(pow(x, 2.0), pow(x, y)), pow(x, 7.0))
# g = Env([x, y, z, a, b, c, d], [e])
# print g
# print g.inputs, g.outputs, g.orphans
# f = sub(add(2.0, y), add(7.0))
# g.replace(e, pow(x, f))
# print g
# print g.inputs, g.outputs, g.orphans
# g.replace(f, sub(add(2.0, y), add(7.0))) # -- moving from sub(add.out, add.out) to sub(add.out, add.out)
# print g
# print g.inputs, g.outputs, g.orphans
###################
# # e = x * exp(y) * exp(z) # # e = x * exp(y) * exp(z)
# # e = x * pow(x, y) * pow(x, z) # # e = x * pow(x, y) * pow(x, z)
# # e = pow(x, y) / pow(x, z) # # e = pow(x, y) / pow(x, z)
# # e = pow(x, 2.0) * pow(x, y) / pow(x, 7.0) # e = pow(x, 2.0) * pow(x, y) / pow(x, 7.0) # <-- fucked
# # e = pow(x - x, y) # # e = pow(x - x, y)
# # e = pow(x, 2.0 + y - 7.0) # # e = pow(x, 2.0 + y - 7.0)
# # e = pow(x, 2.0) * pow(x, y) / pow(x, 7.0) / pow(x, z) # # e = pow(x, 2.0) * pow(x, y) / pow(x, 7.0) / pow(x, z)
# # e = pow(x, 2.0 + y - 7.0 - z) # # e = pow(x, 2.0 + y - 7.0 - z)
# # e = x ** y / x ** y # # e = x ** y / x ** y
# # e = x ** y / x ** (y - 1.0) # # e = x ** y / x ** (y - 1.0)
# e = exp(x) * a * exp(y) / exp(z) # # e = exp(x) * a * exp(y) / exp(z)
# g = Env([x, y, z, a, b, c, d], [e]) # g = Env([x, y, z, a, b, c, d], [e])
# print g # g.extend(gof.PrintListener(g))
# gof.ConstantFinder().optimize(g) # print g, g.orphans
# mulfn = lambda *inputs: reduce(lambda x, y: x * y, (1,) + inputs) # mulfn = lambda *inputs: reduce(lambda x, y: x * y, (1,) + inputs)
# divfn = lambda x, y: x / y # divfn = lambda x, y: x / y
# invfn = lambda x: 1 / x # invfn = lambda x: 1 / x
# Canonizer(Mul, Div, Inv, mulfn, divfn, invfn, group_powers).optimize(g) # Canonizer(mul, div, inv, mulfn, divfn, invfn, group_powers).optimize(g)
# print g # print g, g.orphans
# addfn = lambda *inputs: reduce(lambda x, y: x + y, (0,) + inputs) # addfn = lambda *inputs: reduce(lambda x, y: x + y, (0,) + inputs)
# subfn = lambda x, y: x - y # subfn = lambda x, y: x - y
# negfn = lambda x: -x # negfn = lambda x: -x
# Canonizer(Add, Sub, Neg, addfn, subfn, negfn).optimize(g) # Canonizer(add, sub, neg, addfn, subfn, negfn).optimize(g)
# print g # print g, g.orphans
# pow2one_float.optimize(g) # pow2one_float.optimize(g)
# pow2x_float.optimize(g) # pow2x_float.optimize(g)
# print g # print g, g.orphans
......
...@@ -17,20 +17,20 @@ def _numpy_checker(x, y): ...@@ -17,20 +17,20 @@ def _numpy_checker(x, y):
Checks if x.data and y.data have the same contents. Checks if x.data and y.data have the same contents.
Used in DualLinker to compare C version with Python version. Used in DualLinker to compare C version with Python version.
""" """
x, y = x.data, y.data x, y = x[0], y[0]
if x.dtype != y.dtype or x.shape != y.shape or numpy.any(abs(x - y) > 1e-10): if x.dtype != y.dtype or x.shape != y.shape or numpy.any(abs(x - y) > 1e-10):
raise Exception("Output mismatch.", {'performlinker': x, 'clinker': y}) raise Exception("Output mismatch.", {'performlinker': x, 'clinker': y})
def make_tester(name, op_class, expected, checks = {}, good = {}, bad_build = {}, bad_runtime = {}, grad = {}): def make_tester(name, op, expected, checks = {}, good = {}, bad_build = {}, bad_runtime = {}, grad = {}):
if grad is True: if grad is True:
grad = good grad = good
_op_class, _expected, _checks, _good, _bad_build, _bad_runtime, _grad = op_class, expected, checks, good, bad_build, bad_runtime, grad _op, _expected, _checks, _good, _bad_build, _bad_runtime, _grad = op, expected, checks, good, bad_build, bad_runtime, grad
class Checker(unittest.TestCase): class Checker(unittest.TestCase):
op_class = _op_class op = _op
expected = staticmethod(_expected) expected = staticmethod(_expected)
checks = _checks checks = _checks
good = _good good = _good
...@@ -41,24 +41,25 @@ def make_tester(name, op_class, expected, checks = {}, good = {}, bad_build = {} ...@@ -41,24 +41,25 @@ def make_tester(name, op_class, expected, checks = {}, good = {}, bad_build = {}
def test_good(self): def test_good(self):
for testname, inputs in self.good.items(): for testname, inputs in self.good.items():
inputs = [copy(input) for input in inputs] inputs = [copy(input) for input in inputs]
inputrs = [constant(input).type() for input in inputs]
try: try:
op = self.op_class(*inputs) node = self.op.make_node(*inputrs)
except: except:
type, value, traceback = sys.exc_info() type, value, traceback = sys.exc_info()
err_msg = "With data %s::%s: This error occurred while trying to build a %s instance with inputs %s" \ err_msg = "Test %s::%s: Error occurred while making a node with inputs %s" \
% (self.op_class.__name__, testname, self.op_class, inputs) % (self.op, testname, inputs)
value.args = value.args + (err_msg, ) value.args = value.args + (err_msg, )
raise type, value, traceback raise type, value, traceback
try: try:
f = Function(op.inputs, op.outputs, f = Function(node.inputs, node.outputs,
linker_cls = lambda env: gof.DualLinker(env, checker = _numpy_checker), linker_cls = lambda env: gof.DualLinker(env, checker = _numpy_checker),
unpack_single = False, unpack_single = False,
optimizer = None) optimizer = None)
except: except:
type, value, traceback = sys.exc_info() type, value, traceback = sys.exc_info()
err_msg = "With data %s::%s: This error occurred while trying to make a function out of %s" \ err_msg = "Test %s::%s: Error occurred while trying to make a Function" \
% (self.op_class.__name__, testname, op) % (self.op, testname)
value.args = value.args + (err_msg, ) value.args = value.args + (err_msg, )
raise type, value, traceback raise type, value, traceback
...@@ -68,8 +69,8 @@ def make_tester(name, op_class, expected, checks = {}, good = {}, bad_build = {} ...@@ -68,8 +69,8 @@ def make_tester(name, op_class, expected, checks = {}, good = {}, bad_build = {}
results = f(*inputs) results = f(*inputs)
except: except:
type, value, traceback = sys.exc_info() type, value, traceback = sys.exc_info()
err_msg = "With data %s::%s: This error occurred while calling %s on the inputs %s" \ err_msg = "Test %s::%s: Error occurred while calling the Function on the inputs %s" \
% (self.op_class.__name__, testname, op, inputs) % (self.op, testname, inputs)
value.args = value.args + (err_msg, ) value.args = value.args + (err_msg, )
raise type, value, traceback raise type, value, traceback
...@@ -77,45 +78,47 @@ def make_tester(name, op_class, expected, checks = {}, good = {}, bad_build = {} ...@@ -77,45 +78,47 @@ def make_tester(name, op_class, expected, checks = {}, good = {}, bad_build = {}
expecteds = (expecteds, ) expecteds = (expecteds, )
for i, (result, expected) in enumerate(zip(results, expecteds)): for i, (result, expected) in enumerate(zip(results, expecteds)):
if result.dtype != expected.dtype or result.shape != expected.shape or numpy.any(abs(result - expected) > 1e-10): if result.dtype != expected.dtype or result.shape != expected.shape or numpy.any(abs(result - expected) > 1e-10):
self.fail("With data %s::%s: Output %s of %s gave the wrong value. With inputs %s, expected %s, got %s." self.fail("Test %s::%s: Output %s gave the wrong value. With inputs %s, expected %s, got %s."
% (self.op_class.__name__, testname, i, op, inputs, expected, result)) % (self.op, testname, i, inputs, expected, result))
for description, check in self.checks.items(): for description, check in self.checks.items():
if not check(inputs, results): if not check(inputs, results):
self.fail("With data %s::%s: %s failed the following check: %s (inputs were %s)" self.fail("Test %s::%s: Failed check: %s (inputs were %s)"
% (self.op_class.__name__, testname, op, description, inputs)) % (self.op, testname, description, inputs))
def test_bad_build(self): def test_bad_build(self):
for testname, inputs in self.bad_build.items(): for testname, inputs in self.bad_build.items():
inputs = [copy(input) for input in inputs] inputs = [copy(input) for input in inputs]
inputrs = [constant(input).type() for input in inputs]
try: try:
op = self.op_class(*inputs) node = self.op.make_node(*inputrs)
except: except:
return return
self.fail("With data %s::%s: %s was successfully instantiated on the following bad inputs: %s" self.fail("Test %s::%s: %s was successfully instantiated on the following bad inputs: %s"
% (self.op_class.__name__, testname, op, inputs)) % (self.op, testname, node, inputs))
def test_bad_runtime(self): def test_bad_runtime(self):
for testname, inputs in self.bad_runtime.items(): for testname, inputs in self.bad_runtime.items():
inputs = [copy(input) for input in inputs] inputs = [copy(input) for input in inputs]
inputrs = [constant(input).type() for input in inputs]
try: try:
op = self.op_class(*inputs) node = self.op.make_node(*inputrs)
except: except:
type, value, traceback = sys.exc_info() type, value, traceback = sys.exc_info()
err_msg = "With data %s::%s: This error occurred while trying to build a %s instance with inputs %s" \ err_msg = "Test %s::%s: Error occurred while trying to make a node with inputs %s" \
% (self.op_class.__name__, testname, self.op_class, inputs) % (self.op, testname, inputs)
value.args = value.args + (err_msg, ) value.args = value.args + (err_msg, )
raise type, value, traceback raise type, value, traceback
try: try:
f = Function(op.inputs, op.outputs, f = Function(node.inputs, node.outputs,
linker_cls = lambda env: gof.DualLinker(env, checker = _numpy_checker), linker_cls = lambda env: gof.DualLinker(env, checker = _numpy_checker),
unpack_single = False, unpack_single = False,
optimizer = None) optimizer = None)
except: except:
type, value, traceback = sys.exc_info() type, value, traceback = sys.exc_info()
err_msg = "With data %s::%s: This error occurred while trying to make a function out of %s" \ err_msg = "Test %s::%s: Error occurred while trying to make a Function" \
% (self.op_class.__name__, testname, op) % (self.op, testname)
value.args = value.args + (err_msg, ) value.args = value.args + (err_msg, )
raise type, value, traceback raise type, value, traceback
...@@ -124,18 +127,19 @@ def make_tester(name, op_class, expected, checks = {}, good = {}, bad_build = {} ...@@ -124,18 +127,19 @@ def make_tester(name, op_class, expected, checks = {}, good = {}, bad_build = {}
except: except:
return return
self.fail("With data %s::%s: %s was successfully called on the following bad inputs: %s" self.fail("Test %s::%s: Successful call on the following bad inputs: %s"
% (self.op_class.__name__, testname, op, inputs)) % (self.op, testname, inputs))
def test_grad(self): def test_grad(self):
for testname, inputs in self.grad.items(): for testname, inputs in self.grad.items():
inputs = [copy(input) for input in inputs] inputs = [copy(input) for input in inputs]
inputrs = [constant(input).type() for input in inputs]
try: try:
verify_grad(self, self.op_class, inputs) verify_grad(self, self.op, inputs)
except: except:
type, value, traceback = sys.exc_info() type, value, traceback = sys.exc_info()
err_msg = "With data %s::%s: This error occurred while computing the gradient for %s on the following inputs: %s" \ err_msg = "Test %s::%s: Error occurred while computing the gradient on the following inputs: %s" \
% (self.op_class.__name__, testname, self.op_class, inputs) % (self.op, testname, inputs)
value.args = value.args + (err_msg, ) value.args = value.args + (err_msg, )
raise type, value, traceback raise type, value, traceback
...@@ -157,8 +161,8 @@ def randint_ranged(min, max, shape): ...@@ -157,8 +161,8 @@ def randint_ranged(min, max, shape):
return numpy.random.random_integers(min, max, shape) return numpy.random.random_integers(min, max, shape)
def make_broadcast_tester(op_class, expected, checks = {}, **kwargs): def make_broadcast_tester(op, expected, checks = {}, **kwargs):
name = op_class.__name__ + "Tester" name = str(op) + "Tester"
if kwargs.has_key('inplace'): if kwargs.has_key('inplace'):
if kwargs['inplace']: if kwargs['inplace']:
_expected = expected _expected = expected
...@@ -166,7 +170,7 @@ def make_broadcast_tester(op_class, expected, checks = {}, **kwargs): ...@@ -166,7 +170,7 @@ def make_broadcast_tester(op_class, expected, checks = {}, **kwargs):
checks = dict(checks, checks = dict(checks,
inplace_check = lambda inputs, outputs: inputs[0] is outputs[0]) inplace_check = lambda inputs, outputs: inputs[0] is outputs[0])
del kwargs['inplace'] del kwargs['inplace']
return make_tester(name, op_class, expected, checks, **kwargs) return make_tester(name, op, expected, checks, **kwargs)
...@@ -189,28 +193,28 @@ _grad_broadcast_binary_normal = dict(same_shapes = (rand(2, 3), rand(2, 3)), ...@@ -189,28 +193,28 @@ _grad_broadcast_binary_normal = dict(same_shapes = (rand(2, 3), rand(2, 3)),
column = (rand(2, 3), rand(2, 1))) column = (rand(2, 3), rand(2, 1)))
AddTester = make_broadcast_tester(op_class = Add, AddTester = make_broadcast_tester(op = add,
expected = lambda *inputs: reduce(lambda x, y: x + y, inputs), expected = lambda *inputs: reduce(lambda x, y: x + y, inputs),
good = dict(three_inputs_same_shapes = (rand(2, 3), rand(2, 3), rand(2, 3)), good = dict(three_inputs_same_shapes = (rand(2, 3), rand(2, 3), rand(2, 3)),
four_inputs_broadcast = (rand(2, 3), rand(1, 3), rand(2, 1), rand(1, 1)), four_inputs_broadcast = (rand(2, 3), rand(1, 3), rand(2, 1), rand(1, 1)),
**_good_broadcast_binary_normal), **_good_broadcast_binary_normal),
bad_build = _bad_build_broadcast_binary_normal, bad_build = _bad_build_broadcast_binary_normal,
bad_runtime = _bad_runtime_broadcast_binary_normal) bad_runtime = _bad_runtime_broadcast_binary_normal)
AddInplaceTester = make_broadcast_tester(op_class = AddInplace, AddInplaceTester = make_broadcast_tester(op = add_inplace,
expected = lambda x, y: x + y, expected = lambda x, y: x + y,
good = _good_broadcast_binary_normal, good = _good_broadcast_binary_normal,
bad_build = _bad_build_broadcast_binary_normal, bad_build = _bad_build_broadcast_binary_normal,
bad_runtime = _bad_runtime_broadcast_binary_normal, bad_runtime = _bad_runtime_broadcast_binary_normal,
inplace = True) inplace = True)
SubTester = make_broadcast_tester(op_class = Sub, SubTester = make_broadcast_tester(op = sub,
expected = lambda x, y: x - y, expected = lambda x, y: x - y,
good = _good_broadcast_binary_normal, good = _good_broadcast_binary_normal,
bad_build = _bad_build_broadcast_binary_normal, bad_build = _bad_build_broadcast_binary_normal,
bad_runtime = _bad_runtime_broadcast_binary_normal, bad_runtime = _bad_runtime_broadcast_binary_normal,
grad = _grad_broadcast_binary_normal) grad = _grad_broadcast_binary_normal)
SubInplaceTester = make_broadcast_tester(op_class = SubInplace, SubInplaceTester = make_broadcast_tester(op = sub_inplace,
expected = lambda x, y: x - y, expected = lambda x, y: x - y,
good = _good_broadcast_binary_normal, good = _good_broadcast_binary_normal,
bad_build = _bad_build_broadcast_binary_normal, bad_build = _bad_build_broadcast_binary_normal,
...@@ -218,7 +222,7 @@ SubInplaceTester = make_broadcast_tester(op_class = SubInplace, ...@@ -218,7 +222,7 @@ SubInplaceTester = make_broadcast_tester(op_class = SubInplace,
grad = _grad_broadcast_binary_normal, grad = _grad_broadcast_binary_normal,
inplace = True) inplace = True)
MulTester = make_broadcast_tester(op_class = Mul, MulTester = make_broadcast_tester(op = mul,
expected = lambda *inputs: reduce(lambda x, y: x * y, inputs), expected = lambda *inputs: reduce(lambda x, y: x * y, inputs),
good = dict(three_inputs_same_shapes = (rand(2, 3), rand(2, 3), rand(2, 3)), good = dict(three_inputs_same_shapes = (rand(2, 3), rand(2, 3), rand(2, 3)),
four_inputs_broadcast = (rand(2, 3), rand(1, 3), rand(2, 1), rand(1, 1)), four_inputs_broadcast = (rand(2, 3), rand(1, 3), rand(2, 1), rand(1, 1)),
...@@ -228,7 +232,7 @@ MulTester = make_broadcast_tester(op_class = Mul, ...@@ -228,7 +232,7 @@ MulTester = make_broadcast_tester(op_class = Mul,
grad = dict(three_inputs_same_shapes = (rand(2, 3), rand(2, 3), rand(2, 3)), grad = dict(three_inputs_same_shapes = (rand(2, 3), rand(2, 3), rand(2, 3)),
four_inputs_broadcast = (rand(2, 3), rand(1, 3), rand(2, 1), rand(1, 1)), four_inputs_broadcast = (rand(2, 3), rand(1, 3), rand(2, 1), rand(1, 1)),
**_grad_broadcast_binary_normal)) **_grad_broadcast_binary_normal))
MulInplaceTester = make_broadcast_tester(op_class = MulInplace, MulInplaceTester = make_broadcast_tester(op = mul_inplace,
expected = lambda x, y: x * y, expected = lambda x, y: x * y,
good = _good_broadcast_binary_normal, good = _good_broadcast_binary_normal,
bad_build = _bad_build_broadcast_binary_normal, bad_build = _bad_build_broadcast_binary_normal,
...@@ -236,7 +240,7 @@ MulInplaceTester = make_broadcast_tester(op_class = MulInplace, ...@@ -236,7 +240,7 @@ MulInplaceTester = make_broadcast_tester(op_class = MulInplace,
grad = _grad_broadcast_binary_normal, grad = _grad_broadcast_binary_normal,
inplace = True) inplace = True)
DivTester = make_broadcast_tester(op_class = Div, DivTester = make_broadcast_tester(op = div,
expected = lambda x, y: x / y, expected = lambda x, y: x / y,
good = dict(same_shapes = (rand(2, 3), rand(2, 3)), good = dict(same_shapes = (rand(2, 3), rand(2, 3)),
scalar = (rand(2, 3), rand(1, 1)), scalar = (rand(2, 3), rand(1, 1)),
...@@ -254,7 +258,7 @@ DivTester = make_broadcast_tester(op_class = Div, ...@@ -254,7 +258,7 @@ DivTester = make_broadcast_tester(op_class = Div,
scalar = (rand(2, 3), rand(1, 1)), scalar = (rand(2, 3), rand(1, 1)),
row = (rand(2, 3), rand(1, 3)), row = (rand(2, 3), rand(1, 3)),
column = (rand(2, 3), rand(2, 1)))) column = (rand(2, 3), rand(2, 1))))
DivInplaceTester = make_broadcast_tester(op_class = DivInplace, DivInplaceTester = make_broadcast_tester(op = div_inplace,
expected = lambda x, y: x / y, expected = lambda x, y: x / y,
good = dict(same_shapes = (rand(2, 3), rand(2, 3)), good = dict(same_shapes = (rand(2, 3), rand(2, 3)),
scalar = (rand(2, 3), rand(1, 1)), scalar = (rand(2, 3), rand(1, 1)),
...@@ -269,7 +273,7 @@ DivInplaceTester = make_broadcast_tester(op_class = DivInplace, ...@@ -269,7 +273,7 @@ DivInplaceTester = make_broadcast_tester(op_class = DivInplace,
column = (rand(2, 3), rand(2, 1))), column = (rand(2, 3), rand(2, 1))),
inplace = True) inplace = True)
PowTester = make_broadcast_tester(op_class = Pow, PowTester = make_broadcast_tester(op = pow,
expected = lambda x, y: x ** y, expected = lambda x, y: x ** y,
good = dict(same_shapes = (rand_ranged(1, 5, (2, 3)), rand_ranged(-3, 3, (2, 3))), good = dict(same_shapes = (rand_ranged(1, 5, (2, 3)), rand_ranged(-3, 3, (2, 3))),
scalar = (rand_ranged(1, 5, (2, 3)), rand_ranged(-3, 3, (1, 1))), scalar = (rand_ranged(1, 5, (2, 3)), rand_ranged(-3, 3, (1, 1))),
...@@ -281,7 +285,7 @@ PowTester = make_broadcast_tester(op_class = Pow, ...@@ -281,7 +285,7 @@ PowTester = make_broadcast_tester(op_class = Pow,
row = (rand_ranged(1, 5, (2, 3)), rand_ranged(-3, 3, (1, 3))), row = (rand_ranged(1, 5, (2, 3)), rand_ranged(-3, 3, (1, 3))),
column = (rand_ranged(1, 5, (2, 3)), rand_ranged(-3, 3, (2, 1)))) column = (rand_ranged(1, 5, (2, 3)), rand_ranged(-3, 3, (2, 1))))
) )
PowInplaceTester = make_broadcast_tester(op_class = PowInplace, PowInplaceTester = make_broadcast_tester(op = pow_inplace,
expected = lambda x, y: x ** y, expected = lambda x, y: x ** y,
good = dict(same_shapes = (rand_ranged(1, 5, (2, 3)), rand_ranged(-3, 3, (2, 3))), good = dict(same_shapes = (rand_ranged(1, 5, (2, 3)), rand_ranged(-3, 3, (2, 3))),
scalar = (rand_ranged(1, 5, (2, 3)), rand_ranged(-3, 3, (1, 1))), scalar = (rand_ranged(1, 5, (2, 3)), rand_ranged(-3, 3, (1, 1))),
...@@ -302,49 +306,49 @@ _good_broadcast_unary_normal = dict(normal = (rand_ranged(-5, 5, (2, 3)),), ...@@ -302,49 +306,49 @@ _good_broadcast_unary_normal = dict(normal = (rand_ranged(-5, 5, (2, 3)),),
_grad_broadcast_unary_normal = dict(normal = (rand_ranged(-5, 5, (2, 3)),)) _grad_broadcast_unary_normal = dict(normal = (rand_ranged(-5, 5, (2, 3)),))
AbsTester = make_broadcast_tester(op_class = Abs, AbsTester = make_broadcast_tester(op = tensor._abs,
expected = lambda x: abs(x), expected = lambda x: abs(x),
good = _good_broadcast_unary_normal, good = _good_broadcast_unary_normal,
grad = _grad_broadcast_unary_normal) grad = _grad_broadcast_unary_normal)
AbsInplaceTester = make_broadcast_tester(op_class = AbsInplace, AbsInplaceTester = make_broadcast_tester(op = abs_inplace,
expected = lambda x: abs(x), expected = lambda x: abs(x),
good = _good_broadcast_unary_normal, good = _good_broadcast_unary_normal,
grad = _grad_broadcast_unary_normal, grad = _grad_broadcast_unary_normal,
inplace = True) inplace = True)
NegTester = make_broadcast_tester(op_class = Neg, NegTester = make_broadcast_tester(op = neg,
expected = lambda x: -x, expected = lambda x: -x,
good = _good_broadcast_unary_normal, good = _good_broadcast_unary_normal,
grad = _grad_broadcast_unary_normal) grad = _grad_broadcast_unary_normal)
NegInplaceTester = make_broadcast_tester(op_class = NegInplace, NegInplaceTester = make_broadcast_tester(op = neg_inplace,
expected = lambda x: -x, expected = lambda x: -x,
good = _good_broadcast_unary_normal, good = _good_broadcast_unary_normal,
grad = _grad_broadcast_unary_normal, grad = _grad_broadcast_unary_normal,
inplace = True) inplace = True)
SgnTester = make_broadcast_tester(op_class = Sgn, SgnTester = make_broadcast_tester(op = sgn,
expected = numpy.sign, expected = numpy.sign,
good = _good_broadcast_unary_normal) good = _good_broadcast_unary_normal)
SgnInplaceTester = make_broadcast_tester(op_class = SgnInplace, SgnInplaceTester = make_broadcast_tester(op = sgn_inplace,
expected = numpy.sign, expected = numpy.sign,
good = _good_broadcast_unary_normal, good = _good_broadcast_unary_normal,
inplace = True) inplace = True)
SqrTester = make_broadcast_tester(op_class = Sqr, SqrTester = make_broadcast_tester(op = sqr,
expected = numpy.square, expected = numpy.square,
good = _good_broadcast_unary_normal, good = _good_broadcast_unary_normal,
grad = _grad_broadcast_unary_normal) grad = _grad_broadcast_unary_normal)
SqrInplaceTester = make_broadcast_tester(op_class = SqrInplace, SqrInplaceTester = make_broadcast_tester(op = sqr_inplace,
expected = numpy.square, expected = numpy.square,
good = _good_broadcast_unary_normal, good = _good_broadcast_unary_normal,
grad = _grad_broadcast_unary_normal, grad = _grad_broadcast_unary_normal,
inplace = True) inplace = True)
ExpTester = make_broadcast_tester(op_class = Exp, ExpTester = make_broadcast_tester(op = exp,
expected = numpy.exp, expected = numpy.exp,
good = _good_broadcast_unary_normal, good = _good_broadcast_unary_normal,
grad = _grad_broadcast_unary_normal) grad = _grad_broadcast_unary_normal)
ExpInplaceTester = make_broadcast_tester(op_class = ExpInplace, ExpInplaceTester = make_broadcast_tester(op = exp_inplace,
expected = numpy.exp, expected = numpy.exp,
good = _good_broadcast_unary_normal, good = _good_broadcast_unary_normal,
grad = _grad_broadcast_unary_normal, grad = _grad_broadcast_unary_normal,
...@@ -356,31 +360,31 @@ _good_broadcast_unary_positive = dict(normal = (rand_ranged(0.001, 5, (2, 3)),), ...@@ -356,31 +360,31 @@ _good_broadcast_unary_positive = dict(normal = (rand_ranged(0.001, 5, (2, 3)),),
_grad_broadcast_unary_positive = dict(normal = (rand_ranged(0.001, 5, (2, 3)),)) _grad_broadcast_unary_positive = dict(normal = (rand_ranged(0.001, 5, (2, 3)),))
LogTester = make_broadcast_tester(op_class = Log, LogTester = make_broadcast_tester(op = log,
expected = numpy.log, expected = numpy.log,
good = _good_broadcast_unary_positive, good = _good_broadcast_unary_positive,
grad = _grad_broadcast_unary_positive) grad = _grad_broadcast_unary_positive)
LogInplaceTester = make_broadcast_tester(op_class = LogInplace, LogInplaceTester = make_broadcast_tester(op = log_inplace,
expected = numpy.log, expected = numpy.log,
good = _good_broadcast_unary_positive, good = _good_broadcast_unary_positive,
grad = _grad_broadcast_unary_positive, grad = _grad_broadcast_unary_positive,
inplace = True) inplace = True)
Log2Tester = make_broadcast_tester(op_class = Log2, Log2Tester = make_broadcast_tester(op = log2,
expected = numpy.log2, expected = numpy.log2,
good = _good_broadcast_unary_positive, good = _good_broadcast_unary_positive,
grad = _grad_broadcast_unary_positive) grad = _grad_broadcast_unary_positive)
Log2InplaceTester = make_broadcast_tester(op_class = Log2Inplace, Log2InplaceTester = make_broadcast_tester(op = log2_inplace,
expected = numpy.log2, expected = numpy.log2,
good = _good_broadcast_unary_positive, good = _good_broadcast_unary_positive,
grad = _grad_broadcast_unary_positive, grad = _grad_broadcast_unary_positive,
inplace = True) inplace = True)
SqrtTester = make_broadcast_tester(op_class = Sqrt, SqrtTester = make_broadcast_tester(op = sqrt,
expected = numpy.sqrt, expected = numpy.sqrt,
good = _good_broadcast_unary_positive, good = _good_broadcast_unary_positive,
grad = _grad_broadcast_unary_positive) grad = _grad_broadcast_unary_positive)
SqrtInplaceTester = make_broadcast_tester(op_class = SqrtInplace, SqrtInplaceTester = make_broadcast_tester(op = sqrt_inplace,
expected = numpy.sqrt, expected = numpy.sqrt,
good = _good_broadcast_unary_positive, good = _good_broadcast_unary_positive,
grad = _grad_broadcast_unary_positive, grad = _grad_broadcast_unary_positive,
...@@ -394,33 +398,33 @@ _good_broadcast_unary_wide = dict(normal = (rand_ranged(-1000, 1000, (2, 3)),), ...@@ -394,33 +398,33 @@ _good_broadcast_unary_wide = dict(normal = (rand_ranged(-1000, 1000, (2, 3)),),
_grad_broadcast_unary_wide = dict(normal = (rand_ranged(-1000, 1000, (2, 3)),)) _grad_broadcast_unary_wide = dict(normal = (rand_ranged(-1000, 1000, (2, 3)),))
SinTester = make_broadcast_tester(op_class = Sin, SinTester = make_broadcast_tester(op = sin,
expected = numpy.sin, expected = numpy.sin,
good = _good_broadcast_unary_wide, good = _good_broadcast_unary_wide,
grad = _grad_broadcast_unary_wide) grad = _grad_broadcast_unary_wide)
SinInplaceTester = make_broadcast_tester(op_class = SinInplace, SinInplaceTester = make_broadcast_tester(op = sin_inplace,
expected = numpy.sin, expected = numpy.sin,
good = _good_broadcast_unary_wide, good = _good_broadcast_unary_wide,
grad = _grad_broadcast_unary_wide, grad = _grad_broadcast_unary_wide,
inplace = True) inplace = True)
CosTester = make_broadcast_tester(op_class = Cos, CosTester = make_broadcast_tester(op = cos,
expected = numpy.cos, expected = numpy.cos,
good = _good_broadcast_unary_wide, good = _good_broadcast_unary_wide,
grad = _grad_broadcast_unary_wide) grad = _grad_broadcast_unary_wide)
CosInplaceTester = make_broadcast_tester(op_class = CosInplace, CosInplaceTester = make_broadcast_tester(op = cos_inplace,
expected = numpy.cos, expected = numpy.cos,
good = _good_broadcast_unary_wide, good = _good_broadcast_unary_wide,
grad = _grad_broadcast_unary_wide, grad = _grad_broadcast_unary_wide,
inplace = True) inplace = True)
TanTester = make_broadcast_tester(op_class = Tan, TanTester = make_broadcast_tester(op = tan,
expected = numpy.tan, expected = numpy.tan,
good = dict(normal = (rand_ranged(-3.14, 3.14, (2, 3)),), good = dict(normal = (rand_ranged(-3.14, 3.14, (2, 3)),),
shifted = (rand_ranged(3.15, 6.28, (2, 3)),)), shifted = (rand_ranged(3.15, 6.28, (2, 3)),)),
grad = dict(normal = (rand_ranged(-3.14, 3.14, (2, 3)),), grad = dict(normal = (rand_ranged(-3.14, 3.14, (2, 3)),),
shifted = (rand_ranged(3.15, 6.28, (2, 3)),))) shifted = (rand_ranged(3.15, 6.28, (2, 3)),)))
TanInplaceTester = make_broadcast_tester(op_class = TanInplace, TanInplaceTester = make_broadcast_tester(op = tan_inplace,
expected = numpy.tan, expected = numpy.tan,
good = dict(normal = (rand_ranged(-3.14, 3.14, (2, 3)),), good = dict(normal = (rand_ranged(-3.14, 3.14, (2, 3)),),
shifted = (rand_ranged(3.15, 6.28, (2, 3)),)), shifted = (rand_ranged(3.15, 6.28, (2, 3)),)),
...@@ -429,31 +433,31 @@ TanInplaceTester = make_broadcast_tester(op_class = TanInplace, ...@@ -429,31 +433,31 @@ TanInplaceTester = make_broadcast_tester(op_class = TanInplace,
inplace = True) inplace = True)
CoshTester = make_broadcast_tester(op_class = Cosh, CoshTester = make_broadcast_tester(op = cosh,
expected = numpy.cosh, expected = numpy.cosh,
good = _good_broadcast_unary_normal, good = _good_broadcast_unary_normal,
grad = _grad_broadcast_unary_normal) grad = _grad_broadcast_unary_normal)
CoshInplaceTester = make_broadcast_tester(op_class = CoshInplace, CoshInplaceTester = make_broadcast_tester(op = cosh_inplace,
expected = numpy.cosh, expected = numpy.cosh,
good = _good_broadcast_unary_normal, good = _good_broadcast_unary_normal,
grad = _grad_broadcast_unary_normal, grad = _grad_broadcast_unary_normal,
inplace = True) inplace = True)
SinhTester = make_broadcast_tester(op_class = Sinh, SinhTester = make_broadcast_tester(op = sinh,
expected = numpy.sinh, expected = numpy.sinh,
good = _good_broadcast_unary_normal, good = _good_broadcast_unary_normal,
grad = _grad_broadcast_unary_normal) grad = _grad_broadcast_unary_normal)
SinhInplaceTester = make_broadcast_tester(op_class = SinhInplace, SinhInplaceTester = make_broadcast_tester(op = sinh_inplace,
expected = numpy.sinh, expected = numpy.sinh,
good = _good_broadcast_unary_normal, good = _good_broadcast_unary_normal,
grad = _grad_broadcast_unary_normal, grad = _grad_broadcast_unary_normal,
inplace = True) inplace = True)
TanhTester = make_broadcast_tester(op_class = Tanh, TanhTester = make_broadcast_tester(op = tanh,
expected = numpy.tanh, expected = numpy.tanh,
good = _good_broadcast_unary_normal, good = _good_broadcast_unary_normal,
grad = _grad_broadcast_unary_normal) grad = _grad_broadcast_unary_normal)
TanhInplaceTester = make_broadcast_tester(op_class = TanhInplace, TanhInplaceTester = make_broadcast_tester(op = tanh_inplace,
expected = numpy.tanh, expected = numpy.tanh,
good = _good_broadcast_unary_normal, good = _good_broadcast_unary_normal,
grad = _grad_broadcast_unary_normal, grad = _grad_broadcast_unary_normal,
...@@ -461,15 +465,15 @@ TanhInplaceTester = make_broadcast_tester(op_class = TanhInplace, ...@@ -461,15 +465,15 @@ TanhInplaceTester = make_broadcast_tester(op_class = TanhInplace,
DotTester = make_tester(name = 'DotTester', # DotTester = make_tester(name = 'DotTester',
op_class = Dot, # op = Dot,
expected = lambda x, y: numpy.dot(x, y), # expected = lambda x, y: numpy.dot(x, y),
checks = {}, # checks = {},
good = dict(correct1 = (rand(5, 7), rand(7, 5)), # good = dict(correct1 = (rand(5, 7), rand(7, 5)),
correct2 = (rand(5, 7), rand(7, 9))), # correct2 = (rand(5, 7), rand(7, 9))),
bad_build = dict(), # bad_build = dict(),
bad_runtime = dict(bad1 = (rand(5, 7), rand(5, 7)), # bad_runtime = dict(bad1 = (rand(5, 7), rand(5, 7)),
bad2 = (rand(5, 7), rand(8, 3)))) # bad2 = (rand(5, 7), rand(8, 3))))
...@@ -477,13 +481,14 @@ DotTester = make_tester(name = 'DotTester', ...@@ -477,13 +481,14 @@ DotTester = make_tester(name = 'DotTester',
# rationale: it's tricky, and necessary everytime you want to verify # rationale: it's tricky, and necessary everytime you want to verify
# gradient numerically # gradient numerically
def verify_grad(testcase, op_cls, pt, n_tests=1, rng=numpy.random, eps=0.0000001, tol=0.0001): def verify_grad(testcase, op, pt, n_tests=1, rng=numpy.random, eps=0.0000001, tol=0.0001):
"""testcase.failUnless( analytic gradient matches finite-diff gradient) """ """testcase.failUnless( analytic gradient matches finite-diff gradient) """
pt = [numpy.asarray(p) for p in pt] pt = [numpy.asarray(p) for p in pt]
for test_num in xrange(n_tests): for test_num in xrange(n_tests):
tensor_pt = [astensor(p,name='input %i'%i) for i,p in enumerate(pt)] # tensor_pt = [as_tensor(p,name='input %i'%i) for i,p in enumerate(pt)]
o = op_cls(*[tpt.copy() for tpt in tensor_pt]) tensor_pt = [constant(p).type('input %i'%i) for i,p in enumerate(pt)]
o = op.make_node(*[tpt.copy() for tpt in tensor_pt])
if hasattr(o, 'outputs'): if hasattr(o, 'outputs'):
o_outputs = o.outputs o_outputs = o.outputs
else: else:
...@@ -497,7 +502,7 @@ def verify_grad(testcase, op_cls, pt, n_tests=1, rng=numpy.random, eps=0.0000001 ...@@ -497,7 +502,7 @@ def verify_grad(testcase, op_cls, pt, n_tests=1, rng=numpy.random, eps=0.0000001
o_fn = Function(tensor_pt, o_outputs) o_fn = Function(tensor_pt, o_outputs)
o_fn_out = o_fn(*pt) o_fn_out = o_fn(*pt)
random_projection = rng.rand(*o_fn_out.shape) random_projection = rng.rand(*o_fn_out.shape)
t_r = astensor(random_projection) t_r = as_tensor(random_projection)
#random projection of o onto t_r #random projection of o onto t_r
cost = sum(t_r * o_outputs[0]) cost = sum(t_r * o_outputs[0])
...@@ -505,7 +510,7 @@ def verify_grad(testcase, op_cls, pt, n_tests=1, rng=numpy.random, eps=0.0000001 ...@@ -505,7 +510,7 @@ def verify_grad(testcase, op_cls, pt, n_tests=1, rng=numpy.random, eps=0.0000001
num_grad = gradient.numeric_grad(cost_fn, pt) num_grad = gradient.numeric_grad(cost_fn, pt)
symbolic_grad = gradient.grad(cost, tensor_pt,astensor(1.0,name='g_cost')) symbolic_grad = gradient.grad(cost, tensor_pt,as_tensor(1.0,name='g_cost'))
if 0: if 0:
print '-------' print '-------'
print '----------' print '----------'
...@@ -532,887 +537,887 @@ def verify_grad(testcase, op_cls, pt, n_tests=1, rng=numpy.random, eps=0.0000001 ...@@ -532,887 +537,887 @@ def verify_grad(testcase, op_cls, pt, n_tests=1, rng=numpy.random, eps=0.0000001
verify_grad.E_grad = 'gradient error exceeded tolerance' verify_grad.E_grad = 'gradient error exceeded tolerance'
#useful mostly for unit tests # #useful mostly for unit tests
def _approx_eq(a,b,eps=1.0e-9): # def _approx_eq(a,b,eps=1.0e-9):
a = numpy.asarray(a) # a = numpy.asarray(a)
b = numpy.asarray(b) # b = numpy.asarray(b)
if a.shape != b.shape: # if a.shape != b.shape:
if _approx_eq.debug: # if _approx_eq.debug:
print a.shape, b.shape # print a.shape, b.shape
return False # return False
if numpy.max(numpy.abs(a-b)) >= eps: # if numpy.max(numpy.abs(a-b)) >= eps:
if _approx_eq.debug: # if _approx_eq.debug:
print a, b # print a, b
return False # return False
return True # return True
_approx_eq.debug = 0 # _approx_eq.debug = 0
def check_eq(self, node_in, node_out, arg_in, arg_out): # def check_eq(self, node_in, node_out, arg_in, arg_out):
fn = Function([node_in], [node_out]) # fn = Function([node_in], [node_out])
self.failUnless( numpy.all(fn(arg_in) == arg_out), (arg_in, arg_out)) # self.failUnless( numpy.all(fn(arg_in) == arg_out), (arg_in, arg_out))
def check_eq2(self, inputs, output, args_in, arg_out): # def check_eq2(self, inputs, output, args_in, arg_out):
fn = Function(inputs, [output]) # fn = Function(inputs, [output])
val = fn(*args_in) # val = fn(*args_in)
self.failUnless( numpy.all(val == arg_out), (val, arg_out)) # self.failUnless( numpy.all(val == arg_out), (val, arg_out))
def check_eq2_c(self, inputs, output, args_in, arg_out): # def check_eq2_c(self, inputs, output, args_in, arg_out):
fn = Function(inputs, [output], linker_cls = gof.CLinker) # fn = Function(inputs, [output], linker_cls = gof.CLinker)
val = fn(*args_in) # val = fn(*args_in)
self.failUnless( numpy.all(val == arg_out), (val, arg_out)) # self.failUnless( numpy.all(val == arg_out), (val, arg_out))
def check_eq2_both(self, inputs, output, args_in, arg_out): # def check_eq2_both(self, inputs, output, args_in, arg_out):
fn = Function(inputs, [output], linker_cls = lambda env: gof.DualLinker(env, _numpy_checker)) # fn = Function(inputs, [output], linker_cls = lambda env: gof.DualLinker(env, _numpy_checker))
val = fn(*args_in) # val = fn(*args_in)
self.failUnless( numpy.all(val == arg_out), (val, arg_out)) # self.failUnless( numpy.all(val == arg_out), (val, arg_out))
class T_argmax(unittest.TestCase): # class T_argmax(unittest.TestCase):
def setUp(self): # def setUp(self):
numpy.random.seed(123784) # numpy.random.seed(123784)
Argmax.debug = 0 # Argmax.debug = 0
def test0(self): # def test0(self):
n = astensor(5.0) # n = astensor(5.0)
v,i = eval_outputs(argmax(n)) # v,i = eval_outputs(argmax(n))
self.failUnless(v == 5.0) # self.failUnless(v == 5.0)
self.failUnless(i == 0) # self.failUnless(i == 0)
def test1(self): # def test1(self):
n = astensor([1,2,3,2,-6]) # n = astensor([1,2,3,2,-6])
v,i = eval_outputs(argmax(n)) # v,i = eval_outputs(argmax(n))
self.failUnless(v == 3) # self.failUnless(v == 3)
self.failUnless(i == 2) # self.failUnless(i == 2)
def test2(self): # def test2(self):
n = astensor(numpy.random.rand(2,3)) # n = astensor(numpy.random.rand(2,3))
v,i = eval_outputs(argmax(n)) # v,i = eval_outputs(argmax(n))
self.failUnless(numpy.all(i == [0,1])) # self.failUnless(numpy.all(i == [0,1]))
def test2b(self): # def test2b(self):
n = astensor(numpy.random.rand(2,3)) # n = astensor(numpy.random.rand(2,3))
v,i = eval_outputs(argmax(n,axis=0)) # v,i = eval_outputs(argmax(n,axis=0))
self.failUnless(numpy.all(i == [0,1,1])) # self.failUnless(numpy.all(i == [0,1,1]))
def test2_invalid(self): # def test2_invalid(self):
n = astensor(numpy.random.rand(2,3)) # n = astensor(numpy.random.rand(2,3))
try: # try:
eval_outputs(argmax(n,axis=3)) # eval_outputs(argmax(n,axis=3))
except ValueError, e: # except ValueError, e:
return # return
self.fail() # self.fail()
def test2_invalid_neg(self): # def test2_invalid_neg(self):
n = astensor(numpy.random.rand(2,3)) # n = astensor(numpy.random.rand(2,3))
try: # try:
eval_outputs(argmax(n,axis=-3)) # eval_outputs(argmax(n,axis=-3))
except ValueError, e: # except ValueError, e:
return # return
self.fail() # self.fail()
def test2_valid_neg(self): # def test2_valid_neg(self):
n = astensor(numpy.random.rand(2,3)) # n = astensor(numpy.random.rand(2,3))
v,i = eval_outputs(argmax(n,axis=-1)) # v,i = eval_outputs(argmax(n,axis=-1))
self.failUnless(v.shape == (2,)) # self.failUnless(v.shape == (2,))
v,i = eval_outputs(argmax(n,axis=-2)) # v,i = eval_outputs(argmax(n,axis=-2))
self.failUnless(v.shape == (3,)) # self.failUnless(v.shape == (3,))
def test3(self): # def test3(self):
n = astensor(numpy.random.rand(2,3,4)) # n = astensor(numpy.random.rand(2,3,4))
v,i = eval_outputs(argmax(n,axis=0)) # v,i = eval_outputs(argmax(n,axis=0))
self.failUnless(v.shape == (3,4)) # self.failUnless(v.shape == (3,4))
self.failUnless(i.shape == (3,4)) # self.failUnless(i.shape == (3,4))
v,i = eval_outputs(argmax(n,axis=1)) # v,i = eval_outputs(argmax(n,axis=1))
self.failUnless(v.shape == (2,4)) # self.failUnless(v.shape == (2,4))
self.failUnless(i.shape == (2,4)) # self.failUnless(i.shape == (2,4))
v,i = eval_outputs(argmax(n,axis=2)) # v,i = eval_outputs(argmax(n,axis=2))
self.failUnless(v.shape == (2,3)) # self.failUnless(v.shape == (2,3))
self.failUnless(i.shape == (2,3)) # self.failUnless(i.shape == (2,3))
class T_transpose(unittest.TestCase): # class T_transpose(unittest.TestCase):
def test0(self): # def test0(self):
n = astensor(numpy.ones(())) # n = astensor(numpy.ones(()))
t = transpose(n) # t = transpose(n)
self.failUnless(t.owner.__class__ is TransposeInplace) # self.failUnless(t.owner.__class__ is TransposeInplace)
f = Function([n], [t]) # f = Function([n], [t])
tval = f(n.data) # tval = f(n.data)
self.failUnless(tval.shape == n.data.shape) # self.failUnless(tval.shape == n.data.shape)
#test aliasing # #test aliasing
tval += 55.0 # tval += 55.0
self.failUnless(n.data == 1.0) # self.failUnless(n.data == 1.0)
def test1(self): # def test1(self):
n = astensor(numpy.ones(5)) # n = astensor(numpy.ones(5))
t = transpose(n) # t = transpose(n)
self.failUnless(t.owner.__class__ is TransposeInplace) # self.failUnless(t.owner.__class__ is TransposeInplace)
f = Function([n], [t]) # f = Function([n], [t])
tval = f(n.data) # tval = f(n.data)
self.failUnless(tval.shape == n.data.shape) # self.failUnless(tval.shape == n.data.shape)
#test aliasing # #test aliasing
tval += 55.0 # tval += 55.0
self.failUnless(n.data[0] == 1.0) # self.failUnless(n.data[0] == 1.0)
def test2(self): # def test2(self):
n = astensor(numpy.ones((5,3))) # n = astensor(numpy.ones((5,3)))
t = transpose(n) # t = transpose(n)
self.failUnless(t.owner.__class__ is TransposeInplace) # self.failUnless(t.owner.__class__ is TransposeInplace)
f = Function([n], [t]) # f = Function([n], [t])
tval = f(n.data) # tval = f(n.data)
self.failUnless(tval.shape == (3,5)) # self.failUnless(tval.shape == (3,5))
#test aliasing # #test aliasing
tval += 55.0 # tval += 55.0
self.failUnless(n.data[0,0] == 1.0) # self.failUnless(n.data[0,0] == 1.0)
def test3(self): # def test3(self):
"""Test transpose of tensor, inplace version""" # """Test transpose of tensor, inplace version"""
n = astensor(numpy.ones((5,3,2))) # n = astensor(numpy.ones((5,3,2)))
t = transpose_inplace(n) # t = transpose_inplace(n)
self.failUnless(t.owner.__class__ is TransposeInplace) # self.failUnless(t.owner.__class__ is TransposeInplace)
f = Function([n], [t]) # f = Function([n], [t])
tval = f(n.data) # tval = f(n.data)
self.failUnless(tval.shape == (2,3,5)) # self.failUnless(tval.shape == (2,3,5))
#test aliasing # #test aliasing
tval += 55.0 # tval += 55.0
self.failUnless(n.data[0,0,0] == 56.0) # self.failUnless(n.data[0,0,0] == 56.0)
def test_grad(self): # def test_grad(self):
verify_grad(self, TransposeInplace, [numpy.random.rand(2, 3)]) # verify_grad(self, TransposeInplace, [numpy.random.rand(2, 3)])
verify_grad(self, TransposeInplace, [numpy.ones(3)]) # verify_grad(self, TransposeInplace, [numpy.ones(3)])
class T_subtensor(unittest.TestCase): # class T_subtensor(unittest.TestCase):
def test0_err_invalid(self): # def test0_err_invalid(self):
#it is impossible to retrieve a view of a 0-d tensor # #it is impossible to retrieve a view of a 0-d tensor
n = astensor(numpy.ones(())) # n = astensor(numpy.ones(()))
try: # try:
t = n[0] # t = n[0]
except ValueError, e: # except ValueError, e:
self.failUnless(e[0] is Subtensor.e_invalid) # self.failUnless(e[0] is Subtensor.e_invalid)
return # return
self.fail() # self.fail()
def test1_err_bounds(self): # def test1_err_bounds(self):
n = astensor(numpy.ones(3)) # n = astensor(numpy.ones(3))
t = n[7] # t = n[7]
self.failUnless(t.owner.__class__ is Subtensor) # self.failUnless(t.owner.__class__ is Subtensor)
try: # try:
tval = eval_outputs([t]) # tval = eval_outputs([t])
except Exception, e: # except Exception, e:
if e[0] != 'index out of bounds': # if e[0] != 'index out of bounds':
raise # raise
return # return
self.fail() # self.fail()
def test1_ok_range_finite(self): # def test1_ok_range_finite(self):
n = astensor(numpy.ones(3)*5) # n = astensor(numpy.ones(3)*5)
t = n[0:2] # t = n[0:2]
self.failUnless(t.owner.__class__ is Subtensor) # self.failUnless(t.owner.__class__ is Subtensor)
tval = eval_outputs([t]) # tval = eval_outputs([t])
self.failUnless(tval.shape == (2,)) # self.failUnless(tval.shape == (2,))
self.failUnless(tval[1] == 5.0) # self.failUnless(tval[1] == 5.0)
def test2_ok_range_finite(self): # def test2_ok_range_finite(self):
n = astensor(numpy.ones((3,4))*5) # n = astensor(numpy.ones((3,4))*5)
t = n[0:2,3] # t = n[0:2,3]
self.failUnless(t.owner.__class__ is Subtensor) # self.failUnless(t.owner.__class__ is Subtensor)
tval = eval_outputs([t]) # tval = eval_outputs([t])
self.failUnless(tval.shape == (2,)) # self.failUnless(tval.shape == (2,))
self.failUnless(tval[1] == 5.0) # self.failUnless(tval[1] == 5.0)
def test1_err_invalid(self): # def test1_err_invalid(self):
n = astensor(numpy.ones(1)) # n = astensor(numpy.ones(1))
try: # try:
t = n[0,0] # t = n[0,0]
except ValueError, e: # except ValueError, e:
self.failUnless(e[0] is Subtensor.e_invalid) # self.failUnless(e[0] is Subtensor.e_invalid)
return # return
self.fail() # self.fail()
def test1_ok_elem(self): # def test1_ok_elem(self):
n = astensor(numpy.ones(1)*5) # n = astensor(numpy.ones(1)*5)
t = n[0] # t = n[0]
self.failUnless(t.owner.__class__ is Subtensor) # self.failUnless(t.owner.__class__ is Subtensor)
tval = eval_outputs([t]) # tval = eval_outputs([t])
self.failUnless(tval.shape == ()) # self.failUnless(tval.shape == ())
self.failUnless(tval == 5.0) # self.failUnless(tval == 5.0)
def test1_ok_range_infinite(self): # def test1_ok_range_infinite(self):
n = astensor(numpy.ones(3)*5) # n = astensor(numpy.ones(3)*5)
t = n[1:] # t = n[1:]
self.failUnless(t.owner.__class__ is Subtensor) # self.failUnless(t.owner.__class__ is Subtensor)
tval = eval_outputs([t]) # tval = eval_outputs([t])
self.failUnless(tval.shape == (2,)) # self.failUnless(tval.shape == (2,))
self.failUnless(tval[1] == 5.0) # self.failUnless(tval[1] == 5.0)
def test1_ok_strided(self): # def test1_ok_strided(self):
n = astensor(numpy.ones(5)*5) # n = astensor(numpy.ones(5)*5)
t = n[1::2] # t = n[1::2]
self.failUnless(t.owner.__class__ is Subtensor) # self.failUnless(t.owner.__class__ is Subtensor)
tval = eval_outputs([t]) # tval = eval_outputs([t])
self.failUnless(tval.shape == (2,)) # self.failUnless(tval.shape == (2,))
self.failUnless(tval[1] == 5.0) # self.failUnless(tval[1] == 5.0)
tval = eval_outputs([n[0:-1:2]]) #0 to 1 from the end stepping by 2 # tval = eval_outputs([n[0:-1:2]]) #0 to 1 from the end stepping by 2
self.failUnless(tval.shape == (2,)) # self.failUnless(tval.shape == (2,))
self.failUnless(tval[1] == 5.0) # self.failUnless(tval[1] == 5.0)
def test2_err_bounds0(self): # def test2_err_bounds0(self):
n = astensor(numpy.ones((2,3))*5) # n = astensor(numpy.ones((2,3))*5)
t = n[0,4] # t = n[0,4]
self.failUnless(t.owner.__class__ is Subtensor) # self.failUnless(t.owner.__class__ is Subtensor)
try: # try:
tval = eval_outputs([t]) # tval = eval_outputs([t])
except IndexError, e: # except IndexError, e:
return # return
self.fail() # self.fail()
def test2_err_bounds1(self): # def test2_err_bounds1(self):
n = astensor(numpy.ones((2,3))*5) # n = astensor(numpy.ones((2,3))*5)
t = n[4:5,2] # t = n[4:5,2]
self.failUnless(t.owner.__class__ is Subtensor) # self.failUnless(t.owner.__class__ is Subtensor)
try: # try:
tval = eval_outputs([t]) # tval = eval_outputs([t])
except Exception, e: # except Exception, e:
if e[0] != 'index out of bounds': # if e[0] != 'index out of bounds':
raise # raise
def test2_ok_elem(self): # def test2_ok_elem(self):
n = astensor(numpy.asarray(range(6)).reshape((2,3))) # n = astensor(numpy.asarray(range(6)).reshape((2,3)))
t = n[0,2] # t = n[0,2]
self.failUnless(t.owner.__class__ is Subtensor) # self.failUnless(t.owner.__class__ is Subtensor)
tval = eval_outputs([t]) # tval = eval_outputs([t])
self.failUnless(tval.shape == ()) # self.failUnless(tval.shape == ())
self.failUnless(numpy.all(tval == 2)) # self.failUnless(numpy.all(tval == 2))
def test2_ok_row(self): # def test2_ok_row(self):
n = astensor(numpy.asarray(range(6)).reshape((2,3))) # n = astensor(numpy.asarray(range(6)).reshape((2,3)))
t = n[1] # t = n[1]
self.failIf(any(n.broadcastable)) # self.failIf(any(n.broadcastable))
self.failUnless(t.owner.__class__ is Subtensor) # self.failUnless(t.owner.__class__ is Subtensor)
tval = eval_outputs([t]) # tval = eval_outputs([t])
self.failUnless(tval.shape == (3,)) # self.failUnless(tval.shape == (3,))
self.failUnless(numpy.all(tval == [3,4,5])) # self.failUnless(numpy.all(tval == [3,4,5]))
def test2_ok_col(self): # def test2_ok_col(self):
n = astensor(numpy.ones((2,3))*5) # n = astensor(numpy.ones((2,3))*5)
t = n[:,0] # t = n[:,0]
self.failUnless(t.owner.__class__ is Subtensor) # self.failUnless(t.owner.__class__ is Subtensor)
self.failIf(any(n.broadcastable)) # self.failIf(any(n.broadcastable))
tval = eval_outputs([t]) # tval = eval_outputs([t])
self.failUnless(tval.shape == (2,)) # self.failUnless(tval.shape == (2,))
self.failUnless(numpy.all(tval == 5.0)) # self.failUnless(numpy.all(tval == 5.0))
def test2_ok_rows_finite(self): # def test2_ok_rows_finite(self):
n = astensor(numpy.ones((4,3))*5) # n = astensor(numpy.ones((4,3))*5)
t = n[1:3,0] # t = n[1:3,0]
self.failUnless(t.owner.__class__ is Subtensor) # self.failUnless(t.owner.__class__ is Subtensor)
tval = eval_outputs([t]) # tval = eval_outputs([t])
self.failUnless(tval.shape == (2,)) # self.failUnless(tval.shape == (2,))
self.failUnless(numpy.all(tval == 5.0)) # self.failUnless(numpy.all(tval == 5.0))
def test2_ok_cols_infinite(self): # def test2_ok_cols_infinite(self):
n = astensor(numpy.asarray(range(12)).reshape((4,3))) # n = astensor(numpy.asarray(range(12)).reshape((4,3)))
t = n[1,2:] # t = n[1,2:]
self.failUnless(t.owner.__class__ is Subtensor) # self.failUnless(t.owner.__class__ is Subtensor)
tval = eval_outputs([t]) # tval = eval_outputs([t])
self.failUnless(tval.shape == (1,)) # self.failUnless(tval.shape == (1,))
self.failUnless(numpy.all(tval == 5)) # self.failUnless(numpy.all(tval == 5))
def test2_ok_strided(self): # def test2_ok_strided(self):
n = astensor(numpy.asarray(range(20)).reshape((4,5))) # n = astensor(numpy.asarray(range(20)).reshape((4,5)))
t = n[1:4:2,1:5:2] # t = n[1:4:2,1:5:2]
self.failUnless(t.owner.__class__ is Subtensor) # self.failUnless(t.owner.__class__ is Subtensor)
tval = eval_outputs([t]) # tval = eval_outputs([t])
self.failUnless(tval.shape == (2,2)) # self.failUnless(tval.shape == (2,2))
self.failUnless(numpy.all(tval == [[6, 8],[16, 18]])) # self.failUnless(numpy.all(tval == [[6, 8],[16, 18]]))
def test3_ok_mat(self): # def test3_ok_mat(self):
n = astensor(numpy.asarray(range(24)).reshape((2,3,4))) # n = astensor(numpy.asarray(range(24)).reshape((2,3,4)))
t = n[0,0,0] # t = n[0,0,0]
self.failUnless(t.owner.__class__ is Subtensor) # self.failUnless(t.owner.__class__ is Subtensor)
tval = eval_outputs([t]) # tval = eval_outputs([t])
self.failUnless(tval.shape == ()) # self.failUnless(tval.shape == ())
self.failUnless(numpy.all(tval == 0)) # self.failUnless(numpy.all(tval == 0))
class T_add(unittest.TestCase): # class T_add(unittest.TestCase):
def test_complex_all_ops(self): # def test_complex_all_ops(self):
for nbits in (64, 128): # for nbits in (64, 128):
a = astensor(numpy.ones(3, dtype='complex%i' % nbits)+0.5j) # a = astensor(numpy.ones(3, dtype='complex%i' % nbits)+0.5j)
b = astensor(numpy.ones(3, dtype='complex%i' % nbits)+1.5j) # b = astensor(numpy.ones(3, dtype='complex%i' % nbits)+1.5j)
tests = (("+", lambda x,y: x+y), # tests = (("+", lambda x,y: x+y),
("-", lambda x,y: x-y), # ("-", lambda x,y: x-y),
("*", lambda x,y: x*y), # ("*", lambda x,y: x*y),
("/", lambda x,y: x/y)) # ("/", lambda x,y: x/y))
for s, fn in tests: # for s, fn in tests:
f = Function([a,b], [fn(a, b)], linker_cls = gof.CLinker) # f = Function([a,b], [fn(a, b)], linker_cls = gof.CLinker)
self.failUnless(numpy.all(fn(a.data, b.data) == f(a.data, b.data))) # self.failUnless(numpy.all(fn(a.data, b.data) == f(a.data, b.data)))
def test_grad_scalar_l(self): # def test_grad_scalar_l(self):
verify_grad(self, Add, [numpy.asarray([3.0]), numpy.random.rand(3)]) # verify_grad(self, Add, [numpy.asarray([3.0]), numpy.random.rand(3)])
def test_grad_scalar_r(self): # def test_grad_scalar_r(self):
verify_grad(self, Add, [numpy.random.rand(3), numpy.asarray([3.0])]) # verify_grad(self, Add, [numpy.random.rand(3), numpy.asarray([3.0])])
def test_grad_row(self): # def test_grad_row(self):
verify_grad(self, Add, [numpy.random.rand(3, 5), numpy.random.rand(1, 5)]) # verify_grad(self, Add, [numpy.random.rand(3, 5), numpy.random.rand(1, 5)])
def test_grad_col(self): # def test_grad_col(self):
verify_grad(self, Add, [numpy.random.rand(3, 5), numpy.random.rand(3, 1)]) # verify_grad(self, Add, [numpy.random.rand(3, 5), numpy.random.rand(3, 1)])
class T_abs(unittest.TestCase): # class T_abs(unittest.TestCase):
def test_impl(self): # def test_impl(self):
t = astensor(1.0) # t = astensor(1.0)
check_eq(self, t, abs(t), 1.0, 1.0) # check_eq(self, t, abs(t), 1.0, 1.0)
check_eq(self, t, abs(t), -1.0, 1.0) # check_eq(self, t, abs(t), -1.0, 1.0)
for shape in (2,), (3,4): # for shape in (2,), (3,4):
t = astensor(numpy.ones(shape)) # t = astensor(numpy.ones(shape))
d = numpy.random.rand(*shape)*2-1.0 # d = numpy.random.rand(*shape)*2-1.0
check_eq(self, t, abs(t), d, abs(d)) # check_eq(self, t, abs(t), d, abs(d))
check_eq(self, t, abs(t), -d, abs(-d)) # check_eq(self, t, abs(t), -d, abs(-d))
def test_grad(self): # def test_grad(self):
verify_grad(self, Abs, [numpy.ones(())]) # verify_grad(self, Abs, [numpy.ones(())])
verify_grad(self, Abs, [numpy.ones(3)]) # verify_grad(self, Abs, [numpy.ones(3)])
class AbsBadGrad(Abs): # class AbsBadGrad(Abs):
def grad(self, (x, ), (gz, )): # def grad(self, (x, ), (gz, )):
return mul(gz * sgn(x),0.9), # return mul(gz * sgn(x),0.9),
def test_badgrad(self): # def test_badgrad(self):
try: # try:
verify_grad(self, T_abs.AbsBadGrad, [numpy.ones(())]) # verify_grad(self, T_abs.AbsBadGrad, [numpy.ones(())])
except Exception, e: # except Exception, e:
self.failUnless(str(e) == verify_grad.E_grad, str(e)) # self.failUnless(str(e) == verify_grad.E_grad, str(e))
return # return
self.fail() # self.fail()
class T_fill(unittest.TestCase): # class T_fill(unittest.TestCase):
def test0(self): # def test0(self):
t = fill(numpy.asarray([1,2,3]), 9) # t = fill(numpy.asarray([1,2,3]), 9)
self.failUnless(t.owner.__class__ == Fill) # self.failUnless(t.owner.__class__ == Fill)
o = t.owner # o = t.owner
self.failUnless(o.inputs[0].broadcastable == (0,)) # self.failUnless(o.inputs[0].broadcastable == (0,))
# self.failUnless(o.inputs[0].dtype[0:3] == 'int') # # self.failUnless(o.inputs[0].dtype[0:3] == 'int')
self.failUnless(o.inputs[1].broadcastable == (1,)) # self.failUnless(o.inputs[1].broadcastable == (1,))
# self.failUnless(o.inputs[1].dtype[0:3] == 'flo') # # self.failUnless(o.inputs[1].dtype[0:3] == 'flo')
self.failUnless(o.outputs[0].broadcastable == (0,)) # self.failUnless(o.outputs[0].broadcastable == (0,))
# self.failUnless(o.outputs[0].dtype[0:3] == 'flo') # # self.failUnless(o.outputs[0].dtype[0:3] == 'flo')
self.failUnless(numpy.all(eval_outputs([t]) == [9,9,9])) # self.failUnless(numpy.all(eval_outputs([t]) == [9,9,9]))
def test1(self): # def test1(self):
x = astensor(numpy.ones((4,5))) # x = astensor(numpy.ones((4,5)))
l = ones_like(x[:,0:1]) # l = ones_like(x[:,0:1])
r = ones_like(x[0:1,:]) # r = ones_like(x[0:1,:])
xx = x + dot(l,r) # xx = x + dot(l,r)
self.failUnless(numpy.mean(eval_outputs([xx]) == 2.0)) # self.failUnless(numpy.mean(eval_outputs([xx]) == 2.0))
class T_sum(unittest.TestCase): # class T_sum(unittest.TestCase):
def test_impl(self): # def test_impl(self):
t = astensor(0.0) # t = astensor(0.0)
check_eq(self, t, Sum(t).out, 1.0, 1.0) # check_eq(self, t, Sum(t).out, 1.0, 1.0)
check_eq(self, t, Sum(t).out, -1.0, -1.0) # check_eq(self, t, Sum(t).out, -1.0, -1.0)
t = astensor([0.0, 0.0]) # t = astensor([0.0, 0.0])
d = numpy.asarray([-0.4, 1.2]) # d = numpy.asarray([-0.4, 1.2])
check_eq(self, t, Sum(t).out, d, numpy.sum(d)) # check_eq(self, t, Sum(t).out, d, numpy.sum(d))
check_eq(self, t, Sum(t).out, -d, -numpy.sum(d)) # check_eq(self, t, Sum(t).out, -d, -numpy.sum(d))
class T_mul(unittest.TestCase): # class T_mul(unittest.TestCase):
def setUp(self): # def setUp(self):
numpy.random.seed([1,2,3,4]) # numpy.random.seed([1,2,3,4])
def test_elemwise(self): # def test_elemwise(self):
a = astensor(0.0) # a = astensor(0.0)
b = astensor(0.0) # b = astensor(0.0)
check_eq2_both(self, [a,b], mul(a,b), [3.0, 4.0], 12.0) # check_eq2_both(self, [a,b], mul(a,b), [3.0, 4.0], 12.0)
check_eq2_both(self, [a,b], mul(b,a), [-1.0,2.0], -2.0) # check_eq2_both(self, [a,b], mul(b,a), [-1.0,2.0], -2.0)
a = astensor(numpy.ones(2)) # a = astensor(numpy.ones(2))
b = astensor(numpy.ones(2)) # b = astensor(numpy.ones(2))
aa = numpy.asarray([-0.5, 4.0]) # aa = numpy.asarray([-0.5, 4.0])
bb = numpy.asarray([-0.5, 2.0]) # bb = numpy.asarray([-0.5, 2.0])
check_eq2_both(self, [a,b], mul(a,b), [aa,bb], numpy.asarray([0.25, 8.0])) # check_eq2_both(self, [a,b], mul(a,b), [aa,bb], numpy.asarray([0.25, 8.0]))
check_eq2_both(self, [a,b], mul(a,b), [bb,aa], numpy.asarray([0.25, 8.0])) # check_eq2_both(self, [a,b], mul(a,b), [bb,aa], numpy.asarray([0.25, 8.0]))
def test_scalar(self): # def test_scalar(self):
r = numpy.random.rand(2,3) # r = numpy.random.rand(2,3)
a = astensor(r) # a = astensor(r)
b = astensor(2.0) # b = astensor(2.0)
check_eq2_both(self, [a,b], mul(a,b), [r, 2.0], r*2.0) # check_eq2_both(self, [a,b], mul(a,b), [r, 2.0], r*2.0)
check_eq2_both(self, [a,b], mul(a,b), [r, 4.0], r*4.0) # check_eq2_both(self, [a,b], mul(a,b), [r, 4.0], r*4.0)
self.failUnless(b.data == 2.0) # self.failUnless(b.data == 2.0)
def test_rowcol(self): # def test_rowcol(self):
r1 = numpy.random.rand(3,5) # r1 = numpy.random.rand(3,5)
r2 = numpy.random.rand(1,5) # r2 = numpy.random.rand(1,5)
r3 = numpy.random.rand(3,1) # r3 = numpy.random.rand(3,1)
a1, a2, a3 = astensor(r1), astensor(r2), astensor(r3) # a1, a2, a3 = astensor(r1), astensor(r2), astensor(r3)
check_eq2_both(self, [a1,a2], mul(a1,a2), [r1, r2], r1*r2) # check_eq2_both(self, [a1,a2], mul(a1,a2), [r1, r2], r1*r2)
check_eq2_both(self, [a1,a3], mul(a1,a3), [r1, r3], r1*r3) # check_eq2_both(self, [a1,a3], mul(a1,a3), [r1, r3], r1*r3)
def test_grad_elemwise(self): # def test_grad_elemwise(self):
verify_grad(self, Mul, [numpy.random.rand(3,4), numpy.random.rand(3,4)]) # verify_grad(self, Mul, [numpy.random.rand(3,4), numpy.random.rand(3,4)])
def test_grad_scalar_l(self): # def test_grad_scalar_l(self):
verify_grad(self, Mul, [numpy.asarray([3.0]), numpy.random.rand(3)]) # verify_grad(self, Mul, [numpy.asarray([3.0]), numpy.random.rand(3)])
def test_grad_scalar_r(self): # def test_grad_scalar_r(self):
verify_grad(self, Mul, [numpy.random.rand(3), numpy.asarray([3.0])]) # verify_grad(self, Mul, [numpy.random.rand(3), numpy.asarray([3.0])])
def test_grad_row(self): # def test_grad_row(self):
verify_grad(self, Mul, [numpy.random.rand(3, 5), numpy.random.rand(1, 5)]) # verify_grad(self, Mul, [numpy.random.rand(3, 5), numpy.random.rand(1, 5)])
def test_grad_row2(self): # def test_grad_row2(self):
op = lambda x, y: Mul(x, DimShuffle(y, ['x', 0]).out) # op = lambda x, y: Mul(x, DimShuffle(y, ['x', 0]).out)
verify_grad(self, op, [numpy.random.rand(3, 5), numpy.random.rand(5)]) # verify_grad(self, op, [numpy.random.rand(3, 5), numpy.random.rand(5)])
def test_grad_col(self): # def test_grad_col(self):
verify_grad(self, Mul, [numpy.random.rand(3, 5), numpy.random.rand(3, 1)]) # verify_grad(self, Mul, [numpy.random.rand(3, 5), numpy.random.rand(3, 1)])
def test_wrong_shapes(self): # def test_wrong_shapes(self):
a = astensor(numpy.ones(3)) # a = astensor(numpy.ones(3))
b = astensor(numpy.ones(4)) # b = astensor(numpy.ones(4))
try: # try:
check_eq2(self, [a,b], Mul(a,b).out, # check_eq2(self, [a,b], Mul(a,b).out,
[numpy.ones(3), numpy.ones(4)], 1.0) # [numpy.ones(3), numpy.ones(4)], 1.0)
self.fail() # self.fail()
except ValueError, e: # except ValueError, e:
self.failUnless('shape mismatch' in str(e)) # self.failUnless('shape mismatch' in str(e))
try: # try:
check_eq2_c(self, [a,b], Mul(a,b).out, # check_eq2_c(self, [a,b], Mul(a,b).out,
[numpy.ones(3), numpy.ones(4)], 1.0) # [numpy.ones(3), numpy.ones(4)], 1.0)
self.fail() # self.fail()
except ValueError, e: # except ValueError, e:
pass # pass
class T_div(unittest.TestCase): # class T_div(unittest.TestCase):
def setUp(self): # def setUp(self):
numpy.random.seed(9999) # numpy.random.seed(9999)
def test_grad_e(self): # def test_grad_e(self):
verify_grad(self, Div, [numpy.random.rand(3), numpy.ones(3)]) # verify_grad(self, Div, [numpy.random.rand(3), numpy.ones(3)])
verify_grad(self, Div, [numpy.random.rand(3,5), numpy.random.rand(3,5)+0.1]) # verify_grad(self, Div, [numpy.random.rand(3,5), numpy.random.rand(3,5)+0.1])
verify_grad(self, Div, [numpy.ones(()), numpy.ones(())]) # verify_grad(self, Div, [numpy.ones(()), numpy.ones(())])
def test_grad_sl(self): # def test_grad_sl(self):
verify_grad(self, Div, [numpy.ones((3, 5)), numpy.ones((1, 1))]) # verify_grad(self, Div, [numpy.ones((3, 5)), numpy.ones((1, 1))])
verify_grad(self, Div, [numpy.random.rand(3), numpy.ones((1, ))]) # verify_grad(self, Div, [numpy.random.rand(3), numpy.ones((1, ))])
verify_grad(self, Div, [numpy.random.rand(3,5), numpy.random.rand(1,1)]) # verify_grad(self, Div, [numpy.random.rand(3,5), numpy.random.rand(1,1)])
class T_log2(unittest.TestCase): # class T_log2(unittest.TestCase):
def test0(self): # def test0(self):
verify_grad(self, Log2, [numpy.random.rand(3,1)+0.0001]) # verify_grad(self, Log2, [numpy.random.rand(3,1)+0.0001])
class T_log(unittest.TestCase): # class T_log(unittest.TestCase):
def test0(self): # def test0(self):
verify_grad(self, Log, [numpy.random.rand(3,1)+0.0001]) # verify_grad(self, Log, [numpy.random.rand(3,1)+0.0001])
def test1(self): # def test1(self):
a = astensor(numpy.ones(2)) # a = astensor(numpy.ones(2))
b = astensor(numpy.ones(2)) # b = astensor(numpy.ones(2))
aa = numpy.asarray([0.5, 4.0]) # aa = numpy.asarray([0.5, 4.0])
bb = numpy.asarray([0.5, 2.0]) # bb = numpy.asarray([0.5, 2.0])
check_eq2(self, [a], log(a), [aa], numpy.log(numpy.asarray(aa))) # check_eq2(self, [a], log(a), [aa], numpy.log(numpy.asarray(aa)))
class T_pow(unittest.TestCase): # class T_pow(unittest.TestCase):
def setUp(self): # def setUp(self):
numpy.random.seed(9999) # numpy.random.seed(9999)
def test_elemwise(self): # def test_elemwise(self):
verify_grad(self, Div, [numpy.random.rand(3,4), numpy.random.rand(3,4)+0.1]) # verify_grad(self, Div, [numpy.random.rand(3,4), numpy.random.rand(3,4)+0.1])
verify_grad(self, Pow, [numpy.random.rand(3,4), numpy.random.rand(3,4)]) # verify_grad(self, Pow, [numpy.random.rand(3,4), numpy.random.rand(3,4)])
def test_scalar_l(self): # def test_scalar_l(self):
verify_grad(self, Pow, [numpy.asarray([3.0]), numpy.random.rand(3)]) # verify_grad(self, Pow, [numpy.asarray([3.0]), numpy.random.rand(3)])
def test_scalar_r(self): # def test_scalar_r(self):
verify_grad(self, Pow, [numpy.random.rand(3), numpy.asarray([3.0])]) # verify_grad(self, Pow, [numpy.random.rand(3), numpy.asarray([3.0])])
def test_row(self): # def test_row(self):
verify_grad(self, Pow, [numpy.random.rand(3, 5), numpy.random.rand(1, 5)]) # verify_grad(self, Pow, [numpy.random.rand(3, 5), numpy.random.rand(1, 5)])
def test_col(self): # def test_col(self):
verify_grad(self, Pow, [numpy.random.rand(3, 5), numpy.random.rand(3, 1)]) # verify_grad(self, Pow, [numpy.random.rand(3, 5), numpy.random.rand(3, 1)])
class _testCase_matinv(unittest.TestCase): # class _testCase_matinv(unittest.TestCase):
def setUp(self): # def setUp(self):
numpy.random.seed(1) # numpy.random.seed(1)
def mat_reciprocal(self,dim): # def mat_reciprocal(self,dim):
# symbolic program # # symbolic program
# broadcastable=[False,False] means that the shape of matrix is two dimensional, # # broadcastable=[False,False] means that the shape of matrix is two dimensional,
# and none of the dimensions are constrained to have length 1. # # and none of the dimensions are constrained to have length 1.
# Note that Tensor's constructor does not actually allocate any memory. # # Note that Tensor's constructor does not actually allocate any memory.
# TODO: Make Tensor syntax more explicit, and maybe give shape or number of dimensions. # # TODO: Make Tensor syntax more explicit, and maybe give shape or number of dimensions.
a = Tensor('float64', broadcastable=[False,False], name='a') # a = Tensor('float64', broadcastable=[False,False], name='a')
b = Tensor('float64', broadcastable=[False,False], name='b') # b = Tensor('float64', broadcastable=[False,False], name='b')
ab = a*b # ab = a*b
# Here, astensor actually uses the data allocated by numpy. # # Here, astensor actually uses the data allocated by numpy.
diff = ab - astensor(numpy.ones((dim,dim))) # diff = ab - astensor(numpy.ones((dim,dim)))
# Sum of squared errors # # Sum of squared errors
ssdiff = sum((diff**2.0)) # ssdiff = sum((diff**2.0))
g_b = gradient.grad(ssdiff, b) # g_b = gradient.grad(ssdiff, b)
# compilation to function # # compilation to function
# [a,b] are the inputs, [ssdiff,g_b] are the outputs # # [a,b] are the inputs, [ssdiff,g_b] are the outputs
fn = Function([a,b], [ssdiff,g_b]) # fn = Function([a,b], [ssdiff,g_b])
# use the function # # use the function
x = numpy.random.rand(dim,dim)+0.1 # Initialized s.t. x is not too tiny # x = numpy.random.rand(dim,dim)+0.1 # Initialized s.t. x is not too tiny
w = numpy.random.rand(dim,dim) # w = numpy.random.rand(dim,dim)
for i in xrange(300): # for i in xrange(300):
ssd, gw = fn(x,w) # ssd, gw = fn(x,w)
#print ssd, x*w, x, w # #print ssd, x*w, x, w
if i == 0: # if i == 0:
str0 = str(ssd) # str0 = str(ssd)
w -= 0.4 * gw # w -= 0.4 * gw
return str0, str(ssd) # return str0, str(ssd)
def test_reciprocal(self): # def test_reciprocal(self):
"""Matrix reciprocal by gradient descent""" # """Matrix reciprocal by gradient descent"""
self.assertEqual(('6.10141615619', '0.00703816291711'), self.mat_reciprocal(3)) # self.assertEqual(('6.10141615619', '0.00703816291711'), self.mat_reciprocal(3))
class t_dot(unittest.TestCase): # class t_dot(unittest.TestCase):
def setUp(self): # def setUp(self):
numpy.random.seed(44) # numpy.random.seed(44)
@staticmethod # @staticmethod
def rand(*args): # def rand(*args):
return numpy.random.rand(*args) # return numpy.random.rand(*args)
def cmp_dot(self,x,y): # def cmp_dot(self,x,y):
#x, y are matrices or numbers # #x, y are matrices or numbers
def spec(x): # def spec(x):
x = numpy.asarray(x) # x = numpy.asarray(x)
return type(x), x.dtype, x.shape # return type(x), x.dtype, x.shape
nz = numpy.dot(x,y) # nz = numpy.dot(x,y)
tz = eval_outputs([dot(astensor(x), astensor(y))]) # tz = eval_outputs([dot(astensor(x), astensor(y))])
self.failUnless(tz.dtype == nz.dtype) # self.failUnless(tz.dtype == nz.dtype)
self.failUnless(tz.shape == nz.shape) # self.failUnless(tz.shape == nz.shape)
self.failUnless(_approx_eq(nz, tz)) # self.failUnless(_approx_eq(nz, tz))
def test_dot_0d_0d(self): self.cmp_dot(1.1, 2.2) # def test_dot_0d_0d(self): self.cmp_dot(1.1, 2.2)
def test_dot_0d_1d(self): self.cmp_dot(1.1, self.rand(5)) # def test_dot_0d_1d(self): self.cmp_dot(1.1, self.rand(5))
def test_dot_0d_2d(self): self.cmp_dot(3.0, self.rand(6,7)) # def test_dot_0d_2d(self): self.cmp_dot(3.0, self.rand(6,7))
def test_dot_0d_3d(self): self.cmp_dot(3.0, self.rand(8,6,7)) # def test_dot_0d_3d(self): self.cmp_dot(3.0, self.rand(8,6,7))
def test_dot_1d_0d(self): self.cmp_dot(self.rand(5), 1.1 ) # def test_dot_1d_0d(self): self.cmp_dot(self.rand(5), 1.1 )
def test_dot_1d_1d(self): self.cmp_dot(self.rand(5), self.rand(5)) # def test_dot_1d_1d(self): self.cmp_dot(self.rand(5), self.rand(5))
def test_dot_1d_2d(self): self.cmp_dot(self.rand(6), self.rand(6,7)) # def test_dot_1d_2d(self): self.cmp_dot(self.rand(6), self.rand(6,7))
def test_dot_1d_3d(self): self.cmp_dot(self.rand(6), self.rand(8,6,7)) # def test_dot_1d_3d(self): self.cmp_dot(self.rand(6), self.rand(8,6,7))
def test_dot_2d_0d(self): self.cmp_dot(self.rand(5,6), 1.0) # def test_dot_2d_0d(self): self.cmp_dot(self.rand(5,6), 1.0)
def test_dot_2d_1d(self): self.cmp_dot(self.rand(5,6), self.rand(6)) # def test_dot_2d_1d(self): self.cmp_dot(self.rand(5,6), self.rand(6))
def test_dot_2d_2d(self): self.cmp_dot(self.rand(5,6), self.rand(6,7)) # def test_dot_2d_2d(self): self.cmp_dot(self.rand(5,6), self.rand(6,7))
def test_dot_2d_3d(self): self.cmp_dot(self.rand(5,6), self.rand(8,6,7)) # def test_dot_2d_3d(self): self.cmp_dot(self.rand(5,6), self.rand(8,6,7))
def test_dot_3d_0d(self): self.cmp_dot(self.rand(4,5,6), 1.0) # def test_dot_3d_0d(self): self.cmp_dot(self.rand(4,5,6), 1.0)
def test_dot_3d_1d(self): self.cmp_dot(self.rand(4,5,6), self.rand(6)) # def test_dot_3d_1d(self): self.cmp_dot(self.rand(4,5,6), self.rand(6))
def test_dot_3d_2d(self): self.cmp_dot(self.rand(4,5,6), self.rand(6,7)) # def test_dot_3d_2d(self): self.cmp_dot(self.rand(4,5,6), self.rand(6,7))
def test_dot_3d_3d(self): self.cmp_dot(self.rand(4,5,6), self.rand(8,6,7)) # def test_dot_3d_3d(self): self.cmp_dot(self.rand(4,5,6), self.rand(8,6,7))
def not_aligned(self, x, y): # def not_aligned(self, x, y):
z = dot(x,y) # z = dot(x,y)
try: # try:
tz = eval_outputs([z]) # tz = eval_outputs([z])
except ValueError, e: # except ValueError, e:
self.failUnless(e[0].split()[1:4] == ['are', 'not', 'aligned'], e) # self.failUnless(e[0].split()[1:4] == ['are', 'not', 'aligned'], e)
return # return
self.fail() # self.fail()
def test_align_1_1(self): self.not_aligned(self.rand(5), self.rand(6)) # def test_align_1_1(self): self.not_aligned(self.rand(5), self.rand(6))
def test_align_1_2(self): self.not_aligned(self.rand(5), self.rand(6,4)) # def test_align_1_2(self): self.not_aligned(self.rand(5), self.rand(6,4))
def test_align_1_3(self): self.not_aligned(self.rand(5), self.rand(6,4,7)) # def test_align_1_3(self): self.not_aligned(self.rand(5), self.rand(6,4,7))
def test_align_2_1(self): self.not_aligned(self.rand(5,4), self.rand(6)) # def test_align_2_1(self): self.not_aligned(self.rand(5,4), self.rand(6))
def test_align_2_1(self): self.not_aligned(self.rand(5,4), self.rand(6,7)) # def test_align_2_1(self): self.not_aligned(self.rand(5,4), self.rand(6,7))
def test_align_2_3(self): self.not_aligned(self.rand(5,4), self.rand(6,7,8)) # def test_align_2_3(self): self.not_aligned(self.rand(5,4), self.rand(6,7,8))
def test_align_3_1(self): self.not_aligned(self.rand(5,4,3), self.rand(6)) # def test_align_3_1(self): self.not_aligned(self.rand(5,4,3), self.rand(6))
def test_align_3_2(self): self.not_aligned(self.rand(5,4,3), self.rand(6,7)) # def test_align_3_2(self): self.not_aligned(self.rand(5,4,3), self.rand(6,7))
def test_align_3_3(self): self.not_aligned(self.rand(5,4,3), self.rand(6,7,8)) # def test_align_3_3(self): self.not_aligned(self.rand(5,4,3), self.rand(6,7,8))
def test_grad(self): # def test_grad(self):
verify_grad(self, Dot, [self.rand(2,3), self.rand(3,2)]) # verify_grad(self, Dot, [self.rand(2,3), self.rand(3,2)])
class t_gemm(unittest.TestCase): # class t_gemm(unittest.TestCase):
def setUp(self): # def setUp(self):
numpy.random.seed(44) # numpy.random.seed(44)
_approx_eq.debug = 0 # _approx_eq.debug = 0
Gemm.debug = False # Gemm.debug = False
@staticmethod # @staticmethod
def _gemm(z,a,x,y,b): # def _gemm(z,a,x,y,b):
assert a.shape == () # assert a.shape == ()
assert b.shape == () # assert b.shape == ()
return b * z + a * numpy.dot(x,y) # return b * z + a * numpy.dot(x,y)
@staticmethod # @staticmethod
def rand(*args): # def rand(*args):
return numpy.random.rand(*args) # return numpy.random.rand(*args)
def cmp(self, z, a, x, y, b): # def cmp(self, z, a, x, y, b):
def cmp_linker(z, a, x, y, b, l): # def cmp_linker(z, a, x, y, b, l):
z,a,x,y,b = [numpy.asarray(p) for p in z,a,x,y,b] # z,a,x,y,b = [numpy.asarray(p) for p in z,a,x,y,b]
z_orig = z.copy() # z_orig = z.copy()
tz,ta,tx,ty,tb = [astensor(p) for p in z,a,x,y,b] # tz,ta,tx,ty,tb = [astensor(p) for p in z,a,x,y,b]
f = Function([tz,ta,tx,ty,tb], [gemm(tz,ta,tx,ty,tb)], linker_cls=l) # f = Function([tz,ta,tx,ty,tb], [gemm(tz,ta,tx,ty,tb)], linker_cls=l)
new_z = f(z,a,x,y,b) # new_z = f(z,a,x,y,b)
z_after = self._gemm(z_orig, a, x, y, b) # z_after = self._gemm(z_orig, a, x, y, b)
self.failUnless(z is new_z) # self.failUnless(z is new_z)
#print z_orig, z_after, z, type(z_orig), type(z_after), type(z) # #print z_orig, z_after, z, type(z_orig), type(z_after), type(z)
#_approx_eq.debug = 1 # #_approx_eq.debug = 1
self.failUnless(_approx_eq(z_after, z)) # self.failUnless(_approx_eq(z_after, z))
if a == 0.0 and b == 1.0: # if a == 0.0 and b == 1.0:
return # return
else: # else:
self.failIf(numpy.all(z_orig == z)) # self.failIf(numpy.all(z_orig == z))
cmp_linker(copy(z), a, x, y, b, gof.cc.OpWiseCLinker) # cmp_linker(copy(z), a, x, y, b, gof.cc.OpWiseCLinker)
#cmp_linker(copy(z), a, x, y, b, gof.cc.CLinker) # #cmp_linker(copy(z), a, x, y, b, gof.cc.CLinker)
cmp_linker(copy(z), a, x, y, b, gof.link.PerformLinker) # cmp_linker(copy(z), a, x, y, b, gof.link.PerformLinker)
def test0a(self): # def test0a(self):
Gemm.debug = True # Gemm.debug = True
try: # try:
g = gemm([1.], 1., [1.], [1.], 1.) # g = gemm([1.], 1., [1.], [1.], 1.)
except ValueError, e: # except ValueError, e:
if e[0] is Gemm.E_rank: # if e[0] is Gemm.E_rank:
return # return
self.fail() # self.fail()
def test0(self): # def test0(self):
try: # try:
self.cmp(1., 0., 1.0, 1.0, 1.0) # self.cmp(1., 0., 1.0, 1.0, 1.0)
except ValueError, e: # except ValueError, e:
if e[0] is Gemm.E_rank: # if e[0] is Gemm.E_rank:
return # return
self.fail() # self.fail()
def test2(self): # def test2(self):
try: # try:
self.cmp(2., 1.0, [3,2,1.], [[1],[2],[3.]], 1.0) # self.cmp(2., 1.0, [3,2,1.], [[1],[2],[3.]], 1.0)
except ValueError, e: # except ValueError, e:
self.failUnless(e[0] == Gemm.E_rank) # self.failUnless(e[0] == Gemm.E_rank)
return # return
self.fail() # self.fail()
def test4(self): # def test4(self):
self.cmp(self.rand(3,4), 1.0, self.rand(3,5), self.rand(5,4), 0.0) # self.cmp(self.rand(3,4), 1.0, self.rand(3,5), self.rand(5,4), 0.0)
def test5(self): self.cmp(self.rand(3,4), 1.0, # def test5(self): self.cmp(self.rand(3,4), 1.0,
self.rand(3,5), self.rand(5,4), 1.0) # self.rand(3,5), self.rand(5,4), 1.0)
def test6(self): self.cmp(self.rand(3,4), 1.0, # def test6(self): self.cmp(self.rand(3,4), 1.0,
self.rand(3,5), self.rand(5,4), -1.0) # self.rand(3,5), self.rand(5,4), -1.0)
def test7(self): self.cmp(self.rand(3,4), 0.0, # def test7(self): self.cmp(self.rand(3,4), 0.0,
self.rand(3,5), self.rand(5,4), 0.0) # self.rand(3,5), self.rand(5,4), 0.0)
def test8(self): self.cmp(self.rand(3,4), 0.0, # def test8(self): self.cmp(self.rand(3,4), 0.0,
self.rand(3,5), self.rand(5,4), 0.6) # self.rand(3,5), self.rand(5,4), 0.6)
def test9(self): self.cmp(self.rand(3,4), 0.0, # def test9(self): self.cmp(self.rand(3,4), 0.0,
self.rand(3,5), self.rand(5,4), -1.0) # self.rand(3,5), self.rand(5,4), -1.0)
def test10(self): # def test10(self):
_approx_eq.debug = 1 # _approx_eq.debug = 1
self.cmp(self.rand(3,4), -1.0, self.rand(3,5), self.rand(5,4), 0.0) # self.cmp(self.rand(3,4), -1.0, self.rand(3,5), self.rand(5,4), 0.0)
def test11(self): self.cmp(self.rand(3,4), -1.0, # def test11(self): self.cmp(self.rand(3,4), -1.0,
self.rand(3,5), self.rand(5,4), 1.0) # self.rand(3,5), self.rand(5,4), 1.0)
def test12(self): self.cmp(self.rand(3,4), -1.0, # def test12(self): self.cmp(self.rand(3,4), -1.0,
self.rand(3,5), self.rand(5,4), -1.0) # self.rand(3,5), self.rand(5,4), -1.0)
def test_destroy_map0(self): # def test_destroy_map0(self):
"""test that only first input can be overwritten""" # """test that only first input can be overwritten"""
Z = astensor(self.rand(2,2)) # Z = astensor(self.rand(2,2))
try: # try:
gemm(Z, 1.0, Z, Z, 1.0) # gemm(Z, 1.0, Z, Z, 1.0)
except ValueError, e: # except ValueError, e:
if e[0] == Gemm.E_z_uniq: # if e[0] == Gemm.E_z_uniq:
return # return
self.fail() # self.fail()
def test_destroy_map1(self): # def test_destroy_map1(self):
"""test that only first input can be overwritten""" # """test that only first input can be overwritten"""
Z = astensor(self.rand(2,2)) # Z = astensor(self.rand(2,2))
A = astensor(self.rand(2,2)) # A = astensor(self.rand(2,2))
try: # try:
gemm(Z, 1.0, A, transpose_inplace(Z), 1.0) # gemm(Z, 1.0, A, transpose_inplace(Z), 1.0)
except ValueError, e: # except ValueError, e:
if e[0] == Gemm.E_z_uniq: # if e[0] == Gemm.E_z_uniq:
return # return
self.fail() # self.fail()
def test_destroy_map2(self): # def test_destroy_map2(self):
"""test that only first input can be overwritten""" # """test that only first input can be overwritten"""
Z = astensor(self.rand(2,2)) # Z = astensor(self.rand(2,2))
A = astensor(self.rand(2,2)) # A = astensor(self.rand(2,2))
try: # try:
gemm(Z, 1.0, transpose_inplace(Z), A, 1.0) # gemm(Z, 1.0, transpose_inplace(Z), A, 1.0)
except ValueError, e: # except ValueError, e:
if e[0] == Gemm.E_z_uniq: # if e[0] == Gemm.E_z_uniq:
return # return
self.fail() # self.fail()
def test_destroy_map3(self): # def test_destroy_map3(self):
"""test that only first input can be overwritten""" # """test that only first input can be overwritten"""
Z = astensor(self.rand(2,2)) # Z = astensor(self.rand(2,2))
A = astensor(self.rand(2,2)) # A = astensor(self.rand(2,2))
try: # try:
gemm(Z, 1.0, Z, A, 1.0) # gemm(Z, 1.0, Z, A, 1.0)
except ValueError, e: # except ValueError, e:
if e[0] == Gemm.E_z_uniq: # if e[0] == Gemm.E_z_uniq:
return # return
self.fail() # self.fail()
def test_destroy_map4(self): # def test_destroy_map4(self):
"""test that dot args can be aliased""" # """test that dot args can be aliased"""
Z = astensor(self.rand(2,2)) # Z = astensor(self.rand(2,2))
A = astensor(self.rand(2,2)) # A = astensor(self.rand(2,2))
eval_outputs([gemm(Z, 1.0, A, A, 1.0)]) # eval_outputs([gemm(Z, 1.0, A, A, 1.0)])
eval_outputs([gemm(Z, 1.0, A, A.T, 1.0)]) # eval_outputs([gemm(Z, 1.0, A, A.T, 1.0)])
def test_transposes(self): # def test_transposes(self):
# three square matrices which are not contiguous # # three square matrices which are not contiguous
A = self.rand(4,5)[:,:4] # A = self.rand(4,5)[:,:4]
B = self.rand(4,5)[:,:4] # B = self.rand(4,5)[:,:4]
C = self.rand(4,5)[:,:4] # C = self.rand(4,5)[:,:4]
def t(z,x,y,a=1.0, b=0.0,l=gof.cc.OpWiseCLinker,dt='float64'): # def t(z,x,y,a=1.0, b=0.0,l=gof.cc.OpWiseCLinker,dt='float64'):
z,a,x,y,b = [numpy.asarray(p,dtype=dt) for p in z,a,x,y,b] # z,a,x,y,b = [numpy.asarray(p,dtype=dt) for p in z,a,x,y,b]
z_orig = z.copy() # z_orig = z.copy()
z_after = self._gemm(z, a, x, y, b) # z_after = self._gemm(z, a, x, y, b)
tz,ta,tx,ty,tb = [astensor(p) for p in z,a,x,y,b] # tz,ta,tx,ty,tb = [astensor(p) for p in z,a,x,y,b]
f = Function([tz,ta,tx,ty,tb], [gemm(tz,ta,tx,ty,tb)], linker_cls=l) # f = Function([tz,ta,tx,ty,tb], [gemm(tz,ta,tx,ty,tb)], linker_cls=l)
f(z, a, x, y, b) # f(z, a, x, y, b)
self.failUnless(_approx_eq(z_after, z), (z_orig, z_after, z)) # self.failUnless(_approx_eq(z_after, z), (z_orig, z_after, z))
f(z.T, a, y.T, x.T, b) # f(z.T, a, y.T, x.T, b)
self.failUnless(_approx_eq(z_after, z)) # self.failUnless(_approx_eq(z_after, z))
t(C,A,B) # t(C,A,B)
t(C.T, A, B) # t(C.T, A, B)
t(C, A.T, B, dt='float32') # t(C, A.T, B, dt='float32')
t(C, A, B.T) # t(C, A, B.T)
t(C.T, A.T, B) # t(C.T, A.T, B)
t(C, A.T, B.T, dt='float32') # t(C, A.T, B.T, dt='float32')
t(C.T, A, B.T) # t(C.T, A, B.T)
t(C.T, A.T, B.T, dt='float32') # t(C.T, A.T, B.T, dt='float32')
t(C, A[:,:2], B[:2, :]) # t(C, A[:,:2], B[:2, :])
t(C.T, A[:,:2], B[:2, :], dt='float32') # t(C.T, A[:,:2], B[:2, :], dt='float32')
t(C, A[:2,:].T, B[:2, :]) # t(C, A[:2,:].T, B[:2, :])
t(C.T, A[:2,:].T, B[:2, :], dt='float32') # t(C.T, A[:2,:].T, B[:2, :], dt='float32')
t(C, A[:2,:].T, B[:, :2].T) # t(C, A[:2,:].T, B[:, :2].T)
t(C.T, A[:2,:].T, B[:, :2].T) # t(C.T, A[:2,:].T, B[:, :2].T)
try: # try:
t(C.T, A[:2,:], B[:, :2].T) # t(C.T, A[:2,:], B[:, :2].T)
except ValueError, e: # except ValueError, e:
if e[0].find('aligned') >= 0: # if e[0].find('aligned') >= 0:
return # return
self.fail() # self.fail()
def _tensor(data, broadcastable=None, name=None): # def _tensor(data, broadcastable=None, name=None):
"""Return a Tensor containing given data""" # """Return a Tensor containing given data"""
data = numpy.asarray(data) # data = numpy.asarray(data)
if broadcastable is None: # if broadcastable is None:
broadcastable = [s==1 for s in data.shape] # broadcastable = [s==1 for s in data.shape]
elif broadcastable in [0, 1]: # elif broadcastable in [0, 1]:
broadcastable = [broadcastable] * len(data.shape) # broadcastable = [broadcastable] * len(data.shape)
rval = Tensor(data.dtype, broadcastable, name) # rval = Tensor(data.dtype, broadcastable, name)
rval.data = data # will raise if broadcastable was mis-specified # rval.data = data # will raise if broadcastable was mis-specified
return rval # return rval
class T_tensor(unittest.TestCase): # class T_tensor(unittest.TestCase):
def test0(self): # allocate from a scalar float # def test0(self): # allocate from a scalar float
t = _tensor(1.0) # t = _tensor(1.0)
self.failUnless(isinstance(t, Tensor)) # self.failUnless(isinstance(t, Tensor))
self.failUnless(t.dtype == 'float64') # self.failUnless(t.dtype == 'float64')
self.failUnless(t.broadcastable == ()) # self.failUnless(t.broadcastable == ())
self.failUnless(t.role == None) # self.failUnless(t.role == None)
self.failUnless(isinstance(t.data, numpy.ndarray)) # self.failUnless(isinstance(t.data, numpy.ndarray))
self.failUnless(str(t.data.dtype) == 'float64') # self.failUnless(str(t.data.dtype) == 'float64')
self.failUnless(t.data == 1.0) # self.failUnless(t.data == 1.0)
def test0_int(self): # allocate from a scalar float # def test0_int(self): # allocate from a scalar float
t = _tensor(1) # t = _tensor(1)
self.failUnless(isinstance(t, Tensor)) # self.failUnless(isinstance(t, Tensor))
self.failUnless(t.dtype == 'int64' or t.dtype == 'int32') # self.failUnless(t.dtype == 'int64' or t.dtype == 'int32')
def test1(self): # allocate from a vector of ints, not broadcastable # def test1(self): # allocate from a vector of ints, not broadcastable
t = _tensor(numpy.ones(5,dtype='int32')) # t = _tensor(numpy.ones(5,dtype='int32'))
self.failUnless(isinstance(t, Tensor)) # self.failUnless(isinstance(t, Tensor))
self.failUnless(t.dtype == 'int32') # self.failUnless(t.dtype == 'int32')
self.failUnless(t.broadcastable == (0,)) # self.failUnless(t.broadcastable == (0,))
self.failUnless(isinstance(t.data, numpy.ndarray)) # self.failUnless(isinstance(t.data, numpy.ndarray))
self.failUnless(str(t.data.dtype) == 'int32') # self.failUnless(str(t.data.dtype) == 'int32')
def test2(self): # allocate from a column matrix of complex with name # def test2(self): # allocate from a column matrix of complex with name
t = _tensor(numpy.ones((5,1),dtype='complex64'),name='bart') # t = _tensor(numpy.ones((5,1),dtype='complex64'),name='bart')
self.failUnless(isinstance(t, Tensor)) # self.failUnless(isinstance(t, Tensor))
self.failUnless(t.dtype == 'complex64') # self.failUnless(t.dtype == 'complex64')
self.failUnless(t.broadcastable == (0,1)) # self.failUnless(t.broadcastable == (0,1))
self.failUnless(isinstance(t.data, numpy.ndarray)) # self.failUnless(isinstance(t.data, numpy.ndarray))
self.failUnless(t.name == 'bart') # self.failUnless(t.name == 'bart')
def test2b(self): # allocate from a column matrix, not broadcastable # def test2b(self): # allocate from a column matrix, not broadcastable
t = _tensor(numpy.ones((5,1),dtype='complex64'),broadcastable=0) # t = _tensor(numpy.ones((5,1),dtype='complex64'),broadcastable=0)
self.failUnless(isinstance(t, Tensor)) # self.failUnless(isinstance(t, Tensor))
self.failUnless(t.dtype == 'complex64') # self.failUnless(t.dtype == 'complex64')
self.failUnless(t.broadcastable == (0,0)) # self.failUnless(t.broadcastable == (0,0))
self.failUnless(isinstance(t.data, numpy.ndarray)) # self.failUnless(isinstance(t.data, numpy.ndarray))
f = Function([t], [t], linker_cls=gof.CLinker) # f = Function([t], [t], linker_cls=gof.CLinker)
self.failUnless(numpy.all(t.data == f(t.data))) # self.failUnless(numpy.all(t.data == f(t.data)))
def test_data_normal(self): #test that assigning to .data works when it should # def test_data_normal(self): #test that assigning to .data works when it should
t = _tensor(numpy.ones((5,1),dtype='complex64'), broadcastable=0) # t = _tensor(numpy.ones((5,1),dtype='complex64'), broadcastable=0)
o27 = numpy.ones((2,7), dtype='complex64') # o27 = numpy.ones((2,7), dtype='complex64')
t.data = o27 # t.data = o27
lst = t._data # lst = t._data
self.failUnless(t.data.shape == (2,7)) # self.failUnless(t.data.shape == (2,7))
self.failUnless(t.data is o27) # self.failUnless(t.data is o27)
self.failUnless(t._data is lst) # self.failUnless(t._data is lst)
def test_data_badrank0(self): # def test_data_badrank0(self):
t = _tensor(numpy.ones((5,1),dtype='complex64'), broadcastable=0) # t = _tensor(numpy.ones((5,1),dtype='complex64'), broadcastable=0)
try: # try:
t.data = numpy.ones((2,7,1)) # t.data = numpy.ones((2,7,1))
self.fail() # self.fail()
except ValueError, e: # except ValueError, e:
self.failUnless(e[0] is Tensor.filter.E_rank) # self.failUnless(e[0] is Tensor.filter.E_rank)
try: # try:
t.data = numpy.ones(1) # t.data = numpy.ones(1)
self.fail() # self.fail()
except ValueError, e: # except ValueError, e:
self.failUnless(e[0] is Tensor.filter.E_rank) # self.failUnless(e[0] is Tensor.filter.E_rank)
def test_data_badrank1(self): # def test_data_badrank1(self):
t = _tensor(numpy.ones((1,1),dtype='complex64'), broadcastable=1) # t = _tensor(numpy.ones((1,1),dtype='complex64'), broadcastable=1)
try: # try:
t.data = numpy.ones((1,1,1)) # t.data = numpy.ones((1,1,1))
self.fail() # self.fail()
except ValueError, e: # except ValueError, e:
self.failUnless(e[0] is Tensor.filter.E_rank) # self.failUnless(e[0] is Tensor.filter.E_rank)
try: # try:
t.data = numpy.ones(1) # t.data = numpy.ones(1)
self.fail() # self.fail()
except ValueError, e: # except ValueError, e:
self.failUnless(e[0] is Tensor.filter.E_rank) # self.failUnless(e[0] is Tensor.filter.E_rank)
def test_data_badshape0(self): # def test_data_badshape0(self):
t = _tensor(numpy.ones((1,1),dtype='complex64'), broadcastable=1) # t = _tensor(numpy.ones((1,1),dtype='complex64'), broadcastable=1)
try: # try:
t.data = numpy.ones((1,2)) # t.data = numpy.ones((1,2))
self.fail() # self.fail()
except ValueError, e: # except ValueError, e:
self.failUnless(e[0] is Tensor.filter.E_shape) # self.failUnless(e[0] is Tensor.filter.E_shape)
try: # try:
t.data = numpy.ones((0,1)) # t.data = numpy.ones((0,1))
self.fail() # self.fail()
except ValueError, e: # except ValueError, e:
self.failUnless(e[0] is Tensor.filter.E_shape) # self.failUnless(e[0] is Tensor.filter.E_shape)
def test_cast0(self): # def test_cast0(self):
t = Tensor('float32', [0]) # t = Tensor('float32', [0])
t.data = numpy.random.rand(4) > 0.5 # t.data = numpy.random.rand(4) > 0.5
self.failUnless(str(t.data.dtype) == t.dtype) # self.failUnless(str(t.data.dtype) == t.dtype)
class T_stdlib(unittest.TestCase): # class T_stdlib(unittest.TestCase):
def test0(self): # def test0(self):
t = _tensor(1.0) # t = _tensor(1.0)
tt = t.clone(False) # tt = t.clone(False)
self.failUnless(t.dtype == tt.dtype) # self.failUnless(t.dtype == tt.dtype)
self.failUnless(t.broadcastable is tt.broadcastable) # self.failUnless(t.broadcastable is tt.broadcastable)
self.failUnless(tt.data is None) # self.failUnless(tt.data is None)
self.failUnless(t.data == 1.0) # self.failUnless(t.data == 1.0)
def test0b(self): # def test0b(self):
t = _tensor(1.0) # t = _tensor(1.0)
tt = t.clone() # tt = t.clone()
self.failUnless(t.dtype == tt.dtype) # self.failUnless(t.dtype == tt.dtype)
self.failUnless(t.broadcastable is tt.broadcastable) # self.failUnless(t.broadcastable is tt.broadcastable)
self.failUnless(tt.data is None) # self.failUnless(tt.data is None)
self.failUnless(t.data == 1.0) # self.failUnless(t.data == 1.0)
def test1(self): # def test1(self):
t = _tensor(1.0) # t = _tensor(1.0)
tt = t.clone(True) # tt = t.clone(True)
self.failUnless(t.dtype == tt.dtype) # self.failUnless(t.dtype == tt.dtype)
self.failUnless(t.broadcastable is tt.broadcastable) # self.failUnless(t.broadcastable is tt.broadcastable)
self.failUnless(tt.data == 1.0) # self.failUnless(tt.data == 1.0)
self.failUnless(t.data == 1.0) # self.failUnless(t.data == 1.0)
self.failUnless(t.data is not tt.data) # self.failUnless(t.data is not tt.data)
def test1b(self): # def test1b(self):
t = _tensor(1.0) # t = _tensor(1.0)
tt = copy(t) # tt = copy(t)
self.failUnless(t.dtype == tt.dtype) # self.failUnless(t.dtype == tt.dtype)
self.failUnless(t.broadcastable is tt.broadcastable) # self.failUnless(t.broadcastable is tt.broadcastable)
self.failUnless(tt.data == 1.0) # self.failUnless(tt.data == 1.0)
self.failUnless(t.data == 1.0) # self.failUnless(t.data == 1.0)
self.failUnless(t.data is not tt.data) # self.failUnless(t.data is not tt.data)
......
...@@ -16,10 +16,10 @@ def exec_opt(inputs, outputs, features=[]): ...@@ -16,10 +16,10 @@ def exec_opt(inputs, outputs, features=[]):
exec_opt.optimizer = None exec_opt.optimizer = None
class _DefaultOptimizer(object): class _DefaultOptimizer(object):
const = gof.opt.ConstantFinder() #const = gof.opt.ConstantFinder()
merge = gof.opt.MergeOptimizer() merge = gof.opt.MergeOptimizer()
def __call__(self, env): def __call__(self, env):
self.const(env) #self.const(env)
self.merge(env) self.merge(env)
default_optimizer = _DefaultOptimizer() default_optimizer = _DefaultOptimizer()
...@@ -31,7 +31,7 @@ def linker_cls_python_and_c(env): ...@@ -31,7 +31,7 @@ def linker_cls_python_and_c(env):
"""Use this as the linker_cls argument to Function.__init__ to compare """Use this as the linker_cls argument to Function.__init__ to compare
python and C implementations""" python and C implementations"""
def checker(x, y): def checker(x, y):
x, y = x.data, y.data x, y = x[0], y[0]
if isinstance(x, numpy.ndarray) or isinstance(y, numpy.ndarray): if isinstance(x, numpy.ndarray) or isinstance(y, numpy.ndarray):
if x.dtype != y.dtype or x.shape != y.shape or numpy.any(abs(x - y) > 1e-10): if x.dtype != y.dtype or x.shape != y.shape or numpy.any(abs(x - y) > 1e-10):
raise Exception("Output mismatch.", {'performlinker': x, 'clinker': y}) raise Exception("Output mismatch.", {'performlinker': x, 'clinker': y})
...@@ -105,17 +105,23 @@ class Function: ...@@ -105,17 +105,23 @@ class Function:
#print 'orphans', orphans #print 'orphans', orphans
#print 'ops', gof.graph.ops(inputs, outputs) #print 'ops', gof.graph.ops(inputs, outputs)
env = gof.env.Env(inputs, outputs, features + [gof.EquivTool], consistency_check = True) env = gof.env.Env(inputs, outputs)
#print 'orphans in env', env.orphans() #print 'orphans in env', env.orphans()
env = env.clone(clone_inputs=True) env, equiv = env.clone_get_equiv(clone_inputs=True)
for feature in features:
env.extend(feature(env))
env.extend(gof.DestroyHandler(env))
#print 'orphans after clone', env.orphans() #print 'orphans after clone', env.orphans()
for d, o in zip(orphan_data, [env.equiv(orphan) for orphan in orphans]): for d, o in zip(orphan_data, [equiv[orphan] for orphan in orphans]):
#print 'assigning orphan value', d #print 'assigning orphan value', d
o.data = d #o.data = d
new_o = gof.Constant(o.type, d)
env.replace(o, new_o)
assert new_o in env.orphans
# optimize and link the cloned env # optimize and link the cloned env
if None is not optimizer: if None is not optimizer:
...@@ -127,11 +133,9 @@ class Function: ...@@ -127,11 +133,9 @@ class Function:
self.__dict__.update(locals()) self.__dict__.update(locals())
if profiler is None: if profiler is None:
self.fn = linker.make_function(inplace=True, self.fn = linker.make_function(unpack_single=unpack_single)
unpack_single=unpack_single)
else: else:
self.fn = linker.make_function(inplace=True, self.fn = linker.make_function(unpack_single=unpack_single,
unpack_single=unpack_single,
profiler=profiler) profiler=profiler)
self.inputs = env.inputs self.inputs = env.inputs
self.outputs = env.outputs self.outputs = env.outputs
...@@ -146,16 +150,6 @@ class Function: ...@@ -146,16 +150,6 @@ class Function:
def __call__(self, *args): def __call__(self, *args):
return self.fn(*args) return self.fn(*args)
def __copy__(self):
return Function(self.inputs, self.outputs,
features = self.features,
optimizer = self.optimizer,
linker_cls = self.linker_cls,
profiler = self.profiler,
unpack_single = self.unpack_single,
except_unreachable_input = self.except_unreachable_input,
keep_locals = self.keep_locals)
def eval_outputs(outputs, def eval_outputs(outputs,
features = [], features = [],
...@@ -171,20 +165,23 @@ def eval_outputs(outputs, ...@@ -171,20 +165,23 @@ def eval_outputs(outputs,
else: else:
return [] return []
inputs = list(gof.graph.inputs(outputs)) inputs = gof.graph.inputs(outputs)
in_data = [i.data for i in inputs if i.data is not None] if any(not isinstance(input, gof.Constant) for input in inputs):
raise TypeError("Cannot evaluate outputs because some of the leaves are not Constant.", outputs)
in_data = [i.data for i in inputs]
#print 'in_data = ', in_data #print 'in_data = ', in_data
if len(inputs) != len(in_data): if len(inputs) != len(in_data):
raise Exception('some input data is unknown') raise Exception('some input data is unknown')
env = gof.env.Env(inputs, outputs, features, consistency_check = True) env = gof.env.Env(inputs, outputs)
env.replace_all(dict([(i, i.type()) for i in inputs]))
env = env.clone(clone_inputs=True) env = env.clone(clone_inputs=True)
_mark_indestructible(env.outputs) _mark_indestructible(env.outputs)
if None is not optimizer: if None is not optimizer:
optimizer(env) optimizer(env)
linker = linker_cls(env) linker = linker_cls(env)
fn = linker.make_function(inplace=True, unpack_single=unpack_single) fn = linker.make_function(unpack_single=unpack_single)
rval = fn(*in_data) rval = fn(*in_data)
return rval return rval
......
...@@ -2,29 +2,34 @@ ...@@ -2,29 +2,34 @@
import elemwise_cgen as cgen import elemwise_cgen as cgen
import numpy import numpy
from gof import Op, Viewer, Destroyer from gof import Op, Apply
import scalar import scalar
from scalar import upcast, Scalar from scalar import Scalar
import gof import gof
from gof.python25 import all from gof.python25 import all
def astensor(data): def as_tensor(data):
raise Exception("Circular dependencies prevent using this here. import tensor before elemwise") raise Exception("Circular dependencies prevent using this here. import tensor before elemwise")
def Tensor(*inputs, **kwargs): def Tensor(*inputs, **kwargs):
raise Exception("Circular dependencies prevent using this here. import tensor before elemwise") raise Exception("Circular dependencies prevent using this here. import tensor before elemwise")
def TensorResult(*inputs, **kwargs):
raise Exception("Circular dependencies prevent using this here. import tensor before elemwise")
def TensorConstant(*inputs, **kwargs):
raise Exception("Circular dependencies prevent using this here. import tensor before elemwise")
################## ##################
### DimShuffle ### ### DimShuffle ###
################## ##################
class DimShuffle(Op, Viewer): class DimShuffle(Op):
""" """
Usage: DimShuffle(input, new_order, inplace = True) Usage: DimShuffle(new_order, inplace = True)
* input: a Tensor instance
* new_order: a list representing the relationship between the * new_order: a list representing the relationship between the
input's dimensions and the output's dimensions. Each input's dimensions and the output's dimensions. Each
element of the list can either be an index or 'x'. element of the list can either be an index or 'x'.
...@@ -51,33 +56,18 @@ class DimShuffle(Op, Viewer): ...@@ -51,33 +56,18 @@ class DimShuffle(Op, Viewer):
DimShuffle(t2, [1, 'x', 0]) -> like doing t3.T.reshape((t3.shape[0], 1, t3.shape[1])) in numpy DimShuffle(t2, [1, 'x', 0]) -> like doing t3.T.reshape((t3.shape[0], 1, t3.shape[1])) in numpy
""" """
def __init__(self, input, new_order, inplace = True): def __init__(self, input_broadcastable, new_order, inplace = True):
input_broadcastable = tuple(input_broadcastable)
input = astensor(input) self.input_broadcastable = input_broadcastable
new_order = tuple(new_order)
ib = input.broadcastable
ob = []
for value in new_order:
if value == 'x':
self.has_x = True
ob.append(1)
else:
ob.append(ib[value])
output = Tensor(dtype = input.dtype,
broadcastable = ob)
self.new_order = new_order self.new_order = new_order
self.inputs = input,
self.outputs = output,
self.inplace = inplace self.inplace = inplace
# list of dimensions of the input to drop # list of dimensions of the input to drop
self.drop = [] self.drop = []
i2j = {} # this maps i before dropping dimensions to j after dropping dimensions so self.shuffle can be set properly later on i2j = {} # this maps i before dropping dimensions to j after dropping dimensions so self.shuffle can be set properly later on
j = 0 j = 0
for i, b in enumerate(ib): for i, b in enumerate(input_broadcastable):
if i not in new_order: if i not in new_order:
# we want to drop this dimension because it's not a value in new_order # we want to drop this dimension because it's not a value in new_order
if b == 1: if b == 1:
...@@ -95,24 +85,39 @@ class DimShuffle(Op, Viewer): ...@@ -95,24 +85,39 @@ class DimShuffle(Op, Viewer):
# list of dimensions of the output that are broadcastable and were not in the original input # list of dimensions of the output that are broadcastable and were not in the original input
self.augment = [i for i, x in enumerate(new_order) if x == 'x'] self.augment = [i for i, x in enumerate(new_order) if x == 'x']
def clone_with_new_inputs(self, *new_inputs):
return DimShuffle(new_inputs[0], self.new_order, self.inplace)
def view_map(self):
if self.inplace: if self.inplace:
return {self.outputs[0]: [self.inputs[0]]} self.view_map = {0: [0]}
else:
return {} def make_node(self, input):
ib = tuple(input.type.broadcastable)
if not ib == self.input_broadcastable:
raise TypeError("The number of dimensions and/or broadcastable pattern of the input is incorrect for this op. Expected %s, got %s." % (ib, self.input_broadcastable))
ob = []
for value in self.new_order:
if value == 'x':
ob.append(1)
else:
ob.append(ib[value])
output = Tensor(dtype = input.type.dtype,
broadcastable = ob).make_result()
return Apply(self, [input], [output])
def __eq__(self, other):
return type(self) == type(other) \
and self.inplace == other.inplace \
and self.new_order == other.new_order \
and self.input_broadcastable == other.input_broadcastable
def desc(self): def __hash__(self, other):
return (self.__class__, tuple(self.new_order)) return hash(self.inplace) ^ hash(self.new_order) ^ hash(self.input_broadcastable)
def strdesc(self): def __str__(self):
return "DimShuffle{%s}" % "".join(str(x) for x in self.new_order) return "DimShuffle{%s}" % "".join(str(x) for x in self.new_order)
def perform(self): def perform(self, node, (input, ), (storage, )):
# drop # drop
res = self.inputs[0].data res = input
shape = list(res.shape) shape = list(res.shape)
for drop in reversed(self.drop): for drop in reversed(self.drop):
shape.pop(drop) shape.pop(drop)
...@@ -131,33 +136,30 @@ class DimShuffle(Op, Viewer): ...@@ -131,33 +136,30 @@ class DimShuffle(Op, Viewer):
if not self.inplace: if not self.inplace:
res = numpy.copy(res) res = numpy.copy(res)
self.outputs[0].data = res storage[0] = res
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
grad_order = ['x'] * len(self.inputs[0].broadcastable) gz = as_tensor(gz)
for i, x in enumerate(self.new_order): grad_order = ['x'] * len(x.type.broadcastable)
if x != 'x': for i, v in enumerate(self.new_order):
grad_order[x] = i if v != 'x':
return DimShuffle(gz, grad_order).out, grad_order[v] = i
return DimShuffle(gz.type.broadcastable, grad_order)(gz),
def __str__(self):
return "%s(%s, %s)" % (self.__class__.__name__, str(self.inputs[0]), self.new_order)
################# ################
### Broadcast ### ### Elemwise ###
################# ################
class Broadcast(Op, Destroyer): class Elemwise(Op):
""" """
Generalizes a scalar op to tensors. Generalizes a scalar op to tensors.
Usage: Broadcast(scalar_opclass, inputs, inplace_pattern = {}) Usage: Elemwise(scalar_op, inplace_pattern = {})
* scalar_opclass: a class that extends scalar.ScalarOp, works uniquely on * scalar_op: an instance of a subclass of scalar.ScalarOp which works uniquely on
scalars and can be instantiated from the list of its inputs scalars
* inputs: a list of Tensor instances
* inplace_pattern: a dictionary that maps the index of an output to the * inplace_pattern: a dictionary that maps the index of an output to the
index of an input so the output is calculated inplace using index of an input so the output is calculated inplace using
the input's storage. the input's storage.
...@@ -175,94 +177,86 @@ class Broadcast(Op, Destroyer): ...@@ -175,94 +177,86 @@ class Broadcast(Op, Destroyer):
as the input (in a nutshell, int + float -> float but int += float -> int) as the input (in a nutshell, int + float -> float but int += float -> int)
Examples: Examples:
Broadcast(Add, rand(10, 5), rand(10, 5), {0 : 0}) # this does input0 += input1 Elemwise(add) # represents + on tensors (x + y)
Broadcast(Add, rand(10, 5), rand(10, 5), {0 : 1}) # this does input1 += input0 Elemwise(add, {0 : 0}) # represents the += operation (x += y)
Broadcast(Mul, rand(10, 5), rand(1, 5)) # the second input is completed along the first dimension to match the first input Elemwise(add, {0 : 1}) # represents += on the second argument (y += x)
Broadcast(Div, rand(10, 5), rand(10, 1)) # same but along the second dimension Elemwise(mul)(rand(10, 5), rand(1, 5)) # the second input is completed along the first dimension to match the first input
Broadcast(Div, rand(1, 5), rand(10, 1)) # the output has size (10, 5) Elemwise(div)(rand(10, 5), rand(10, 1)) # same but along the second dimension
Broadcast(Log, rand(3, 4, 5)) Elemwise(div)(rand(1, 5), rand(10, 1)) # the output has size (10, 5)
Elemwise(log)(rand(3, 4, 5))
""" """
def __init__(self, scalar_opclass, inputs, inplace_pattern = {}): def __init__(self, scalar_op, inplace_pattern = {}):
self.scalar_op = scalar_op
self.inplace_pattern = inplace_pattern
self.destroy_map = dict((o, [i]) for o, i in inplace_pattern.items())
if scalar_op.nin > 0:
self.ufunc = numpy.frompyfunc(scalar_op.impl, scalar_op.nin, scalar_op.nout)
else:
self.ufunc = None
inputs = map(astensor, inputs) def make_node(self, *inputs):
inputs = map(as_tensor, inputs)
try: shadow = self.scalar_op.make_node(*[Scalar(dtype = t.type.dtype)() for t in inputs])
assert len(set([len(input.broadcastable) for input in inputs])) == 1
except (AssertionError, AttributeError):
raise TypeError("All inputs to a Broadcast subclass must be Tensor instances and their broadcastable fields must all have the same length.", self.__class__)
# self.shadow is an instance of scalar_opclass used to get values for all the properties we need (dtypes, gradient, etc.) target_length = max([input.type.ndim for input in inputs])
self.shadow = scalar_opclass(*[Scalar(dtype = t.dtype) for t in inputs]) args = []
for input in inputs:
self.nin = self.shadow.nin length = input.type.ndim
self.nout = self.shadow.nout difference = target_length - length
out_broadcastables = [[1*all(bcast) for bcast in zip(*[input.broadcastable for input in inputs])]] * self.nout if not difference:
args.append(input)
else:
args.append(DimShuffle(range(length), ['x']*difference + range(length))(input))
inputs = args
# try:
# assert len(set([len(input.type.broadcastable) for input in inputs])) == 1
# except (AssertionError, AttributeError):
# raise TypeError("All inputs to a Broadcast subclass must be Tensor instances and their broadcastable fields must all have the same length.", inputs)
out_broadcastables = [[all(bcast) for bcast in zip(*[input.type.broadcastable for input in inputs])]] * shadow.nout
inplace_pattern = self.inplace_pattern
if inplace_pattern: if inplace_pattern:
for overwriter, overwritten in inplace_pattern.items(): for overwriter, overwritten in inplace_pattern.items():
for ob, ib in zip(out_broadcastables[overwriter], inputs[overwritten].broadcastable): for ob, ib in zip(out_broadcastables[overwriter], inputs[overwritten].type.broadcastable):
if ib and not ob: if ib and not ob:
raise ValueError("Operation cannot be done inplace on an input with broadcasted dimensions.") raise ValueError("Operation cannot be done inplace on an input with broadcasted dimensions.")
out_dtypes = [o.type.dtype for o in shadow.outputs]
if any(inputs[i].type.dtype != out_dtypes[o] for i, o in inplace_pattern.items()):
raise TypeError("Cannot do an inplace operation on incompatible data types.", [i.type.dtype for i in inputs], out_dtypes)
outputs = [Tensor(dtype = dtype, broadcastable = broadcastable)() for dtype, broadcastable in zip(out_dtypes, out_broadcastables)]
return Apply(self, inputs, outputs)
out_dtypes = [t.dtype for t in self.shadow.outputs] def __str__(self):
def get_dtype(i):
# If an operation is done inplace, the dtype of the output
# will be the same as the dtype of the input it overwrites
# eg int + float -> float, but int += float -> int
input_idx = inplace_pattern.get(i, None)
if input_idx is not None:
return inputs[input_idx].dtype
else:
return out_dtypes[i]
out_dtypes = map(get_dtype, xrange(self.nout))
self.inputs = inputs
self.outputs = [Tensor(dtype = dtype, broadcastable = broadcastable) for dtype, broadcastable in zip(out_dtypes, out_broadcastables)]
self.inplace_pattern = inplace_pattern
self.scalar_opclass = scalar_opclass
self.ufunc = numpy.frompyfunc(self.shadow.impl, self.shadow.nin, self.shadow.nout)
def clone_with_new_inputs(self, *new_inputs):
return Broadcast(self.scalar_opclass, new_inputs, self.inplace_pattern)
def desc(self):
return (Broadcast, self.scalar_opclass, tuple(self.inplace_pattern.items()))
def strdesc(self):
if self.inplace_pattern: if self.inplace_pattern:
return "Broadcast{%s}%s" % (self.shadow.strdesc(), str(self.inplace_pattern)) return "Broadcast{%s}%s" % (self.scalar_op, str(self.inplace_pattern))
else: else:
return "Broadcast{%s}" % (self.shadow.strdesc()) return "Broadcast{%s}" % (self.scalar_op)
def destroy_map(self):
ret = {}
for key, value in self.inplace_pattern.items():
ret[self.outputs[key]] = [self.inputs[value]]
return ret
def grad(self, inputs, ograds): def grad(self, inputs, ograds):
ograds = map(astensor, ograds) ograds = map(as_tensor, ograds) # this shouldn't be necessary...
shadow = self.shadow scalar_inputs = [Scalar(dtype = t.type.dtype)() for t in inputs]
scalar_ograds = [Scalar(dtype = ograd.dtype) for ograd in ograds] scalar_ograds = [Scalar(dtype = ograd.type.dtype)() for ograd in ograds]
scalar_igrads = shadow.grad(shadow.inputs, scalar_ograds) scalar_igrads = self.scalar_op.grad(scalar_inputs, scalar_ograds)
nd = len(inputs[0].broadcastable) # this is the same for everyone nd = len(inputs[0].type.broadcastable) # this is the same for everyone
def transform(r): def transform(r):
# From a graph of ScalarOps, make a graph of Broadcast ops. # From a graph of ScalarOps, make a graph of Broadcast ops.
if r in shadow.inputs: if r in scalar_inputs:
return inputs[shadow.inputs.index(r)] return inputs[scalar_inputs.index(r)]
if r in scalar_ograds: if r in scalar_ograds:
return ograds[scalar_ograds.index(r)] return ograds[scalar_ograds.index(r)]
op = r.owner node = r.owner
if op is None: if node is None:
# the gradient contains a constant, translate it as # the gradient contains a constant, translate it as
# an equivalent Tensor of size 1 and proper number of dimensions # an equivalent Tensor of size 1 and proper number of dimensions
b = [1] * nd b = [1] * nd
res = astensor(numpy.asarray(r.data).reshape(b), res = TensorConstant(Tensor(dtype = r.type.dtype,
broadcastable = b) broadcastable = b),
numpy.asarray(r.data).reshape(b))
return res return res
op_class = op.__class__ new_r = Elemwise(node.op, {})(*[transform(input) for input in node.inputs])
bcasted = Broadcast(op_class, [transform(input) for input in op.inputs], {}).out return new_r
return bcasted
ret = [] ret = []
for scalar_igrad, input in zip(scalar_igrads, inputs): for scalar_igrad, input in zip(scalar_igrads, inputs):
if scalar_igrad is None: if scalar_igrad is None:
...@@ -274,91 +268,89 @@ class Broadcast(Op, Destroyer): ...@@ -274,91 +268,89 @@ class Broadcast(Op, Destroyer):
# list of all the dimensions that are broadcastable for that input so we # list of all the dimensions that are broadcastable for that input so we
# can sum over them # can sum over them
# todo: only count dimensions that were effectively broadcasted # todo: only count dimensions that were effectively broadcasted
to_sum = [i for i, bcast in enumerate(input.broadcastable) if bcast] to_sum = [i for i, bcast in enumerate(input.type.broadcastable) if bcast]
if to_sum: if to_sum:
shuffle = [] shuffle = []
j = 0 j = 0
for bcast in input.broadcastable: for bcast in input.type.broadcastable:
if bcast == 1: if bcast == 1:
shuffle.append('x') shuffle.append('x')
else: else:
shuffle.append(j) shuffle.append(j)
j += 1 j += 1
sr = Sum(r, axis = to_sum).out sr = Sum(axis = to_sum)(r)
sr = DimShuffle(sr, shuffle).out sr = DimShuffle(sr.type.broadcastable, shuffle)(sr)
ret.append(sr) ret.append(sr)
else: else:
ret.append(r) ret.append(r)
return ret return ret
def perform(self): def perform(self, node, inputs, output_storage):
output_storage = []
if not self.inplace_pattern: if not self.inplace_pattern:
for output in self.outputs: for output, storage in zip(node.outputs, output_storage):
odat = output.data odat = storage[0]
shape = [max(values) for values in zip(*[input.data.shape for input in self.inputs])] shape = [max(values) for values in zip(*[input.shape for input in inputs])]
if odat is not None: if odat is not None:
# reuse storage if we can # reuse storage if we can
odat.resize(shape, refcheck = 0) odat.resize(shape, refcheck = 0)
else: else:
odat = numpy.ndarray(shape, dtype = output.dtype) odat = numpy.ndarray(shape, dtype = output.type.dtype)
output_storage.append(odat) storage[0] = odat
output.data = odat
else: else:
for i, output in enumerate(self.outputs): for i, (output, storage) in enumerate(zip(node.outputs, output_storage)):
if i in self.inplace_pattern: if i in self.inplace_pattern:
odat = self.inputs[self.inplace_pattern[i]].data odat = inputs[self.inplace_pattern[i]]
else: else:
odat = output.data odat = storage[0]
shape = [max(values) for values in zip(*[input.data.shape for input in self.inputs])] shape = [max(values) for values in zip(*[input.shape for input in inputs])]
if odat is not None: if odat is not None:
odat.resize(shape) odat.resize(shape, refcheck = 0)
else: else:
odat = numpy.ndarray(shape, dtype = output.dtype) odat = numpy.ndarray(shape, dtype = output.type.dtype)
output_storage.append(odat) storage[0] = odat
output.data = odat
# the second calling form is used because in certain versions of numpy # the second calling form is used because in certain versions of numpy
# the first (faster) version leads to segfaults # the first (faster) version leads to segfaults
ufunc_args = [input.data for input in self.inputs]# + output_storage ufunc_args = inputs # + output_storage
results = self.ufunc(*ufunc_args) ufunc = self.ufunc or numpy.frompyfunc(self.scalar_op.impl, len(inputs), self.scalar_op.nout)
if self.ufunc.nout == 1: results = [results] results = ufunc(*ufunc_args)
if ufunc.nout == 1: results = [results]
for result, storage in zip(results, output_storage): for result, storage in zip(results, output_storage):
if storage.shape: if storage[0].shape:
storage[:] = result storage[0][:] = result
else: else:
storage.itemset(result) storage[0].itemset(result)
# the following should be used instead of the previous loop, unfortunately it tends to segfault # the following should be used instead of the previous loop, unfortunately it tends to segfault
# self.ufunc(*(ufunc_args+output_storage)) # self.ufunc(*(ufunc_args+[s[0] for s in output_storage]))
def _c_all(self, inames, onames, sub): def _c_all(self, node, name, inames, onames, sub):
_inames = inames _inames = inames
_onames = onames _onames = onames
inames = gof.utils.uniq(inames) inames = gof.utils.uniq(inames)
inputs = gof.utils.uniq(self.inputs) inputs = gof.utils.uniq(node.inputs)
defines = "" defines = ""
undefs = "" undefs = ""
dmap = self.destroy_map() dmap = dict([(node.outputs[i], [node.inputs[o]]) for i, o in self.inplace_pattern.items()])
idtypes = [input.dtype_specs()[1] for input in inputs] idtypes = [input.type.dtype_specs()[1] for input in inputs]
real = zip(*[(r, s, r.dtype_specs()[1]) real = zip(*[(r, s, r.type.dtype_specs()[1])
for r, s in zip(self.outputs, onames) if r not in dmap]) for r, s in zip(node.outputs, onames) if r not in dmap])
if real: if real:
real_outputs, real_onames, real_odtypes = real real_outputs, real_onames, real_odtypes = real
else: else:
real_outputs, real_onames, real_odtypes = [], [], [] real_outputs, real_onames, real_odtypes = [], [], []
aliased = zip(*[(r, s) aliased = zip(*[(r, s)
for (r, s) in zip(self.outputs, onames) if r in dmap]) for (r, s) in zip(node.outputs, onames) if r in dmap])
if aliased: if aliased:
aliased_outputs, aliased_onames = aliased aliased_outputs, aliased_onames = aliased
else: else:
aliased_outputs, aliased_onames = [], [] aliased_outputs, aliased_onames = [], []
orders = [[x and 'x' or i for i, x in enumerate(input.broadcastable)] for input in inputs] orders = [[x and 'x' or i for i, x in enumerate(input.type.broadcastable)] for input in inputs]
nnested = len(orders[0]) nnested = len(orders[0])
sub = dict(sub) sub = dict(sub)
for i, (input, iname) in enumerate(zip(inputs, inames)): for i, (input, iname) in enumerate(zip(inputs, inames)):
...@@ -387,9 +379,13 @@ class Broadcast(Op, Destroyer): ...@@ -387,9 +379,13 @@ class Broadcast(Op, Destroyer):
defines += "#define %(oname)s_i %(iname)s_i" % locals() defines += "#define %(oname)s_i %(iname)s_i" % locals()
undefs += "#undef %(oname)s_i" % locals() undefs += "#undef %(oname)s_i" % locals()
task_code = self.shadow.c_code(["%s_i" % s for s in _inames], task_code = self.scalar_op.c_code(Apply(self.scalar_op,
["%s_i" % s for s in onames], [Scalar(dtype = input.type.dtype)() for input in node.inputs],
sub) [Scalar(dtype = output.type.dtype)() for input in node.outputs]),
None,
["%s_i" % s for s in _inames],
["%s_i" % s for s in onames],
sub)
task_decl = "".join(["%(dtype)s& %(name)s_i = *%(name)s_iter;\n" % locals() for name, dtype in zip(inames + list(real_onames), idtypes + list(real_odtypes))]) task_decl = "".join(["%(dtype)s& %(name)s_i = *%(name)s_iter;\n" % locals() for name, dtype in zip(inames + list(real_onames), idtypes + list(real_odtypes))])
code = """ code = """
{ {
...@@ -406,72 +402,72 @@ class Broadcast(Op, Destroyer): ...@@ -406,72 +402,72 @@ class Broadcast(Op, Destroyer):
loop = cgen.make_loop(orders + [range(nnested)] * len(real_onames), idtypes + list(real_odtypes), all_code, sub) loop = cgen.make_loop(orders + [range(nnested)] * len(real_onames), idtypes + list(real_odtypes), all_code, sub)
return decl, checks, alloc, loop return decl, checks, alloc, loop
def c_code(self, inames, onames, sub): def c_code(self, node, name, inames, onames, sub):
code = "\n".join(self._c_all(inames, onames, sub)) code = "\n".join(self._c_all(node, name, inames, onames, sub))
return code return code
def make_broadcast(scalar_opclass, inplace_pattern = {}, name = None, module_name = None): # def make_broadcast(scalar_opclass, inplace_pattern = {}, name = None, module_name = None):
scalar_name = scalar_opclass.__name__ # scalar_name = scalar_opclass.__name__
if name is None: # if name is None:
name = scalar_name # name = scalar_name
if module_name is None: # if module_name is None:
module_name = 'elemwise.make_broadcast(%s, %s, %s)' % (scalar_name, inplace_pattern, repr(name)) # module_name = 'elemwise.make_broadcast(%s, %s, %s)' % (scalar_name, inplace_pattern, repr(name))
name = "New" # name = "New"
previous_doc = Broadcast.__doc__ # previous_doc = Broadcast.__doc__
scalar_doc = scalar_opclass.__doc__ or "" # scalar_doc = scalar_opclass.__doc__ or ""
if scalar_doc: # if scalar_doc:
scalar_doc = """ # scalar_doc = """
%(scalar_name)s documentation: # %(scalar_name)s documentation:
%(scalar_doc)s # %(scalar_doc)s
""" % locals() # """ % locals()
doc = """ # doc = """
Usage: %(name)s(*inputs) # Usage: %(name)s(*inputs)
Equivalent to: Broadcast(scalar.%(scalar_name)s, inputs, %(inplace_pattern)s) # Equivalent to: Broadcast(scalar.%(scalar_name)s, inputs, %(inplace_pattern)s)
Performs Scalar %(scalar_name)s on each element of the # Performs Scalar %(scalar_name)s on each element of the
input tensors. # input tensors.
%(scalar_doc)s # %(scalar_doc)s
Documention for Broadcast: # Documention for Broadcast:
================================================== # ==================================================
%(previous_doc)s # %(previous_doc)s
================================================== # ==================================================
""" % locals() # """ % locals()
class New(Broadcast): # class New(Broadcast):
__doc__ = doc # __doc__ = doc
def __init__(self, *inputs): # def __init__(self, *inputs):
Broadcast.__init__(self, scalar_opclass, inputs, inplace_pattern) # Broadcast.__init__(self, scalar_opclass, inputs, inplace_pattern)
def clone_with_new_inputs(self, *new_inputs): # def clone_with_new_inputs(self, *new_inputs):
return New(*new_inputs) # return New(*new_inputs)
@classmethod # @classmethod
def desc(cls): # def desc(cls):
return (Broadcast, scalar_opclass, tuple(inplace_pattern.items())) # return (Broadcast, scalar_opclass, tuple(inplace_pattern.items()))
New.__name__ = name # New.__name__ = name
New.__module__ = module_name # New.__module__ = module_name
return New # return New
def wrap_broadcast(op): # def wrap_broadcast(op):
def instantiate(*inputs): # def instantiate(*inputs):
inputs = map(astensor, inputs) # inputs = map(astensor, inputs)
target_length = max([len(input.broadcastable) for input in inputs]) # target_length = max([len(input.broadcastable) for input in inputs])
args = [] # args = []
for input in inputs: # for input in inputs:
length = len(input.broadcastable) # length = len(input.broadcastable)
difference = target_length - length # difference = target_length - length
if not difference: # if not difference:
args.append(input) # args.append(input)
else: # else:
args.append(DimShuffle(input, ['x']*difference + range(length)).out) # args.append(DimShuffle(input, ['x']*difference + range(length)).out)
return op(*args) # return op(*args)
instantiate.__name__ = "instantiate{%s}" % op.__name__ # instantiate.__name__ = "instantiate{%s}" % op.__name__
instantiate.__doc__ = op.__doc__ # instantiate.__doc__ = op.__doc__
return instantiate # return instantiate
...@@ -497,11 +493,10 @@ class CAReduce(Op): ...@@ -497,11 +493,10 @@ class CAReduce(Op):
over the reduced dimensions using the specified scalar op. over the reduced dimensions using the specified scalar op.
Examples: Examples:
CAReduce(Add, inputs) -> sum(inputs) CAReduce(add) -> sum
CAReduce(Mul, inputs) -> product(inputs) CAReduce(mul) -> product
CAReduce(Or, inputs) -> any(inputs) # not lazy CAReduce(_or) -> any # not lazy
CAReduce(And, inputs) -> all(inputs) # not lazy CAReduce(_and) -> all # not lazy
CAReduce(Xor, inputs) -> sum(inputs != 0) % 2
In order to optimize memory usage patterns, L{CAReduce} makes zero In order to optimize memory usage patterns, L{CAReduce} makes zero
guarantees on the order in which it iterates over the dimensions guarantees on the order in which it iterates over the dimensions
...@@ -510,74 +505,72 @@ class CAReduce(Op): ...@@ -510,74 +505,72 @@ class CAReduce(Op):
both commutative and associative (eg add, multiply, binary both commutative and associative (eg add, multiply, binary
or/and/xor - but not subtract, divide or power). or/and/xor - but not subtract, divide or power).
""" """
def __init__(self, scalar_opclass, inputs, axis = None):
inputs = map(astensor, inputs)
self.shadow = scalar_opclass(*[Scalar(dtype = inputs[0].dtype) for i in xrange(len(inputs) + 1)]) def __init__(self, scalar_op, axis = None):
if scalar_op.nin not in [-1, 2] or scalar_op.nout != 1:
if self.shadow.nin != 2 or self.shadow.nout != 1:
raise NotImplementedError("CAReduce only supports binary functions with a single output.") raise NotImplementedError("CAReduce only supports binary functions with a single output.")
if len(inputs) != 1: self.scalar_op = scalar_op
raise TypeError("Only one argument expected.") if isinstance(axis, int):
self.axis = [axis]
else:
self.axis = axis
self.ufunc = numpy.frompyfunc(scalar_op.impl, 2, 1)
def make_node(self, input):
input = as_tensor(input)
axis = self.axis
if axis is None: if axis is None:
axis = range(len(inputs[0].broadcastable)) axis = range(len(input.type.broadcastable))
elif isinstance(axis, int): output = Tensor(dtype = input.type.dtype,
axis = [axis] broadcastable = [x for i, x in enumerate(input.type.broadcastable) if i not in axis])()
return Apply(self, [input], [output])
self.inputs = inputs
self.outputs = [Tensor(dtype = inputs[0].dtype,
broadcastable = [x for i, x in enumerate(inputs[0].broadcastable) if i not in axis])]
self.axis = axis
self.scalar_opclass = scalar_opclass
self.ufunc = numpy.frompyfunc(self.shadow.impl, self.shadow.nin, self.shadow.nout)
def desc(self):
return (self.__class__, self.scalar_opclass, tuple(self.axis))
def strdesc(self): def __str__(self):
if set(self.axis) != set(xrange(len(self.inputs[0].broadcastable))): if self.axis is not None:
return "Reduce{%s}{%s}" % (self.scalar_opclass.__name__, "".join(str(x) for x in self.axis)) return "Reduce{%s}{%s}" % (self.scalar_op, ", ".join(str(x) for x in self.axis))
else: else:
return "Reduce{%s}" % self.scalar_opclass.__name__ return "Reduce{%s}" % self.scalar_op
def clone_with_new_inputs(self, *new_inputs):
return CAReduce(self.scalar_opclass, new_inputs, self.axis)
def perform(self): def perform(self, node, (input, ), (output, )):
result = self.inputs[0].data axis = self.axis
to_reduce = reversed(sorted(self.axis)) if axis is None:
axis = range(input.ndim)
result = input
to_reduce = reversed(sorted(axis))
if to_reduce: if to_reduce:
for dimension in to_reduce: for dimension in to_reduce:
result = self.ufunc.reduce(result, dimension) result = self.ufunc.reduce(result, dimension)
self.outputs[0].data = result output[0] = numpy.asarray(result, dtype = node.outputs[0].type.dtype)
else: else:
self.outputs[0].data = numpy.copy(result) output[0] = numpy.copy(result)
def _c_all(self, inames, onames, sub): def _c_all(self, node, name, inames, onames, sub):
input = self.inputs[0] input = node.inputs[0]
output = self.outputs[0] output = node.outputs[0]
iname = inames[0] iname = inames[0]
oname = onames[0] oname = onames[0]
idtype = input.dtype_specs()[1] idtype = input.type.dtype_specs()[1]
odtype = output.dtype_specs()[1] odtype = output.type.dtype_specs()[1]
tosum = self.axis axis = self.axis
if axis is None:
axis = range(len(input.type.broadcastable))
if tosum == (): if axis == ():
return Broadcast(scalar.Identity, (input, ))._c_all(inames, onames, sub) op = Elemwise(scalar.identity)
return op._c_all(op.make_node(input), name, inames, onames, sub)
# return Broadcast(scalar.Identity, (input, ))._c_all(inames, onames, sub)
order1 = [i for i in xrange(len(input.broadcastable)) if i not in tosum] order1 = [i for i in xrange(input.type.ndim) if i not in axis]
order = order1 + list(tosum) order = order1 + list(axis)
nnested = len(order1) nnested = len(order1)
sub = dict(sub) sub = dict(sub)
for i, (input, iname) in enumerate(zip(self.inputs, inames)): for i, (input, iname) in enumerate(zip(node.inputs, inames)):
sub['lv%i' % i] = iname sub['lv%i' % i] = iname
decl = cgen.make_declare([order], [idtype], sub) decl = cgen.make_declare([order], [idtype], sub)
...@@ -587,18 +580,23 @@ class CAReduce(Op): ...@@ -587,18 +580,23 @@ class CAReduce(Op):
i += 1 i += 1
sub['lv%i' % i] = oname sub['lv%i' % i] = oname
sub['olv'] = oname sub['olv'] = oname
alloc += cgen.make_declare([range(nnested) + ['x'] * len(tosum)], [odtype], dict(sub, lv0 = oname)) alloc += cgen.make_declare([range(nnested) + ['x'] * len(axis)], [odtype], dict(sub, lv0 = oname))
alloc += cgen.make_alloc([order1], odtype, sub) alloc += cgen.make_alloc([order1], odtype, sub)
alloc += cgen.make_checks([range(nnested) + ['x'] * len(tosum)], [odtype], dict(sub, lv0 = oname)) alloc += cgen.make_checks([range(nnested) + ['x'] * len(axis)], [odtype], dict(sub, lv0 = oname))
task0_decl = "%(dtype)s& %(name)s_i = *%(name)s_iter;\n%(name)s_i = %(identity)s;" % dict(dtype = odtype, task0_decl = "%(dtype)s& %(name)s_i = *%(name)s_iter;\n%(name)s_i = %(identity)s;" % dict(dtype = odtype,
name = onames[0], name = onames[0],
identity = self.shadow.identity) identity = self.scalar_op.identity)
task1_decl = "%(dtype)s& %(name)s_i = *%(name)s_iter;\n" % dict(dtype = idtype, name = inames[0]) task1_decl = "%(dtype)s& %(name)s_i = *%(name)s_iter;\n" % dict(dtype = idtype, name = inames[0])
task1_code = self.shadow.c_code(["%s_i" % onames[0], "%s_i" % inames[0]],
["%s_i" % onames[0]], task1_code = self.scalar_op.c_code(Apply(self.scalar_op,
sub) [Scalar(dtype = input.type.dtype)() for input in node.inputs*2],
[Scalar(dtype = output.type.dtype)() for input in node.outputs]),
None,
["%s_i" % onames[0], "%s_i" % inames[0]],
["%s_i" % onames[0]],
sub)
code1 = """ code1 = """
{ {
%(task1_decl)s %(task1_decl)s
...@@ -606,107 +604,100 @@ class CAReduce(Op): ...@@ -606,107 +604,100 @@ class CAReduce(Op):
} }
""" % locals() """ % locals()
if len(tosum) == 1: if len(axis) == 1:
all_code = [("", "")] * nnested + [(task0_decl, code1), ""] all_code = [("", "")] * nnested + [(task0_decl, code1), ""]
else: else:
all_code = [("", "")] * nnested + [(task0_decl, "")] + [("", "")] * (len(tosum) - 2) + [("", code1), ""] all_code = [("", "")] * nnested + [(task0_decl, "")] + [("", "")] * (len(axis) - 2) + [("", code1), ""]
# if nnested: # if nnested:
# all_code = [("", "")] * (nnested - 1) + [("", code)] + [""] # all_code = [("", "")] * (nnested - 1) + [("", code)] + [""]
# else: # else:
# all_code = [code] # all_code = [code]
# print [order, range(nnested) + ['x'] * len(tosum)] # print [order, range(nnested) + ['x'] * len(axis)]
loop = cgen.make_loop([order, range(nnested) + ['x'] * len(tosum)], [idtype, odtype], all_code, sub) loop = cgen.make_loop([order, range(nnested) + ['x'] * len(axis)], [idtype, odtype], all_code, sub)
return decl, checks, alloc, loop return decl, checks, alloc, loop
def c_code(self, inames, onames, sub): def c_code(self, node, name, inames, onames, sub):
code = "\n".join(self._c_all(inames, onames, sub)) code = "\n".join(self._c_all(node, name, inames, onames, sub))
# print code # print code
return code return code
def __str__(self):
input = self.inputs[0]
if len(input.broadcastable) == len(self.axis):
return "%s:%s(%s)" % (self.__class__.__name__,
self.scalar_opclass.__name__,
str(input))
else:
return "%s:%s(%s, axis = %s)" % (self.__class__.__name__,
self.scalar_opclass.__name__,
str(input),
self.axis)
def make_reduce(scalar_opclass, name = None): # def make_reduce(scalar_opclass, name = None):
if getattr(scalar_opclass, 'commutative', False) \ # if getattr(scalar_opclass, 'commutative', False) \
and getattr(scalar_opclass, 'associative', False): # and getattr(scalar_opclass, 'associative', False):
reducer = CAReduce # reducer = CAReduce
else: # else:
raise NotImplementedError("The scalar op class to reduce must be commutative and associative.") # raise NotImplementedError("The scalar op class to reduce must be commutative and associative.")
scalar_name = scalar_opclass.__name__ # scalar_name = scalar_opclass.__name__
if name is None: # if name is None:
name = "Reduce" + scalar_name # name = "Reduce" + scalar_name
previous_doc = reducer.__doc__ # previous_doc = reducer.__doc__
doc = """ # doc = """
Usage: %(name)s(input, axis) # Usage: %(name)s(input, axis)
Equivalent to: CAReduce(%(scalar_name)s, input, axis) # Equivalent to: CAReduce(%(scalar_name)s, input, axis)
Reduces the input over the specified axis. # Reduces the input over the specified axis.
Documention for CAReduce: # Documention for CAReduce:
================================================== # ==================================================
%(previous_doc)s # %(previous_doc)s
================================================== # ==================================================
""" % locals() # """ % locals()
class New(reducer): # class New(reducer):
__doc__ = doc # __doc__ = doc
def __init__(self, *inputs, **kwargs): # def __init__(self, *inputs, **kwargs):
reducer.__init__(self, scalar_opclass, inputs, kwargs.get('axis', None)) # reducer.__init__(self, scalar_opclass, inputs, kwargs.get('axis', None))
def clone_with_new_inputs(self, *new_inputs): # def clone_with_new_inputs(self, *new_inputs):
return New(*new_inputs, **dict(axis = self.axis)) # return New(*new_inputs, **dict(axis = self.axis))
def __str__(self): # def __str__(self):
input = self.inputs[0] # input = self.inputs[0]
if len(input.broadcastable) == len(self.axis): # if len(input.broadcastable) == len(self.axis):
return "%s(%s)" % (self.__class__.__name__, # return "%s(%s)" % (self.__class__.__name__,
str(input)) # str(input))
else: # else:
return "%s(%s, axis = %s)" % (self.__class__.__name__, # return "%s(%s, axis = %s)" % (self.__class__.__name__,
str(input), # str(input),
self.axis) # self.axis)
New.__name__ = name # New.__name__ = name
return New # return New
_Sum = make_reduce(scalar.Add, '_Sum') class Sum(CAReduce):
class Sum(_Sum): def __init__(self, axis = None):
__doc__ = _Sum.__doc__ CAReduce.__init__(self, scalar.add, axis)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
if self.axis == (): gz = as_tensor(gz)
axis = self.axis
if axis is None:
axis = range(x.type.ndim)
if axis == ():
return gz, return gz,
new_dims = [] new_dims = []
i = 0 i = 0
for j, _ in enumerate(x.broadcastable): for j, _ in enumerate(x.type.broadcastable):
if j in self.axis: if j in axis:
new_dims.append('x') new_dims.append('x')
else: else:
new_dims.append(i) new_dims.append(i)
i += 1 i += 1
return Broadcast(scalar.Second, (x, DimShuffle(gz, new_dims).out)).out, return Elemwise(scalar.second)(x, DimShuffle(gz.type.broadcastable, new_dims)(gz)),
def reduce(op): # def reduce(op):
if getattr(op, 'commutative', True) and getattr(op, 'associative', True): # if getattr(op, 'commutative', True) and getattr(op, 'associative', True):
reducer = CAReduce # reducer = CAReduce
else: # else:
raise NotImplementedError("The scalar op class to reduce must be commutative and associative.") # raise NotImplementedError("The scalar op class to reduce must be commutative and associative.")
def instantiate(*inputs): # def instantiate(*inputs):
return reducer(op, inputs, axis) # return reducer(op, inputs, axis)
return instantiate # return instantiate
import op, result, ext, link, env, features, toolbox, graph, cc, opt import op, type, ext, link, env, features, toolbox, graph, cc, opt
from op import * from op import *
from result import * from graph import Apply, Result, Constant, as_apply, as_result
from type import *
from ext import * from ext import *
from link import * from link import *
from env import * from env import *
......
...@@ -3,28 +3,15 @@ import unittest ...@@ -3,28 +3,15 @@ import unittest
from link import PerformLinker, Profiler from link import PerformLinker, Profiler
from cc import * from cc import *
from result import Result from type import Type
from graph import Result, as_result, Apply, Constant
from op import Op from op import Op
from env import Env from env import Env
class Double(Result): class TDouble(Type):
def filter(self, data):
return float(data)
def __init__(self, data, name = "oignon"):
Result.__init__(self, role = None, name = name)
assert isinstance(data, float)
self.data = data
def __str__(self):
return self.name
def __repr__(self):
return self.name
def __copy__(self):
return Double(self.data, self.name)
# def c_is_simple(self): return True
def c_declare(self, name, sub): def c_declare(self, name, sub):
return "double %(name)s; void* %(name)s_bad_thing;" % locals() return "double %(name)s; void* %(name)s_bad_thing;" % locals()
...@@ -35,8 +22,8 @@ class Double(Result): ...@@ -35,8 +22,8 @@ class Double(Result):
//printf("Initializing %(name)s\\n"); //printf("Initializing %(name)s\\n");
""" % locals() """ % locals()
def c_literal(self): def c_literal(self, data):
return str(self.data) return str(data)
def c_extract(self, name, sub): def c_extract(self, name, sub):
return """ return """
...@@ -65,116 +52,119 @@ class Double(Result): ...@@ -65,116 +52,119 @@ class Double(Result):
free(%(name)s_bad_thing); free(%(name)s_bad_thing);
""" % locals() """ % locals()
tdouble = TDouble()
class MyOp(Op): def double(name):
return Result(tdouble, None, None, name = name)
nin = -1
def __init__(self, *inputs): class MyOp(Op):
def __init__(self, nin, name):
self.nin = nin
self.name = name
def make_node(self, *inputs):
assert len(inputs) == self.nin assert len(inputs) == self.nin
inputs = map(as_result, inputs)
for input in inputs: for input in inputs:
if not isinstance(input, Double): if input.type is not tdouble:
raise Exception("Error 1") raise Exception("Error 1")
self.inputs = inputs outputs = [double(self.name + "_R")]
self.outputs = [Double(0.0, self.__class__.__name__ + "_R")] return Apply(self, inputs, outputs)
def __str__(self):
return self.name
def perform(self, node, inputs, (out, )):
out[0] = self.impl(*inputs)
class Unary(MyOp): class Unary(MyOp):
nin = 1 def __init__(self):
# def c_var_names(self): MyOp.__init__(self, 1, self.__class__.__name__)
# return [['x'], ['z']]
class Binary(MyOp): class Binary(MyOp):
nin = 2 def __init__(self):
# def c_var_names(self): MyOp.__init__(self, 2, self.__class__.__name__)
# return [['x', 'y'], ['z']]
class Add(Binary): class Add(Binary):
def c_code(self, (x, y), (z, ), sub): def c_code(self, node, name, (x, y), (z, ), sub):
return "%(z)s = %(x)s + %(y)s;" % locals() return "%(z)s = %(x)s + %(y)s;" % locals()
def perform(self): def impl(self, x, y):
self.outputs[0].data = self.inputs[0].data + self.inputs[1].data return x + y
add = Add()
class Sub(Binary): class Sub(Binary):
def c_code(self, (x, y), (z, ), sub): def c_code(self, node, name, (x, y), (z, ), sub):
return "%(z)s = %(x)s - %(y)s;" % locals() return "%(z)s = %(x)s - %(y)s;" % locals()
def perform(self): def impl(self, x, y):
self.outputs[0].data = -10 # erroneous return -10 # erroneous (most of the time)
sub = Sub()
class Mul(Binary): class Mul(Binary):
def c_code(self, (x, y), (z, ), sub): def c_code(self, node, name, (x, y), (z, ), sub):
return "%(z)s = %(x)s * %(y)s;" % locals() return "%(z)s = %(x)s * %(y)s;" % locals()
def perform(self): def impl(self, x, y):
self.outputs[0].data = self.inputs[0].data * self.inputs[1].data return x * y
mul = Mul()
class Div(Binary): class Div(Binary):
def c_validate_update(self, (x, y), (z, ), sub): def c_code(self, node, name, (x, y), (z, ), sub):
return """
if (%(y)s == 0.0) {
PyErr_SetString(PyExc_ZeroDivisionError, "division by zero");
%(fail)s
}
""" % dict(locals(), **sub)
def c_code(self, (x, y), (z, ), sub):
return "%(z)s = %(x)s / %(y)s;" % locals() return "%(z)s = %(x)s / %(y)s;" % locals()
def perform(self): def impl(self, x, y):
self.outputs[0].data = self.inputs[0].data / self.inputs[1].data return x / y
div = Div()
import modes
modes.make_constructors(globals())
def inputs(): def inputs():
x = modes.build(Double(1.0, 'x')) x = double('x')
y = modes.build(Double(2.0, 'y')) y = double('y')
z = modes.build(Double(3.0, 'z')) z = double('z')
return x, y, z return x, y, z
def env(inputs, outputs, validate = True, features = []):
return Env(inputs, outputs, features = features, consistency_check = validate)
class _test_CLinker(unittest.TestCase): class _test_CLinker(unittest.TestCase):
def test_straightforward(self): def test_straightforward(self):
x, y, z = inputs() x, y, z = inputs()
e = add(mul(add(x, y), div(x, y)), sub(sub(x, y), z)) e = add(mul(add(x, y), div(x, y)), sub(sub(x, y), z))
lnk = CLinker(env([x, y, z], [e])) lnk = CLinker(Env([x, y, z], [e]))
fn = lnk.make_function() fn = lnk.make_function()
self.failUnless(fn(2.0, 2.0, 2.0) == 2.0) self.failUnless(fn(2.0, 2.0, 2.0) == 2.0)
def test_orphan(self): # def test_orphan(self):
x, y, z = inputs() # x, y, z = inputs()
z.data = 4.12345678 # z = Constant(tdouble, 4.12345678)
e = add(mul(add(x, y), div(x, y)), sub(sub(x, y), z)) # e = add(mul(add(x, y), div(x, y)), sub(sub(x, y), z))
lnk = CLinker(env([x, y], [e])) # lnk = CLinker(Env([x, y], [e]))
fn = lnk.make_function() # fn = lnk.make_function()
self.failUnless(abs(fn(2.0, 2.0) + 0.12345678) < 1e-9) # self.failUnless(abs(fn(2.0, 2.0) + 0.12345678) < 1e-9)
self.failUnless("4.12345678" not in lnk.code_gen()) # we do not expect the number to be inlined # print lnk.code_gen()
# self.failUnless("4.12345678" not in lnk.code_gen()) # we do not expect the number to be inlined
def test_literal_inlining(self): def test_literal_inlining(self):
x, y, z = inputs() x, y, z = inputs()
z.data = 4.12345678 z = Constant(tdouble, 4.12345678)
z.constant = True # this should tell the compiler to inline z as a literal
e = add(mul(add(x, y), div(x, y)), sub(sub(x, y), z)) e = add(mul(add(x, y), div(x, y)), sub(sub(x, y), z))
lnk = CLinker(env([x, y], [e])) lnk = CLinker(Env([x, y], [e]))
fn = lnk.make_function() fn = lnk.make_function()
self.failUnless(abs(fn(2.0, 2.0) + 0.12345678) < 1e-9) self.failUnless(abs(fn(2.0, 2.0) + 0.12345678) < 1e-9)
self.failUnless("4.12345678" in lnk.code_gen()) # we expect the number to be inlined self.failUnless("4.12345678" in lnk.code_gen()) # we expect the number to be inlined
def test_single_op(self): def test_single_node(self):
x, y, z = inputs() x, y, z = inputs()
op = Add(x, y) node = add.make_node(x, y)
lnk = CLinker(op) lnk = CLinker(Env(node.inputs, node.outputs))
fn = lnk.make_function() fn = lnk.make_function()
self.failUnless(fn(2.0, 7.0) == 9) self.failUnless(fn(2.0, 7.0) == 9)
def test_dups(self): def test_dups(self):
# Testing that duplicate inputs are allowed. # Testing that duplicate inputs are allowed.
x, y, z = inputs() x, y, z = inputs()
op = Add(x, x) e = add(x, x)
lnk = CLinker(op) lnk = CLinker(Env([x, x], [e]))
fn = lnk.make_function() fn = lnk.make_function()
self.failUnless(fn(2.0, 2.0) == 4) self.failUnless(fn(2.0, 2.0) == 4)
# note: for now the behavior of fn(2.0, 7.0) is undefined # note: for now the behavior of fn(2.0, 7.0) is undefined
...@@ -183,7 +173,7 @@ class _test_CLinker(unittest.TestCase): ...@@ -183,7 +173,7 @@ class _test_CLinker(unittest.TestCase):
# Testing that duplicates are allowed inside the graph # Testing that duplicates are allowed inside the graph
x, y, z = inputs() x, y, z = inputs()
e = add(mul(y, y), add(x, z)) e = add(mul(y, y), add(x, z))
lnk = CLinker(env([x, y, z], [e])) lnk = CLinker(Env([x, y, z], [e]))
fn = lnk.make_function() fn = lnk.make_function()
self.failUnless(fn(1.0, 2.0, 3.0) == 8.0) self.failUnless(fn(1.0, 2.0, 3.0) == 8.0)
...@@ -194,16 +184,25 @@ class _test_OpWiseCLinker(unittest.TestCase): ...@@ -194,16 +184,25 @@ class _test_OpWiseCLinker(unittest.TestCase):
def test_straightforward(self): def test_straightforward(self):
x, y, z = inputs() x, y, z = inputs()
e = add(mul(add(x, y), div(x, y)), sub(sub(x, y), z)) e = add(mul(add(x, y), div(x, y)), sub(sub(x, y), z))
lnk = OpWiseCLinker(env([x, y, z], [e])) lnk = OpWiseCLinker(Env([x, y, z], [e]))
fn = lnk.make_function() fn = lnk.make_function()
self.failUnless(fn(2.0, 2.0, 2.0) == 2.0) self.failUnless(fn(2.0, 2.0, 2.0) == 2.0)
def test_constant(self):
x, y, z = inputs()
x = Constant(tdouble, 7.2, name = 'x')
e = add(mul(x, y), mul(y, z))
lnk = OpWiseCLinker(Env([y, z], [e]))
fn = lnk.make_function()
res = fn(1.5, 3.0)
self.failUnless(res == 15.3, res)
class MyExc(Exception): class MyExc(Exception):
pass pass
def _my_checker(x, y): def _my_checker(x, y):
if x.data != y.data: if x[0] != y[0]:
raise MyExc("Output mismatch.", {'performlinker': x.data, 'clinker': y.data}) raise MyExc("Output mismatch.", {'performlinker': x[0], 'clinker': y[0]})
class _test_DualLinker(unittest.TestCase): class _test_DualLinker(unittest.TestCase):
...@@ -211,7 +210,7 @@ class _test_DualLinker(unittest.TestCase): ...@@ -211,7 +210,7 @@ class _test_DualLinker(unittest.TestCase):
def test_straightforward(self): def test_straightforward(self):
x, y, z = inputs() x, y, z = inputs()
e = add(mul(x, y), mul(y, z)) # add and mul are correct in C and in Python e = add(mul(x, y), mul(y, z)) # add and mul are correct in C and in Python
lnk = DualLinker(env([x, y, z], [e]), checker = _my_checker) lnk = DualLinker(Env([x, y, z], [e]), checker = _my_checker)
fn = lnk.make_function() fn = lnk.make_function()
res = fn(7.2, 1.5, 3.0) res = fn(7.2, 1.5, 3.0)
self.failUnless(res == 15.3, res) self.failUnless(res == 15.3, res)
...@@ -219,7 +218,7 @@ class _test_DualLinker(unittest.TestCase): ...@@ -219,7 +218,7 @@ class _test_DualLinker(unittest.TestCase):
def test_mismatch(self): def test_mismatch(self):
x, y, z = inputs() x, y, z = inputs()
e = sub(mul(x, y), mul(y, z)) # sub is correct in C but erroneous in Python e = sub(mul(x, y), mul(y, z)) # sub is correct in C but erroneous in Python
g = env([x, y, z], [e]) g = Env([x, y, z], [e])
lnk = DualLinker(g, checker = _my_checker) lnk = DualLinker(g, checker = _my_checker)
fn = lnk.make_function() fn = lnk.make_function()
...@@ -238,14 +237,14 @@ class _test_DualLinker(unittest.TestCase): ...@@ -238,14 +237,14 @@ class _test_DualLinker(unittest.TestCase):
else: else:
self.fail() self.fail()
def test_orphan(self): # def test_orphan(self):
x, y, z = inputs() # x, y, z = inputs()
x.data = 7.2 # x = Constant(tdouble, 7.2, name = 'x')
e = add(mul(x, y), mul(y, z)) # add and mul are correct in C and in Python # e = add(mul(x, y), mul(y, z)) # add and mul are correct in C and in Python
lnk = DualLinker(env([y, z], [e]), checker = _my_checker) # lnk = DualLinker(Env([y, z], [e]), checker = _my_checker)
fn = lnk.make_function() # fn = lnk.make_function()
res = fn(1.5, 3.0) # res = fn(1.5, 3.0)
self.failUnless(res == 15.3, res) # self.failUnless(res == 15.3, res)
......
import unittest import unittest
from result import Result from type import Type
from graph import Result, as_result, Apply
from op import Op from op import Op
from opt import PatternOptimizer, OpSubOptimizer from opt import PatternOptimizer, OpSubOptimizer
...@@ -9,62 +10,64 @@ from ext import * ...@@ -9,62 +10,64 @@ from ext import *
from env import Env, InconsistencyError from env import Env, InconsistencyError
from toolbox import EquivTool from toolbox import EquivTool
from _test_result import MyResult from copy import copy
class MyOp(Op): #from _test_result import MyResult
nin = -1
def __init__(self, *inputs):
assert len(inputs) == self.nin
for input in inputs:
if not isinstance(input, MyResult):
raise Exception("Error 1")
self.inputs = inputs
self.outputs = [MyResult(self.__class__.__name__ + "_R")]
class Sigmoid(MyOp):
nin = 1
class TransposeView(MyOp, Viewer): class MyType(Type):
nin = 1
def view_map(self):
return {self.outputs[0]: [self.inputs[0]]}
class Add(MyOp): def filter(self, data):
nin = 2 return data
class AddInPlace(MyOp, Destroyer): def __eq__(self, other):
nin = 2 return isinstance(other, MyType)
def destroyed_inputs(self):
return self.inputs[:1]
class Dot(MyOp):
nin = 2
def MyResult(name):
return Result(MyType(), None, None, name = name)
# dtv_elim = PatternOptimizer((TransposeView, (TransposeView, 'x')), 'x')
# AddCls = Add class MyOp(Op):
# AddInPlaceCls = AddInPlace
def __init__(self, nin, name, vmap = {}, dmap = {}):
self.nin = nin
self.name = name
self.destroy_map = dmap
self.view_map = vmap
def make_node(self, *inputs):
assert len(inputs) == self.nin
inputs = map(as_result, inputs)
for input in inputs:
if not isinstance(input.type, MyType):
raise Exception("Error 1")
outputs = [MyResult(self.name + "_R")]
return Apply(self, inputs, outputs)
# a2i = OpSubOptimizer(Add, AddInPlace) def __str__(self):
# i2a = OpSubOptimizer(AddInPlace, Add) return self.name
# t2s = OpSubOptimizer(TransposeView, Sigmoid)
# s2t = OpSubOptimizer(Sigmoid, TransposeView)
sigmoid = MyOp(1, 'Sigmoid')
transpose_view = MyOp(1, 'TransposeView', vmap = {0: [0]})
add = MyOp(2, 'Add')
add_in_place = MyOp(2, 'AddInPlace', dmap = {0: [0]})
dot = MyOp(2, 'Dot')
import modes
modes.make_constructors(globals()) #, name_filter = lambda x:x)
def inputs(): def inputs():
x = modes.build(MyResult('x')) x = MyResult('x')
y = modes.build(MyResult('y')) y = MyResult('y')
z = modes.build(MyResult('z')) z = MyResult('z')
return x, y, z return x, y, z
def env(inputs, outputs, validate = True): _Env = Env
return Env(inputs, outputs, features = [EquivTool], consistency_check = validate) def Env(inputs, outputs, validate = True):
e = _Env(inputs, outputs)
e.extend(EquivTool(e))
e.extend(DestroyHandler(e), validate = validate)
return e
class FailureWatch: class FailureWatch:
...@@ -82,7 +85,7 @@ class _test_all(unittest.TestCase): ...@@ -82,7 +85,7 @@ class _test_all(unittest.TestCase):
x, y, z = inputs() x, y, z = inputs()
e = add(add_in_place(x, y), add_in_place(x, y)) e = add(add_in_place(x, y), add_in_place(x, y))
try: try:
g = env([x,y,z], [e]) g = Env([x,y,z], [e])
self.fail() self.fail()
except InconsistencyError, e: except InconsistencyError, e:
pass pass
...@@ -90,10 +93,10 @@ class _test_all(unittest.TestCase): ...@@ -90,10 +93,10 @@ class _test_all(unittest.TestCase):
def test_multi_destroyers_through_views(self): def test_multi_destroyers_through_views(self):
x, y, z = inputs() x, y, z = inputs()
e = dot(add(transpose_view(z), y), add(z, x)) e = dot(add(transpose_view(z), y), add(z, x))
g = env([x,y,z], [e]) g = Env([x,y,z], [e])
assert g.consistent() assert g.consistent()
fail = FailureWatch() fail = FailureWatch()
OpSubOptimizer(Add, AddInPlace, fail).optimize(g) OpSubOptimizer(add, add_in_place, fail).optimize(g)
assert g.consistent() assert g.consistent()
assert fail.failures == 1 # should have succeeded once and failed once assert fail.failures == 1 # should have succeeded once and failed once
...@@ -102,7 +105,7 @@ class _test_all(unittest.TestCase): ...@@ -102,7 +105,7 @@ class _test_all(unittest.TestCase):
x, y, z = inputs() x, y, z = inputs()
e1 = add(x, y) e1 = add(x, y)
e2 = add(y, x) e2 = add(y, x)
g = env([x,y,z], [e1, e2]) g = Env([x,y,z], [e1, e2])
chk = g.checkpoint() chk = g.checkpoint()
assert g.consistent() assert g.consistent()
g.replace(e1, add_in_place(x, y)) g.replace(e1, add_in_place(x, y))
...@@ -126,30 +129,30 @@ class _test_all(unittest.TestCase): ...@@ -126,30 +129,30 @@ class _test_all(unittest.TestCase):
def test_long_destroyers_loop(self): def test_long_destroyers_loop(self):
x, y, z = inputs() x, y, z = inputs()
e = dot(dot(add_in_place(x,y), add_in_place(y,z)), add(z,x)) e = dot(dot(add_in_place(x,y), add_in_place(y,z)), add(z,x))
g = env([x,y,z], [e]) g = Env([x,y,z], [e])
assert g.consistent() assert g.consistent()
OpSubOptimizer(Add, AddInPlace).optimize(g) OpSubOptimizer(add, add_in_place).optimize(g)
assert g.consistent() assert g.consistent()
assert str(g) != "[Dot(Dot(AddInPlace(x, y), AddInPlace(y, z)), AddInPlace(z, x))]" # we don't want to see that! assert str(g) != "[Dot(Dot(AddInPlace(x, y), AddInPlace(y, z)), AddInPlace(z, x))]" # we don't want to see that!
e2 = dot(dot(add_in_place(x,y), add_in_place(y,z)), add_in_place(z,x)) e2 = dot(dot(add_in_place(x,y), add_in_place(y,z)), add_in_place(z,x))
try: try:
g2 = env([x,y,z], [e2]) g2 = Env([x,y,z], [e2])
self.fail() self.fail()
except InconsistencyError: except InconsistencyError:
pass pass
def test_usage_loop(self): def test_usage_loop(self):
x, y, z = inputs() x, y, z = inputs()
g = env([x,y,z], [dot(add_in_place(x, z), x)], False) g = Env([x,y,z], [dot(add_in_place(x, z), x)], False)
assert not g.consistent() assert not g.consistent()
OpSubOptimizer(AddInPlace, Add).optimize(g) # replace AddInPlace with Add OpSubOptimizer(add_in_place, add).optimize(g) # replace add_in_place with add
assert g.consistent() assert g.consistent()
def test_usage_loop_through_views(self): def test_usage_loop_through_views(self):
x, y, z = inputs() x, y, z = inputs()
aip = add_in_place(x, y) aip = add_in_place(x, y)
e = dot(aip, transpose_view(x)) e = dot(aip, transpose_view(x))
g = env([x,y,z], [e], False) g = Env([x,y,z], [e], False)
assert not g.consistent() assert not g.consistent()
g.replace(aip, add(x, z)) g.replace(aip, add(x, z))
assert g.consistent() assert g.consistent()
...@@ -158,7 +161,7 @@ class _test_all(unittest.TestCase): ...@@ -158,7 +161,7 @@ class _test_all(unittest.TestCase):
x, y, z = inputs() x, y, z = inputs()
e0 = transpose_view(transpose_view(transpose_view(sigmoid(x)))) e0 = transpose_view(transpose_view(transpose_view(sigmoid(x))))
e = dot(add_in_place(x,y), transpose_view(e0)) e = dot(add_in_place(x,y), transpose_view(e0))
g = env([x,y,z], [e]) g = Env([x,y,z], [e])
assert g.consistent() # because sigmoid can do the copy assert g.consistent() # because sigmoid can do the copy
g.replace(e0, x, False) g.replace(e0, x, False)
assert not g.consistent() # we cut off the path to the sigmoid assert not g.consistent() # we cut off the path to the sigmoid
...@@ -166,22 +169,23 @@ class _test_all(unittest.TestCase): ...@@ -166,22 +169,23 @@ class _test_all(unittest.TestCase):
def test_usage_loop_insert_views(self): def test_usage_loop_insert_views(self):
x, y, z = inputs() x, y, z = inputs()
e = dot(add_in_place(x, add(y, z)), sigmoid(sigmoid(sigmoid(sigmoid(sigmoid(x)))))) e = dot(add_in_place(x, add(y, z)), sigmoid(sigmoid(sigmoid(sigmoid(sigmoid(x))))))
g = env([x,y,z], [e]) g = Env([x,y,z], [e])
assert g.consistent() assert g.consistent()
fail = FailureWatch() fail = FailureWatch()
OpSubOptimizer(Sigmoid, TransposeView, fail).optimize(g) OpSubOptimizer(sigmoid, transpose_view, fail).optimize(g)
assert g.consistent() assert g.consistent()
assert fail.failures == 1 # it must keep one sigmoid in the long sigmoid chain assert fail.failures == 1 # it must keep one sigmoid in the long sigmoid chain
def test_misc(self): def test_misc(self):
x, y, z = inputs() x, y, z = inputs()
e = transpose_view(transpose_view(transpose_view(transpose_view(x)))) e = transpose_view(transpose_view(transpose_view(transpose_view(x))))
g = env([x,y,z], [e]) g = Env([x,y,z], [e])
assert g.consistent() assert g.consistent()
chk = g.checkpoint() chk = g.checkpoint()
PatternOptimizer((TransposeView, (TransposeView, 'x')), 'x').optimize(g) PatternOptimizer((transpose_view, (transpose_view, 'x')), 'x').optimize(g)
assert str(g) == "[x]" assert str(g) == "[x]"
g.replace(g.equiv(e), add(x,y)) g.replace(g.equiv(e), add(x,y))
print g
assert str(g) == "[Add(x, y)]" assert str(g) == "[Add(x, y)]"
g.replace(g.equiv(e), dot(add_in_place(x,y), transpose_view(x)), False) g.replace(g.equiv(e), dot(add_in_place(x,y), transpose_view(x)), False)
assert str(g) == "[Dot(AddInPlace(x, y), TransposeView(x))]" assert str(g) == "[Dot(AddInPlace(x, y), TransposeView(x))]"
...@@ -193,8 +197,10 @@ class _test_all(unittest.TestCase): ...@@ -193,8 +197,10 @@ class _test_all(unittest.TestCase):
def test_indestructible(self): def test_indestructible(self):
x, y, z = inputs() x, y, z = inputs()
x.indestructible = True x.indestructible = True
x = copy(x)
assert x.indestructible # checking if indestructible survives the copy!
e = add_in_place(x, y) e = add_in_place(x, y)
g = env([x,y,z], [e], False) g = Env([x,y,z], [e], False)
assert not g.consistent() assert not g.consistent()
g.replace(e, add(x, y)) g.replace(e, add(x, y))
assert g.consistent() assert g.consistent()
...@@ -204,7 +210,7 @@ class _test_all(unittest.TestCase): ...@@ -204,7 +210,7 @@ class _test_all(unittest.TestCase):
x.indestructible = True x.indestructible = True
tv = transpose_view(x) tv = transpose_view(x)
e = add_in_place(tv, y) e = add_in_place(tv, y)
g = env([x,y,z], [e], False) g = Env([x,y,z], [e], False)
assert not g.consistent() assert not g.consistent()
g.replace(tv, sigmoid(x)) g.replace(tv, sigmoid(x))
assert g.consistent() assert g.consistent()
...@@ -215,7 +221,7 @@ class _test_all(unittest.TestCase): ...@@ -215,7 +221,7 @@ class _test_all(unittest.TestCase):
e2 = transpose_view(transpose_view(e1)) e2 = transpose_view(transpose_view(e1))
e3 = add_in_place(e2, y) e3 = add_in_place(e2, y)
e4 = add_in_place(e1, z) e4 = add_in_place(e1, z)
g = env([x,y,z], [e3, e4], False) g = Env([x,y,z], [e3, e4], False)
assert not g.consistent() assert not g.consistent()
g.replace(e2, transpose_view(x), False) g.replace(e2, transpose_view(x), False)
assert not g.consistent() assert not g.consistent()
...@@ -224,7 +230,7 @@ class _test_all(unittest.TestCase): ...@@ -224,7 +230,7 @@ class _test_all(unittest.TestCase):
x, y, z = inputs() x, y, z = inputs()
e0 = add_in_place(x, y) e0 = add_in_place(x, y)
e = dot(sigmoid(e0), transpose_view(x)) e = dot(sigmoid(e0), transpose_view(x))
g = env([x,y,z], [e], False) g = Env([x,y,z], [e], False)
assert not g.consistent() assert not g.consistent()
new_e0 = add(x, y) new_e0 = add(x, y)
g.replace(e0, new_e0, False) g.replace(e0, new_e0, False)
...@@ -236,7 +242,7 @@ class _test_all(unittest.TestCase): ...@@ -236,7 +242,7 @@ class _test_all(unittest.TestCase):
x, y, z = inputs() x, y, z = inputs()
e0 = transpose_view(x) e0 = transpose_view(x)
e = dot(sigmoid(add_in_place(x, y)), e0) e = dot(sigmoid(add_in_place(x, y)), e0)
g = env([x,y,z], [e], False) g = Env([x,y,z], [e], False)
assert not g.consistent() assert not g.consistent()
new_e0 = add(e0, y) new_e0 = add(e0, y)
g.replace(e0, new_e0, False) g.replace(e0, new_e0, False)
......
...@@ -4,21 +4,18 @@ import unittest ...@@ -4,21 +4,18 @@ import unittest
from graph import * from graph import *
from op import Op from op import Op
from result import Result from type import Type
from graph import Result
class MyResult(Result):
class MyType(Type):
def __init__(self, thingy): def __init__(self, thingy):
self.thingy = thingy self.thingy = thingy
Result.__init__(self, role = None )
self.data = [self.thingy]
def __eq__(self, other): def __eq__(self, other):
return self.same_properties(other) return isinstance(other, MyType) and other.thingy == self.thingy
def same_properties(self, other):
return isinstance(other, MyResult) and other.thingy == self.thingy
def __str__(self): def __str__(self):
return str(self.thingy) return str(self.thingy)
...@@ -26,37 +23,78 @@ class MyResult(Result): ...@@ -26,37 +23,78 @@ class MyResult(Result):
def __repr__(self): def __repr__(self):
return str(self.thingy) return str(self.thingy)
def MyResult(thingy):
return Result(MyType(thingy), None, None)
class MyOp(Op): class MyOp(Op):
def __init__(self, *inputs): def make_node(self, *inputs):
inputs = map(as_result, inputs)
for input in inputs: for input in inputs:
if not isinstance(input, MyResult): if not isinstance(input.type, MyType):
print input, input.type, type(input), type(input.type)
raise Exception("Error 1") raise Exception("Error 1")
self.inputs = inputs outputs = [MyResult(sum([input.type.thingy for input in inputs]))]
self.outputs = [MyResult(sum([input.thingy for input in inputs]))] return Apply(self, inputs, outputs)
def __str__(self):
return self.__class__.__name__
MyOp = MyOp()
# class MyResult(Result):
# def __init__(self, thingy):
# self.thingy = thingy
# Result.__init__(self, role = None )
# self.data = [self.thingy]
# def __eq__(self, other):
# return self.same_properties(other)
# def same_properties(self, other):
# return isinstance(other, MyResult) and other.thingy == self.thingy
# def __str__(self):
# return str(self.thingy)
# def __repr__(self):
# return str(self.thingy)
# class MyOp(Op):
# def __init__(self, *inputs):
# for input in inputs:
# if not isinstance(input, MyResult):
# raise Exception("Error 1")
# self.inputs = inputs
# self.outputs = [MyResult(sum([input.thingy for input in inputs]))]
class _test_inputs(unittest.TestCase): class _test_inputs(unittest.TestCase):
def test_straightforward(self): def test_straightforward(self):
r1, r2 = MyResult(1), MyResult(2) r1, r2 = MyResult(1), MyResult(2)
op = MyOp(r1, r2) node = MyOp.make_node(r1, r2)
assert inputs(op.outputs) == set([r1, r2]) assert inputs(node.outputs) == set([r1, r2])
def test_deep(self): def test_deep(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5) r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(r1, r2) node = MyOp.make_node(r1, r2)
op2 = MyOp(op.outputs[0], r5) node2 = MyOp.make_node(node.outputs[0], r5)
assert inputs(op2.outputs) == set([r1, r2, r5]) assert inputs(node2.outputs) == set([r1, r2, r5])
def test_unreached_inputs(self): def test_unreached_inputs(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5) r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(r1, r2) node = MyOp.make_node(r1, r2)
op2 = MyOp(op.outputs[0], r5) node2 = MyOp.make_node(node.outputs[0], r5)
try: try:
# function doesn't raise if we put False instead of True # function doesn't raise if we put False instead of True
ro = results_and_orphans([r1, r2, op2.outputs[0]], op.outputs, True) ro = results_and_orphans([r1, r2, node2.outputs[0]], node.outputs, True)
self.fail() self.fail()
except Exception, e: except Exception, e:
if e[0] is results_and_orphans.E_unreached: if e[0] is results_and_orphans.E_unreached:
...@@ -68,70 +106,84 @@ class _test_orphans(unittest.TestCase): ...@@ -68,70 +106,84 @@ class _test_orphans(unittest.TestCase):
def test_straightforward(self): def test_straightforward(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5) r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(r1, r2) node = MyOp.make_node(r1, r2)
op2 = MyOp(op.outputs[0], r5) node2 = MyOp.make_node(node.outputs[0], r5)
assert orphans([r1, r2], op2.outputs) == set([r5]) assert orphans([r1, r2], node2.outputs) == set([r5])
class _test_as_string(unittest.TestCase): class _test_as_string(unittest.TestCase):
leaf_formatter = str leaf_formatter = lambda self, leaf: str(leaf.type)
node_formatter = lambda op, argstrings: "%s(%s)" % (op.__class__.__name__, node_formatter = lambda self, node, argstrings: "%s(%s)" % (node.op,
", ".join(argstrings)) ", ".join(argstrings))
def str(self, inputs, outputs):
return as_string(inputs, outputs,
leaf_formatter = self.leaf_formatter,
node_formatter = self.node_formatter)
def test_straightforward(self): def test_straightforward(self):
r1, r2 = MyResult(1), MyResult(2) r1, r2 = MyResult(1), MyResult(2)
op = MyOp(r1, r2) node = MyOp.make_node(r1, r2)
assert as_string([r1, r2], op.outputs) == ["MyOp(1, 2)"] assert self.str([r1, r2], node.outputs) == ["MyOp(1, 2)"]
def test_deep(self): def test_deep(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5) r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(r1, r2) node = MyOp.make_node(r1, r2)
op2 = MyOp(op.outputs[0], r5) node2 = MyOp.make_node(node.outputs[0], r5)
assert as_string([r1, r2, r5], op2.outputs) == ["MyOp(MyOp(1, 2), 5)"] assert self.str([r1, r2, r5], node2.outputs) == ["MyOp(MyOp(1, 2), 5)"]
def test_multiple_references(self): def test_multiple_references(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5) r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(r1, r2) node = MyOp.make_node(r1, r2)
op2 = MyOp(op.outputs[0], op.outputs[0]) node2 = MyOp.make_node(node.outputs[0], node.outputs[0])
assert as_string([r1, r2, r5], op2.outputs) == ["MyOp(*1 -> MyOp(1, 2), *1)"] assert self.str([r1, r2, r5], node2.outputs) == ["MyOp(*1 -> MyOp(1, 2), *1)"]
def test_cutoff(self): def test_cutoff(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5) r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(r1, r2) node = MyOp.make_node(r1, r2)
op2 = MyOp(op.outputs[0], op.outputs[0]) node2 = MyOp.make_node(node.outputs[0], node.outputs[0])
assert as_string(op.outputs, op2.outputs) == ["MyOp(3, 3)"] assert self.str(node.outputs, node2.outputs) == ["MyOp(3, 3)"]
assert as_string(op2.inputs, op2.outputs) == ["MyOp(3, 3)"] assert self.str(node2.inputs, node2.outputs) == ["MyOp(3, 3)"]
class _test_clone(unittest.TestCase): class _test_clone(unittest.TestCase):
leaf_formatter = lambda self, leaf: str(leaf.type)
node_formatter = lambda self, node, argstrings: "%s(%s)" % (node.op,
", ".join(argstrings))
def str(self, inputs, outputs):
return as_string(inputs, outputs,
leaf_formatter = self.leaf_formatter,
node_formatter = self.node_formatter)
def test_accurate(self): def test_accurate(self):
r1, r2 = MyResult(1), MyResult(2) r1, r2 = MyResult(1), MyResult(2)
op = MyOp(r1, r2) node = MyOp.make_node(r1, r2)
new = clone([r1, r2], op.outputs) new = clone([r1, r2], node.outputs)
assert as_string([r1, r2], new) == ["MyOp(1, 2)"] assert self.str([r1, r2], new) == ["MyOp(1, 2)"]
def test_copy(self): def test_copy(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5) r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(r1, r2) node = MyOp.make_node(r1, r2)
op2 = MyOp(op.outputs[0], r5) node2 = MyOp.make_node(node.outputs[0], r5)
new = clone([r1, r2, r5], op2.outputs) new = clone([r1, r2, r5], node2.outputs)
assert op2.outputs[0] == new[0] and op2.outputs[0] is not new[0] # the new output is like the old one but not the same object assert node2.outputs[0].type == new[0].type and node2.outputs[0] is not new[0] # the new output is like the old one but not the same object
assert op2 is not new[0].owner # the new output has a new owner assert node2 is not new[0].owner # the new output has a new owner
assert new[0].owner.inputs[1] is r5 # the inputs are not copied assert new[0].owner.inputs[1] is r5 # the inputs are not copied
assert new[0].owner.inputs[0] == op.outputs[0] and new[0].owner.inputs[0] is not op.outputs[0] # check that we copied deeper too assert new[0].owner.inputs[0].type == node.outputs[0].type and new[0].owner.inputs[0] is not node.outputs[0] # check that we copied deeper too
def test_not_destructive(self): def test_not_destructive(self):
# Checks that manipulating a cloned graph leaves the original unchanged. # Checks that manipulating a cloned graph leaves the original unchanged.
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5) r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(MyOp(r1, r2).outputs[0], r5) node = MyOp.make_node(MyOp.make_node(r1, r2).outputs[0], r5)
new = clone([r1, r2, r5], op.outputs) new = clone([r1, r2, r5], node.outputs)
new_op = new[0].owner new_node = new[0].owner
new_op.inputs = MyResult(7), MyResult(8) new_node.inputs = MyResult(7), MyResult(8)
assert as_string(inputs(new_op.outputs), new_op.outputs) == ["MyOp(7, 8)"] assert self.str(inputs(new_node.outputs), new_node.outputs) == ["MyOp(7, 8)"]
assert as_string(inputs(op.outputs), op.outputs) == ["MyOp(MyOp(1, 2), 5)"] assert self.str(inputs(node.outputs), node.outputs) == ["MyOp(MyOp(1, 2), 5)"]
......
...@@ -2,70 +2,67 @@ ...@@ -2,70 +2,67 @@
import unittest import unittest
from result import Result from graph import Result, as_result, Apply
from type import Type
from op import Op from op import Op
from env import Env from env import Env
from link import * from link import *
from _test_result import Double #from _test_result import Double
class MyOp(Op):
nin = -1 class TDouble(Type):
def filter(self, data):
return float(data)
tdouble = TDouble()
def double(name):
return Result(tdouble, None, None, name = name)
def __init__(self, *inputs): class MyOp(Op):
def __init__(self, nin, name, impl = None):
self.nin = nin
self.name = name
if impl:
self.impl = impl
def make_node(self, *inputs):
assert len(inputs) == self.nin assert len(inputs) == self.nin
inputs = map(as_result, inputs)
for input in inputs: for input in inputs:
if not isinstance(input, Double): if input.type is not tdouble:
raise Exception("Error 1") raise Exception("Error 1")
self.inputs = inputs outputs = [double(self.name + "_R")]
self.outputs = [Double(0.0, self.__class__.__name__ + "_R")] return Apply(self, inputs, outputs)
def perform(self):
self.outputs[0].data = self.impl(*[input.data for input in self.inputs])
class Unary(MyOp): def __str__(self):
nin = 1 return self.name
def perform(self, node, inputs, (out, )):
out[0] = self.impl(*inputs)
class Binary(MyOp): add = MyOp(2, 'Add', lambda x, y: x + y)
nin = 2 sub = MyOp(2, 'Sub', lambda x, y: x - y)
mul = MyOp(2, 'Mul', lambda x, y: x * y)
div = MyOp(2, 'Div', lambda x, y: x / y)
def notimpl(self, x):
class Add(Binary): raise NotImplementedError()
def impl(self, x, y):
return x + y
class Sub(Binary):
def impl(self, x, y):
return x - y
class Mul(Binary):
def impl(self, x, y):
return x * y
class Div(Binary):
def impl(self, x, y):
return x / y
class RaiseErr(Unary): raise_err = MyOp(1, 'RaiseErr', notimpl)
def impl(self, x):
raise NotImplementedError()
import modes
modes.make_constructors(globals())
def inputs(): def inputs():
x = modes.build(Double(1.0, 'x')) x = double('x')
y = modes.build(Double(2.0, 'y')) y = double('y')
z = modes.build(Double(3.0, 'z')) z = double('z')
return x, y, z return x, y, z
def env(inputs, outputs, validate = True, features = []):
return Env(inputs, outputs, features = features, consistency_check = validate)
def perform_linker(env): def perform_linker(env):
lnk = PerformLinker(env) lnk = PerformLinker(env)
return lnk return lnk
...@@ -73,58 +70,79 @@ def perform_linker(env): ...@@ -73,58 +70,79 @@ def perform_linker(env):
class _test_PerformLinker(unittest.TestCase): class _test_PerformLinker(unittest.TestCase):
def test_thunk_inplace(self): def test_thunk(self):
x, y, z = inputs() x, y, z = inputs()
e = mul(add(x, y), div(x, y)) e = mul(add(x, y), div(x, y))
fn, i, o = perform_linker(env([x, y, z], [e])).make_thunk(True) fn, i, o = perform_linker(Env([x, y, z], [e])).make_thunk()
fn() i[0].data = 1
assert e.data == 1.5 i[1].data = 2
def test_thunk_not_inplace(self):
x, y, z = inputs()
e = mul(add(x, y), div(x, y))
fn, i, o = perform_linker(env([x, y, z], [e])).make_thunk(False)
fn() fn()
assert o[0].data == 1.5 assert o[0].data == 1.5
assert e.data != 1.5
def test_function(self): def test_function(self):
x, y, z = inputs() x, y, z = inputs()
e = mul(add(x, y), div(x, y)) e = mul(add(x, y), div(x, y))
fn = perform_linker(env([x, y, z], [e])).make_function() fn = perform_linker(Env([x, y, z], [e])).make_function()
assert fn(1.0, 2.0, 3.0) == 1.5 assert fn(1.0, 2.0, 3.0) == 1.5
assert e.data != 1.5 # not inplace
def test_constant(self):
x, y, z = inputs()
y = Constant(tdouble, 2.0)
e = mul(add(x, y), div(x, y))
fn = perform_linker(Env([x], [e])).make_function()
assert fn(1.0) == 1.5
def test_input_output_same(self): def test_input_output_same(self):
x, y, z = inputs() x, y, z = inputs()
a,d = add(x,y), div(x,y) a,d = add(x,y), div(x,y)
e = mul(a,d) e = mul(a,d)
fn = perform_linker(env([e], [e])).make_function() fn = perform_linker(Env([e], [e])).make_function()
self.failUnless(1 is fn(1)) self.failUnless(1.0 is fn(1.0))
def test_input_dependency0(self): def test_input_dependency0(self):
x, y, z = inputs() x, y, z = inputs()
a,d = add(x,y), div(x,y) a,d = add(x,y), div(x,y)
e = mul(a,d) e = mul(a,d)
fn = perform_linker(env([x, a], [e])).make_function() fn = perform_linker(Env([x, y, a], [e])).make_function()
self.failUnless(fn(1.0,9.0) == 4.5) self.failUnless(fn(1.0,2.0,9.0) == 4.5)
def test_skiphole(self): def test_skiphole(self):
x,y,z = inputs() x,y,z = inputs()
a = add(x,y) a = add(x,y)
r = RaiseErr(a).out r = raise_err(a)
e = add(r,a) e = add(r,a)
fn = perform_linker(env([x, y,r], [e])).make_function() fn = perform_linker(Env([x, y,r], [e])).make_function()
self.failUnless(fn(1.0,2.0,4.5) == 7.5) self.failUnless(fn(1.0,2.0,4.5) == 7.5)
def test_disconnected_input_output(self): # def test_disconnected_input_output(self):
x,y,z = inputs() # x,y,z = inputs()
a = add(x,y) # a = add(x,y)
a.data = 3.0 # simulate orphan calculation # a.data = 3.0 # simulate orphan calculation
fn = perform_linker(env([z], [a])).make_function(inplace=True) # fn = perform_linker(env([z], [a])).make_function(inplace=True)
self.failUnless(fn(1.0) == 3.0) # self.failUnless(fn(1.0) == 3.0)
self.failUnless(fn(2.0) == 3.0) # self.failUnless(fn(2.0) == 3.0)
# def test_thunk_inplace(self):
# x, y, z = inputs()
# e = mul(add(x, y), div(x, y))
# fn, i, o = perform_linker(Env([x, y, z], [e])).make_thunk(True)
# fn()
# assert e.data == 1.5
# def test_thunk_not_inplace(self):
# x, y, z = inputs()
# e = mul(add(x, y), div(x, y))
# fn, i, o = perform_linker(env([x, y, z], [e])).make_thunk(False)
# fn()
# assert o[0].data == 1.5
# assert e.data != 1.5
# def test_function(self):
# x, y, z = inputs()
# e = mul(add(x, y), div(x, y))
# fn = perform_linker(env([x, y, z], [e])).make_function()
# assert fn(1.0, 2.0, 3.0) == 1.5
# assert e.data != 1.5 # not inplace
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -4,7 +4,7 @@ import unittest ...@@ -4,7 +4,7 @@ import unittest
from modes import * from modes import *
from result import Result from graph import Result
from op import Op from op import Op
from env import Env from env import Env
...@@ -86,53 +86,53 @@ def env(inputs, outputs, validate = True): ...@@ -86,53 +86,53 @@ def env(inputs, outputs, validate = True):
return Env(inputs, outputs, features = [], consistency_check = validate) return Env(inputs, outputs, features = [], consistency_check = validate)
class _test_Modes(unittest.TestCase): # class _test_Modes(unittest.TestCase):
def test_0(self): # def test_0(self):
x, y, z = inputs(build) # x, y, z = inputs(build)
e = add(add(x, y), z) # e = add(add(x, y), z)
g = env([x, y, z], [e]) # g = env([x, y, z], [e])
assert str(g) == "[Add(Add(x, y), z)]" # assert str(g) == "[Add(Add(x, y), z)]"
assert e.data == 0.0 # assert e.data == 0.0
def test_1(self): # def test_1(self):
x, y, z = inputs(build_eval) # x, y, z = inputs(build_eval)
e = add(add(x, y), z) # e = add(add(x, y), z)
g = env([x, y, z], [e]) # g = env([x, y, z], [e])
assert str(g) == "[Add(Add(x, y), z)]" # assert str(g) == "[Add(Add(x, y), z)]"
assert e.data == 6.0 # assert e.data == 6.0
def test_2(self): # def test_2(self):
x, y, z = inputs(eval) # x, y, z = inputs(eval)
e = add(add(x, y), z) # e = add(add(x, y), z)
g = env([x, y, z], [e]) # g = env([x, y, z], [e])
assert str(g) == "[Add_R]" # assert str(g) == "[Add_R]"
assert e.data == 6.0 # assert e.data == 6.0
def test_3(self): # def test_3(self):
x, y, z = inputs(build) # x, y, z = inputs(build)
e = x + y + z # e = x + y + z
g = env([x, y, z], [e]) # g = env([x, y, z], [e])
assert str(g) == "[Add(Add(x, y), z)]" # assert str(g) == "[Add(Add(x, y), z)]"
assert e.data == 0.0 # assert e.data == 0.0
def test_4(self): # def test_4(self):
x, y, z = inputs(build_eval) # x, y, z = inputs(build_eval)
e = x + 34.0 # e = x + 34.0
g = env([x, y, z], [e]) # g = env([x, y, z], [e])
assert str(g) == "[Add(x, oignon)]" # assert str(g) == "[Add(x, oignon)]"
assert e.data == 35.0 # assert e.data == 35.0
def test_5(self): # def test_5(self):
xb, yb, zb = inputs(build) # xb, yb, zb = inputs(build)
xe, ye, ze = inputs(eval) # xe, ye, ze = inputs(eval)
try: # try:
e = xb + ye # e = xb + ye
except TypeError: # except TypeError:
# Trying to add inputs from different modes is forbidden # # Trying to add inputs from different modes is forbidden
pass # pass
else: # else:
raise Exception("Expected an error.") # raise Exception("Expected an error.")
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -2,21 +2,19 @@ ...@@ -2,21 +2,19 @@
import unittest import unittest
from copy import copy from copy import copy
from op import * from op import *
from result import Result from type import Type, Generic
from graph import Apply, as_result
#from result import Result
class MyResult(Result):
class MyType(Type):
def __init__(self, thingy): def __init__(self, thingy):
self.thingy = thingy self.thingy = thingy
Result.__init__(self, role = None)
self.data = [self.thingy]
def __eq__(self, other): def __eq__(self, other):
return self.same_properties(other) return type(other) == type(self) and other.thingy == self.thingy
def same_properties(self, other):
return isinstance(other, MyResult) and other.thingy == self.thingy
def __str__(self): def __str__(self):
return str(self.thingy) return str(self.thingy)
...@@ -27,32 +25,35 @@ class MyResult(Result): ...@@ -27,32 +25,35 @@ class MyResult(Result):
class MyOp(Op): class MyOp(Op):
def __init__(self, *inputs): def make_node(self, *inputs):
inputs = map(as_result, inputs)
for input in inputs: for input in inputs:
if not isinstance(input, MyResult): if not isinstance(input.type, MyType):
raise Exception("Error 1") raise Exception("Error 1")
self.inputs = inputs outputs = [MyType(sum([input.type.thingy for input in inputs]))()]
self.outputs = [MyResult(sum([input.thingy for input in inputs]))] return Apply(self, inputs, outputs)
MyOp = MyOp()
class _test_Op(unittest.TestCase): class _test_Op(unittest.TestCase):
# Sanity tests # Sanity tests
def test_sanity_0(self): def test_sanity_0(self):
r1, r2 = MyResult(1), MyResult(2) r1, r2 = MyType(1)(), MyType(2)()
op = MyOp(r1, r2) node = MyOp.make_node(r1, r2)
assert op.inputs == [r1, r2] # Are the inputs what I provided? assert [x for x in node.inputs] == [r1, r2] # Are the inputs what I provided?
assert op.outputs == [MyResult(3)] # Are the outputs what I expect? assert [x.type for x in node.outputs] == [MyType(3)] # Are the outputs what I expect?
assert op.outputs[0].owner is op and op.outputs[0].index == 0 assert node.outputs[0].owner is node and node.outputs[0].index == 0
# validate_update # validate
def test_validate_update(self): def test_validate(self):
try: try:
MyOp(Result(), MyResult(1)) # MyOp requires MyResult instances MyOp(Generic()(), MyType(1)()) # MyOp requires MyType instances
except Exception, e:
assert str(e) == "Error 1"
else:
raise Exception("Expected an exception") raise Exception("Expected an exception")
except Exception, e:
if str(e) != "Error 1":
raise
......
import unittest import unittest
from result import Result from graph import Result, as_result, Apply, Constant
from op import Op from op import Op
from ext import Destroyer from ext import Destroyer
from opt import * from opt import *
...@@ -9,71 +9,61 @@ from env import Env ...@@ -9,71 +9,61 @@ from env import Env
from toolbox import * from toolbox import *
class MyResult(Result):
def __init__(self, name): class MyType(Type):
Result.__init__(self, role = None, name = name)
self.data = [1000]
def __str__(self): def filter(self, data):
return self.name return data
def __repr__(self): def __eq__(self, other):
return self.name return isinstance(other, MyType)
def desc(self):
return self.data
def MyResult(name):
return Result(MyType(), None, None, name = name)
class MyOp(Op): class MyOp(Op):
def __init__(self, *inputs): def __init__(self, name, dmap = {}, x = None):
self.name = name
self.destroy_map = dmap
self.x = x
def make_node(self, *inputs):
inputs = map(as_result, inputs)
for input in inputs: for input in inputs:
if not isinstance(input, MyResult): if not isinstance(input.type, MyType):
raise Exception("Error 1") raise Exception("Error 1")
self.inputs = inputs outputs = [MyType()()]
self.outputs = [MyResult(self.__class__.__name__ + "_R")] return Apply(self, inputs, outputs)
class Op1(MyOp): def __str__(self):
pass return self.name
class Op2(MyOp):
pass
class Op3(MyOp):
pass
class Op4(MyOp): def __eq__(self, other):
pass return self is other or isinstance(other, MyOp) and self.x is not None and self.x == other.x
class OpD(MyOp, Destroyer): def __hash__(self):
def destroyed_inputs(self): return self.x if self.x is not None else id(self)
return [self.inputs[0]]
class OpZ(MyOp): op1 = MyOp('Op1')
def __init__(self, x, y, a, b): op2 = MyOp('Op2')
self.a = a op3 = MyOp('Op3')
self.b = b op4 = MyOp('Op4')
MyOp.__init__(self, x, y) op_d = MyOp('OpD', {0: [0]})
def desc(self):
return (self.a, self.b)
op_y = MyOp('OpY', x = 1)
op_z = MyOp('OpZ', x = 1)
import modes
modes.make_constructors(globals())
def inputs(): def inputs():
x = modes.build(MyResult('x')) x = MyResult('x')
y = modes.build(MyResult('y')) y = MyResult('y')
z = modes.build(MyResult('z')) z = MyResult('z')
return x, y, z return x, y, z
def env(inputs, outputs, validate = True):
return Env(inputs, outputs, features = [EquivTool], consistency_check = validate)
class _test_PatternOptimizer(unittest.TestCase): class _test_PatternOptimizer(unittest.TestCase):
...@@ -81,34 +71,34 @@ class _test_PatternOptimizer(unittest.TestCase): ...@@ -81,34 +71,34 @@ class _test_PatternOptimizer(unittest.TestCase):
# replacing the whole graph # replacing the whole graph
x, y, z = inputs() x, y, z = inputs()
e = op1(op2(x, y), z) e = op1(op2(x, y), z)
g = env([x, y, z], [e]) g = Env([x, y, z], [e])
PatternOptimizer((Op1, (Op2, '1', '2'), '3'), PatternOptimizer((op1, (op2, '1', '2'), '3'),
(Op4, '3', '2')).optimize(g) (op4, '3', '2')).optimize(g)
assert str(g) == "[Op4(z, y)]" assert str(g) == "[Op4(z, y)]"
def test_nested_out_pattern(self): def test_nested_out_pattern(self):
x, y, z = inputs() x, y, z = inputs()
e = op1(x, y) e = op1(x, y)
g = env([x, y, z], [e]) g = Env([x, y, z], [e])
PatternOptimizer((Op1, '1', '2'), PatternOptimizer((op1, '1', '2'),
(Op4, (Op1, '1'), (Op2, '2'), (Op3, '1', '2'))).optimize(g) (op4, (op1, '1'), (op2, '2'), (op3, '1', '2'))).optimize(g)
assert str(g) == "[Op4(Op1(x), Op2(y), Op3(x, y))]" assert str(g) == "[Op4(Op1(x), Op2(y), Op3(x, y))]"
def test_unification_1(self): def test_unification_1(self):
x, y, z = inputs() x, y, z = inputs()
e = op1(op2(x, x), z) # the arguments to op2 are the same e = op1(op2(x, x), z) # the arguments to op2 are the same
g = env([x, y, z], [e]) g = Env([x, y, z], [e])
PatternOptimizer((Op1, (Op2, '1', '1'), '2'), # they are the same in the pattern PatternOptimizer((op1, (op2, '1', '1'), '2'), # they are the same in the pattern
(Op4, '2', '1')).optimize(g) (op4, '2', '1')).optimize(g)
# So the replacement should occur # So the replacement should occur
assert str(g) == "[Op4(z, x)]" assert str(g) == "[Op4(z, x)]"
def test_unification_2(self): def test_unification_2(self):
x, y, z = inputs() x, y, z = inputs()
e = op1(op2(x, y), z) # the arguments to op2 are different e = op1(op2(x, y), z) # the arguments to op2 are different
g = env([x, y, z], [e]) g = Env([x, y, z], [e])
PatternOptimizer((Op1, (Op2, '1', '1'), '2'), # they are the same in the pattern PatternOptimizer((op1, (op2, '1', '1'), '2'), # they are the same in the pattern
(Op4, '2', '1')).optimize(g) (op4, '2', '1')).optimize(g)
# The replacement should NOT occur # The replacement should NOT occur
assert str(g) == "[Op1(Op2(x, y), z)]" assert str(g) == "[Op1(Op2(x, y), z)]"
...@@ -116,9 +106,9 @@ class _test_PatternOptimizer(unittest.TestCase): ...@@ -116,9 +106,9 @@ class _test_PatternOptimizer(unittest.TestCase):
# replacing inside the graph # replacing inside the graph
x, y, z = inputs() x, y, z = inputs()
e = op1(op2(x, y), z) e = op1(op2(x, y), z)
g = env([x, y, z], [e]) g = Env([x, y, z], [e])
PatternOptimizer((Op2, '1', '2'), PatternOptimizer((op2, '1', '2'),
(Op1, '2', '1')).optimize(g) (op1, '2', '1')).optimize(g)
assert str(g) == "[Op1(Op1(y, x), z)]" assert str(g) == "[Op1(Op1(y, x), z)]"
def test_no_recurse(self): def test_no_recurse(self):
...@@ -126,18 +116,18 @@ class _test_PatternOptimizer(unittest.TestCase): ...@@ -126,18 +116,18 @@ class _test_PatternOptimizer(unittest.TestCase):
# it should do the replacement and stop # it should do the replacement and stop
x, y, z = inputs() x, y, z = inputs()
e = op1(op2(x, y), z) e = op1(op2(x, y), z)
g = env([x, y, z], [e]) g = Env([x, y, z], [e])
PatternOptimizer((Op2, '1', '2'), PatternOptimizer((op2, '1', '2'),
(Op2, '2', '1')).optimize(g) (op2, '2', '1')).optimize(g)
assert str(g) == "[Op1(Op2(y, x), z)]" assert str(g) == "[Op1(Op2(y, x), z)]"
def test_multiple(self): def test_multiple(self):
# it should replace all occurrences of the pattern # it should replace all occurrences of the pattern
x, y, z = inputs() x, y, z = inputs()
e = op1(op2(x, y), op2(x, y), op2(y, z)) e = op1(op2(x, y), op2(x, y), op2(y, z))
g = env([x, y, z], [e]) g = Env([x, y, z], [e])
PatternOptimizer((Op2, '1', '2'), PatternOptimizer((op2, '1', '2'),
(Op4, '1')).optimize(g) (op4, '1')).optimize(g)
assert str(g) == "[Op1(Op4(x), Op4(x), Op4(y))]" assert str(g) == "[Op1(Op4(x), Op4(x), Op4(y))]"
def test_nested_even(self): def test_nested_even(self):
...@@ -145,120 +135,128 @@ class _test_PatternOptimizer(unittest.TestCase): ...@@ -145,120 +135,128 @@ class _test_PatternOptimizer(unittest.TestCase):
# should work # should work
x, y, z = inputs() x, y, z = inputs()
e = op1(op1(op1(op1(x)))) e = op1(op1(op1(op1(x))))
g = env([x, y, z], [e]) g = Env([x, y, z], [e])
PatternOptimizer((Op1, (Op1, '1')), PatternOptimizer((op1, (op1, '1')),
'1').optimize(g) '1').optimize(g)
assert str(g) == "[x]" assert str(g) == "[x]"
def test_nested_odd(self): def test_nested_odd(self):
x, y, z = inputs() x, y, z = inputs()
e = op1(op1(op1(op1(op1(x))))) e = op1(op1(op1(op1(op1(x)))))
g = env([x, y, z], [e]) g = Env([x, y, z], [e])
PatternOptimizer((Op1, (Op1, '1')), PatternOptimizer((op1, (op1, '1')),
'1').optimize(g) '1').optimize(g)
assert str(g) == "[Op1(x)]" assert str(g) == "[Op1(x)]"
def test_expand(self): def test_expand(self):
x, y, z = inputs() x, y, z = inputs()
e = op1(op1(op1(x))) e = op1(op1(op1(x)))
g = env([x, y, z], [e]) g = Env([x, y, z], [e])
PatternOptimizer((Op1, '1'), PatternOptimizer((op1, '1'),
(Op2, (Op1, '1'))).optimize(g) (op2, (op1, '1'))).optimize(g)
assert str(g) == "[Op2(Op1(Op2(Op1(Op2(Op1(x))))))]" assert str(g) == "[Op2(Op1(Op2(Op1(Op2(Op1(x))))))]"
def test_ambiguous(self): # def test_ambiguous(self):
# this test is known to fail most of the time # # this test is known to fail most of the time
# the reason is that PatternOptimizer doesn't go through # # the reason is that PatternOptimizer doesn't go through
# the ops in topological order. The order is random and # # the ops in topological order. The order is random and
# it does not visit ops that it creates. # # it does not visit ops that it creates.
x, y, z = inputs() # x, y, z = inputs()
e = op1(op1(op1(op1(op1(x))))) # e = op1(op1(op1(op1(op1(x)))))
g = env([x, y, z], [e]) # g = Env([x, y, z], [e])
PatternOptimizer((Op1, (Op1, '1')), # PatternOptimizer((op1, (op1, '1')),
(Op1, '1')).optimize(g) # (op1, '1')).optimize(g)
assert str(g) == "[Op1(x)]" # assert str(g) == "[Op1(x)]"
def test_constant_unification(self): def test_constant_unification(self):
x, y, z = inputs() x = Constant(MyType(), 2, name = 'x')
x.constant = True y = MyResult('y')
x.value = 2 z = Constant(MyType(), 2, name = 'z')
z.constant = True
z.value = 2
e = op1(op1(x, y), y) e = op1(op1(x, y), y)
g = env([y], [e]) g = Env([y], [e])
PatternOptimizer((Op1, z, '1'), PatternOptimizer((op1, z, '1'),
(Op2, '1', z)).optimize(g) (op2, '1', z)).optimize(g)
assert str(g) == "[Op1(Op2(y, z), y)]" assert str(g) == "[Op1(Op2(y, z), y)]"
def test_constraints(self): def test_constraints(self):
x, y, z = inputs() x, y, z = inputs()
e = op4(op1(op2(x, y)), op1(op1(x, y))) e = op4(op1(op2(x, y)), op1(op1(x, y)))
g = env([x, y, z], [e]) g = Env([x, y, z], [e])
def constraint(env, r): def constraint(env, r):
# Only replacing if the input is an instance of Op2 # Only replacing if the input is an instance of Op2
return isinstance(r.owner, Op2) return r.owner.op == op2
PatternOptimizer((Op1, {'pattern': '1', PatternOptimizer((op1, {'pattern': '1',
'constraint': constraint}), 'constraint': constraint}),
(Op3, '1')).optimize(g) (op3, '1')).optimize(g)
assert str(g) == "[Op4(Op3(Op2(x, y)), Op1(Op1(x, y)))]" assert str(g) == "[Op4(Op3(Op2(x, y)), Op1(Op1(x, y)))]"
def test_match_same(self): def test_match_same(self):
x, y, z = inputs() x, y, z = inputs()
e = op1(x, x) e = op1(x, x)
g = env([x, y, z], [e]) g = Env([x, y, z], [e])
PatternOptimizer((Op1, 'x', 'y'), PatternOptimizer((op1, 'x', 'y'),
(Op3, 'x', 'y')).optimize(g) (op3, 'x', 'y')).optimize(g)
assert str(g) == "[Op3(x, x)]" assert str(g) == "[Op3(x, x)]"
def test_match_same_illegal(self): def test_match_same_illegal(self):
x, y, z = inputs() x, y, z = inputs()
e = op2(op1(x, x), op1(x, y)) e = op2(op1(x, x), op1(x, y))
g = env([x, y, z], [e]) g = Env([x, y, z], [e])
def constraint(env, r): def constraint(env, r):
# Only replacing if the input is an instance of Op2 # Only replacing if the input is an instance of Op2
return r.owner.inputs[0] is not r.owner.inputs[1] return r.owner.inputs[0] is not r.owner.inputs[1]
PatternOptimizer({'pattern': (Op1, 'x', 'y'), PatternOptimizer({'pattern': (op1, 'x', 'y'),
'constraint': constraint}, 'constraint': constraint},
(Op3, 'x', 'y')).optimize(g) (op3, 'x', 'y')).optimize(g)
assert str(g) == "[Op2(Op1(x, x), Op3(x, y))]" assert str(g) == "[Op2(Op1(x, x), Op3(x, y))]"
def test_multi(self): def test_multi(self):
x, y, z = inputs() x, y, z = inputs()
e0 = op1(x, y) e0 = op1(x, y)
e = op3(op4(e0), e0) e = op3(op4(e0), e0)
g = env([x, y, z], [e]) g = Env([x, y, z], [e])
PatternOptimizer((Op4, (Op1, 'x', 'y')), PatternOptimizer((op4, (op1, 'x', 'y')),
(Op3, 'x', 'y')).optimize(g) (op3, 'x', 'y')).optimize(g)
assert str(g) == "[Op3(Op4(*1 -> Op1(x, y)), *1)]" assert str(g) == "[Op3(Op4(*1 -> Op1(x, y)), *1)]"
def test_multi_ingraph(self):
x, y, z = inputs()
e0 = op1(x, y)
e = op4(e0, e0)
g = env([x, y, z], [e])
PatternOptimizer((Op4, (Op1, 'x', 'y'), (Op1, 'x', 'y')),
(Op3, 'x', 'y')).optimize(g)
assert str(g) == "[Op3(x, y)]"
class _test_PatternDescOptimizer(unittest.TestCase):
def test_replace_output(self): def test_eq(self):
# replacing the whole graph # replacing the whole graph
x, y, z = inputs() x, y, z = inputs()
e = op1(op2(x, y), z) e = op1(op_y(x, y), z)
g = env([x, y, z], [e]) g = Env([x, y, z], [e])
PatternDescOptimizer((Op1, (Op2, '1', '2'), '3'), PatternOptimizer((op1, (op_z, '1', '2'), '3'),
(Op4, '3', '2')).optimize(g) (op4, '3', '2')).optimize(g)
assert str(g) == "[Op4(z, y)]" assert str(g) == "[Op4(z, y)]"
def test_desc(self): # def test_multi_ingraph(self):
x, y, z = inputs() # # known to fail
e = op1(op_z(x, y, 37, 88), op2(op_z(y, z, 1, 7))) # x, y, z = inputs()
g = env([x, y, z], [e]) # e0 = op1(x, y)
PatternDescOptimizer(((37, 88), '1', '2'), # e = op4(e0, e0)
(Op3, '2', '1')).optimize(g) # g = Env([x, y, z], [e])
assert str(g) == "[Op1(Op3(y, x), Op2(OpZ(y, z)))]" # PatternOptimizer((op4, (op1, 'x', 'y'), (op1, 'x', 'y')),
# (op3, 'x', 'y')).optimize(g)
# assert str(g) == "[Op3(x, y)]"
# class _test_PatternDescOptimizer(unittest.TestCase):
# def test_replace_output(self):
# # replacing the whole graph
# x, y, z = inputs()
# e = op1(op2(x, y), z)
# g = env([x, y, z], [e])
# PatternDescOptimizer((Op1, (Op2, '1', '2'), '3'),
# (Op4, '3', '2')).optimize(g)
# assert str(g) == "[Op4(z, y)]"
# def test_eq(self):
# x, y, z = inputs()
# e = op1(op_y(x, y, 37, 88), op2(op_y(y, z, 1, 7)))
# g = env([x, y, z], [e])
# PatternDescOptimizer((op_z, '1', '2'),
# (op3, '2', '1')).optimize(g)
# assert str(g) == "[Op1(Op3(y, x), Op2(OpZ(y, z)))]"
class _test_OpSubOptimizer(unittest.TestCase): class _test_OpSubOptimizer(unittest.TestCase):
...@@ -266,15 +264,15 @@ class _test_OpSubOptimizer(unittest.TestCase): ...@@ -266,15 +264,15 @@ class _test_OpSubOptimizer(unittest.TestCase):
def test_straightforward(self): def test_straightforward(self):
x, y, z = inputs() x, y, z = inputs()
e = op1(op1(op1(op1(op1(x))))) e = op1(op1(op1(op1(op1(x)))))
g = env([x, y, z], [e]) g = Env([x, y, z], [e])
OpSubOptimizer(Op1, Op2).optimize(g) OpSubOptimizer(op1, op2).optimize(g)
assert str(g) == "[Op2(Op2(Op2(Op2(Op2(x)))))]" assert str(g) == "[Op2(Op2(Op2(Op2(Op2(x)))))]"
def test_straightforward_2(self): def test_straightforward_2(self):
x, y, z = inputs() x, y, z = inputs()
e = op1(op2(x), op3(y), op4(z)) e = op1(op2(x), op3(y), op4(z))
g = env([x, y, z], [e]) g = Env([x, y, z], [e])
OpSubOptimizer(Op3, Op4).optimize(g) OpSubOptimizer(op3, op4).optimize(g)
assert str(g) == "[Op1(Op2(x), Op4(y), Op4(z))]" assert str(g) == "[Op1(Op2(x), Op4(y), Op4(z))]"
...@@ -283,18 +281,16 @@ class _test_MergeOptimizer(unittest.TestCase): ...@@ -283,18 +281,16 @@ class _test_MergeOptimizer(unittest.TestCase):
def test_straightforward(self): def test_straightforward(self):
x, y, z = inputs() x, y, z = inputs()
e = op1(op2(x, y), op2(x, y), op2(x, z)) e = op1(op2(x, y), op2(x, y), op2(x, z))
g = env([x, y, z], [e]) g = Env([x, y, z], [e])
MergeOptimizer().optimize(g) MergeOptimizer().optimize(g)
assert str(g) == "[Op1(*1 -> Op2(x, y), *1, Op2(x, z))]" assert str(g) == "[Op1(*1 -> Op2(x, y), *1, Op2(x, z))]"
def test_constant_merging(self): def test_constant_merging(self):
x, y, z = inputs() x = MyResult('x')
y.data = 2 y = Constant(MyType(), 2, name = 'y')
y.constant = True z = Constant(MyType(), 2, name = 'z')
z.data = 2.0
z.constant = True
e = op1(op2(x, y), op2(x, y), op2(x, z)) e = op1(op2(x, y), op2(x, y), op2(x, z))
g = env([x, y, z], [e]) g = Env([x, y, z], [e])
MergeOptimizer().optimize(g) MergeOptimizer().optimize(g)
strg = str(g) strg = str(g)
assert strg == "[Op1(*1 -> Op2(x, y), *1, *1)]" \ assert strg == "[Op1(*1 -> Op2(x, y), *1, *1)]" \
...@@ -303,14 +299,14 @@ class _test_MergeOptimizer(unittest.TestCase): ...@@ -303,14 +299,14 @@ class _test_MergeOptimizer(unittest.TestCase):
def test_deep_merge(self): def test_deep_merge(self):
x, y, z = inputs() x, y, z = inputs()
e = op1(op3(op2(x, y), z), op4(op3(op2(x, y), z))) e = op1(op3(op2(x, y), z), op4(op3(op2(x, y), z)))
g = env([x, y, z], [e]) g = Env([x, y, z], [e])
MergeOptimizer().optimize(g) MergeOptimizer().optimize(g)
assert str(g) == "[Op1(*1 -> Op3(Op2(x, y), z), Op4(*1))]" assert str(g) == "[Op1(*1 -> Op3(Op2(x, y), z), Op4(*1))]"
def test_no_merge(self): def test_no_merge(self):
x, y, z = inputs() x, y, z = inputs()
e = op1(op3(op2(x, y)), op3(op2(y, x))) e = op1(op3(op2(x, y)), op3(op2(y, x)))
g = env([x, y, z], [e]) g = Env([x, y, z], [e])
MergeOptimizer().optimize(g) MergeOptimizer().optimize(g)
assert str(g) == "[Op1(Op3(Op2(x, y)), Op3(Op2(y, x)))]" assert str(g) == "[Op1(Op3(Op2(x, y)), Op3(Op2(y, x)))]"
...@@ -318,7 +314,7 @@ class _test_MergeOptimizer(unittest.TestCase): ...@@ -318,7 +314,7 @@ class _test_MergeOptimizer(unittest.TestCase):
x, y, z = inputs() x, y, z = inputs()
e1 = op3(op2(x, y)) e1 = op3(op2(x, y))
e2 = op3(op2(x, y)) e2 = op3(op2(x, y))
g = env([x, y, z], [e1, e2]) g = Env([x, y, z], [e1, e2])
MergeOptimizer().optimize(g) MergeOptimizer().optimize(g)
assert str(g) == "[*1 -> Op3(Op2(x, y)), *1]" assert str(g) == "[*1 -> Op3(Op2(x, y)), *1]"
...@@ -327,7 +323,7 @@ class _test_MergeOptimizer(unittest.TestCase): ...@@ -327,7 +323,7 @@ class _test_MergeOptimizer(unittest.TestCase):
e1 = op1(x, y) e1 = op1(x, y)
e2 = op2(op3(x), y, z) e2 = op2(op3(x), y, z)
e = op1(e1, op4(e2, e1), op1(e2)) e = op1(e1, op4(e2, e1), op1(e2))
g = env([x, y, z], [e]) g = Env([x, y, z], [e])
MergeOptimizer().optimize(g) MergeOptimizer().optimize(g)
strg = str(g) strg = str(g)
# note: graph.as_string can only produce the following two possibilities, but if # note: graph.as_string can only produce the following two possibilities, but if
...@@ -336,84 +332,82 @@ class _test_MergeOptimizer(unittest.TestCase): ...@@ -336,84 +332,82 @@ class _test_MergeOptimizer(unittest.TestCase):
or strg == "[Op1(*2 -> Op1(x, y), Op4(*1 -> Op2(Op3(x), y, z), *2), Op1(*1))]" or strg == "[Op1(*2 -> Op1(x, y), Op4(*1 -> Op2(Op3(x), y, z), *2), Op1(*1))]"
def test_identical_constant_args(self): def test_identical_constant_args(self):
x, y, z = inputs() x = MyResult('x')
y.data = 2.0 y = Constant(MyType(), 2, name = 'y')
y.constant = True z = Constant(MyType(), 2, name = 'z')
z.data = 2.0
z.constant = True
e1 = op1(y, z) e1 = op1(y, z)
g = env([x, y, z], [e1]) g = Env([x, y, z], [e1])
MergeOptimizer().optimize(g) MergeOptimizer().optimize(g)
strg = str(g) strg = str(g)
self.failUnless(strg == '[Op1(y, y)]' or strg == '[Op1(z, z)]', strg) self.failUnless(strg == '[Op1(y, y)]' or strg == '[Op1(z, z)]', strg)
def test_identical_constant_args_with_destroymap(self): # def test_identical_constant_args_with_destroymap(self):
x, y, z = inputs() # x, y, z = inputs()
y.data = 2.0 # y.data = 2.0
y.constant = False # y.constant = False
z.data = 2.0 # z.data = 2.0
z.constant = True # z.constant = True
e1 = op_d(y, z) # e1 = op_d(y, z)
g = env([x, y, z], [e1]) # g = env([x, y, z], [e1])
MergeOptimizer().optimize(g) # MergeOptimizer().optimize(g)
strg = str(g) # strg = str(g)
self.failUnless(strg == '[OpD(y, z)]', strg) # self.failUnless(strg == '[OpD(y, z)]', strg)
def test_merge_with_destroyer_1(self): # def test_merge_with_destroyer_1(self):
x, y, z = inputs() # x, y, z = inputs()
e1 = op_d(op1(x,y), y) # e1 = op_d(op1(x,y), y)
e2 = op_d(op1(x,y), z) # e2 = op_d(op1(x,y), z)
g = env([x, y, z], [e1,e2]) # g = env([x, y, z], [e1,e2])
MergeOptimizer().optimize(g) # MergeOptimizer().optimize(g)
strg = str(g) # strg = str(g)
self.failUnless(strg == '[OpD(Op1(x, y), y), OpD(Op1(x, y), z)]', strg) # self.failUnless(strg == '[OpD(Op1(x, y), y), OpD(Op1(x, y), z)]', strg)
def test_merge_with_destroyer_2(self): # def test_merge_with_destroyer_2(self):
x, y, z = inputs() # x, y, z = inputs()
e1 = op_d(op1(x,y), z) # e1 = op_d(op1(x,y), z)
e2 = op_d(op1(x,y), z) # e2 = op_d(op1(x,y), z)
g = env([x, y, z], [e1,e2]) # g = env([x, y, z], [e1,e2])
MergeOptimizer().optimize(g) # MergeOptimizer().optimize(g)
strg = str(g) # strg = str(g)
self.failUnless(strg == '[*1 -> OpD(Op1(x, y), z), *1]', strg) # self.failUnless(strg == '[*1 -> OpD(Op1(x, y), z), *1]', strg)
class _test_ConstantFinder(unittest.TestCase): # class _test_ConstantFinder(unittest.TestCase):
def test_straightforward(self): # def test_straightforward(self):
x, y, z = inputs() # x, y, z = inputs()
y.data = 2 # y.data = 2
z.data = 2 # z.data = 2
e = op1(x, y, z) # e = op1(x, y, z)
g = env([x], [e]) # g = env([x], [e])
ConstantFinder().optimize(g) # ConstantFinder().optimize(g)
assert y.constant and z.constant # assert y.constant and z.constant
MergeOptimizer().optimize(g) # MergeOptimizer().optimize(g)
assert str(g) == "[Op1(x, y, y)]" \ # assert str(g) == "[Op1(x, y, y)]" \
or str(g) == "[Op1(x, z, z)]" # or str(g) == "[Op1(x, z, z)]"
def test_deep(self): # def test_deep(self):
x, y, z = inputs() # x, y, z = inputs()
y.data = 2 # y.data = 2
z.data = 2 # z.data = 2
e = op1(op2(x, y), op2(x, y), op2(x, z)) # e = op1(op2(x, y), op2(x, y), op2(x, z))
g = env([x], [e]) # g = env([x], [e])
ConstantFinder().optimize(g) # ConstantFinder().optimize(g)
assert y.constant and z.constant # assert y.constant and z.constant
MergeOptimizer().optimize(g) # MergeOptimizer().optimize(g)
assert str(g) == "[Op1(*1 -> Op2(x, y), *1, *1)]" \ # assert str(g) == "[Op1(*1 -> Op2(x, y), *1, *1)]" \
or str(g) == "[Op1(*1 -> Op2(x, z), *1, *1)]" # or str(g) == "[Op1(*1 -> Op2(x, z), *1, *1)]"
def test_destroyed_orphan_not_constant(self): # def test_destroyed_orphan_not_constant(self):
x, y, z = inputs() # x, y, z = inputs()
y.data = 2 # y.data = 2
z.data = 2 # z.data = 2
e = op_d(x, op2(y, z)) # here x is destroyed by op_d # e = op_d(x, op2(y, z)) # here x is destroyed by op_d
g = env([y], [e]) # g = env([y], [e])
ConstantFinder().optimize(g) # ConstantFinder().optimize(g)
assert not getattr(x, 'constant', False) and z.constant # assert not getattr(x, 'constant', False) and z.constant
MergeOptimizer().optimize(g) # MergeOptimizer().optimize(g)
......
import unittest
from result import *
class Double(Result):
def __init__(self, data, name = "oignon"):
Result.__init__(self, role = None, name = name)
assert isinstance(data, float)
self.data = data
def __str__(self):
return self.name
def __repr__(self):
return self.name
def __copy__(self):
return Double(self.data, self.name)
class MyResult(Result):
def __init__(self, name):
Result.__init__(self, role = None, name = name)
self.data = [1000]
def __str__(self):
return self.name
def __repr__(self):
return self.name
def __copy__(self):
return MyResult(self.name)
class _test_Result(unittest.TestCase):
def test_trivial(self):
r = Result()
def test_state(self):
r = Result()
assert r.state is Empty
r.data = 0
assert r.data == 0
assert r.state is Computed
r.data = 1
assert r.data == 1
assert r.state is Computed
r.data = None
assert r.data == None
assert r.state is Empty
if __name__ == '__main__':
unittest.main()
import unittest import unittest
from result import Result from graph import Result, as_result, Apply
from type import Type
from op import Op from op import Op
from opt import PatternOptimizer, OpSubOptimizer #from opt import PatternOptimizer, OpSubOptimizer
from env import Env, InconsistencyError from env import Env, InconsistencyError
from toolbox import * from toolbox import *
class MyType(Type):
class MyResult(Result):
def __init__(self, name): def __init__(self, name):
Result.__init__(self, role = None, name = name) self.name = name
self.data = [1000]
def __str__(self): def __str__(self):
return self.name return self.name
...@@ -22,40 +21,43 @@ class MyResult(Result): ...@@ -22,40 +21,43 @@ class MyResult(Result):
def __repr__(self): def __repr__(self):
return self.name return self.name
def __eq__(self, other):
return isinstance(other, MyType)
def MyResult(name):
return Result(MyType(name), None, None)
class MyOp(Op): class MyOp(Op):
nin = -1
def __init__(self, *inputs): def __init__(self, nin, name):
self.nin = nin
self.name = name
def make_node(self, *inputs):
assert len(inputs) == self.nin assert len(inputs) == self.nin
inputs = map(as_result, inputs)
for input in inputs: for input in inputs:
if not isinstance(input, MyResult): if not isinstance(input.type, MyType):
raise Exception("Error 1") raise Exception("Error 1")
self.inputs = inputs outputs = [MyType(self.name + "_R")()]
self.outputs = [MyResult(self.__class__.__name__ + "_R")] return Apply(self, inputs, outputs)
class Sigmoid(MyOp): def __str__(self):
nin = 1 return self.name
class Add(MyOp):
nin = 2
class Dot(MyOp):
nin = 2
sigmoid = MyOp(1, 'Sigmoid')
add = MyOp(2, 'Add')
dot = MyOp(2, 'Dot')
import modes
modes.make_constructors(globals())
def inputs(): def inputs():
x = modes.build(MyResult('x')) x = MyResult('x')
y = modes.build(MyResult('y')) y = MyResult('y')
z = modes.build(MyResult('z')) z = MyResult('z')
return x, y, z return x, y, z
def env(inputs, outputs, validate = True, features = []):
return Env(inputs, outputs, features = features, consistency_check = validate)
class _test_EquivTool(unittest.TestCase): class _test_EquivTool(unittest.TestCase):
...@@ -63,38 +65,43 @@ class _test_EquivTool(unittest.TestCase): ...@@ -63,38 +65,43 @@ class _test_EquivTool(unittest.TestCase):
x, y, z = inputs() x, y, z = inputs()
sx = sigmoid(x) sx = sigmoid(x)
e = add(sx, sigmoid(y)) e = add(sx, sigmoid(y))
g = env([x, y, z], [e], features = [EquivTool]) g = Env([x, y, z], [e])
g.extend(EquivTool(g))
assert hasattr(g, 'equiv')
assert g.equiv(sx) is sx assert g.equiv(sx) is sx
g.replace(sx, dot(x, z)) g.replace(sx, dot(x, z))
assert g.equiv(sx) is not sx assert g.equiv(sx) is not sx
assert isinstance(g.equiv(sx).owner, Dot) assert g.equiv(sx).owner.op is dot
class _test_InstanceFinder(unittest.TestCase): class _test_NodeFinder(unittest.TestCase):
def test_straightforward(self): def test_straightforward(self):
x, y, z = inputs() x, y, z = inputs()
e0 = dot(y, z) e0 = dot(y, z)
e = add(add(sigmoid(x), sigmoid(sigmoid(z))), dot(add(x, y), e0)) e = add(add(sigmoid(x), sigmoid(sigmoid(z))), dot(add(x, y), e0))
g = env([x, y, z], [e], features = [InstanceFinder]) g = Env([x, y, z], [e])
for type, num in ((Add, 3), (Sigmoid, 3), (Dot, 2)): g.extend(NodeFinder(g))
if not len([x for x in g.get_instances_of(type)]) == num: assert hasattr(g, 'get_nodes')
for type, num in ((add, 3), (sigmoid, 3), (dot, 2)):
if not len([x for x in g.get_nodes(type)]) == num:
self.fail((type, num)) self.fail((type, num))
new_e0 = add(y, z) new_e0 = add(y, z)
assert e0.owner in g.get_instances_of(Dot) assert e0.owner in g.get_nodes(dot)
assert new_e0.owner not in g.get_instances_of(Add) assert new_e0.owner not in g.get_nodes(add)
g.replace(e0, new_e0) g.replace(e0, new_e0)
assert e0.owner not in g.get_instances_of(Dot) assert e0.owner not in g.get_nodes(dot)
assert new_e0.owner in g.get_instances_of(Add) assert new_e0.owner in g.get_nodes(add)
for type, num in ((Add, 4), (Sigmoid, 3), (Dot, 1)): for type, num in ((add, 4), (sigmoid, 3), (dot, 1)):
if not len([x for x in g.get_instances_of(type)]) == num: if not len([x for x in g.get_nodes(type)]) == num:
self.fail((type, num)) self.fail((type, num))
def test_robustness(self): def test_robustness(self):
x, y, z = inputs() x, y, z = inputs()
e = add(add(sigmoid(x), sigmoid(sigmoid(z))), dot(add(x, y), dot(y, z))) e = add(add(sigmoid(x), sigmoid(sigmoid(z))), dot(add(x, y), dot(y, z)))
g = env([x, y, z], [e], features = [InstanceFinder]) g = Env([x, y, z], [e])
gen = g.get_instances_of(Sigmoid) # I want to get Sigmoid instances g.extend(NodeFinder(g))
gen = g.get_nodes(sigmoid) # I want to get Sigmoid instances
g.replace(e, add(x, y)) # but here I prune them all g.replace(e, add(x, y)) # but here I prune them all
assert len([x for x in gen]) == 0 # the generator should not yield them assert len([x for x in gen]) == 0 # the generator should not yield them
......
import unittest
from type import *
# todo: test generic
if __name__ == '__main__':
unittest.main()
import unittest, os, sys import unittest, os, sys, traceback
if __name__ == '__main__': if __name__ == '__main__':
suite = None suite = None
......
from link import Linker, raise_with_op from graph import Constant
from link import Linker, LocalLinker, raise_with_op, Filter, map_storage, PerformLinker
from copy import copy from copy import copy
from utils import AbstractFunctionError from utils import AbstractFunctionError
from env import Env
import md5 import md5
import sys import sys
import os import os
...@@ -244,26 +246,26 @@ def get_c_declare(r, name, sub): ...@@ -244,26 +246,26 @@ def get_c_declare(r, name, sub):
pre = """ pre = """
PyObject* py_%(name)s; PyObject* py_%(name)s;
""" % locals() """ % locals()
return pre + r.c_declare(name, sub) return pre + r.type.c_declare(name, sub)
def get_c_init(r, name, sub): def get_c_init(r, name, sub):
pre = "" """ pre = "" """
py_%(name)s = Py_None; py_%(name)s = Py_None;
""" % locals() """ % locals()
return pre + r.c_init(name, sub) return pre + r.type.c_init(name, sub)
def get_c_extract(r, name, sub): def get_c_extract(r, name, sub):
pre = """ pre = """
py_%(name)s = PyList_GET_ITEM(storage_%(name)s, 0); py_%(name)s = PyList_GET_ITEM(storage_%(name)s, 0);
Py_XINCREF(py_%(name)s); Py_XINCREF(py_%(name)s);
""" % locals() """ % locals()
return pre + r.c_extract(name, sub) return pre + r.type.c_extract(name, sub)
def get_c_cleanup(r, name, sub): def get_c_cleanup(r, name, sub):
post = """ post = """
Py_XDECREF(py_%(name)s); Py_XDECREF(py_%(name)s);
""" % locals() """ % locals()
return r.c_cleanup(name, sub) + post return r.type.c_cleanup(name, sub) + post
def get_c_sync(r, name, sub): def get_c_sync(r, name, sub):
return """ return """
...@@ -274,7 +276,7 @@ def get_c_sync(r, name, sub): ...@@ -274,7 +276,7 @@ def get_c_sync(r, name, sub):
PyList_SET_ITEM(storage_%(name)s, 0, py_%(name)s); PyList_SET_ITEM(storage_%(name)s, 0, py_%(name)s);
Py_XDECREF(old); Py_XDECREF(old);
} }
""" % dict(sync = r.c_sync(name, sub), name = name, **sub) """ % dict(sync = r.type.c_sync(name, sub), name = name, **sub)
def apply_policy(policy, r, name, sub): def apply_policy(policy, r, name, sub):
""" """
...@@ -329,34 +331,25 @@ class CLinker(Linker): ...@@ -329,34 +331,25 @@ class CLinker(Linker):
It can take an env or an Op as input. It can take an env or an Op as input.
""" """
def __init__(self, env): def __init__(self, env, no_recycling = []):
self.env = env self.env = env
self.fetch_results() self.fetch_results()
self.no_recycling = no_recycling
def fetch_results(self): def fetch_results(self):
""" """
Fills the inputs, outputs, results, orphans, temps and op_order fields. Fills the inputs, outputs, results, orphans, temps and node_order fields.
""" """
env = self.env env = self.env
self.inputs = env.inputs self.inputs = env.inputs
self.outputs = env.outputs self.outputs = env.outputs
self.results = list(env.results)
try: self.results = list(env.results())
except AttributeError: self.results = self.inputs + self.outputs
# The orphans field is listified to ensure a consistent order. # The orphans field is listified to ensure a consistent order.
try: self.orphans = list(env.orphans().difference(self.outputs)) self.orphans = list(env.orphans.difference(self.outputs))
except AttributeError: self.orphans = [] self.temps = list(set(self.results).difference(self.inputs).difference(self.outputs).difference(self.orphans))
self.node_order = env.toposort()
try: self.temps = list(set(self.results).difference(self.inputs).difference(self.outputs).difference(self.orphans))
except AttributeError: self.temps = []
try: self.op_order = env.toposort()
except AttributeError: self.op_order = [env]
def code_gen(self, do_not_reuse = []): # reuse_storage = True): def code_gen(self):
""" """
Generates code for a struct that does the computation of the env and Generates code for a struct that does the computation of the env and
stores it in the struct_code field of the instance. stores it in the struct_code field of the instance.
...@@ -370,9 +363,11 @@ class CLinker(Linker): ...@@ -370,9 +363,11 @@ class CLinker(Linker):
This method caches its computations. This method caches its computations.
""" """
if getattr(self, 'struct_code', False) and self.do_not_reuse == do_not_reuse: if getattr(self, 'struct_code', False):
return self.struct_code return self.struct_code
no_recycling = self.no_recycling
env = self.env env = self.env
consts = [] consts = []
...@@ -397,34 +392,33 @@ class CLinker(Linker): ...@@ -397,34 +392,33 @@ class CLinker(Linker):
for result in set(self.results): for result in set(self.results):
# it might be possible to inline constant results as C literals # it might be possible to inline constant results as C literals
if getattr(result, 'constant', False): ## if getattr(result, 'constant', False):
if result in self.outputs or result in self.temps:
raise Exception("Temporaries and outputs should not be marked constant. Check your graph.")
try:
symbol[result] = result.c_literal()
consts.append(result)
if result in self.inputs:
print "Warning: input %s is marked as constant and has been compiled as a literal." % result
elif result in self.orphans:
self.orphans.remove(result)
continue
except (AbstractFunctionError, NotImplementedError):
pass
# policy = [[what to declare in the struct, what to do at construction, what to do at destruction], # policy = [[what to declare in the struct, what to do at construction, what to do at destruction],
# [what to declare in each run, what to do at the beginning of each run, what to do at the end of each run]] # [what to declare in each run, what to do at the beginning of each run, what to do at the end of each run]]
if result in self.inputs: if result in self.inputs:
# we need to extract the new inputs at each run # we need to extract the new inputs at each run
# they do not need to be relayed to Python, so we don't sync # they do not need to be relayed to Python, so we don't sync
# if isinstance(result, Constant):
# raise TypeError("Inputs to CLinker cannot be Constant.", result)
policy = [[get_nothing, get_nothing, get_nothing], policy = [[get_nothing, get_nothing, get_nothing],
[get_c_declare, get_c_extract, get_c_cleanup]] [get_c_declare, get_c_extract, get_c_cleanup]]
elif result in self.orphans: elif result in self.orphans:
if not isinstance(result, Constant):
raise TypeError("All orphans to CLinker must be Constant.", result)
try:
symbol[result] = "(" + result.type.c_literal(result.data) + ")"
consts.append(result)
self.orphans.remove(result)
continue
except (AbstractFunctionError, NotImplementedError):
pass
# orphans are not inputs so we'll just get fetch them when we initialize the struct and assume they stay the same # orphans are not inputs so we'll just get fetch them when we initialize the struct and assume they stay the same
policy = [[get_c_declare, get_c_extract, get_c_cleanup], policy = [[get_c_declare, get_c_extract, get_c_cleanup],
[get_nothing, get_nothing, get_nothing]] [get_nothing, get_nothing, get_nothing]]
elif result in self.temps: elif result in self.temps:
# temps don't need to be extracted from Python, so we call c_init rather than c_extract # temps don't need to be extracted from Python, so we call c_init rather than c_extract
# they do not need to be relayed to Python, so we don't sync # they do not need to be relayed to Python, so we don't sync
if result.c_is_simple() or result in do_not_reuse: if result.type.c_is_simple() or result in no_recycling:
policy = [[get_nothing, get_nothing, get_nothing], policy = [[get_nothing, get_nothing, get_nothing],
[get_c_declare, get_c_init, get_c_cleanup]] [get_c_declare, get_c_init, get_c_cleanup]]
else: else:
...@@ -433,7 +427,7 @@ class CLinker(Linker): ...@@ -433,7 +427,7 @@ class CLinker(Linker):
[get_nothing, get_nothing, get_nothing]] [get_nothing, get_nothing, get_nothing]]
elif result in self.outputs: elif result in self.outputs:
# outputs don't need to be extracted from Python, so we call c_init rather than c_extract # outputs don't need to be extracted from Python, so we call c_init rather than c_extract
if result.c_is_simple() or result in do_not_reuse: if result.type.c_is_simple() or result in no_recycling:
policy = [[get_nothing, get_nothing, get_nothing], policy = [[get_nothing, get_nothing, get_nothing],
[get_c_declare, get_c_init, (get_c_sync, get_c_cleanup)]] [get_c_declare, get_c_init, (get_c_sync, get_c_cleanup)]]
...@@ -458,7 +452,7 @@ class CLinker(Linker): ...@@ -458,7 +452,7 @@ class CLinker(Linker):
id += 2 id += 2
for op in self.op_order: for node in self.node_order:
# We populate sub with a mapping from the variable names specified by the op's c_var_names # We populate sub with a mapping from the variable names specified by the op's c_var_names
# method to the actual variable names that we will use. # method to the actual variable names that we will use.
...@@ -467,36 +461,28 @@ class CLinker(Linker): ...@@ -467,36 +461,28 @@ class CLinker(Linker):
## for result, vname in zip(op.inputs + op.outputs, ivnames + ovnames): ## for result, vname in zip(op.inputs + op.outputs, ivnames + ovnames):
## sub[vname] = symbol[result] ## sub[vname] = symbol[result]
isyms, osyms = [symbol[r] for r in op.inputs], [symbol[r] for r in op.outputs] name = "<invalid_c_thing>"
isyms, osyms = [symbol[r] for r in node.inputs], [symbol[r] for r in node.outputs]
# Make the CodeBlock for c_validate_update
sub['id'] = id
sub['fail'] = failure_code(sub)
try: validate_behavior = op.c_validate_update(isyms, osyms, sub)
except AbstractFunctionError:
validate_behavior = ""
try: validate_cleanup = op.c_validate_update_cleanup(isyms, osyms, sub)
except AbstractFunctionError:
validate_cleanup = ""
blocks.append(CodeBlock("", validate_behavior, validate_cleanup, sub)) # c_validate_update is deprecated
tasks.append((op, 'validate_update', id)) if hasattr(node.op, 'c_validate_update'):
id += 1 raise Exception("c_validate_update is deprecated, move contents to c_code", node.op)
# Make the CodeBlock for c_code # Make the CodeBlock for c_code
sub['id'] = id sub['id'] = id
sub['fail'] = failure_code(sub) sub['fail'] = failure_code(sub)
behavior = op.c_code(isyms, osyms, sub) # this one must be implemented! op = node.op
try: behavior = op.c_code(node, name, isyms, osyms, sub)
except AbstractFunctionError:
raise NotImplementedError("%s cannot produce C code" % op)
try: cleanup = op.c_code_cleanup(isyms, osyms, sub) try: cleanup = op.c_code_cleanup(node, name, isyms, osyms, sub)
except AbstractFunctionError: except AbstractFunctionError:
cleanup = "" cleanup = ""
blocks.append(CodeBlock("", behavior, cleanup, sub)) blocks.append(CodeBlock("", behavior, cleanup, sub))
tasks.append((op, 'code', id)) tasks.append((node, 'code', id))
id += 1 id += 1
# List of arg names for use in struct_gen. Note the call to uniq: duplicate inputs # List of arg names for use in struct_gen. Note the call to uniq: duplicate inputs
...@@ -513,7 +499,6 @@ class CLinker(Linker): ...@@ -513,7 +499,6 @@ class CLinker(Linker):
struct_code %= dict(name = struct_name) struct_code %= dict(name = struct_name)
self.struct_code = struct_code self.struct_code = struct_code
self.do_not_reuse = do_not_reuse
self.struct_name = struct_name self.struct_name = struct_name
self.hash = hash self.hash = hash
self.args = args self.args = args
...@@ -550,7 +535,7 @@ class CLinker(Linker): ...@@ -550,7 +535,7 @@ class CLinker(Linker):
This might contain duplicates. This might contain duplicates.
""" """
ret = [] ret = []
for x in self.results + self.op_order: for x in [y.type for y in self.results] + [y.op for y in self.node_order]:
try: ret.append(x.c_support_code()) try: ret.append(x.c_support_code())
except AbstractFunctionError: pass except AbstractFunctionError: pass
return ret return ret
...@@ -563,7 +548,7 @@ class CLinker(Linker): ...@@ -563,7 +548,7 @@ class CLinker(Linker):
This might contain duplicates. This might contain duplicates.
""" """
ret = [] ret = []
for x in self.results + self.op_order: for x in [y.type for y in self.results] + [y.op for y in self.node_order]:
try: ret += x.c_compile_args() try: ret += x.c_compile_args()
except AbstractFunctionError: pass except AbstractFunctionError: pass
return ret return ret
...@@ -576,7 +561,7 @@ class CLinker(Linker): ...@@ -576,7 +561,7 @@ class CLinker(Linker):
This might contain duplicates. This might contain duplicates.
""" """
ret = [] ret = []
for x in self.results + self.op_order: for x in [y.type for y in self.results] + [y.op for y in self.node_order]:
try: ret += x.c_headers() try: ret += x.c_headers()
except AbstractFunctionError: pass except AbstractFunctionError: pass
return ret return ret
...@@ -589,36 +574,43 @@ class CLinker(Linker): ...@@ -589,36 +574,43 @@ class CLinker(Linker):
This might contain duplicates. This might contain duplicates.
""" """
ret = [] ret = []
for x in self.results + self.op_order: for x in [y.type for y in self.results] + [y.op for y in self.node_order]:
try: ret += x.c_libraries() try: ret += x.c_libraries()
except AbstractFunctionError: pass except AbstractFunctionError: pass
return ret return ret
def __compile__(self, inplace = False): def __compile__(self, input_storage = None, output_storage = None):
""" """
@todo update
Compiles this linker's env. If inplace is True, it will use the Compiles this linker's env. If inplace is True, it will use the
Results contained in the env, if it is False it will copy the Results contained in the env, if it is False it will copy the
input and output Results. input and output Results.
Returns: thunk, in_results, out_results, error_storage Returns: thunk, in_results, out_results, error_storage
""" """
if inplace: # if inplace:
in_results = self.inputs # in_results = self.inputs
out_results = self.outputs # out_results = self.outputs
else: # else:
in_results = [copy(input) for input in self.inputs] # in_results = [copy(input) for input in self.inputs]
out_results = [copy(output) for output in self.outputs] # out_results = [copy(output) for output in self.outputs]
error_storage = [None, None, None] error_storage = [None, None, None]
if input_storage is None:
input_storage = [[None] for result in self.inputs]
if output_storage is None:
output_storage = [[None] for result in self.outputs]
thunk = self.cthunk_factory(error_storage, thunk = self.cthunk_factory(error_storage,
[result._data for result in in_results], input_storage,
[result._data for result in out_results]) output_storage)
if not inplace: return thunk, [Filter(input.type, storage) for input, storage in zip(self.env.inputs, input_storage)], \
for r in in_results + out_results: [Filter(output.type, storage, True) for output, storage in zip(self.env.outputs, output_storage)], \
r._role = None # we just need the wrapper, not the (copied) graph associated to it error_storage
return thunk, in_results, out_results, error_storage
# return thunk, [Filter(x) for x in input_storage], [Filter(x) for x in output_storage], error_storage
def make_thunk(self, inplace = False):
cthunk, in_results, out_results, error_storage = self.__compile__(inplace) def make_thunk(self, input_storage = None, output_storage = None):
cthunk, in_storage, out_storage, error_storage = self.__compile__(input_storage, output_storage)
def execute(): def execute():
failure = cutils.run_cthunk(cthunk) failure = cutils.run_cthunk(cthunk)
if failure: if failure:
...@@ -631,7 +623,7 @@ class CLinker(Linker): ...@@ -631,7 +623,7 @@ class CLinker(Linker):
exc_value = exc_type(_exc_value, task) exc_value = exc_type(_exc_value, task)
exc_value.__thunk_trace__ = trace # this can be used to retrieve the location the Op was declared exc_value.__thunk_trace__ = trace # this can be used to retrieve the location the Op was declared
raise exc_type, exc_value, exc_trace raise exc_type, exc_value, exc_trace
return execute, in_results, out_results return execute, in_storage, out_storage
def cthunk_factory(self, error_storage, in_storage, out_storage): def cthunk_factory(self, error_storage, in_storage, out_storage):
""" """
...@@ -719,13 +711,13 @@ class CLinker(Linker): ...@@ -719,13 +711,13 @@ class CLinker(Linker):
out_storage = [x for i, x in enumerate(out_storage) if (i+len(in_storage)) not in self.dupidx] out_storage = [x for i, x in enumerate(out_storage) if (i+len(in_storage)) not in self.dupidx]
in_storage = [x for i, x in enumerate(in_storage) if i not in self.dupidx] in_storage = [x for i, x in enumerate(in_storage) if i not in self.dupidx]
ret = module.instantiate(error_storage, *(in_storage + out_storage + [orphan._data for orphan in self.orphans])) ret = module.instantiate(error_storage, *(in_storage + out_storage + [orphan.data for orphan in self.orphans]))
assert sys.getrefcount(ret) == 2 # refcount leak check assert sys.getrefcount(ret) == 2 # refcount leak check
return ret return ret
class OpWiseCLinker(Linker): class OpWiseCLinker(LocalLinker):
""" """
Uses CLinker on the individual Ops that comprise an env and loops Uses CLinker on the individual Ops that comprise an env and loops
over them in Python. The result is slower than a compiled version of over them in Python. The result is slower than a compiled version of
...@@ -737,46 +729,76 @@ class OpWiseCLinker(Linker): ...@@ -737,46 +729,76 @@ class OpWiseCLinker(Linker):
perform method if no C version can be generated. perform method if no C version can be generated.
""" """
def __init__(self, env, fallback_on_perform = True): def __init__(self, env, fallback_on_perform = True, no_recycling = []):
self.env = env self.env = env
self.fallback_on_perform = fallback_on_perform self.fallback_on_perform = fallback_on_perform
self.no_recycling = no_recycling
def make_thunk(self, profiler = None, input_storage = None, output_storage = None):
return self.make_all(profiler = profiler,
input_storage = input_storage,
output_storage = output_storage)[:3]
def make_all(self, profiler = None, input_storage = None, output_storage = None):
env = self.env
order = env.toposort()
no_recycling = self.no_recycling
input_storage, output_storage, storage_map = map_storage(env, order, input_storage, output_storage)
def make_thunk(self, inplace = False, profiler = None):
if inplace:
env = self.env
else:
env = self.env.clone(True)
op_order = env.toposort()
inputs, outputs = env.inputs, env.outputs
env = None
thunks = [] thunks = []
for op in op_order: for node in order:
node_input_storage = [storage_map[r] for r in node.inputs]
node_output_storage = [storage_map[r] for r in node.outputs]
try: try:
cl = CLinker(op) cl = CLinker(Env(node.inputs, node.outputs))
thunk, in_results, out_results = cl.make_thunk(True) thunk, node_input_filters, node_output_filters = cl.make_thunk(
input_storage = node_input_storage,
output_storage = node_output_storage)
thunk.inputs = node_input_storage
thunk.outputs = node_output_storage
thunks.append(thunk) thunks.append(thunk)
except AbstractFunctionError: except (NotImplementedError, AbstractFunctionError):
if self.fallback_on_perform: if self.fallback_on_perform:
thunks.append(op.perform) p = node.op.perform
thunk = lambda p = p, i = node_input_storage, o = node_output_storage, n = node: p(n, [x[0] for x in i], o)
thunk.inputs = node_input_storage
thunk.outputs = node_output_storage
thunk.perform = p
thunks.append(thunk)
else: else:
raise raise
if profiler is None:
def f():
try:
for thunk, op in zip(thunks, op_order):
thunk()
except:
raise_with_op(op)
else:
def f():
def g():
for thunk, op in zip(thunks, op_order):
profiler.profile_op(thunk, op)
profiler.profile_env(g, env)
f.profiler = profiler
return f, inputs, outputs if no_recycling is True:
no_recycling = storage_map.values()
no_recycling = utils.difference(no_recycling, input_storage)
else:
no_recycling = [storage_map[r] for r in no_recycling if r not in env.inputs]
f = self.streamline(env, thunks, order, no_recycling = no_recycling, profiler = profiler)
# if profiler is None:
# def f():
# for x in no_recycling:
# x[0] = None
# try:
# for thunk, node in zip(thunks, order):
# thunk()
# except:
# raise_with_op(node)
# else:
# def f():
# for x in no_recycling:
# x[0] = None
# def g():
# for thunk, node in zip(thunks, order):
# profiler.profile_op(thunk, node)
# profiler.profile_env(g, env)
# f.profiler = profiler
return f, [Filter(input.type, storage) for input, storage in zip(env.inputs, input_storage)], \
[Filter(output.type, storage, True) for output, storage in zip(env.outputs, output_storage)], \
thunks, order
...@@ -786,8 +808,8 @@ def _default_checker(x, y): ...@@ -786,8 +808,8 @@ def _default_checker(x, y):
Default checker for DualLinker. This checks that the Default checker for DualLinker. This checks that the
results contain the same data using ==. results contain the same data using ==.
""" """
if x.data != y.data: if x[0] != y[0]:
raise Exception("Output mismatch.", {'performlinker': x.data, 'clinker': y.data}) raise Exception("Output mismatch.", {'performlinker': x[0], 'clinker': y[0]})
class DualLinker(Linker): class DualLinker(Linker):
""" """
...@@ -823,38 +845,42 @@ class DualLinker(Linker): ...@@ -823,38 +845,42 @@ class DualLinker(Linker):
self.env = env self.env = env
self.checker = checker self.checker = checker
def make_thunk(self, inplace = False): def make_thunk(self, **kwargs):
if inplace: # if inplace:
env1 = self.env # env1 = self.env
else: # else:
env1 = self.env.clone(True) # env1 = self.env.clone(True)
env2, equiv = env1.clone_get_equiv(True) # env2, equiv = env1.clone_get_equiv(True)
# op_order_1 = env1.toposort()
# op_order_2 = [equiv[op.outputs[0]].owner for op in op_order_1] # we need to have the exact same order so we can compare each step
# def c_make_thunk(op):
# try:
# return CLinker(op).make_thunk(True)[0]
# except AbstractFunctionError:
# return op.perform
# thunks1 = [op.perform for op in op_order_1]
# thunks2 = [c_make_thunk(op) for op in op_order_2]
op_order_1 = env1.toposort() env = self.env
op_order_2 = [equiv[op.outputs[0]].owner for op in op_order_1] # we need to have the exact same order so we can compare each step _f, i1, o1, thunks1, order1 = PerformLinker(env).make_all(**kwargs)
_f, i2, o2, thunks2, order2 = OpWiseCLinker(env).make_all(**kwargs)
def c_make_thunk(op):
try:
return CLinker(op).make_thunk(True)[0]
except AbstractFunctionError:
return op.perform
thunks1 = [op.perform for op in op_order_1]
thunks2 = [c_make_thunk(op) for op in op_order_2]
def f(): def f():
for input1, input2 in zip(env1.inputs, env2.inputs): for input1, input2 in zip(i1, i2):
# set the inputs to be the same in both branches # set the inputs to be the same in both branches
# the copy is necessary in order for inplace ops not to interfere # the copy is necessary in order for inplace ops not to interfere
input2.data = copy(input1.data) input2.storage[0] = copy(input1.storage[0])
for thunk1, thunk2, op1, op2 in zip(thunks1, thunks2, op_order_1, op_order_2): for thunk1, thunk2, node1, node2 in zip(thunks1, thunks2, order1, order2):
try: try:
thunk1() thunk1()
thunk2() thunk2()
for output1, output2 in zip(op1.outputs, op2.outputs): for output1, output2 in zip(thunk1.outputs, thunk2.outputs):
self.checker(output1, output2) self.checker(output1, output2)
except: except:
raise_with_op(op1) raise_with_op(node1)
# exc_type, exc_value, exc_trace = sys.exc_info() # exc_type, exc_value, exc_trace = sys.exc_info()
# try: # try:
# trace = op1.trace # trace = op1.trace
...@@ -864,7 +890,7 @@ class DualLinker(Linker): ...@@ -864,7 +890,7 @@ class DualLinker(Linker):
# exc_value.args = exc_value.args + (op1, ) # exc_value.args = exc_value.args + (op1, )
# raise exc_type, exc_value, exc_trace # raise exc_type, exc_value, exc_trace
return f, env1.inputs, env1.outputs return f, i1, o1
......
...@@ -6,9 +6,6 @@ from features import Listener, Orderings, Constraint, Tool, uniq_features ...@@ -6,9 +6,6 @@ from features import Listener, Orderings, Constraint, Tool, uniq_features
import utils import utils
from utils import AbstractFunctionError from utils import AbstractFunctionError
__all__ = ['InconsistencyError',
'Env']
class InconsistencyError(Exception): class InconsistencyError(Exception):
""" """
...@@ -18,22 +15,12 @@ class InconsistencyError(Exception): ...@@ -18,22 +15,12 @@ class InconsistencyError(Exception):
pass pass
def require_set(cls): def require_set(x):
"""Return the set of objects named in a __env_require__ field in a base class""" try:
r = set() req = x.env_require
except AttributeError:
if hasattr(cls, '__class__'): req = []
cls = cls.__class__ return req
bases = utils.all_bases(cls, lambda cls: hasattr(cls, '__env_require__'))
for base in bases:
req = base.__env_require__
if isinstance(req, (list, tuple)):
r.update(req)
else:
r.add(req)
return r
class Env(graph.Graph): class Env(graph.Graph):
...@@ -63,39 +50,41 @@ class Env(graph.Graph): ...@@ -63,39 +50,41 @@ class Env(graph.Graph):
### Special ### ### Special ###
def __init__(self, inputs, outputs, features = [], consistency_check = True): # **listeners): def __init__(self, inputs, outputs): #, consistency_check = True):
""" """
Create an Env which operates on the subgraph bound by the inputs and outputs Create an Env which operates on the subgraph bound by the inputs and outputs
sets. If consistency_check is False, an illegal graph will be tolerated. sets. If consistency_check is False, an illegal graph will be tolerated.
""" """
self._features = {} # self._features = {}
self._listeners = {} # self._listeners = {}
self._constraints = {} # self._constraints = {}
self._orderings = {} # self._orderings = {}
self._tools = {} # self._tools = {}
self._features = []
# The inputs and outputs set bound the subgraph this Env operates on. # The inputs and outputs set bound the subgraph this Env operates on.
self.inputs = list(inputs) self.inputs = list(inputs)
self.outputs = list(outputs) self.outputs = list(outputs)
# All ops in the subgraph defined by inputs and outputs are cached in _ops # All nodes in the subgraph defined by inputs and outputs are cached in nodes
self._ops = set() self.nodes = set()
# Ditto for results # Ditto for results
self._results = set(self.inputs) self.results = set(self.inputs)
# Set of all the results that are not an output of an op in the subgraph but # Set of all the results that are not an output of an op in the subgraph but
# are an input of an op in the subgraph. # are an input of an op in the subgraph.
# e.g. z for inputs=(x, y) and outputs=(x + (y - z),) # e.g. z for inputs=(x, y) and outputs=(x + (y - z),)
# We initialize them to the set of outputs; if an output depends on an input, # We initialize them to the set of outputs; if an output depends on an input,
# it will be removed from the set of orphans. # it will be removed from the set of orphans.
self._orphans = set(outputs).difference(inputs) self.orphans = set(outputs).difference(inputs)
for feature_class in uniq_features(features): # for feature_class in uniq_features(features):
self.add_feature(feature_class, False) # self.add_feature(feature_class, False)
# Maps results to ops that use them: # Maps results to nodes that use them:
# if op.inputs[i] == v then (op, i) in self._clients[v] # if op.inputs[i] == v then (op, i) in self._clients[v]
self._clients = {} self._clients = {}
...@@ -104,11 +93,11 @@ class Env(graph.Graph): ...@@ -104,11 +93,11 @@ class Env(graph.Graph):
self.history = [] self.history = []
self.__import_r__(self.outputs) self.__import_r__(self.outputs)
for op in self.ops(): # for op in self.nodes():
self.satisfy(op) # self.satisfy(op)
if consistency_check: # if consistency_check:
self.validate() # self.validate()
### Public interface ### ### Public interface ###
...@@ -141,97 +130,87 @@ class Env(graph.Graph): ...@@ -141,97 +130,87 @@ class Env(graph.Graph):
return False return False
return True return True
def satisfy(self, x): # def satisfy(self, x):
"Adds the features required by x unless they are already present." # "Adds the features required by x unless they are already present."
for feature_class in require_set(x): # for feature_class in require_set(x):
self.add_feature(feature_class) # self.add_feature(feature_class)
def add_feature(self, feature_class, do_import = True): def extend(self, feature, do_import = True, validate = False):
""" """
@todo out of date
Adds an instance of the feature_class to this env's supported Adds an instance of the feature_class to this env's supported
features. If do_import is True and feature_class is a subclass features. If do_import is True and feature_class is a subclass
of Listener, its on_import method will be called on all the Ops of Listener, its on_import method will be called on all the Nodes
already in the env. already in the env.
""" """
if feature_class in self._features: if feature in self._features:
return # the feature is already present return # the feature is already present
self.__add_feature__(feature_class, do_import) self.__add_feature__(feature, do_import)
if validate:
def __add_feature__(self, feature_class, do_import): self.validate()
if not issubclass(feature_class, (Listener, Constraint, Orderings, Tool)):
raise TypeError("features must be subclasses of Listener, Constraint, Orderings and/or Tools",
(feature_class,type(feature_class)))
feature = feature_class(self)
if issubclass(feature_class, Listener):
self._listeners[feature_class] = feature
if do_import:
for op in self.io_toposort():
try:
feature.on_import(op)
except AbstractFunctionError:
pass
if issubclass(feature_class, Constraint):
self._constraints[feature_class] = feature
if issubclass(feature_class, Orderings):
self._orderings[feature_class] = feature
if issubclass(feature_class, Tool):
self._tools[feature_class] = feature
feature.publish()
self._features[feature_class] = feature
def __del_feature__(self, feature_class):
for set in [self._features, self._constraints, self._orderings, self._tools, self._listeners]:
try:
del set[feature_class]
except KeyError:
pass
def get_feature(self, feature_class): def execute_callbacks(self, name, *args):
return self._features[feature_class] for feature in self._features:
try:
fn = getattr(feature, name)
except AttributeError:
continue
fn(*args)
def __add_feature__(self, feature, do_import):
self._features.append(feature)
publish = getattr(feature, 'publish', None)
if publish is not None:
publish()
if do_import:
try:
fn = feature.on_import
except AttributeError:
return
for node in self.io_toposort():
fn(node)
def has_feature(self, feature_class): def __del_feature__(self, feature):
try: try:
self.get_feature(feature_class) del self._features[feature]
except: except:
return False pass
return True unpublish = hasattr(feature, 'unpublish')
if unpublish is not None:
unpublish()
def get_feature(self, feature):
idx = self._features.index(feature)
return self._features[idx]
def has_feature(self, feature):
return feature in self._features
def nclients(self, r): def nclients(self, r):
"Same as len(self.clients(r))." "Same as len(self.clients(r))."
return len(self.clients(r)) return len(self.clients(r))
def edge(self, r): def edge(self, r):
return r in self.inputs or r in self.orphans() return r in self.inputs or r in self.orphans
def follow(self, r): def follow(self, r):
op = r.owner node = r.owner
if self.edge(r): if self.edge(r):
return None return None
else: else:
if op is None: if node is None:
raise Exception("what the fuck") raise Exception("what the fuck")
return op.inputs return node.inputs
def ops(self):
"All ops within the subgraph bound by env.inputs and env.outputs."
return self._ops
def has_op(self, op):
return op in self._ops
def orphans(self): def has_node(self, node):
""" return node in self.nodes
All results not within the subgraph bound by env.inputs and
env.outputs, not in env.inputs but required by some op.
"""
return self._orphans
def replace(self, r, new_r, consistency_check = True): def replace(self, r, new_r, consistency_check = True):
""" """
This is the main interface to manipulate the subgraph in Env. This is the main interface to manipulate the subgraph in Env.
For every op that uses r as input, makes it use new_r instead. For every op that uses r as input, makes it use new_r instead.
This may raise an error if the new result violates type This may raise an error if the new result violates type
constraints for one of the target ops. In that case, no constraints for one of the target nodes. In that case, no
changes are made. changes are made.
If the replacement makes the graph inconsistent and the value If the replacement makes the graph inconsistent and the value
...@@ -243,9 +222,8 @@ class Env(graph.Graph): ...@@ -243,9 +222,8 @@ class Env(graph.Graph):
even if there is an inconsistency, unless the replacement even if there is an inconsistency, unless the replacement
violates hard constraints on the types involved. violates hard constraints on the types involved.
""" """
assert r in self.results
self.__import_r_satisfy__([new_r])
# Save where we are so we can backtrack # Save where we are so we can backtrack
if consistency_check: if consistency_check:
chk = self.checkpoint() chk = self.checkpoint()
...@@ -257,13 +235,15 @@ class Env(graph.Graph): ...@@ -257,13 +235,15 @@ class Env(graph.Graph):
# result. Note that if v is an input result, we do nothing at all for # result. Note that if v is an input result, we do nothing at all for
# now (it's not clear what it means to replace an input result). # now (it's not clear what it means to replace an input result).
was_output = False was_output = False
new_was_output = False
if new_r in self.outputs:
new_was_output = True
if r in self.outputs: if r in self.outputs:
was_output = True was_output = True
self.outputs[self.outputs.index(r)] = new_r self.outputs[self.outputs.index(r)] = new_r
was_input = False
if r in self.inputs:
was_input = True
self.inputs[self.inputs.index(r)] = new_r
# The actual replacement operation occurs here. This might raise # The actual replacement operation occurs here. This might raise
# an error. # an error.
self.__move_clients__(clients, r, new_r) # not sure how to order this wrt to adjusting the outputs self.__move_clients__(clients, r, new_r) # not sure how to order this wrt to adjusting the outputs
...@@ -272,8 +252,11 @@ class Env(graph.Graph): ...@@ -272,8 +252,11 @@ class Env(graph.Graph):
def undo(): def undo():
# Restore self.outputs # Restore self.outputs
if was_output: if was_output:
if not new_was_output: self.outputs[self.outputs.index(new_r)] = r
self.outputs[self.outputs.index(new_r)] = r
# Restore self.inputs
if was_input:
self.inputs[self.inputs.index(new_r)] = r
# Move back the clients. This should never raise an error. # Move back the clients. This should never raise an error.
self.__move_clients__(clients, new_r, r) self.__move_clients__(clients, new_r, r)
...@@ -307,13 +290,6 @@ class Env(graph.Graph): ...@@ -307,13 +290,6 @@ class Env(graph.Graph):
self.revert(chk) self.revert(chk)
raise raise
def results(self):
"""
All results within the subgraph bound by env.inputs and
env.outputs and including them
"""
return self._results
def revert(self, checkpoint): def revert(self, checkpoint):
""" """
Reverts the graph to whatever it was at the provided Reverts the graph to whatever it was at the provided
...@@ -332,14 +308,15 @@ class Env(graph.Graph): ...@@ -332,14 +308,15 @@ class Env(graph.Graph):
relationships). relationships).
""" """
ords = {} ords = {}
for ordering in self._orderings.values(): for feature in self._features:
for op, prereqs in ordering.orderings().items(): if hasattr(feature, 'orderings'):
ords.setdefault(op, set()).update(prereqs) for op, prereqs in feature.orderings().items():
ords.setdefault(op, set()).update(prereqs)
return ords return ords
def toposort(self): def toposort(self):
""" """
Returns a list of ops in the order that they must be executed Returns a list of nodes in the order that they must be executed
in order to preserve the semantics of the graph and respect in order to preserve the semantics of the graph and respect
the constraints put forward by the listeners. the constraints put forward by the listeners.
""" """
...@@ -351,8 +328,9 @@ class Env(graph.Graph): ...@@ -351,8 +328,9 @@ class Env(graph.Graph):
""" """
Raises an error if the graph is inconsistent. Raises an error if the graph is inconsistent.
""" """
for constraint in self._constraints.values(): self.execute_callbacks('validate')
constraint.validate() # for constraint in self._constraints.values():
# constraint.validate()
return True return True
...@@ -379,12 +357,12 @@ class Env(graph.Graph): ...@@ -379,12 +357,12 @@ class Env(graph.Graph):
self._clients[r].difference_update(all) self._clients[r].difference_update(all)
if not self._clients[r]: if not self._clients[r]:
del self._clients[r] del self._clients[r]
if r in self._orphans: if r in self.orphans:
self._orphans.remove(r) self.orphans.remove(r)
def __import_r_satisfy__(self, results): def __import_r_satisfy__(self, results):
# Satisfies the owners of the results. # Satisfies the owners of the results.
for op in graph.ops(self.results(), results): for op in graph.ops(self.results, results):
self.satisfy(op) self.satisfy(op)
def __import_r__(self, results): def __import_r__(self, results):
...@@ -393,35 +371,39 @@ class Env(graph.Graph): ...@@ -393,35 +371,39 @@ class Env(graph.Graph):
owner = result.owner owner = result.owner
if owner: if owner:
self.__import__(result.owner) self.__import__(result.owner)
if result not in self.results:
self.results.add(result)
self.orphans.add(result)
def __import__(self, op): def __import__(self, op):
# We import the ops in topological order. We only are interested # We import the nodes in topological order. We only are interested
# in new ops, so we use all results we know of as if they were the input set. # in new nodes, so we use all results we know of as if they were the input set.
# (the functions in the graph module only use the input set to # (the functions in the graph module only use the input set to
# know where to stop going down) # know where to stop going down)
new_ops = graph.io_toposort(self.results().difference(self.orphans()), op.outputs) new_nodes = graph.io_toposort(self.results.difference(self.orphans), op.outputs)
for op in new_ops: for op in new_nodes:
self._ops.add(op) self.nodes.add(op)
self._results.update(op.outputs) self.results.update(op.outputs)
self._orphans.difference_update(op.outputs) self.orphans.difference_update(op.outputs)
for i, input in enumerate(op.inputs): for i, input in enumerate(op.inputs):
self.__add_clients__(input, [(op, i)]) self.__add_clients__(input, [(op, i)])
if input not in self._results: if input not in self.results:
# This input is an orphan because if the op that # This input is an orphan because if the op that
# produced it was in the subgraph, io_toposort # produced it was in the subgraph, io_toposort
# would have placed it before, so we would have # would have placed it before, so we would have
# seen it (or it would already be in the graph) # seen it (or it would already be in the graph)
self._orphans.add(input) self.orphans.add(input)
self._results.add(input) self.results.add(input)
for listener in self._listeners.values(): self.execute_callbacks('on_import', op)
try: # for listener in self._listeners.values():
listener.on_import(op) # try:
except AbstractFunctionError: # listener.on_import(op)
pass # except AbstractFunctionError:
# pass
__import__.E_output = 'op output in Env.inputs' __import__.E_output = 'op output in Env.inputs'
def __prune_r__(self, results): def __prune_r__(self, results):
...@@ -432,6 +414,10 @@ class Env(graph.Graph): ...@@ -432,6 +414,10 @@ class Env(graph.Graph):
owner = result.owner owner = result.owner
if owner: if owner:
self.__prune__(owner) self.__prune__(owner)
# if result in self.results:
# self.results.remove(result)
# if result in self.orphans:
# self.orphans.remove(result)
def __prune__(self, op): def __prune__(self, op):
# If op's outputs have no clients, removes it from the graph # If op's outputs have no clients, removes it from the graph
...@@ -442,35 +428,41 @@ class Env(graph.Graph): ...@@ -442,35 +428,41 @@ class Env(graph.Graph):
# Cannot prune an op which is an output or used somewhere # Cannot prune an op which is an output or used somewhere
if self.clients(output) or output in self.outputs: #output in self.outputs or self.clients(output): if self.clients(output) or output in self.outputs: #output in self.outputs or self.clients(output):
return return
if op not in self._ops: # this can happen from replacing an orphan if op not in self.nodes: # this can happen from replacing an orphan
return return
self._ops.remove(op) self.nodes.remove(op)
self._results.difference_update(op.outputs) self.results.difference_update(op.outputs)
for listener in self._listeners.values(): self.execute_callbacks('on_prune', op)
try: # for listener in self._listeners.values():
listener.on_prune(op) # try:
except AbstractFunctionError: # listener.on_prune(op)
pass # except AbstractFunctionError:
# pass
for i, input in enumerate(op.inputs): for i, input in enumerate(op.inputs):
self.__remove_clients__(input, [(op, i)]) self.__remove_clients__(input, [(op, i)])
self.__prune_r__(op.inputs) self.__prune_r__(op.inputs)
def __move_clients__(self, clients, r, new_r): def __move_clients__(self, clients, r, new_r):
if not (r.type == new_r.type):
raise TypeError("Cannot move clients between Results that have different types.", r, new_r)
# We import the new result in the fold # We import the new result in the fold
self.__import_r__([new_r]) self.__import_r__([new_r])
try: for op, i in clients:
# Try replacing the inputs op.inputs[i] = new_r
for op, i in clients: # try:
op.set_input(i, new_r) # # Try replacing the inputs
except: # for op, i in clients:
# Oops! # op.set_input(i, new_r)
for op, i in clients: # except:
op.set_input(i, r) # # Oops!
self.__prune_r__([new_r]) # for op, i in clients:
raise # op.set_input(i, r)
# self.__prune_r__([new_r])
# raise
self.__remove_clients__(r, clients) self.__remove_clients__(r, clients)
self.__add_clients__(new_r, clients) self.__add_clients__(new_r, clients)
...@@ -478,12 +470,13 @@ class Env(graph.Graph): ...@@ -478,12 +470,13 @@ class Env(graph.Graph):
# # why was this line AFTER the set_inputs??? # # why was this line AFTER the set_inputs???
# # if we do it here then satisfy in import fucks up... # # if we do it here then satisfy in import fucks up...
# self.__import_r__([new_r]) # self.__import_r__([new_r])
for listener in self._listeners.values(): self.execute_callbacks('on_rewire', clients, r, new_r)
try: # for listener in self._listeners.values():
listener.on_rewire(clients, r, new_r) # try:
except AbstractFunctionError: # listener.on_rewire(clients, r, new_r)
pass # except AbstractFunctionError:
# pass
# We try to get rid of the old one # We try to get rid of the old one
self.__prune_r__([r]) self.__prune_r__([r])
...@@ -494,17 +487,17 @@ class Env(graph.Graph): ...@@ -494,17 +487,17 @@ class Env(graph.Graph):
def clone_get_equiv(self, clone_inputs = True): def clone_get_equiv(self, clone_inputs = True):
equiv = graph.clone_get_equiv(self.inputs, self.outputs, clone_inputs) equiv = graph.clone_get_equiv(self.inputs, self.outputs, clone_inputs)
new = self.__class__([equiv[input] for input in self.inputs], new = self.__class__([equiv[input] for input in self.inputs],
[equiv[output] for output in self.outputs], [equiv[output] for output in self.outputs])
self._features.keys(), for feature in self._features:
consistency_check = False) new.extend(feature)
return new, equiv return new, equiv
def clone(self, clone_inputs = True): def clone(self, clone_inputs = True):
equiv = graph.clone_get_equiv(self.inputs, self.outputs, clone_inputs) equiv = graph.clone_get_equiv(self.inputs, self.outputs, clone_inputs)
new = self.__class__([equiv[input] for input in self.inputs], new = self.__class__([equiv[input] for input in self.inputs],
[equiv[output] for output in self.outputs], [equiv[output] for output in self.outputs])
self._features.keys(), for feature in self._features:
consistency_check = False) new.extend(feature)
try: try:
new.set_equiv(equiv) new.set_equiv(equiv)
except AttributeError: except AttributeError:
......
...@@ -69,7 +69,7 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool): ...@@ -69,7 +69,7 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
self.seen = set() self.seen = set()
# Initialize the children if the inputs and orphans. # Initialize the children if the inputs and orphans.
for input in env.orphans().union(env.inputs): for input in env.orphans.union(env.inputs):
self.children[input] = set() self.children[input] = set()
def publish(self): def publish(self):
...@@ -197,17 +197,23 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool): ...@@ -197,17 +197,23 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
for user in users: for user in users:
self.__detect_cycles_helper__(user, []) self.__detect_cycles_helper__(user, [])
def get_maps(self, op): def get_maps(self, node):
""" """
@return: (vmap, dmap) where: @return: (vmap, dmap) where:
- vmap -> {output : [inputs output is a view of]} - vmap -> {output : [inputs output is a view of]}
- dmap -> {output : [inputs that are destroyed by the Op - dmap -> {output : [inputs that are destroyed by the Op
(and presumably returned as that output)]} (and presumably returned as that output)]}
""" """
try: vmap = op.view_map() try: _vmap = node.op.view_map
except AttributeError, AbstractFunctionError: vmap = {} except AttributeError, AbstractFunctionError: _vmap = {}
try: dmap = op.destroy_map() try: _dmap = node.op.destroy_map
except AttributeError, AbstractFunctionError: dmap = {} except AttributeError, AbstractFunctionError: _dmap = {}
vmap = {}
for oidx, iidxs in _vmap.items():
vmap[node.outputs[oidx]] = [node.inputs[iidx] for iidx in iidxs]
dmap = {}
for oidx, iidxs in _dmap.items():
dmap[node.outputs[oidx]] = [node.inputs[iidx] for iidx in iidxs]
return vmap, dmap return vmap, dmap
def on_import(self, op): def on_import(self, op):
...@@ -395,6 +401,7 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool): ...@@ -395,6 +401,7 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
# Recompute the cycles from both r_1 and r_2. # Recompute the cycles from both r_1 and r_2.
self.__detect_cycles__(r_1) # we should really just remove the cycles that have r_1 and a result in prev just before self.__detect_cycles__(r_1) # we should really just remove the cycles that have r_1 and a result in prev just before
self.children.setdefault(r_2, set())
self.__detect_cycles__(r_2) self.__detect_cycles__(r_2)
def validate(self): def validate(self):
......
...@@ -2,21 +2,124 @@ ...@@ -2,21 +2,124 @@
from copy import copy from copy import copy
import utils import utils
from utils import object2
class Apply(object2):
#__slots__ = ['op', 'inputs', 'outputs']
def __init__(self, op, inputs, outputs):
self.op = op
self.inputs = []
for input in inputs:
if isinstance(input, Result):
self.inputs.append(input)
# elif isinstance(input, Type):
# self.inputs.append(Result(input, None, None))
else:
raise TypeError("The 'inputs' argument to Apply must contain Result instances, not %s" % input)
self.outputs = []
for i, output in enumerate(outputs):
if isinstance(output, Result):
if output.owner is None:
output.owner = self
output.index = i
elif output.owner is not self or output.index != i:
raise ValueError("All output results passed to Apply must belong to it.")
self.outputs.append(output)
# elif isinstance(output, Type):
# self.outputs.append(Result(output, self, i))
else:
raise TypeError("The 'outputs' argument to Apply must contain Result instances with no owner, not %s" % output)
def default_output(self):
"""
Returns the default output for this Node, typically self.outputs[0].
Depends on the value of node.op.default_output
"""
do = self.op.default_output
if do < 0:
raise AttributeError("%s does not have a default output." % self.op)
elif do > len(self.outputs):
raise AttributeError("default output for %s is out of range." % self.op)
return self.outputs[do]
out = property(default_output,
doc = "Same as self.outputs[0] if this Op's has_default_output field is True.")
def __str__(self):
return op_as_string(self.inputs, self)
def __repr__(self):
return str(self)
def __asapply__(self):
return self
nin = property(lambda self: len(self.inputs))
nout = property(lambda self: len(self.outputs))
class Result(object2):
#__slots__ = ['type', 'owner', 'index', 'name']
def __init__(self, type, owner = None, index = None, name = None):
self.type = type
self.owner = owner
self.index = index
self.name = name
def __str__(self):
__all__ = ['inputs', if self.name is not None:
'results_and_orphans', 'results', 'orphans', return self.name
'ops', if self.owner is not None:
'clone', 'clone_get_equiv', op = self.owner.op
'io_toposort', if self.index == op.default_output:
'default_leaf_formatter', 'default_node_formatter', return str(self.owner.op) + ".out"
'op_as_string', else:
'as_string', return str(self.owner.op) + "." + str(self.index)
'Graph'] else:
return "?::" + str(self.type)
def __repr__(self):
return str(self)
def __asresult__(self):
return self
class Constant(Result):
#__slots__ = ['data']
def __init__(self, type, data, name = None):
Result.__init__(self, type, None, None, name)
self.data = type.filter(data)
self.indestructible = True
def equals(self, other):
# this does what __eq__ should do, but Result and Apply should always be hashable by id
return isinstance(other, Constant) and self.signature() == other.signature()
def signature(self):
return (self.type, self.data)
def __str__(self):
if self.name is not None:
return self.name
return str(self.data) #+ "::" + str(self.type)
def as_result(x):
if isinstance(x, Result):
return x
# elif isinstance(x, Type):
# return Result(x, None, None)
elif hasattr(x, '__asresult__'):
r = x.__asresult__()
if not isinstance(r, Result):
raise TypeError("%s.__asresult__ must return a Result instance" % x, (x, r))
return r
else:
raise TypeError("Cannot wrap %s in a Result" % x)
def as_apply(x):
if isinstance(x, Apply):
return x
elif hasattr(x, '__asapply__'):
node = x.__asapply__()
if not isinstance(node, Apply):
raise TypeError("%s.__asapply__ must return an Apply instance" % x, (x, node))
return node
else:
raise TypeError("Cannot map %s to Apply" % x)
is_result = utils.attr_checker('owner', 'index')
is_op = utils.attr_checker('inputs', 'outputs')
def inputs(o): def inputs(o):
...@@ -177,24 +280,39 @@ def clone_get_equiv(i, o, copy_inputs_and_orphans = False): ...@@ -177,24 +280,39 @@ def clone_get_equiv(i, o, copy_inputs_and_orphans = False):
for input in i: for input in i:
if copy_inputs_and_orphans: if copy_inputs_and_orphans:
d[input] = copy(input) cpy = copy(input)
cpy.owner = None
cpy.index = None
d[input] = cpy
else: else:
d[input] = input d[input] = input
def clone_helper(result): def clone_helper(result):
if result in d: if result in d:
return d[result] return d[result]
op = result.owner node = result.owner
if not op: # result is an orphan if node is None: # result is an orphan
if copy_inputs_and_orphans: if copy_inputs_and_orphans:
d[result] = copy(result) cpy = copy(result)
cpy.owner = None
cpy.index = None
d[result] = cpy
else: else:
d[result] = result d[result] = result
return d[result] return d[result]
else: else:
new_op = op.clone_with_new_inputs(*[clone_helper(input) for input in op.inputs]) new_node = copy(node)
d[op] = new_op new_node.inputs = [clone_helper(input) for input in node.inputs]
for output, new_output in zip(op.outputs, new_op.outputs): new_node.outputs = []
for output in node.outputs:
new_output = copy(output)
new_output.owner = new_node
new_node.outputs.append(new_output)
# new_node = Apply(node.op,
# [clone_helper(input) for input in node.inputs],
# [output.type for output in node.outputs])
d[node] = new_node
for output, new_output in zip(node.outputs, new_node.outputs):
d[output] = new_output d[output] = new_output
return d[result] return d[result]
...@@ -203,6 +321,36 @@ def clone_get_equiv(i, o, copy_inputs_and_orphans = False): ...@@ -203,6 +321,36 @@ def clone_get_equiv(i, o, copy_inputs_and_orphans = False):
return d return d
# d = {}
# for input in i:
# if copy_inputs_and_orphans:
# d[input] = copy(input)
# else:
# d[input] = input
# def clone_helper(result):
# if result in d:
# return d[result]
# op = result.owner
# if not op: # result is an orphan
# if copy_inputs_and_orphans:
# d[result] = copy(result)
# else:
# d[result] = result
# return d[result]
# else:
# new_op = op.clone_with_new_inputs(*[clone_helper(input) for input in op.inputs])
# d[op] = new_op
# for output, new_output in zip(op.outputs, new_op.outputs):
# d[output] = new_output
# return d[result]
# for output in o:
# clone_helper(output)
# return d
def io_toposort(i, o, orderings = {}): def io_toposort(i, o, orderings = {}):
""" """
...@@ -231,7 +379,7 @@ def io_toposort(i, o, orderings = {}): ...@@ -231,7 +379,7 @@ def io_toposort(i, o, orderings = {}):
default_leaf_formatter = str default_leaf_formatter = str
default_node_formatter = lambda op, argstrings: "%s(%s)" % (op.strdesc(), default_node_formatter = lambda op, argstrings: "%s(%s)" % (op.op,
", ".join(argstrings)) ", ".join(argstrings))
def op_as_string(i, op, def op_as_string(i, op,
...@@ -291,7 +439,7 @@ def as_string(i, o, ...@@ -291,7 +439,7 @@ def as_string(i, o,
if r.owner is not None and r not in i and r not in orph: if r.owner is not None and r not in i and r not in orph:
op = r.owner op = r.owner
idx = op.outputs.index(r) idx = op.outputs.index(r)
if idx == op._default_output_idx: if idx == op.op.default_output:
idxs = "" idxs = ""
else: else:
idxs = "::%i" % idx idxs = "::%i" % idx
......
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
from utils import AbstractFunctionError from utils import AbstractFunctionError
import utils import utils
from graph import Constant
import sys import sys
import traceback import traceback
...@@ -46,7 +48,7 @@ def raise_with_op(op, exc_info = None): ...@@ -46,7 +48,7 @@ def raise_with_op(op, exc_info = None):
class Linker: class Linker:
def make_thunk(self, inplace = False): def make_thunk(self):
""" """
This function must return a triplet (function, input_results, output_results) This function must return a triplet (function, input_results, output_results)
where function is a thunk that operates on the returned results. If inplace where function is a thunk that operates on the returned results. If inplace
...@@ -55,6 +57,7 @@ class Linker: ...@@ -55,6 +57,7 @@ class Linker:
results will be returned. results will be returned.
Example:: Example::
x, y = Result(Double), Result(Double)
e = x + y e = x + y
env = Env([x, y], [e]) env = Env([x, y], [e])
fn, (new_x, new_y), (new_e, ) = MyLinker(env).make_thunk(inplace) fn, (new_x, new_y), (new_e, ) = MyLinker(env).make_thunk(inplace)
...@@ -66,7 +69,7 @@ class Linker: ...@@ -66,7 +69,7 @@ class Linker:
""" """
raise AbstractFunctionError() raise AbstractFunctionError()
def make_function(self, inplace = False, unpack_single = True, **kwargs): def make_function(self, unpack_single = True, **kwargs):
""" """
Returns a function that takes values corresponding to the inputs of the Returns a function that takes values corresponding to the inputs of the
env used by this L{Linker} and returns values corresponding the the outputs env used by this L{Linker} and returns values corresponding the the outputs
...@@ -85,8 +88,7 @@ class Linker: ...@@ -85,8 +88,7 @@ class Linker:
output, then that output will be returned. Else, a list or tuple of output, then that output will be returned. Else, a list or tuple of
length 1 will be returned. length 1 will be returned.
""" """
thunk, inputs, outputs = self.make_thunk(inplace, **kwargs) thunk, inputs, outputs = self.make_thunk(**kwargs)
def execute(*args): def execute(*args):
def e_arity(takes, got): def e_arity(takes, got):
return 'Function call takes exactly %i %s (%i given)' \ return 'Function call takes exactly %i %s (%i given)' \
...@@ -107,9 +109,78 @@ class Linker: ...@@ -107,9 +109,78 @@ class Linker:
return execute return execute
class Filter(object):
def __init__(self, type, storage, readonly = False):
self.type = type
self.storage = storage
self.readonly = readonly
def __get(self):
return self.storage[0]
def __set(self, value):
if self.readonly:
raise Exception("Cannot set readonly storage.")
self.storage[0] = self.type.filter(value)
data = property(__get, __set)
def __str__(self):
return "<" + str(self.storage[0]) + ">"
def __repr__(self):
return "<" + repr(self.storage[0]) + ">"
def map_storage(env, order, input_storage, output_storage):
if input_storage is None:
input_storage = [[None] for input in env.inputs]
else:
assert len(env.inputs) == len(input_storage)
storage_map = {}
for r, storage in zip(env.inputs, input_storage):
storage_map[r] = storage
for orphan in env.orphans:
if not isinstance(orphan, Constant):
raise TypeError("Cannot link a graph with non-constant orphans.", orphan)
storage_map[orphan] = [orphan.data]
if output_storage is not None:
assert len(env.outputs) == len(output_storage)
for r, storage in zip(env.outputs, output_storage):
storage_map[r] = storage
thunks = []
for node in order:
for r in node.outputs:
storage_map.setdefault(r, [None])
if output_storage is None:
output_storage = [storage_map[r] for r in env.outputs]
return input_storage, output_storage, storage_map
class LocalLinker(Linker):
def streamline(self, env, thunks, order, no_recycling = [], profiler = None):
if profiler is None:
def f():
for x in no_recycling:
x[0] = None
try:
for thunk, node in zip(thunks, order):
thunk()
except:
raise_with_op(node)
else:
def f():
for x in no_recycling:
x[0] = None
def g():
for thunk, node in zip(thunks, order):
profiler.profile_node(thunk, node)
profiler.profile_env(g, env)
f.profiler = profiler
return f
class PerformLinker(Linker): class PerformLinker(LocalLinker):
""" """
Basic L{Linker} subclass that calls the perform method on each L{Op} in Basic L{Linker} subclass that calls the perform method on each L{Op} in
the L{Env} in the order given by L{Env.toposort}. the L{Env} in the order given by L{Env.toposort}.
...@@ -119,38 +190,107 @@ class PerformLinker(Linker): ...@@ -119,38 +190,107 @@ class PerformLinker(Linker):
self.env = env self.env = env
self.no_recycling = no_recycling self.no_recycling = no_recycling
def make_thunk(self, inplace = False, profiler = None): def make_thunk(self, profiler = None, input_storage = None, output_storage = None):
if inplace: return self.make_all(profiler = profiler,
env = self.env input_storage = input_storage,
else: output_storage = output_storage)[:3]
env = self.env.clone(True)
def make_all(self, profiler = None, input_storage = None, output_storage = None):
env = self.env
order = env.toposort() order = env.toposort()
thunks = [op.perform for op in order]
no_recycling = self.no_recycling no_recycling = self.no_recycling
# input_storage = [[None] for input in env.inputs]
# output_storage = [[None] for output in env.outputs]
# storage_map = {}
# for r, storage in zip(env.inputs, input_storage):
# storage_map[r] = storage
# for orphan in env.orphans:
# if not isinstance(orphan, Constant):
# raise TypeError("Cannot link a graph with non-constant orphans.", orphan)
# storage_map[orphan] = [orphan.data]
# thunks = []
# for node in order:
# node_input_storage = [storage_map[input] for input in node.inputs]
# node_output_storage = [storage_map.setdefault(r, [None]) for r in node.outputs]
# p = node.op.perform
# thunks.append(lambda p = p, i = node_input_storage, o = node_output_storage: p([x[0] for x in i], o))
# output_storage = [storage_map[r] for r in env.outputs]
thunks = []
input_storage, output_storage, storage_map = map_storage(env, order, input_storage, output_storage)
for node in order:
node_input_storage = [storage_map[input] for input in node.inputs]
node_output_storage = [storage_map[output] for output in node.outputs]
p = node.op.perform
thunk = lambda p = p, i = node_input_storage, o = node_output_storage, n = node: p(n, [x[0] for x in i], o)
thunk.inputs = node_input_storage
thunk.outputs = node_output_storage
thunk.perform = p
thunks.append(thunk)
if no_recycling is True: if no_recycling is True:
no_recycling = list(env.results()) no_recycling = storage_map.values()
no_recycling = utils.difference(no_recycling, env.inputs) no_recycling = utils.difference(no_recycling, input_storage)
if profiler is None:
def f():
for r in no_recycling:
r.data = None
try:
for thunk, op in zip(thunks, order):
thunk()
except:
raise_with_op(op)
else: else:
def f(): no_recycling = [storage_map[r] for r in no_recycling if r not in env.inputs]
for r in no_recycling:
r.data = None f = self.streamline(env, thunks, order, no_recycling = no_recycling, profiler = profiler)
def g():
for thunk, op in zip(thunks, order): return f, [Filter(input.type, storage) for input, storage in zip(env.inputs, input_storage)], \
profiler.profile_op(thunk, op) [Filter(output.type, storage, True) for output, storage in zip(env.outputs, output_storage)], \
profiler.profile_env(g, env) thunks, order
f.profiler = profiler
# return f, env.inputs, env.outputs
# class PerformLinker(Linker):
# """
# Basic L{Linker} subclass that calls the perform method on each L{Op} in
# the L{Env} in the order given by L{Env.toposort}.
# """
# def __init__(self, env, no_recycling = []):
# self.env = env
# self.no_recycling = no_recycling
# def make_thunk(self, inplace = False, profiler = None):
# if inplace:
# env = self.env
# else:
# env = self.env.clone(True)
# order = env.toposort()
# thunks = [op.perform for op in order]
# no_recycling = self.no_recycling
# if no_recycling is True:
# no_recycling = list(env.results())
# no_recycling = utils.difference(no_recycling, env.inputs)
# if profiler is None:
# def f():
# for r in no_recycling:
# r.data = None
# try:
# for thunk, op in zip(thunks, order):
# thunk()
# except:
# raise_with_op(op)
# else:
# def f():
# for r in no_recycling:
# r.data = None
# def g():
# for thunk, op in zip(thunks, order):
# profiler.profile_op(thunk, op)
# profiler.profile_env(g, env)
# f.profiler = profiler
return f, env.inputs, env.outputs # return f, env.inputs, env.outputs
......
...@@ -5,240 +5,78 @@ compatible with gof's graph manipulation routines. ...@@ -5,240 +5,78 @@ compatible with gof's graph manipulation routines.
""" """
import utils import utils
from utils import ClsInit, all_bases, all_bases_collect, AbstractFunctionError from utils import AbstractFunctionError, object2
import graph
from copy import copy from copy import copy
__all__ = ['Op',
'GuardedOp',
]
class Op(object2):
default_output = 0
def constructor(op_cls, name = None):
"""
Make an L{Op} look like a L{Result}-valued function.
"""
def f(*args, **kwargs):
op = op_cls(*args, **kwargs)
if len(op.outputs) > 1:
return op.outputs
else:
return op.outputs[0]
opname = op_cls.__name__
if name is None:
name = "constructor{%s}" % opname
f.__name__ = name
doc = op_cls.__doc__
f.__doc__ = """
Constructor for %(opname)s:
%(doc)s
""" % locals()
return f
class Op(object):
"""
L{Op} represents a computation on the storage in its 'inputs' slot,
the results of which are stored in the L{Result} instances in the
'outputs' slot. The owner of each L{Result} in the outputs list must
be set to this L{Op} and thus any L{Result} instance is in the outputs
list of at most one L{Op}, its owner. It is the responsibility of the
L{Op} to ensure that it owns its outputs and it is encouraged (though
not required) that it creates them.
"""
__slots__ = ['_inputs', '_outputs', '_hash_id']
_default_output_idx = 0
def default_output(self):
"""Returns the default output of this Op instance, typically self.outputs[0]."""
try:
return self.outputs[self._default_output_idx]
except (IndexError, TypeError):
raise AttributeError("Op does not have a default output.")
out = property(default_output,
doc = "Same as self.outputs[0] if this Op's has_default_output field is True.")
def __init__(self, **kwargs):
self._hash_id = utils.hashgen()
#
# Python stdlib compatibility
#
# These are defined so that sets of Ops, Results will have a consistent
# ordering
def __cmp__(self, other):
return cmp(id(self), id(other))
def __eq__(self, other):
return self is other #assuming this is faster, equiv to id(self) == id(other)
def __ne__(self, other):
return self is not other #assuming this is faster, equiv to id(self) != id(other)
def __hash__(self):
if not hasattr(self, '_hash_id'):
self._hash_id = utils.hashgen()
return self._hash_id
def desc(self):
"""
Description (signature) of this L{Op}. L{Op}s with the same
signature may be collapsed by the L{MergeOptimizer}.
@attention: If your L{Op} has additional options or a different
constructor you probably want to override this.
"""
return self.__class__
def strdesc(self):
return self.__class__.__name__
#
#
#
def get_input(self, i):
return self._inputs[i]
def set_input(self, i, new):
self._inputs[i] = new
def get_inputs(self):
return self._inputs
def set_inputs(self, new):
self._inputs = list(new)
def get_output(self, i):
return self._outputs[i]
def get_outputs(self): #############
return self._outputs # make_node #
def set_outputs(self, new): #############
"""
The point of this function is:
1. to save the subclass's __init__ function always having to set the role of the outputs
2. to prevent accidentally re-setting outputs, which would probably be a bug
"""
if not hasattr(self, '_outputs') or self._outputs is None:
for i, output in enumerate(new):
output.role = (self, i)
self._outputs = list(new)
else:
raise Exception("Can only set outputs once, to initialize them.")
#create inputs and outputs as read-only attributes
inputs = property(get_inputs, set_inputs, doc = "The list of this Op's input Results.")
outputs = property(get_outputs, set_outputs, doc = "The list of this Op's output Results.")
def make_node(self, *inputs):
raise AbstractFunctionError()
# def __call__(self, *inputs):
# copy return self.make_node(*inputs).out
#
def __copy__(self):
"""
Shallow copy of this L{Op}. The inputs are the exact same, but
the outputs are recreated because of the one-owner-per-result
policy. The default behavior is to call the constructor on this
L{Op}'s inputs.
To do a bottom-up copy of a graph, use L{clone_with_new_inputs}. #########################
# Python implementation #
#########################
@attention: If your L{Op} has additional options or a different def impl(self, node, inputs, output_storage):
constructor you probably want to override this.
""" """
return self.__class__(*self.inputs) Calculate the function on the inputs and put the results in the
output storage.
def clone_with_new_inputs(self, *new_inputs): - inputs: sequence of inputs (immutable)
""" - outputs: mutable list
Returns a clone of this L{Op} that takes different inputs. The
default behavior is to call the constructor on the new inputs.
@attention: If your L{Op} has additional options or a different The output_storage list might contain data. If an element of
constructor you probably want to override this. output_storage is not None, it is guaranteed that it was produced
""" by a previous call to impl and impl is free to reuse it as it
return self.__class__(*new_inputs) sees fit.
#
# String representation
#
def __str__(self):
return graph.op_as_string(self.inputs, self)
def __repr__(self):
return str(self)
#
# perform
#
def impl(self, *args):
"""Return output data [tuple], given input data
If this L{Op} has a single output (len(self.outputs)==1) then the return
value of this function will be assigned to self.outputs[0].data.
If this L{Op} has multiple otuputs, then this function should return a
tuple with the data for outputs[0], outputs[1], outputs[2], etc.
""" """
raise AbstractFunctionError() raise AbstractFunctionError()
def perform(self):
"""
Performs the computation associated to this L{Op} and places the
result(s) in the output L{Result}s.
TODO: consider moving this function to the python linker.
"""
res = self.impl(*[input.data for input in self.inputs])
if len(self.outputs) == 1:
self.outputs[0].data = res
else:
assert len(res) == len(self.outputs)
for output, value in zip(self.outputs, res):
output.data = value
#####################
# C code generation #
#####################
# # def c_validate_update(self, inputs, outputs, sub):
# C code generators # """
# # Returns templated C code that checks that the inputs to this
# function can be worked on. If a failure occurs, set an
def c_validate_update(self, inputs, outputs, sub): # Exception and insert "%(fail)s".
"""
Returns templated C code that checks that the inputs to this
function can be worked on. If a failure occurs, set an
Exception and insert "%(fail)s".
You may use the variable names defined by c_var_names() in # You may use the variable names defined by c_var_names() in
the template. # the template.
Note: deprecated!! # Note: deprecated!!
@todo: Merge this with c_code. # @todo: Merge this with c_code.
""" # """
raise AbstractFunctionError() # raise AbstractFunctionError()
def c_validate_update_cleanup(self, inputs, outputs, sub): # def c_validate_update_cleanup(self, inputs, outputs, sub):
""" # """
Clean up things allocated by L{c_validate}(). # Clean up things allocated by L{c_validate}().
Note: deprecated!! # Note: deprecated!!
@todo: Merge this with c_code. # @todo: Merge this with c_code.
""" # """
raise AbstractFunctionError() # raise AbstractFunctionError()
raise AbstractFunctionError('%s.c_validate_update_cleanup ' \ # raise AbstractFunctionError('%s.c_validate_update_cleanup ' \
% self.__class__.__name__) # % self.__class__.__name__)
def c_code(self, inputs, outputs, sub): def c_code(self, node, name, inputs, outputs, sub):
"""Return the C implementation of an Op. """Return the C implementation of an Op.
Returns templated C code that does the computation associated Returns templated C code that does the computation associated
...@@ -262,7 +100,7 @@ class Op(object): ...@@ -262,7 +100,7 @@ class Op(object):
raise AbstractFunctionError('%s.c_code' \ raise AbstractFunctionError('%s.c_code' \
% self.__class__.__name__) % self.__class__.__name__)
def c_code_cleanup(self, inputs, outputs, sub): def c_code_cleanup(self, node, name, inputs, outputs, sub):
"""Code to be run after c_code, whether it failed or not. """Code to be run after c_code, whether it failed or not.
This is a convenient place to clean up things allocated by c_code(). This is a convenient place to clean up things allocated by c_code().
...@@ -297,26 +135,40 @@ class Op(object): ...@@ -297,26 +135,40 @@ class Op(object):
raise AbstractFunctionError() raise AbstractFunctionError()
#TODO: consider adding a flag to the base class that toggles this behaviour class PropertiedOp(Op):
class GuardedOp(Op):
"""An Op that disallows input properties to change after construction""" def __eq__(self, other):
return type(self) == type(other) and self.__dict__ == other.__dict__
def set_input(self, i, new):
old = self._inputs[i] def __str__(self):
if old is new: if hasattr(self, 'name') and self.name:
return return self.name
try:
if not old.same_properties(new):
raise TypeError("The new input must have the same properties as the previous one.")
except AbstractFunctionError:
pass
Op.set_input(self, i, new)
def set_inputs(self, new):
if not hasattr(self, '_inputs') or self_inputs is None:
Op.set_inputs(self, new)
else: else:
if not len(new) == len(self._inputs): return "%s{%s}" % (self.__class__.__name__, ", ".join("%s=%s" % (k, v) for k, v in self.__dict__.items() if k != "name"))
raise TypeError("The new inputs are not as many as the previous ones.")
for i, new in enumerate(new):
self.set_input(i, new)
# #TODO: consider adding a flag to the base class that toggles this behaviour
# class GuardedOp(Op):
# """An Op that disallows input properties to change after construction"""
# def set_input(self, i, new):
# old = self._inputs[i]
# if old is new:
# return
# try:
# if not old.same_properties(new):
# raise TypeError("The new input must have the same properties as the previous one.")
# except AbstractFunctionError:
# pass
# Op.set_input(self, i, new)
# def set_inputs(self, new):
# if not hasattr(self, '_inputs') or self_inputs is None:
# Op.set_inputs(self, new)
# else:
# if not len(new) == len(self._inputs):
# raise TypeError("The new inputs are not as many as the previous ones.")
# for i, new in enumerate(new):
# self.set_input(i, new)
from op import Op from op import Op
from result import Result from graph import Constant
from type import Type
from env import InconsistencyError from env import InconsistencyError
import utils import utils
import unify import unify
...@@ -30,12 +32,15 @@ class Optimizer: ...@@ -30,12 +32,15 @@ class Optimizer:
env.satisfy(opt) env.satisfy(opt)
opt.apply(env) opt.apply(env)
""" """
env.satisfy(self) self.add_requirements(env)
self.apply(env) self.apply(env)
def __call__(self, env): def __call__(self, env):
return self.optimize(env) return self.optimize(env)
def add_requirements(self, env):
pass
DummyOpt = Optimizer() DummyOpt = Optimizer()
DummyOpt.__doc__ = "Does nothing." DummyOpt.__doc__ = "Does nothing."
...@@ -81,13 +86,13 @@ class LocalOptimizer(Optimizer): ...@@ -81,13 +86,13 @@ class LocalOptimizer(Optimizer):
def candidates(self, env): def candidates(self, env):
""" """
Must return a set of ops that can be optimized. Must return a set of nodes that can be optimized.
""" """
raise utils.AbstractFunctionError() raise utils.AbstractFunctionError()
def apply_on_op(self, env, op): def apply_on_node(self, env, node):
""" """
For each op in candidates, this function will be called to For each node in candidates, this function will be called to
perform the actual optimization. perform the actual optimization.
""" """
raise utils.AbstractFunctionError() raise utils.AbstractFunctionError()
...@@ -96,26 +101,27 @@ class LocalOptimizer(Optimizer): ...@@ -96,26 +101,27 @@ class LocalOptimizer(Optimizer):
""" """
Calls self.apply_on_op(env, op) for each op in self.candidates(env). Calls self.apply_on_op(env, op) for each op in self.candidates(env).
""" """
for op in self.candidates(env): for node in self.candidates(env):
if env.has_op(op): if env.has_node(node):
self.apply_on_op(env, op) self.apply_on_node(env, node)
class OpSpecificOptimizer(LocalOptimizer): class OpSpecificOptimizer(LocalOptimizer):
""" """
Generic L{Optimizer} that applies only to ops of a certain Generic L{Optimizer} that applies only to ops of a certain
type. The type in question is accessed through L{self.opclass}. type. The type in question is accessed through L{self.op}.
opclass can also be a class variable of the subclass. op can also be a class variable of the subclass.
""" """
__env_require__ = toolbox.InstanceFinder def add_requirements(self, env):
env.extend(toolbox.NodeFinder(env))
def candidates(self, env): def candidates(self, env):
""" """
Returns all instances of L{self.opclass}. Returns all instances of L{self.op}.
""" """
return env.get_instances_of(self.opclass) return env.get_nodes(self.op)
...@@ -128,7 +134,8 @@ class OpSubOptimizer(Optimizer): ...@@ -128,7 +134,8 @@ class OpSubOptimizer(Optimizer):
e.g. OpSubOptimizer(add, sub) ==> add(div(x, y), add(y, x)) -> sub(div(x, y), sub(y, x)) e.g. OpSubOptimizer(add, sub) ==> add(div(x, y), add(y, x)) -> sub(div(x, y), sub(y, x))
""" """
__env_require__ = toolbox.InstanceFinder def add_requirements(self, env):
env.extend(toolbox.NodeFinder(env))
def __init__(self, op1, op2, failure_callback = None): def __init__(self, op1, op2, failure_callback = None):
""" """
...@@ -149,38 +156,40 @@ class OpSubOptimizer(Optimizer): ...@@ -149,38 +156,40 @@ class OpSubOptimizer(Optimizer):
the Optimizer fails to do a replacement in the graph. The the Optimizer fails to do a replacement in the graph. The
arguments to the callback are: (op1_instance, replacement, exception) arguments to the callback are: (op1_instance, replacement, exception)
""" """
candidates = env.get_instances_of(self.op1) candidates = env.get_nodes(self.op1)
for op in candidates: for node in candidates:
try: try:
repl = self.op2(*op.inputs) repl = self.op2.make_node(*node.inputs)
assert len(op.outputs) == len(repl.outputs) assert len(node.outputs) == len(repl.outputs)
for old, new in zip(op.outputs, repl.outputs): for old, new in zip(node.outputs, repl.outputs):
env.replace(old, new) env.replace(old, new)
except Exception, e: except Exception, e:
if self.failure_callback is not None: if self.failure_callback is not None:
self.failure_callback(op, repl, e) self.failure_callback(node, repl, e)
pass pass
def str(self): def str(self):
return "%s -> %s" % (self.op1.__name__, self.op2.__name__) return "%s -> %s" % (self.op1, self.op2)
class OpRemover(Optimizer): class OpRemover(Optimizer):
""" """
@todo untested
Removes all ops of a certain type by transferring each of its Removes all ops of a certain type by transferring each of its
outputs to the corresponding input. outputs to the corresponding input.
""" """
__env_require__ = toolbox.InstanceFinder def add_requirements(self, env):
env.extend(toolbox.NodeFinder(env))
def __init__(self, opclass, failure_callback = None): def __init__(self, op, failure_callback = None):
""" """
opclass is the class of the ops to remove. It must take as opclass is the class of the ops to remove. It must take as
many inputs as outputs. many inputs as outputs.
""" """
self.opclass = opclass self.op = op
self.failure_callback = failure_callback self.failure_callback = failure_callback
def apply(self, env): def apply(self, env):
...@@ -192,25 +201,27 @@ class OpRemover(Optimizer): ...@@ -192,25 +201,27 @@ class OpRemover(Optimizer):
arguments to the callback are: (opclass_instance, exception) arguments to the callback are: (opclass_instance, exception)
""" """
candidates = env.get_instances_of(self.opclass) candidates = env.get_nodes(self.op)
for op in candidates: for node in candidates:
try: try:
assert len(op.inputs) == len(op.outputs) assert len(node.inputs) == len(node.outputs)
for input, output in zip(op.inputs, op.outputs): for input, output in zip(node.inputs, node.outputs):
env.replace(output, input) env.replace(output, input)
except Exception, e: except Exception, e:
if self.failure_callback is not None: if self.failure_callback is not None:
self.failure_callback(op, e) self.failure_callback(node, e)
pass pass
def str(self): def str(self):
return "f(%s(x)) -> f(x)" % self.opclass return "f(%s(x)) -> f(x)" % self.op
class PatternOptimizer(OpSpecificOptimizer): class PatternOptimizer(OpSpecificOptimizer):
""" """
@todo update
Replaces all occurrences of the input pattern by the output pattern:: Replaces all occurrences of the input pattern by the output pattern::
input_pattern ::= (OpClass, <sub_pattern1>, <sub_pattern2>, ...) input_pattern ::= (OpClass, <sub_pattern1>, <sub_pattern2>, ...)
...@@ -253,23 +264,19 @@ class PatternOptimizer(OpSpecificOptimizer): ...@@ -253,23 +264,19 @@ class PatternOptimizer(OpSpecificOptimizer):
""" """
def __init__(self, in_pattern, out_pattern, allow_multiple_clients = False, failure_callback = None): def __init__(self, in_pattern, out_pattern, allow_multiple_clients = False, failure_callback = None):
"""
Sets in_pattern for replacement by out_pattern.
self.opclass is set to in_pattern[0] to accelerate the search.
"""
self.in_pattern = in_pattern self.in_pattern = in_pattern
self.out_pattern = out_pattern self.out_pattern = out_pattern
if isinstance(in_pattern, (list, tuple)): if isinstance(in_pattern, (list, tuple)):
self.opclass = self.in_pattern[0] self.op = self.in_pattern[0]
elif isinstance(in_pattern, dict): elif isinstance(in_pattern, dict):
self.opclass = self.in_pattern['pattern'][0] self.op = self.in_pattern['pattern'][0]
else: else:
raise TypeError("The pattern to search for must start with a specific Op class.") raise TypeError("The pattern to search for must start with a specific Op instance.")
self.__doc__ = self.__class__.__doc__ + "\n\nThis instance does: " + str(self) + "\n" self.__doc__ = self.__class__.__doc__ + "\n\nThis instance does: " + str(self) + "\n"
self.failure_callback = failure_callback self.failure_callback = failure_callback
self.allow_multiple_clients = allow_multiple_clients self.allow_multiple_clients = allow_multiple_clients
def apply_on_op(self, env, op): def apply_on_node(self, env, node):
""" """
Checks if the graph from op corresponds to in_pattern. If it does, Checks if the graph from op corresponds to in_pattern. If it does,
constructs out_pattern and performs the replacement. constructs out_pattern and performs the replacement.
...@@ -283,116 +290,9 @@ class PatternOptimizer(OpSpecificOptimizer): ...@@ -283,116 +290,9 @@ class PatternOptimizer(OpSpecificOptimizer):
""" """
def match(pattern, expr, u, first = False): def match(pattern, expr, u, first = False):
if isinstance(pattern, (list, tuple)): if isinstance(pattern, (list, tuple)):
if not issubclass(expr.owner.__class__, pattern[0]) or (not self.allow_multiple_clients and not first and env.nclients(expr) > 1): if expr.owner is None:
return False return False
if len(pattern) - 1 != len(expr.owner.inputs): if not (expr.owner.op == pattern[0]) or (not self.allow_multiple_clients and not first and env.nclients(expr) > 1):
return False
for p, v in zip(pattern[1:], expr.owner.inputs):
u = match(p, v, u)
if not u:
return False
elif isinstance(pattern, dict):
try:
real_pattern = pattern['pattern']
constraint = pattern['constraint']
except KeyError:
raise KeyError("Malformed pattern: %s (expected keys pattern and constraint)" % pattern)
if constraint(env, expr):
return match(real_pattern, expr, u, False)
elif isinstance(pattern, str):
v = unify.Var(pattern)
if u[v] is not v and u[v] is not expr:
return False
else:
u = u.merge(expr, v)
elif isinstance(pattern, Result) \
and getattr(pattern, 'constant', False) \
and isinstance(expr, Result) \
and getattr(expr, 'constant', False) \
and pattern.desc() == expr.desc():
return u
else:
return False
return u
def build(pattern, u):
if isinstance(pattern, (list, tuple)):
args = [build(p, u) for p in pattern[1:]]
return pattern[0](*args).out
elif isinstance(pattern, str):
return u[unify.Var(pattern)]
else:
return pattern
u = match(self.in_pattern, op.out, unify.Unification(), True)
if u:
try:
# note: only replaces the default 'out' port if it exists
p = self.out_pattern
new = 'unassigned'
new = build(p, u)
env.replace(op.out, new)
except Exception, e:
if self.failure_callback is not None:
self.failure_callback(op.out, new, e)
pass
def __str__(self):
def pattern_to_str(pattern):
if isinstance(pattern, (list, tuple)):
return "%s(%s)" % (pattern[0].__name__, ", ".join([pattern_to_str(p) for p in pattern[1:]]))
elif isinstance(pattern, dict):
return "%s subject to %s" % (pattern_to_str(pattern['pattern']), str(pattern['constraint']))
else:
return str(pattern)
return "%s -> %s" % (pattern_to_str(self.in_pattern), pattern_to_str(self.out_pattern))
class PatternDescOptimizer(LocalOptimizer):
"""
"""
__env_require__ = toolbox.DescFinder
def __init__(self, in_pattern, out_pattern, allow_multiple_clients = False, failure_callback = None):
"""
Sets in_pattern for replacement by out_pattern.
self.opclass is set to in_pattern[0] to accelerate the search.
"""
self.in_pattern = in_pattern
self.out_pattern = out_pattern
if isinstance(in_pattern, (list, tuple)):
self.desc = self.in_pattern[0]
elif isinstance(in_pattern, dict):
self.desc = self.in_pattern['pattern'][0]
else:
raise TypeError("The pattern to search for must start with a specific desc.")
self.__doc__ = self.__class__.__doc__ + "\n\nThis instance does: " + str(self) + "\n"
self.failure_callback = failure_callback
self.allow_multiple_clients = allow_multiple_clients
def candidates(self, env):
"""
Returns all instances of self.desc
"""
return env.get_from_desc(self.desc)
def apply_on_op(self, env, op):
"""
Checks if the graph from op corresponds to in_pattern. If it does,
constructs out_pattern and performs the replacement.
If self.failure_callback is not None, if there is a match but a
replacement fails to occur, the callback will be called with
arguments (results_to_replace, replacement, exception).
If self.allow_multiple_clients is False, he pattern matching will fail
if one of the subpatterns has more than one client.
"""
def match(pattern, expr, u, first = False):
if isinstance(pattern, (list, tuple)):
if not expr.owner or not expr.owner.desc() == pattern[0] or (self.allow_multiple_clients and not first and env.nclients(expr.owner) > 1):
return False return False
if len(pattern) - 1 != len(expr.owner.inputs): if len(pattern) - 1 != len(expr.owner.inputs):
return False return False
...@@ -414,11 +314,7 @@ class PatternDescOptimizer(LocalOptimizer): ...@@ -414,11 +314,7 @@ class PatternDescOptimizer(LocalOptimizer):
return False return False
else: else:
u = u.merge(expr, v) u = u.merge(expr, v)
elif isinstance(pattern, Result) \ elif isinstance(pattern, Constant) and isinstance(expr, Constant) and pattern.equals(expr):
and getattr(pattern, 'constant', False) \
and isinstance(expr, Result) \
and getattr(expr, 'constant', False) \
and pattern.desc() == expr.desc():
return u return u
else: else:
return False return False
...@@ -427,29 +323,29 @@ class PatternDescOptimizer(LocalOptimizer): ...@@ -427,29 +323,29 @@ class PatternDescOptimizer(LocalOptimizer):
def build(pattern, u): def build(pattern, u):
if isinstance(pattern, (list, tuple)): if isinstance(pattern, (list, tuple)):
args = [build(p, u) for p in pattern[1:]] args = [build(p, u) for p in pattern[1:]]
return pattern[0](*args).out return pattern[0](*args)
elif isinstance(pattern, str): elif isinstance(pattern, str):
return u[unify.Var(pattern)] return u[unify.Var(pattern)]
else: else:
return pattern return pattern
u = match(self.in_pattern, op.out, unify.Unification(), True) u = match(self.in_pattern, node.out, unify.Unification(), True)
if u: if u:
try: try:
# note: only replaces the default 'out' port if it exists # note: only replaces the default 'out' port if it exists
p = self.out_pattern p = self.out_pattern
new = 'unassigned' new = 'unassigned' # this is for the callback if build fails
new = build(p, u) new = build(p, u)
env.replace(op.out, new) env.replace(node.out, new)
except Exception, e: except Exception, e:
if self.failure_callback is not None: if self.failure_callback is not None:
self.failure_callback(op.out, new, e) self.failure_callback(node.out, new, e)
pass pass
def __str__(self): def __str__(self):
def pattern_to_str(pattern): def pattern_to_str(pattern):
if isinstance(pattern, (list, tuple)): if isinstance(pattern, (list, tuple)):
return "%s(%s)" % (pattern[0], ", ".join([pattern_to_str(p) for p in pattern[1:]])) return "%s(%s)" % (str(pattern[0]), ", ".join([pattern_to_str(p) for p in pattern[1:]]))
elif isinstance(pattern, dict): elif isinstance(pattern, dict):
return "%s subject to %s" % (pattern_to_str(pattern['pattern']), str(pattern['constraint'])) return "%s subject to %s" % (pattern_to_str(pattern['pattern']), str(pattern['constraint']))
else: else:
...@@ -458,28 +354,59 @@ class PatternDescOptimizer(LocalOptimizer): ...@@ -458,28 +354,59 @@ class PatternDescOptimizer(LocalOptimizer):
class ConstantFinder(Optimizer): # class ConstantFinder(Optimizer):
""" # """
Sets as constant every orphan that is not destroyed. # Sets as constant every orphan that is not destroyed.
""" # """
def apply(self, env): # def apply(self, env):
if env.has_feature(ext.DestroyHandler): # if env.has_feature(ext.DestroyHandler(env)):
for r in env.orphans(): # for r in env.orphans():
if not env.destroyers(r):
r.indestructible = True
r.constant = True
# for r in env.inputs:
# if not env.destroyers(r): # if not env.destroyers(r):
# r.indestructible = True # r.indestructible = True
else: # r.constant = True
for r in env.orphans(): # # for r in env.inputs:
r.indestructible = True # # if not env.destroyers(r):
r.constant = True # # r.indestructible = True
# for r in env.inputs: # else:
# for r in env.orphans():
# r.indestructible = True # r.indestructible = True
# r.constant = True
# # for r in env.inputs:
# # r.indestructible = True
import graph import graph
class _metadict:
# dict that accepts unhashable keys
# uses an associative list
# for internal use only
def __init__(self):
self.d = {}
self.l = []
def __getitem__(self, item):
return self.get(item, None)
def __setitem__(self, item, value):
try:
self.d[item] = value
except:
self.l.append((item, value))
def get(self, item, default):
try:
return self.d[item]
except:
for item2, value in self.l:
if item == item2:
return value
else:
return default
def clear(self):
self.d = {}
self.l = []
def __str__(self):
return "(%s, %s)" % (self.d, self.l)
class MergeOptimizer(Optimizer): class MergeOptimizer(Optimizer):
""" """
Merges parts of the graph that are identical, i.e. parts that Merges parts of the graph that are identical, i.e. parts that
...@@ -489,37 +416,41 @@ class MergeOptimizer(Optimizer): ...@@ -489,37 +416,41 @@ class MergeOptimizer(Optimizer):
""" """
def apply(self, env): def apply(self, env):
cid = {} #result -> result.desc() (for constants) cid = _metadict() #result -> result.desc() (for constants)
inv_cid = {} #desc -> result (for constants) inv_cid = _metadict() #desc -> result (for constants)
for i, r in enumerate(env.orphans().union(env.inputs)): for i, r in enumerate(env.orphans.union(env.inputs)):
if getattr(r, 'constant', False): if isinstance(r, Constant):
ref = ('const', r.desc()) sig = r.signature()
other_r = inv_cid.get(ref, None) other_r = inv_cid.get(sig, None)
if other_r is not None: if other_r is not None:
env.replace(r, other_r) env.replace(r, other_r)
else: else:
cid[r] = ref cid[r] = sig
inv_cid[ref] = r inv_cid[sig] = r
else: # we clear the dicts because the Constants signatures are not necessarily hashable
cid[r] = i # and it's more efficient to give them an integer cid like the other Results
inv_cid[i] = r cid.clear()
inv_cid.clear()
for op in env.io_toposort(): for i, r in enumerate(env.orphans.union(env.inputs)):
op_cid = (op.desc(), tuple([cid[input] for input in op.inputs])) cid[r] = i
dup = inv_cid.get(op_cid, None) inv_cid[i] = r
for node in env.io_toposort():
node_cid = (node.op, tuple([cid[input] for input in node.inputs]))
dup = inv_cid.get(node_cid, None)
success = False success = False
if dup is not None: if dup is not None:
success = True success = True
d = dict(zip(op.outputs, dup.outputs)) d = dict(zip(node.outputs, dup.outputs))
try: try:
env.replace_all(d) env.replace_all(d)
except Exception, e: except Exception, e:
success = False success = False
if not success: if not success:
cid[op] = op_cid cid[node] = node_cid
inv_cid[op_cid] = op inv_cid[node_cid] = node
for i, output in enumerate(op.outputs): for i, output in enumerate(node.outputs):
ref = (i, op_cid) ref = (i, node_cid)
cid[output] = ref cid[output] = ref
inv_cid[ref] = output inv_cid[ref] = output
......
from features import Listener, Tool
from random import shuffle from random import shuffle
import utils import utils
__all__ = ['EquivTool',
'InstanceFinder',
'DescFinder',
'PrintListener',
]
class EquivTool(dict):
class EquivTool(Listener, Tool, dict): def __init__(self, env):
self.env = env
def on_rewire(self, clients, r, new_r): def on_rewire(self, clients, r, new_r):
repl = self(new_r) repl = self(new_r)
...@@ -25,6 +21,10 @@ class EquivTool(Listener, Tool, dict): ...@@ -25,6 +21,10 @@ class EquivTool(Listener, Tool, dict):
self.env.equiv = self self.env.equiv = self
self.env.set_equiv = self.set_equiv self.env.set_equiv = self.set_equiv
def unpublish(self):
del self.env.equiv
del self.env.set_equiv
def set_equiv(self, d): def set_equiv(self, d):
self.update(d) self.update(d)
...@@ -56,71 +56,109 @@ class EquivTool(Listener, Tool, dict): ...@@ -56,71 +56,109 @@ class EquivTool(Listener, Tool, dict):
return key return key
class InstanceFinder(Listener, Tool, dict): class NodeFinder(dict):
def __init__(self, env): def __init__(self, env):
self.env = env self.env = env
def all_bases(self, cls): def on_import(self, node):
return utils.all_bases(cls, lambda cls: cls is not object) try:
self.setdefault(node.op, set()).add(node)
def on_import(self, op): except TypeError:
for base in self.all_bases(op.__class__): pass
self.setdefault(base, set()).add(op)
def on_prune(self, node):
def on_prune(self, op): try:
for base in self.all_bases(op.__class__): self[node.op].remove(node)
self[base].remove(op) except TypeError:
if not self[base]: return
del self[base] if not self[node.op]:
del self[node.op]
def __query__(self, cls):
all = [x for x in self.get(cls, [])] def query(self, op):
try:
all = self.get(op, [])
except TypeError:
raise TypeError("%s in unhashable and cannot be queried by the optimizer" % op)
all = [x for x in all]
shuffle(all) # this helps a lot for debugging because the order of the replacements will vary shuffle(all) # this helps a lot for debugging because the order of the replacements will vary
while all: while all:
next = all.pop() next = all.pop()
if next in self.env.ops(): if self.env.has_node(next):
yield next yield next
def query(self, cls):
return self.__query__(cls)
def publish(self): def publish(self):
self.env.get_instances_of = self.query self.env.get_nodes = self.query
def __eq__(self, other):
return isinstance(other, NodeFinder) and self.env is other.env
class DescFinder(Listener, Tool, dict): # class InstanceFinder(Listener, Tool, dict):
def __init__(self, env): # def __init__(self, env):
self.env = env # self.env = env
def on_import(self, op): # def all_bases(self, cls):
self.setdefault(op.desc(), set()).add(op) # return utils.all_bases(cls, lambda cls: cls is not object)
def on_prune(self, op): # def on_import(self, op):
desc = op.desc() # for base in self.all_bases(op.__class__):
self[desc].remove(op) # self.setdefault(base, set()).add(op)
if not self[desc]:
del self[desc]
def __query__(self, desc): # def on_prune(self, op):
all = [x for x in self.get(desc, [])] # for base in self.all_bases(op.__class__):
shuffle(all) # this helps for debugging because the order of the replacements will vary # self[base].remove(op)
while all: # if not self[base]:
next = all.pop() # del self[base]
if next in self.env.ops():
yield next
def query(self, desc): # def __query__(self, cls):
return self.__query__(desc) # all = [x for x in self.get(cls, [])]
# shuffle(all) # this helps a lot for debugging because the order of the replacements will vary
# while all:
# next = all.pop()
# if next in self.env.ops():
# yield next
def publish(self): # def query(self, cls):
self.env.get_from_desc = self.query # return self.__query__(cls)
# def publish(self):
# self.env.get_instances_of = self.query
# class DescFinder(Listener, Tool, dict):
# def __init__(self, env):
# self.env = env
# def on_import(self, op):
# self.setdefault(op.desc(), set()).add(op)
# def on_prune(self, op):
# desc = op.desc()
# self[desc].remove(op)
# if not self[desc]:
# del self[desc]
# def __query__(self, desc):
# all = [x for x in self.get(desc, [])]
# shuffle(all) # this helps for debugging because the order of the replacements will vary
# while all:
# next = all.pop()
# if next in self.env.ops():
# yield next
# def query(self, desc):
# return self.__query__(desc)
# def publish(self):
# self.env.get_from_desc = self.query
class PrintListener(Listener): class PrintListener(object):
def __init__(self, env, active = True): def __init__(self, env, active = True):
self.env = env self.env = env
...@@ -128,13 +166,13 @@ class PrintListener(Listener): ...@@ -128,13 +166,13 @@ class PrintListener(Listener):
if active: if active:
print "-- initializing" print "-- initializing"
def on_import(self, op): def on_import(self, node):
if self.active: if self.active:
print "-- importing: %s" % op print "-- importing: %s" % node
def on_prune(self, op): def on_prune(self, node):
if self.active: if self.active:
print "-- pruning: %s" % op print "-- pruning: %s" % node
def on_rewire(self, clients, r, new_r): def on_rewire(self, clients, r, new_r):
if self.active: if self.active:
......
"""
Contains the L{Result} class, which is the base interface for a
value that is the input or the output of an L{Op}.
"""
import copy import copy
import utils import utils
from utils import AbstractFunctionError from utils import AbstractFunctionError, object2
from graph import Result
__all__ = ['Result',
'PythonResult',
'StateError',
'Empty',
'Allocated',
'Computed',
]
### CLEANUP - DO WE REALLY EVEN THE STATE ANYMORE? ###
class StateError(Exception):
"""The state of the L{Result} is a problem"""
# Result state keywords
class Empty : """Memory has not been allocated"""
class Allocated: """Memory has been allocated, contents are not the owner's output."""
class Computed : """Memory has been allocated, contents are the owner's output."""
############################
# Result
############################
class Result(object):
"""
Base class for storing L{Op} inputs and outputs
Attributes:
- _role - None or (owner, index) #or BrokenLink
- _data - anything
- state - one of (Empty, Allocated, Computed)
- name - string
"""
__slots__ = ['_role', '_data', 'state', '_name', '_hash_id']
def __init__(self, role=None, name=None):
self._role = None
if role is not None:
self.role = role
self._data = [None]
self.state = Empty
self.name = name
self._hash_id = utils.hashgen()
#
# Python stdlib compatibility
#
def __cmp__(self, other):
return cmp(id(self), id(other))
def __eq__(self, other):
return self is other #assuming this is faster, equiv to id(self) == id(other)
def __ne__(self, other):
return self is not other #assuming this is faster, equiv to id(self) != id(other)
def __hash__(self):
return self._hash_id
def desc(self):
return id(self)
#
# role
#
def __get_role(self):
return self._role
def __set_role(self, role):
owner, index = role
if self._role is not None:
# this is either an error or a no-op
_owner, _index = self._role
if _owner is not owner:
raise ValueError("Result %s already has an owner." % self)
if _index != index:
raise ValueError("Result %s was already mapped to a different index." % self)
return # because _owner is owner and _index == index
#TODO: this doesn't work because many bits of code set the role before
# owner.outputs. Op.__init__ should do this I think. -JSB
#assert owner.outputs[index] is self
self._role = role
role = property(__get_role, __set_role, doc="(writeable)")
#
# owner
#
def __get_owner(self):
if self._role is None: return None
return self._role[0]
owner = property(__get_owner,
doc = "Op of which this Result is an output, or None if role is None (read-only)")
#
# index
#
def __get_index(self): ########
if self._role is None: return None # Type #
return self._role[1] ########
index = property(__get_index, class Type(object2):
doc = "position of self in owner's outputs, or None if role is None (read-only)")
def filter(self, data, strict = False):
#
# data
#
def __get_data(self):
return self._data[0]
def __set_data(self, data):
"""
Filters the data provided and sets the result in the storage.
""" """
if data is self._data[0]: Return data or an appropriately wrapped data. Raise an
return exception if the data is not of an acceptable type.
if data is None:
self._data[0] = None
self.state = Empty
return
try:
data = self.filter(data)
except AbstractFunctionError:
pass
self._data[0] = data
self.state = Computed
data = property(__get_data, __set_data,
doc = "The storage associated with this result (writeable)")
def filter(self, data):
"""
Raise an exception if the data is not of an acceptable type.
If a subclass overrides this function, L{__set_data} will use it
to check that the argument can be used properly. This gives a
subclass the opportunity to ensure that the contents of
L{self._data} remain sensible.
Returns data or an appropriately wrapped data. If strict is True, the data returned must be the same
as the data passed as an argument. If it is False, filter
may cast it to the appropriate type.
""" """
raise AbstractFunctionError() raise AbstractFunctionError()
def make_result(self, name = None):
# return Result(self, name = name)
# C code generators
# def __call__(self, name = None):
return self.make_result(name)
def c_is_simple(self): def c_is_simple(self):
""" """
A hint to tell the compiler that this type is a builtin C A hint to tell the compiler that this type is a builtin C
...@@ -175,7 +36,7 @@ class Result(object): ...@@ -175,7 +36,7 @@ class Result(object):
""" """
return False return False
def c_literal(self): def c_literal(self, data):
raise AbstractFunctionError() raise AbstractFunctionError()
def c_declare(self, name, sub): def c_declare(self, name, sub):
...@@ -249,67 +110,26 @@ class Result(object): ...@@ -249,67 +110,26 @@ class Result(object):
L{Result}. L{Result}.
""" """
raise AbstractFunctionError() raise AbstractFunctionError()
#
# name
#
def __get_name(self):
if self._name:
return self._name
elif self._role:
return "%s.%i" % (self.owner.__class__, self.owner.outputs.index(self))
else:
return None
def __set_name(self, name):
if name is not None and not isinstance(name, str):
raise TypeError("Name is expected to be a string, or None.")
self._name = name
name = property(__get_name, __set_name,
doc = "Name of the Result.")
#
# String representation
#
class SingletonType(Type):
__instance = None
def __new__(cls):
if cls.__instance is None:
cls.__instance = Type.__new__(cls)
return cls.__instance
def __str__(self): def __str__(self):
name = self.name return self.__class__.__name__
if name:
if self.state is Computed:
return name + ":" + str(self.data)
else:
return name
elif self.state is Computed:
return str(self.data)
else:
return "<?>"
def __repr__(self):
return self.name or "<?>"
#
# same properties
#
def same_properties(self, other):
"""Return bool; True iff all properties are equal (ignores contents, role)"""
raise AbstractFunction()
class Generic(SingletonType):
def __copy__(self):
"""Create a new instance of self.__class__ with role None, independent data"""
raise AbstractFunctionError()
class PythonResult(Result):
""" """
Represents a generic Python object. The object is available Represents a generic Python object.
through %(name)s.
""" """
def filter(self, data, strict = False):
return data
def c_declare(self, name, sub): def c_declare(self, name, sub):
return """ return """
PyObject* %(name)s; PyObject* %(name)s;
...@@ -332,18 +152,12 @@ class PythonResult(Result): ...@@ -332,18 +152,12 @@ class PythonResult(Result):
py_%(name)s = %(name)s; py_%(name)s = %(name)s;
Py_XINCREF(py_%(name)s); Py_XINCREF(py_%(name)s);
""" % locals() """ % locals()
def same_properties(self, other):
return False
def __copy__(self): generic = Generic()
rval = PythonResult(None, self.name)
rval.data = copy.copy(self.data)
return rval
def python_result(data, **kwargs): class PropertiedType(Type):
rval = PythonResult(**kwargs)
rval.data = data
return rval
def __eq__(self, other):
return type(self) == type(other) and self.__dict__ == other.__dict__
...@@ -19,6 +19,23 @@ class AbstractFunctionError(Exception): ...@@ -19,6 +19,23 @@ class AbstractFunctionError(Exception):
function has been left out of an implementation class. function has been left out of an implementation class.
""" """
class object2(object):
__slots__ = []
def __hash__(self):
# this fixes silent-error-prone new-style class behavior
if hasattr(self, '__eq__') or hasattr(self, '__cmp__'):
raise TypeError("unhashable object: %s" % self)
return id(self)
class scratchpad:
def clear(self):
self.__dict__.clear()
def __str__(self):
print "scratch" + str(self.__dict__)
def uniq(seq): def uniq(seq):
#TODO: consider building a set out of seq so that the if condition is constant time -JB #TODO: consider building a set out of seq so that the if condition is constant time -JB
return [x for i, x in enumerate(seq) if seq.index(x) == i] return [x for i, x in enumerate(seq) if seq.index(x) == i]
......
import gof, gof.result import gof #, gof.result
import numpy #for numeric_grad import numpy #for numeric_grad
from gof.python25 import all from gof.python25 import all
...@@ -60,17 +60,17 @@ def grad_sources_inputs(sources, graph_inputs): ...@@ -60,17 +60,17 @@ def grad_sources_inputs(sources, graph_inputs):
if graph_inputs is None: if graph_inputs is None:
graph_inputs = gof.graph.inputs(graph_outputs) graph_inputs = gof.graph.inputs(graph_outputs)
for op in gof.graph.io_toposort(graph_inputs, graph_outputs).__reversed__(): for node in gof.graph.io_toposort(graph_inputs, graph_outputs).__reversed__():
g_outputs = [gmap.get(o,None) for o in op.outputs] g_outputs = [gmap.get(o,None) for o in node.outputs]
#if all output gradients are None, continue #if all output gradients are None, continue
if all(map(lambda x:x is None, g_outputs)): continue if all(map(lambda x:x is None, g_outputs)): continue
output_arg = g_outputs output_arg = g_outputs
input_arg = op.inputs input_arg = node.inputs
try: try:
dinputs = [x[0] for x in op.destroy_map().values()] dinputs = [node.inputs[x[0]] for x in node.op.destroy_map.values()]
except AttributeError: except AttributeError:
dinputs = [] dinputs = []
...@@ -90,17 +90,17 @@ def grad_sources_inputs(sources, graph_inputs): ...@@ -90,17 +90,17 @@ def grad_sources_inputs(sources, graph_inputs):
# Other possibilities: # Other possibilities:
# * return a partial back-prop # * return a partial back-prop
# #
op_grad = op.grad(input_arg, output_arg) op_grad = node.op.grad(input_arg, output_arg)
if not isinstance(op_grad, (list,tuple)): if not isinstance(op_grad, (list,tuple)):
raise ValueError(_msg_retType, op.__class__) raise ValueError(_msg_retType, node.op)
g_inputs = op_grad #_pack_result(op_grad) g_inputs = op_grad #_pack_result(op_grad)
assert isinstance(g_inputs, (list, tuple)) assert isinstance(g_inputs, (list, tuple))
if len(g_inputs) != len(op.inputs): if len(g_inputs) != len(node.inputs):
raise ValueError(_msg_badlen, raise ValueError(_msg_badlen,
op.__class__, node.op,
len(g_inputs), len(g_inputs),
len(op.inputs)) len(node.inputs))
for r, g_r in zip(op.inputs, g_inputs): for r, g_r in zip(node.inputs, g_inputs):
if g_r is not None: if g_r is not None:
if r in gmap: if r in gmap:
gmap[r] = gmap[r] + g_r gmap[r] = gmap[r] + g_r
......
...@@ -7,61 +7,57 @@ from copy import copy ...@@ -7,61 +7,57 @@ from copy import copy
from functools import partial from functools import partial
import gof import gof
from gof import Result, GuardedOp, Env, utils from gof import PropertiedType, Op, PropertiedOp, utils, Result, Constant, Type, Apply, Env
def upcast(dtype, *dtypes):
z = numpy.zeros((), dtype = dtype)
for dtype in dtypes:
z = z + numpy.zeros((), dtype = dtype)
return str(z.dtype)
def as_scalar(x, name = None): def as_scalar(x, name = None):
if isinstance(x, gof.Op): if isinstance(x, gof.Apply):
if len(x.outputs) != 1: if len(x.outputs) != 1:
raise ValueError("It is ambiguous which output of a multi-output Op has to be fetched.", x) raise ValueError("It is ambiguous which output of a multi-output Op has to be fetched.", x)
else: else:
x = x.outputs[0] x = x.outputs[0]
if isinstance(x, float): if isinstance(x, Result):
s = Scalar('float64', name = name) if not isinstance(x.type, Scalar):
s.data = x raise TypeError("Result type field must be a Scalar.", x, x.type)
return s
if isinstance(x, int):
s = Scalar('int32', name = name)
s.data = x
return s
if isinstance(x, Scalar):
return x return x
raise TypeError("Cannot convert %s to Scalar" % x) if isinstance(x, Constant):
if not isinstance(x.type, Scalar):
raise TypeError("Constant type field must be a Scalar.", x, x.type)
return x
try:
return constant(x)
except TypeError:
raise TypeError("Cannot convert %s to Scalar" % x, type(x))
def constant(x): def constant(x):
res = as_scalar(x) if isinstance(x, float):
res.constant = True return ScalarConstant(float64, x)
return res if isinstance(x, int):
return ScalarConstant(int64, x)
return ScalarConstant(float64, float(x))
class Scalar(Result): class Scalar(Type):
def __init__(self, dtype, name = None): def __init__(self, dtype):
Result.__init__(self, role = None, name = name)
self.dtype = dtype self.dtype = dtype
self.dtype_specs() self.dtype_specs() # error checking
def __get_constant(self):
if not hasattr(self, '_constant'):
return False
return self._constant
def __set_constant(self, value):
if value:
self.indestructible = True
self._constant = value
constant = property(__get_constant, __set_constant)
def desc(self):
return (self.dtype, self.data)
def filter(self, data): def filter(self, data, strict = False):
py_type = self.dtype_specs()[0] py_type = self.dtype_specs()[0]
if strict: assert isinstance(data, py_type)
return py_type(data) return py_type(data)
def same_properties(self, other): def __eq__(self, other):
return other.dtype == self.dtype return type(self) == type(other) and other.dtype == self.dtype
def __hash__(self):
return hash(self.dtype)
def dtype_specs(self): def dtype_specs(self):
try: try:
...@@ -77,10 +73,22 @@ class Scalar(Result): ...@@ -77,10 +73,22 @@ class Scalar(Result):
except KeyError: except KeyError:
raise TypeError("Unsupported dtype for %s: %s" % (self.__class__.__name__, self.dtype)) raise TypeError("Unsupported dtype for %s: %s" % (self.__class__.__name__, self.dtype))
def c_literal(self): def upcast(self, *others):
return upcast(*[x.dtype for x in [self]+list(others)])
def make_result(self, name = None):
return ScalarResult(self, name = name)
def __str__(self):
return str(self.dtype)
def __repr__(self):
return "Scalar{%s}" % self.dtype
def c_literal(self, data):
if 'complex' in self.dtype: if 'complex' in self.dtype:
raise NotImplementedError("No literal for complex values.") raise NotImplementedError("No literal for complex values.")
return str(self.data) return str(data)
def c_declare(self, name, sub): def c_declare(self, name, sub):
return """ return """
...@@ -119,7 +127,7 @@ class Scalar(Result): ...@@ -119,7 +127,7 @@ class Scalar(Result):
def c_cleanup(self, name, sub): def c_cleanup(self, name, sub):
return "" return ""
def c_support_code(cls): def c_support_code(self):
template = """ template = """
struct theano_complex%(nbits)s : public npy_complex%(nbits)s struct theano_complex%(nbits)s : public npy_complex%(nbits)s
{ {
...@@ -155,15 +163,25 @@ class Scalar(Result): ...@@ -155,15 +163,25 @@ class Scalar(Result):
""" """
return template % dict(nbits = 64, half_nbits = 32) + template % dict(nbits = 128, half_nbits = 64) return template % dict(nbits = 64, half_nbits = 32) + template % dict(nbits = 128, half_nbits = 64)
def __copy__(self):
"""Return a copy of this instance (with its own attributes)""" int8 = Scalar('int8')
cpy = self.__class__(self.dtype, self.name) int16 = Scalar('int16')
cpy.data = self.data int32 = Scalar('int32')
return cpy int64 = Scalar('int64')
float32 = Scalar('float32')
float64 = Scalar('float64')
complex64 = Scalar('complex64')
complex128 = Scalar('complex128')
int_types = int8, int16, int32, int64
float_types = float32, float64
complex_types = complex64, complex128
class _scalar_py_operators:
#UNARY #UNARY
def __abs__(self): return Abs(self).out def __abs__(self): return _abs(self)
def __neg__(self): return Neg(self).out def __neg__(self): return neg(self)
#CASTS #CASTS
def __int__(self): return AsInt(self).out def __int__(self): return AsInt(self).out
...@@ -190,6 +208,12 @@ class Scalar(Result): ...@@ -190,6 +208,12 @@ class Scalar(Result):
def __rdiv__(self,other): return div(other,self) def __rdiv__(self,other): return div(other,self)
def __rpow__(self,other): return pow(other,self) def __rpow__(self,other): return pow(other,self)
class ScalarResult(Result, _scalar_py_operators):
pass
class ScalarConstant(Constant, _scalar_py_operators):
pass
# Easy constructors # Easy constructors
...@@ -204,60 +228,102 @@ def _multi(*fns): ...@@ -204,60 +228,102 @@ def _multi(*fns):
else: else:
return [partial(f2, f) for f in fns] return [partial(f2, f) for f in fns]
def intr(name): ints = _multi(int64)
return Scalar(name = name, dtype = 'int64') floats = _multi(float64)
ints = _multi(intr)
def floatr(name):
return Scalar(name = name, dtype = 'float64')
floats = _multi(floatr)
def upcast(dtype, *dtypes): def upcast_out(*types):
z = numpy.zeros((), dtype = dtype) return Scalar(dtype = Scalar.upcast(*types)),
for dtype in dtypes: def same_out(type):
z = z + numpy.zeros((), dtype = dtype) return type,
return str(z.dtype) def transfer_type(i):
assert type(i) == int
def f(*types):
return types[i],
f.__name__ = "transfer_type_%i" % i
return f
def specific_out(*spec):
def f(*types):
return spec
return f
def int_out(*types):
return int64,
def float_out(*types):
return float64,
def upgrade_to_float(*types):
conv = {int8: float32,
int16: float32,
int32: float64,
int64: float64}
return Scalar(Scalar.upcast(*[conv.get(type, type) for type in types])),
class ScalarOp(GuardedOp): class ScalarOp(Op):
nin = -1 nin = -1
nout = 1 nout = 1
def __init__(self, *inputs): def __init__(self, output_types_preference = None, name = None):
self.name = name
if output_types_preference is not None:
if not callable(output_types_preference):
raise TypeError("Expected a callable for the 'output_types_preference' argument to %s." % self.__class__)
self.output_types_preference = output_types_preference
def make_node(self, *inputs):
if self.nin >= 0: if self.nin >= 0:
if len(inputs) != self.nin: if len(inputs) != self.nin:
raise TypeError("Wrong number of inputs for %s (got %i, expected %i)" \ raise TypeError("Wrong number of inputs for %s.make_node (got %i, expected %i)" \
% (self.__class__.__name__, len(inputs), self.nin)) % (self, len(inputs), self.nin))
else:
self.nin = len(inputs)
inputs = [as_scalar(input) for input in inputs] inputs = [as_scalar(input) for input in inputs]
i_dtypes = [getattr(input, 'dtype', None) for input in inputs] outputs = [t() for t in self.output_types([input.type for input in inputs])]
o_dtypes = self.output_dtypes(*i_dtypes) if len(outputs) != self.nout:
raise TypeError("Not the right number of outputs produced for %s(%s). Expected %s, got %s."
self.inputs = inputs % (self, ", ".join(str(input) for input in inputs), self.nout, len(outputs)))
self.outputs = [Scalar(dtype) for dtype in o_dtypes] return Apply(self, inputs, outputs)
def output_types(self, types):
if hasattr(self, 'output_types_preference'):
results = self.output_types_preference(*types)
if not isinstance(results, (list, tuple)) or any(not isinstance(x, Type) for x in results):
raise TypeError("output_types_preference should return a list or a tuple of types", self.output_types_preference, results)
if len(results) != self.nout:
raise TypeError("Not the right number of outputs produced for %s(%s) by %s. Expected %s, got ?s."
% (self, ", ".join(str(input.type) for input in inputs),
self.output_types_preference, self.nout, len(results)))
return results
else:
raise NotImplementedError("Cannot calculate the output types for %s" % self)
def output_dtypes(self, *dtypes): def perform(self, node, inputs, output_storage):
if self.nout != 1: if self.nout == 1:
raise NotImplementedError() output_storage[0][0] = self.impl(*inputs)
return upcast(*dtypes), else:
results = utils.from_return_values(self.impl(*inputs))
assert len(results) == len(output_storage)
for storage, result in zip(output_storage, results):
storage[0] = result
def impl(self, *inputs): def impl(self, *inputs):
raise AbstractFunctionError() raise AbstractFunctionError()
def grad(self, inputs, output_gradients): def grad(self, inputs, output_gradients):
raise AbstractFunctionError() raise AbstractFunctionError()
def perform(self): def __eq__(self, other):
if self.nout == 1: return type(self) == type(other) and self.output_types_preference == other.output_types_preference
self.outputs[0].data = self.impl(*[input.data for input in self.inputs])
def __hash__(self):
return hash(self.output_types_preference)
def __str__(self):
if hasattr(self, 'name') and self.name:
return self.name
else: else:
results = utils.from_return_values(self.impl(*[input.data for input in self.inputs])) return "%s{%s}" % (self.__class__.__name__, ", ".join("%s=%s" % (k, v) for k, v in self.__dict__.items() if k != "name"))
for output, result in zip(self.outputs, results):
output.data = result
class UnaryScalarOp(ScalarOp): class UnaryScalarOp(ScalarOp):
nin = 1 nin = 1
...@@ -265,12 +331,6 @@ class UnaryScalarOp(ScalarOp): ...@@ -265,12 +331,6 @@ class UnaryScalarOp(ScalarOp):
class BinaryScalarOp(ScalarOp): class BinaryScalarOp(ScalarOp):
nin = 2 nin = 2
class FloatUnaryScalarOp(UnaryScalarOp):
def output_dtypes(self, input_dtype):
if 'int' in input_dtype: return 'float64',
if 'float' in input_dtype: return input_dtype,
raise NotImplementedError()
class Add(ScalarOp): class Add(ScalarOp):
...@@ -279,13 +339,14 @@ class Add(ScalarOp): ...@@ -279,13 +339,14 @@ class Add(ScalarOp):
associative = True associative = True
def impl(self, *inputs): def impl(self, *inputs):
return sum(inputs) return sum(inputs)
def c_code(self, inputs, (z, ), sub): def c_code(self, node, name, inputs, (z, ), sub):
if not inputs: if not inputs:
return z + " = 0;" return z + " = 0;"
else: else:
return z + " = " + " + ".join(inputs) + ";" return z + " = " + " + ".join(inputs) + ";"
def grad(self, inputs, (gz, )): def grad(self, inputs, (gz, )):
return (gz, ) * len(inputs) return (gz, ) * len(inputs)
add = Add(upcast_out, name = 'add')
class Mul(ScalarOp): class Mul(ScalarOp):
identity = 1 identity = 1
...@@ -293,7 +354,7 @@ class Mul(ScalarOp): ...@@ -293,7 +354,7 @@ class Mul(ScalarOp):
associative = True associative = True
def impl(self, *inputs): def impl(self, *inputs):
return numpy.product(inputs) return numpy.product(inputs)
def c_code(self, inputs, (z, ), sub): def c_code(self, node, name, inputs, (z, ), sub):
if not inputs: if not inputs:
return z + " = 1;" return z + " = 1;"
else: else:
...@@ -301,67 +362,74 @@ class Mul(ScalarOp): ...@@ -301,67 +362,74 @@ class Mul(ScalarOp):
def grad(self, inputs, (gz, )): def grad(self, inputs, (gz, )):
return [mul(*([gz] + utils.difference(inputs, [input]))) return [mul(*([gz] + utils.difference(inputs, [input])))
for input in inputs] for input in inputs]
mul = Mul(upcast_out, name = 'mul')
class Sub(BinaryScalarOp): class Sub(BinaryScalarOp):
def impl(self, x, y): def impl(self, x, y):
return x - y return x - y
def c_code(self, (x, y), (z, ), sub): def c_code(self, node, name, (x, y), (z, ), sub):
return "%(z)s = %(x)s - %(y)s;" % locals() return "%(z)s = %(x)s - %(y)s;" % locals()
def grad(self, (x, y), (gz, )): def grad(self, (x, y), (gz, )):
return gz, -gz return gz, -gz
sub = Sub(upcast_out, name = 'sub')
class Div(BinaryScalarOp): class Div(BinaryScalarOp):
def impl(self, x, y): def impl(self, x, y):
return x / y return x / y
def c_code(self, (x, y), (z, ), sub): def c_code(self, node, name, (x, y), (z, ), sub):
if 'int' in self.inputs[0].dtype and 'int' in self.inputs[1].dtype: if node.inputs[0].type in int_types and node.inputs[1].type in int_types:
raise NotImplementedError("For integer arguments the behavior of division in C and in Python differ when the quotient is negative (to implement).") raise NotImplementedError("For integer arguments the behavior of division in C and in Python differ when the quotient is negative (to implement).")
return "%(z)s = %(x)s / %(y)s;" % locals() return "%(z)s = %(x)s / %(y)s;" % locals()
def grad(self, (x, y), (gz, )): def grad(self, (x, y), (gz, )):
return gz / y, -(gz * x) / (y * y) return gz / y, -(gz * x) / (y * y)
div = Div(upcast_out, name = 'div')
class Pow(BinaryScalarOp): class Pow(BinaryScalarOp):
def impl(self, x, y): def impl(self, x, y):
return x ** y return x ** y
def c_code(self, (x, y), (z, ), sub): def c_code(self, node, name, (x, y), (z, ), sub):
return "%(z)s = pow(%(x)s, %(y)s);" % locals() return "%(z)s = pow(%(x)s, %(y)s);" % locals()
def grad(self, (x, y), (gz, )): def grad(self, (x, y), (gz, )):
return gz * y * x**(y - 1), gz * log(x) * x**y return gz * y * x**(y - 1), gz * log(x) * x**y
pow = Pow(upcast_out, name = 'pow')
class First(BinaryScalarOp): class First(BinaryScalarOp):
def impl(self, x, y): def impl(self, x, y):
return x return x
def c_code(self, (x, y), (z, ), sub): def c_code(self, node, name, (x, y), (z, ), sub):
return "%(z)s = %(x)s;" % locals() return "%(z)s = %(x)s;" % locals()
def grad(self, (x, y), (gz, )): def grad(self, (x, y), (gz, )):
return gz, None return gz, None
first = First(transfer_type(0), name = 'first')
class Second(BinaryScalarOp): class Second(BinaryScalarOp):
def impl(self, x, y): def impl(self, x, y):
return y return y
def c_code(self, (x, y), (z, ), sub): def c_code(self, node, name, (x, y), (z, ), sub):
return "%(z)s = %(y)s;" % locals() return "%(z)s = %(y)s;" % locals()
def grad(self, (x, y), (gz, )): def grad(self, (x, y), (gz, )):
return None, gz return None, gz
second = Second(transfer_type(1), name = 'second')
class Identity(UnaryScalarOp): class Identity(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return x return x
def c_code(self, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = %(x)s;" % locals() return "%(z)s = %(x)s;" % locals()
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return gz, return gz,
identity = Identity(same_out, name = 'identity')
class Neg(UnaryScalarOp): class Neg(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return -x return -x
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return -gz, return -gz,
def c_code(self, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = -%(x)s;" % locals() return "%(z)s = -%(x)s;" % locals()
neg = Neg(same_out, name = 'neg')
class Abs(UnaryScalarOp): class Abs(UnaryScalarOp):
#TODO: for complex input, output is some flavour of float #TODO: for complex input, output is some flavour of float
...@@ -369,14 +437,15 @@ class Abs(UnaryScalarOp): ...@@ -369,14 +437,15 @@ class Abs(UnaryScalarOp):
return numpy.abs(x) return numpy.abs(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return gz * sgn(x), return gz * sgn(x),
def c_code(self, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
dtype = str(self.inputs[0].dtype) type = node.inputs[0].type
if 'int' in dtype: if type in int_types:
return "%(z)s = abs(%(x)s);" % locals() return "%(z)s = abs(%(x)s);" % locals()
if 'float' in dtype: if type in float_types:
return "%(z)s = fabs(%(x)s);" % locals() return "%(z)s = fabs(%(x)s);" % locals()
#complex, other? #complex, other?
raise NotImplementedError('dtype not supported', dtype) raise NotImplementedError('type not supported', type)
abs = Abs(same_out)
class Sgn(UnaryScalarOp): class Sgn(UnaryScalarOp):
def impl(self, x): def impl(self, x):
...@@ -384,84 +453,94 @@ class Sgn(UnaryScalarOp): ...@@ -384,84 +453,94 @@ class Sgn(UnaryScalarOp):
return numpy.sign(x) return numpy.sign(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return None, return None,
def c_code(self, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
#casting is done by compiler #casting is done by compiler
#TODO: use copysign #TODO: use copysign
return "%(z)s = (%(x)s >= 0) ? (%(x)s == 0) ? 0.0 : 1.0 : -1.0;" % locals() return "%(z)s = (%(x)s >= 0) ? (%(x)s == 0) ? 0.0 : 1.0 : -1.0;" % locals()
sgn = Sgn(same_out, name = 'abs')
class Inv(FloatUnaryScalarOp): class Inv(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return 1.0 / x return 1.0 / x
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return -gz / (x * x), return -gz / (x * x),
def c_code(self, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = 1.0 / %(x)s;" % locals() return "%(z)s = 1.0 / %(x)s;" % locals()
inv = Inv(upgrade_to_float, name = 'inv')
class Log(FloatUnaryScalarOp): class Log(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return math.log(x) return math.log(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return gz / x, return gz / x,
def c_code(self, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = log(%(x)s);" % locals() return "%(z)s = log(%(x)s);" % locals()
log = Log(upgrade_to_float, name = 'log')
class Log2(FloatUnaryScalarOp): class Log2(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return numpy.log2(x) return numpy.log2(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return gz / (x * math.log(2.0)), return gz / (x * math.log(2.0)),
def c_code(self, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = log2(%(x)s);" % locals() return "%(z)s = log2(%(x)s);" % locals()
log2 = Log2(upgrade_to_float, name = 'log2')
class Exp(FloatUnaryScalarOp): class Exp(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return math.exp(x) return math.exp(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return gz * exp(x), return gz * exp(x),
def c_code(self, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = exp(%(x)s);" % locals() return "%(z)s = exp(%(x)s);" % locals()
exp = Exp(upgrade_to_float, name = 'exp')
class Sqr(UnaryScalarOp): class Sqr(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return x*x return x*x
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return gz * x * 2, return gz * x * 2,
def c_code(self, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = %(x)s * %(x)s;" % locals() return "%(z)s = %(x)s * %(x)s;" % locals()
sqr = Sqr(same_out, name = 'sqr')
class Sqrt(FloatUnaryScalarOp): class Sqrt(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return math.sqrt(x) return math.sqrt(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return (gz * 0.5) / sqrt(x), return (gz * 0.5) / sqrt(x),
def c_code(self, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = sqrt(%(x)s);" % locals() return "%(z)s = sqrt(%(x)s);" % locals()
sqrt = Sqrt(upgrade_to_float, name = 'sqrt')
class Cos(FloatUnaryScalarOp): class Cos(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return math.cos(x) return math.cos(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return -gz * sin(x), return -gz * sin(x),
def c_code(self, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = cos(%(x)s);" % locals() return "%(z)s = cos(%(x)s);" % locals()
cos = Cos(upgrade_to_float, name = 'cos')
class Sin(FloatUnaryScalarOp): class Sin(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return math.sin(x) return math.sin(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return gz * cos(x), return gz * cos(x),
def c_code(self, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = sin(%(x)s);" % locals() return "%(z)s = sin(%(x)s);" % locals()
sin = Sin(upgrade_to_float, name = 'sin')
class Tan(FloatUnaryScalarOp): class Tan(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return math.tan(x) return math.tan(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return gz / (cos(x) ** 2), return gz / (cos(x) ** 2),
def c_code(self, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = tan(%(x)s);" % locals() return "%(z)s = tan(%(x)s);" % locals()
tan = Tan(upgrade_to_float, name = 'tan')
class Cosh(FloatUnaryScalarOp): class Cosh(UnaryScalarOp):
""" """
sinh(x) = (exp(x) + exp(-x)) / 2 sinh(x) = (exp(x) + exp(-x)) / 2
""" """
...@@ -469,10 +548,11 @@ class Cosh(FloatUnaryScalarOp): ...@@ -469,10 +548,11 @@ class Cosh(FloatUnaryScalarOp):
return math.cosh(x) return math.cosh(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return gz * sinh(x), return gz * sinh(x),
def c_code(self, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = cosh(%(x)s);" % locals() return "%(z)s = cosh(%(x)s);" % locals()
cosh = Cosh(upgrade_to_float, name = 'cosh')
class Sinh(FloatUnaryScalarOp): class Sinh(UnaryScalarOp):
""" """
sinh(x) = (exp(x) - exp(-x)) / 2 sinh(x) = (exp(x) - exp(-x)) / 2
""" """
...@@ -480,10 +560,11 @@ class Sinh(FloatUnaryScalarOp): ...@@ -480,10 +560,11 @@ class Sinh(FloatUnaryScalarOp):
return math.sinh(x) return math.sinh(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return gz * cosh(x), return gz * cosh(x),
def c_code(self, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = sinh(%(x)s);" % locals() return "%(z)s = sinh(%(x)s);" % locals()
sinh = Sinh(upgrade_to_float, name = 'sinh')
class Tanh(FloatUnaryScalarOp): class Tanh(UnaryScalarOp):
""" """
tanh(x) = sinh(x) / cosh(x) tanh(x) = sinh(x) / cosh(x)
= (exp(2*x) - 1) / (exp(2*x) + 1) = (exp(2*x) - 1) / (exp(2*x) + 1)
...@@ -492,124 +573,740 @@ class Tanh(FloatUnaryScalarOp): ...@@ -492,124 +573,740 @@ class Tanh(FloatUnaryScalarOp):
return math.tanh(x) return math.tanh(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return gz * (1 - tanh(x)**2), return gz * (1 - tanh(x)**2),
def c_code(self, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = tanh(%(x)s);" % locals() return "%(z)s = tanh(%(x)s);" % locals()
tanh = Tanh(upgrade_to_float, name = 'tanh')
class Composite(ScalarOp):
def __init__(self, inputs, outputs):
env = Env(inputs, outputs).clone()
inputs, outputs = env.inputs, env.outputs
for node in env.nodes:
if not isinstance(node.op, ScalarOp):
raise ValueError("The env to Composite must be exclusively composed of ScalarOp instances.")
subd = dict(zip(inputs,
["%%(i%i)s"%i for i in range(len(inputs))]) +
zip(outputs,
["%%(o%i)s"%i for i in range(len(outputs))]))
for orphan in env.orphans:
if isinstance(orphan, Constant):
subd[orphan] = orphan.type.c_literal(orphan.data)
else:
raise ValueError("All orphans in the env to Composite must be Constant instances.")
_c_code = "{\n"
i = 0
j = 0
for node in env.toposort():
j += 1
for output in node.outputs:
if output not in subd:
i += 1
name = "V%%(id)s_tmp%i" % i
subd[output] = name
_c_code += "%s %s;\n" % (output.type.dtype_specs()[1], name)
_c_code += node.op.c_code(node.inputs,
"%(name)s",
[subd[input] for input in node.inputs],
[subd[output] for output in node.outputs],
dict(fail = "%(fail)s",
id = "%%(id)s_%i" % j))
_c_code += "\n"
_c_code += "}\n"
def compose_impl(r):
# this is not optimal at all eg in add(*1 -> mul(x, y), *1)
# it will calculate *1 twice
# it also doesn't follow env.toposort but that's (presumably)
# still correct since we only have scalar ops
if r in env.inputs:
idx = env.inputs.index(r)
return lambda inputs: inputs[idx]
elif r in env.orphans:
return lambda inputs: r.data
node = r.owner
producers = [compose_impl(input) for input in node.inputs]
return lambda inputs: node.op.impl(*[p(inputs) for p in producers])
_impls = [compose_impl(r) for r in env.outputs]
self._c_code = _c_code
self._impls = _impls
self.nin = len(inputs)
self.nout = len(outputs)
self.env = env
def output_types(self, input_types):
if tuple(input_types) != tuple([input.type for input in self.env.inputs]):
raise TypeError("Wrong types for Composite. Expected %s, got %s."
% (tuple([input.type for input in self.env.inputs]), tuple(input_types)))
return [output.type for output in self.env.outputs]
def perform(self, node, inputs, output_storage):
for storage, impl in zip(output_storage, self._impls):
storage[0] = impl(inputs)
def impl(self, *inputs):
output_storage = [[None] for i in xrange(self.nout)]
self.perform(None, inputs, output_storage)
return utils.to_return_values([storage[0] for storage in output_storage])
def grad(self, inputs, output_grads):
raise NotImplementedError("grad is not implemented for Composite")
def c_code(self, node, name, inames, onames, sub):
d = dict(zip(["i%i"%i for i in range(len(inames))],
inames) +
zip(["o%i"%i for i in range(len(onames))],
onames),
**sub)
d['name'] = name
return self._c_code % d
#NOTE WELL!!!
# The following adds functions to this module automatically.
# For every scalar op class, a lower-case symbol is added which is a constructor
# for that class.
# NOTE(review): this mutates the module namespace at import time via
# globals(); presumably `modes.make_constructors` scans for Op subclasses --
# confirm against gof.modes.
from gof import modes
modes.make_constructors(globals())
def composite(inputs, outputs):
"""
Usage: composite(inputs, outputs)
Produces an Op class which represents the computations
between the provided inputs and outputs as a single
operation.
# def as_scalar(x, name = None):
# if isinstance(x, gof.Op):
# if len(x.outputs) != 1:
# raise ValueError("It is ambiguous which output of a multi-output Op has to be fetched.", x)
# else:
# x = x.outputs[0]
# if isinstance(x, float):
# s = Scalar('float64', name = name)
# s.data = x
# return s
# if isinstance(x, int):
# s = Scalar('int32', name = name)
# s.data = x
# return s
# if isinstance(x, Scalar):
# return x
# raise TypeError("Cannot convert %s to Scalar" % x)
# def constant(x):
# res = as_scalar(x)
# res.constant = True
# return res
# class Scalar(Result):
# def __init__(self, dtype, name = None):
# Result.__init__(self, role = None, name = name)
# self.dtype = dtype
# self.dtype_specs()
# def __get_constant(self):
# if not hasattr(self, '_constant'):
# return False
# return self._constant
# def __set_constant(self, value):
# if value:
# self.indestructible = True
# self._constant = value
# constant = property(__get_constant, __set_constant)
# def desc(self):
# return (self.dtype, self.data)
The operations between inputs and outputs (as given by # def filter(self, data):
Env(inputs, outputs).ops()) must all be instances of # py_type = self.dtype_specs()[0]
ScalarOp. # return py_type(data)
Examples: # def same_properties(self, other):
x, y = Scalar(), Scalar() # return other.dtype == self.dtype
SquareDiff = composite([x, y], [(x - y)**2])
TimesTen = composite([x], [x * 10.0]) # def dtype_specs(self):
Neighbors = composite([x], [x - 1, x + 1]) # try:
""" # return {'float32': (numpy.float32, 'npy_float32', 'PyFloat_Check', 'PyFloat_AsDouble', 'PyFloat_FromDouble'),
# 'float64': (numpy.float64, 'npy_float64', 'PyFloat_Check', 'PyFloat_AsDouble', 'PyFloat_FromDouble'),
# 'complex128': (numpy.complex128, 'theano_complex128', 'PyComplex_Check', 'PyComplex_AsCComplex', 'PyComplex_FromCComplex'),
# 'complex64': (numpy.complex64, 'theano_complex64', None, None, None),
# 'int8': (numpy.int8, 'npy_int8', 'PyInt_Check', 'PyInt_AsLong', 'PyInt_FromLong'),
# 'int16': (numpy.int16, 'npy_int16', 'PyInt_Check', 'PyInt_AsLong', 'PyInt_FromLong'),
# 'int32': (numpy.int32, 'npy_int32', 'PyInt_Check', 'PyInt_AsLong', 'PyInt_FromLong'),
# 'int64': (numpy.int64, 'npy_int64', 'PyInt_Check', 'PyInt_AsLong', 'PyInt_FromLong')
# }[self.dtype]
# except KeyError:
# raise TypeError("Unsupported dtype for %s: %s" % (self.__class__.__name__, self.dtype))
# def c_literal(self):
# if 'complex' in self.dtype:
# raise NotImplementedError("No literal for complex values.")
# return str(self.data)
# def c_declare(self, name, sub):
# return """
# %(dtype)s %(name)s;
# typedef %(dtype)s %(name)s_dtype;
# """ % dict(name = name, dtype = self.dtype_specs()[1])
# def c_init(self, name, sub):
# return """
# %(name)s = 0;
# """ % locals()
env = Env(inputs, outputs).clone() # def c_extract(self, name, sub):
gof.opt.ConstantFinder().apply(env) # specs = self.dtype_specs()
# return """
# if (!%(check)s(py_%(name)s))
# %(fail)s
# %(name)s = (%(dtype)s)%(conv)s(py_%(name)s);
# """ % dict(sub,
# name = name,
# dtype = specs[1],
# check = specs[2],
# conv = specs[3])
inputs, outputs = env.inputs, env.outputs # def c_sync(self, name, sub):
# specs = self.dtype_specs()
# return """
# Py_XDECREF(py_%(name)s);
# py_%(name)s = %(conv)s((%(dtype)s)%(name)s);
# if (!py_%(name)s)
# py_%(name)s = Py_None;
# """ % dict(name = name,
# dtype = specs[1],
# conv = specs[4])
# def c_cleanup(self, name, sub):
# return ""
# def c_support_code(cls):
# template = """
# struct theano_complex%(nbits)s : public npy_complex%(nbits)s
# {
# typedef theano_complex%(nbits)s complex_type;
# typedef npy_float%(half_nbits)s scalar_type;
# complex_type operator +(complex_type y) {
# complex_type ret;
# ret.real = this->real + y.real;
# ret.imag = this->imag + y.imag;
# return ret;
# }
# complex_type operator -(complex_type y) {
# complex_type ret;
# ret.real = this->real - y.real;
# ret.imag = this->imag - y.imag;
# return ret;
# }
# complex_type operator *(complex_type y) {
# complex_type ret;
# ret.real = this->real * y.real - this->imag * y.imag;
# ret.imag = this->real * y.imag + this->imag * y.real;
# return ret;
# }
# complex_type operator /(complex_type y) {
# complex_type ret;
# scalar_type y_norm_square = y.real * y.real + y.imag * y.imag;
# ret.real = (this->real * y.real + this->imag * y.imag) / y_norm_square;
# ret.imag = (this->imag * y.real - this->real * y.imag) / y_norm_square;
# return ret;
# }
# };
# """
# return template % dict(nbits = 64, half_nbits = 32) + template % dict(nbits = 128, half_nbits = 64)
# def __copy__(self):
# """Return a copy of this instance (with its own attributes)"""
# cpy = self.__class__(self.dtype, self.name)
# cpy.data = self.data
# return cpy
# #UNARY
# def __abs__(self): return Abs(self).out
# def __neg__(self): return Neg(self).out
# #CASTS
# def __int__(self): return AsInt(self).out
# def __float__(self): return AsInt(self).out
# def __complex__(self): return AsComplex(self).out
# #COMPARISONS
# def __lt__(self,other): return lt(self, other)
# def __le__(self,other): return le(self, other)
# def __gt__(self,other): return gt(self, other)
# def __ge__(self,other): return ge(self, other)
# #ARITHMETIC - NORMAL
# def __add__(self,other): return add(self,other)
# def __sub__(self,other): return sub(self,other)
# def __mul__(self,other): return mul(self,other)
# def __div__(self,other): return div(self,other)
# def __pow__(self,other): return pow(self,other)
# #ARITHMETIC - RIGHT-OPERAND
# def __radd__(self,other): return add(other,self)
# def __rsub__(self,other): return sub(other,self)
# def __rmul__(self,other): return mul(other,self)
# def __rdiv__(self,other): return div(other,self)
# def __rpow__(self,other): return pow(other,self)
# # Easy constructors
# def _multi(*fns):
# def f2(f, names):
# if len(names) == 1:
# return f(names)
# else:
# return [f(name) for name in names]
# if len(fns) == 1:
# return partial(f2, fns[0])
# else:
# return [partial(f2, f) for f in fns]
# def intr(name):
# return Scalar(name = name, dtype = 'int64')
# ints = _multi(intr)
# def floatr(name):
# return Scalar(name = name, dtype = 'float64')
# floats = _multi(floatr)
# def upcast(dtype, *dtypes):
# z = numpy.zeros((), dtype = dtype)
# for dtype in dtypes:
# z = z + numpy.zeros((), dtype = dtype)
# return str(z.dtype)
# class ScalarOp(GuardedOp):
# nin = -1
# nout = 1
# def __init__(self, *inputs):
# if self.nin >= 0:
# if len(inputs) != self.nin:
# raise TypeError("Wrong number of inputs for %s (got %i, expected %i)" \
# % (self.__class__.__name__, len(inputs), self.nin))
# else:
# self.nin = len(inputs)
# inputs = [as_scalar(input) for input in inputs]
# i_dtypes = [getattr(input, 'dtype', None) for input in inputs]
# o_dtypes = self.output_dtypes(*i_dtypes)
for op in env.ops(): # self.inputs = inputs
if not isinstance(op, ScalarOp): # self.outputs = [Scalar(dtype) for dtype in o_dtypes]
raise ValueError("The input env to composite must be exclusively composed of ScalarOp instances.")
subd = dict(zip(inputs, # def output_dtypes(self, *dtypes):
["%%(i%i)s"%i for i in range(len(inputs))]) + # if self.nout != 1:
zip(outputs, # raise NotImplementedError()
["%%(o%i)s"%i for i in range(len(outputs))])) # return upcast(*dtypes),
# def impl(self, *inputs):
# raise AbstractFunctionError()
for orphan in env.orphans(): # def grad(self, inputs, output_gradients):
if orphan.constant: # raise AbstractFunctionError()
subd[orphan] = orphan.c_literal()
else: # def perform(self):
raise ValueError("All orphans in the input env to composite must be constant.") # if self.nout == 1:
# self.outputs[0].data = self.impl(*[input.data for input in self.inputs])
_c_code = "{\n" # else:
i = 0 # results = utils.from_return_values(self.impl(*[input.data for input in self.inputs]))
j = 0 # for output, result in zip(self.outputs, results):
for op in env.toposort(): # output.data = result
j += 1
for output in op.outputs: # class UnaryScalarOp(ScalarOp):
if output not in subd: # nin = 1
i += 1
name = "V%%(id)s_tmp%i" % i # class BinaryScalarOp(ScalarOp):
subd[output] = name # nin = 2
_c_code += "%s %s;\n" % (output.dtype_specs()[1], name)
_c_code += op.c_code([subd[input] for input in op.inputs], # class FloatUnaryScalarOp(UnaryScalarOp):
[subd[output] for output in op.outputs], # def output_dtypes(self, input_dtype):
dict(fail = "%(fail)s", # if 'int' in input_dtype: return 'float64',
id = "%%(id)s_%i" % j)) # if 'float' in input_dtype: return input_dtype,
_c_code += "\n" # raise NotImplementedError()
_c_code += "}\n"
# class Add(ScalarOp):
# identity = 0
# commutative = True
# associative = True
# def impl(self, *inputs):
# return sum(inputs)
# def c_code(self, inputs, (z, ), sub):
# if not inputs:
# return z + " = 0;"
# else:
# return z + " = " + " + ".join(inputs) + ";"
# def grad(self, inputs, (gz, )):
# return (gz, ) * len(inputs)
# class Mul(ScalarOp):
# identity = 1
# commutative = True
# associative = True
# def impl(self, *inputs):
# return numpy.product(inputs)
# def c_code(self, inputs, (z, ), sub):
# if not inputs:
# return z + " = 1;"
# else:
# return z + " = " + " * ".join(inputs) + ";"
# def grad(self, inputs, (gz, )):
# return [mul(*([gz] + utils.difference(inputs, [input])))
# for input in inputs]
# class Sub(BinaryScalarOp):
# def impl(self, x, y):
# return x - y
# def c_code(self, (x, y), (z, ), sub):
# return "%(z)s = %(x)s - %(y)s;" % locals()
# def grad(self, (x, y), (gz, )):
# return gz, -gz
# class Div(BinaryScalarOp):
# def impl(self, x, y):
# return x / y
# def c_code(self, (x, y), (z, ), sub):
# if 'int' in self.inputs[0].dtype and 'int' in self.inputs[1].dtype:
# raise NotImplementedError("For integer arguments the behavior of division in C and in Python differ when the quotient is negative (to implement).")
# return "%(z)s = %(x)s / %(y)s;" % locals()
# def grad(self, (x, y), (gz, )):
# return gz / y, -(gz * x) / (y * y)
# class Pow(BinaryScalarOp):
# def impl(self, x, y):
# return x ** y
# def c_code(self, (x, y), (z, ), sub):
# return "%(z)s = pow(%(x)s, %(y)s);" % locals()
# def grad(self, (x, y), (gz, )):
# return gz * y * x**(y - 1), gz * log(x) * x**y
# class First(BinaryScalarOp):
# def impl(self, x, y):
# return x
# def c_code(self, (x, y), (z, ), sub):
# return "%(z)s = %(x)s;" % locals()
# def grad(self, (x, y), (gz, )):
# return gz, None
# class Second(BinaryScalarOp):
# def impl(self, x, y):
# return y
# def c_code(self, (x, y), (z, ), sub):
# return "%(z)s = %(y)s;" % locals()
# def grad(self, (x, y), (gz, )):
# return None, gz
# class Identity(UnaryScalarOp):
# def impl(self, x):
# return x
# def c_code(self, (x, ), (z, ), sub):
# return "%(z)s = %(x)s;" % locals()
# def grad(self, (x, ), (gz, )):
# return gz,
# class Neg(UnaryScalarOp):
# def impl(self, x):
# return -x
# def grad(self, (x, ), (gz, )):
# return -gz,
# def c_code(self, (x, ), (z, ), sub):
# return "%(z)s = -%(x)s;" % locals()
# class Abs(UnaryScalarOp):
# #TODO: for complex input, output is some flavour of float
# def impl(self, x):
# return numpy.abs(x)
# def grad(self, (x, ), (gz, )):
# return gz * sgn(x),
# def c_code(self, (x, ), (z, ), sub):
# dtype = str(self.inputs[0].dtype)
# if 'int' in dtype:
# return "%(z)s = abs(%(x)s);" % locals()
# if 'float' in dtype:
# return "%(z)s = fabs(%(x)s);" % locals()
# #complex, other?
# raise NotImplementedError('dtype not supported', dtype)
# class Sgn(UnaryScalarOp):
# def impl(self, x):
# #casting to output type is handled by filter
# return numpy.sign(x)
# def grad(self, (x, ), (gz, )):
# return None,
# def c_code(self, (x, ), (z, ), sub):
# #casting is done by compiler
# #TODO: use copysign
# return "%(z)s = (%(x)s >= 0) ? (%(x)s == 0) ? 0.0 : 1.0 : -1.0;" % locals()
# class Inv(FloatUnaryScalarOp):
# def impl(self, x):
# return 1.0 / x
# def grad(self, (x, ), (gz, )):
# return -gz / (x * x),
# def c_code(self, (x, ), (z, ), sub):
# return "%(z)s = 1.0 / %(x)s;" % locals()
# class Log(FloatUnaryScalarOp):
# def impl(self, x):
# return math.log(x)
# def grad(self, (x, ), (gz, )):
# return gz / x,
# def c_code(self, (x, ), (z, ), sub):
# return "%(z)s = log(%(x)s);" % locals()
# class Log2(FloatUnaryScalarOp):
# def impl(self, x):
# return numpy.log2(x)
# def grad(self, (x, ), (gz, )):
# return gz / (x * math.log(2.0)),
# def c_code(self, (x, ), (z, ), sub):
# return "%(z)s = log2(%(x)s);" % locals()
# class Exp(FloatUnaryScalarOp):
# def impl(self, x):
# return math.exp(x)
# def grad(self, (x, ), (gz, )):
# return gz * exp(x),
# def c_code(self, (x, ), (z, ), sub):
# return "%(z)s = exp(%(x)s);" % locals()
# class Sqr(UnaryScalarOp):
# def impl(self, x):
# return x*x
# def grad(self, (x, ), (gz, )):
# return gz * x * 2,
# def c_code(self, (x, ), (z, ), sub):
# return "%(z)s = %(x)s * %(x)s;" % locals()
# class Sqrt(FloatUnaryScalarOp):
# def impl(self, x):
# return math.sqrt(x)
# def grad(self, (x, ), (gz, )):
# return (gz * 0.5) / sqrt(x),
# def c_code(self, (x, ), (z, ), sub):
# return "%(z)s = sqrt(%(x)s);" % locals()
# class Cos(FloatUnaryScalarOp):
# def impl(self, x):
# return math.cos(x)
# def grad(self, (x, ), (gz, )):
# return -gz * sin(x),
# def c_code(self, (x, ), (z, ), sub):
# return "%(z)s = cos(%(x)s);" % locals()
# class Sin(FloatUnaryScalarOp):
# def impl(self, x):
# return math.sin(x)
# def grad(self, (x, ), (gz, )):
# return gz * cos(x),
# def c_code(self, (x, ), (z, ), sub):
# return "%(z)s = sin(%(x)s);" % locals()
# class Tan(FloatUnaryScalarOp):
# def impl(self, x):
# return math.tan(x)
# def grad(self, (x, ), (gz, )):
# return gz / (cos(x) ** 2),
# def c_code(self, (x, ), (z, ), sub):
# return "%(z)s = tan(%(x)s);" % locals()
# class Cosh(FloatUnaryScalarOp):
# """
# sinh(x) = (exp(x) + exp(-x)) / 2
# """
# def impl(self, x):
# return math.cosh(x)
# def grad(self, (x, ), (gz, )):
# return gz * sinh(x),
# def c_code(self, (x, ), (z, ), sub):
# return "%(z)s = cosh(%(x)s);" % locals()
# class Sinh(FloatUnaryScalarOp):
# """
# sinh(x) = (exp(x) - exp(-x)) / 2
# """
# def impl(self, x):
# return math.sinh(x)
# def grad(self, (x, ), (gz, )):
# return gz * cosh(x),
# def c_code(self, (x, ), (z, ), sub):
# return "%(z)s = sinh(%(x)s);" % locals()
# class Tanh(FloatUnaryScalarOp):
# """
# tanh(x) = sinh(x) / cosh(x)
# = (exp(2*x) - 1) / (exp(2*x) + 1)
# """
# def impl(self, x):
# return math.tanh(x)
# def grad(self, (x, ), (gz, )):
# return gz * (1 - tanh(x)**2),
# def c_code(self, (x, ), (z, ), sub):
# return "%(z)s = tanh(%(x)s);" % locals()
# #NOTE WELL!!!
# # The following adds functions to this module automatically.
# # For every scalar op class, a lower-case symbol is added which is a constructor
# # for that class.
# from gof import modes
# modes.make_constructors(globals())
# def composite(inputs, outputs):
# """
# Usage: composite(inputs, outputs)
# Produces an Op class which represents the computations
# between the provided inputs and outputs as a single
# operation.
# The operations between inputs and outputs (as given by
# Env(inputs, outputs).ops()) must all be instances of
# ScalarOp.
# Examples:
# x, y = Scalar(), Scalar()
# SquareDiff = composite([x, y], [(x - y)**2])
# TimesTen = composite([x], [x * 10.0])
# Neighbors = composite([x], [x - 1, x + 1])
# """
# env = Env(inputs, outputs).clone()
# gof.opt.ConstantFinder().apply(env)
# inputs, outputs = env.inputs, env.outputs
# for op in env.ops():
# if not isinstance(op, ScalarOp):
# raise ValueError("The input env to composite must be exclusively composed of ScalarOp instances.")
# subd = dict(zip(inputs,
# ["%%(i%i)s"%i for i in range(len(inputs))]) +
# zip(outputs,
# ["%%(o%i)s"%i for i in range(len(outputs))]))
# for orphan in env.orphans():
# if orphan.constant:
# subd[orphan] = orphan.c_literal()
# else:
# raise ValueError("All orphans in the input env to composite must be constant.")
# _c_code = "{\n"
# i = 0
# j = 0
# for op in env.toposort():
# j += 1
# for output in op.outputs:
# if output not in subd:
# i += 1
# name = "V%%(id)s_tmp%i" % i
# subd[output] = name
# _c_code += "%s %s;\n" % (output.dtype_specs()[1], name)
# _c_code += op.c_code([subd[input] for input in op.inputs],
# [subd[output] for output in op.outputs],
# dict(fail = "%(fail)s",
# id = "%%(id)s_%i" % j))
# _c_code += "\n"
# _c_code += "}\n"
def compose_impl(r): # def compose_impl(r):
# this is not optimal at all eg in add(*1 -> mul(x, y), *1) # # this is not optimal at all eg in add(*1 -> mul(x, y), *1)
# it will calculate *1 twice # # it will calculate *1 twice
# it also doesn't follow env.toposort but that's (presumably) # # it also doesn't follow env.toposort but that's (presumably)
# still correct since we only have scalar ops # # still correct since we only have scalar ops
if r in env.inputs: # if r in env.inputs:
idx = env.inputs.index(r) # idx = env.inputs.index(r)
return lambda inputs: inputs[idx] # return lambda inputs: inputs[idx]
elif r in env.orphans(): # elif r in env.orphans():
return lambda inputs: r.data # return lambda inputs: r.data
op = r.owner # op = r.owner
producers = [compose_impl(input) for input in op.inputs] # producers = [compose_impl(input) for input in op.inputs]
return lambda inputs: op.impl(*[p(inputs) for p in producers]) # return lambda inputs: op.impl(*[p(inputs) for p in producers])
_impls = [compose_impl(r) for r in env.outputs] # _impls = [compose_impl(r) for r in env.outputs]
class Composite(ScalarOp): # class Composite(ScalarOp):
nin = len(inputs) # nin = len(inputs)
nout = len(outputs) # nout = len(outputs)
def output_dtypes(self, *input_dtypes): # def output_dtypes(self, *input_dtypes):
assert input_dtypes == tuple([input.dtype for input in inputs]) # assert input_dtypes == tuple([input.dtype for input in inputs])
return [output.dtype for dtype in outputs] # return [output.dtype for dtype in outputs]
def perform(self): # def perform(self):
inputs = [input.data for input in self.inputs] # inputs = [input.data for input in self.inputs]
for output, impl in zip(self.outputs, _impls): # for output, impl in zip(self.outputs, _impls):
output.data = impl(inputs) # output.data = impl(inputs)
def impl(self, *inputs): # def impl(self, *inputs):
for r, input in zip(self.inputs, inputs): # for r, input in zip(self.inputs, inputs):
r.data = input # r.data = input
self.perform() # self.perform()
return utils.to_return_values([output.data for output in self.outputs]) # return utils.to_return_values([output.data for output in self.outputs])
def grad(self, inputs, output_grads): # def grad(self, inputs, output_grads):
raise NotImplementedError("grad is not implemented for Composite") # raise NotImplementedError("grad is not implemented for Composite")
def c_code(self, inames, onames, sub): # def c_code(self, inames, onames, sub):
d = dict(zip(["i%i"%i for i in range(len(inames))], # d = dict(zip(["i%i"%i for i in range(len(inames))],
inames) + # inames) +
zip(["o%i"%i for i in range(len(onames))], # zip(["o%i"%i for i in range(len(onames))],
onames), # onames),
**sub) # **sub)
return _c_code % d # return _c_code % d
return Composite # return Composite
...@@ -6,20 +6,20 @@ from gof import utils ...@@ -6,20 +6,20 @@ from gof import utils
C = constant C = constant
# x**2 -> x*x # x**2 -> x*x
pow2sqr_float = Pattern((Pow, 'x', C(2.0)), (Sqr, 'x')) pow2sqr_float = Pattern((pow, 'x', C(2.0)), (sqr, 'x'))
pow2sqr_int = Pattern((Pow, 'x', C(2)), (Sqr, 'x')) pow2sqr_int = Pattern((pow, 'x', C(2)), (sqr, 'x'))
# x**0 -> 1 # x**0 -> 1
pow2one_float = Pattern((Pow, 'x', C(0.0)), C(1.0)) pow2one_float = Pattern((pow, 'x', C(0.0)), C(1.0))
pow2one_int = Pattern((Pow, 'x', C(0)), C(1)) pow2one_int = Pattern((pow, 'x', C(0)), C(1))
# x**1 -> x # x**1 -> x
pow2x_float = Pattern((Pow, 'x', C(1.0)), 'x') pow2x_float = Pattern((pow, 'x', C(1.0)), 'x')
pow2x_int = Pattern((Pow, 'x', C(1)), 'x') pow2x_int = Pattern((pow, 'x', C(1)), 'x')
# log(x**y) -> y*log(x) # log(x**y) -> y*log(x)
logpow = Pattern((Log, (Pow, 'x', 'y')), logpow = Pattern((log, (pow, 'x', 'y')),
(Mul, 'y', (Log, 'x'))) (mul, 'y', (log, 'x')))
class Canonizer(gof.Optimizer): class Canonizer(gof.Optimizer):
...@@ -71,8 +71,6 @@ class Canonizer(gof.Optimizer): ...@@ -71,8 +71,6 @@ class Canonizer(gof.Optimizer):
def canonize(r): def canonize(r):
# if r in env.inputs or r in env.orphans():
# return
next = env.follow(r) next = env.follow(r)
if next is None: if next is None:
return return
...@@ -84,19 +82,20 @@ class Canonizer(gof.Optimizer): ...@@ -84,19 +82,20 @@ class Canonizer(gof.Optimizer):
if env.edge(r): if env.edge(r):
return [r], [] return [r], []
op = r.owner print "a", r, r.owner, env, env.orphans
# if op is None or r in env.inputs or r in env.orphans(): node = r.owner
# return [r], [] op = node.op
print "b"
results = [r2.dtype == r.dtype and flatten(r2) or ([r2], []) for r2 in op.inputs] results = [r2.type == r.type and flatten(r2) or ([r2], []) for r2 in node.inputs]
if isinstance(op, self.main) and (not nclients_check or env.nclients(r) == 1): if op == self.main and (not nclients_check or env.nclients(r) == 1):
nums = [x[0] for x in results] nums = [x[0] for x in results]
denums = [x[1] for x in results] denums = [x[1] for x in results]
elif isinstance(op, self.inverse) and (not nclients_check or env.nclients(r) == 1): elif op == self.inverse and (not nclients_check or env.nclients(r) == 1):
# num, denum of the second argument are added to the denum, num respectively # num, denum of the second argument are added to the denum, num respectively
nums = [results[0][0], results[1][1]] nums = [results[0][0], results[1][1]]
denums = [results[0][1], results[1][0]] denums = [results[0][1], results[1][0]]
elif isinstance(op, self.reciprocal) and (not nclients_check or env.nclients(r) == 1): elif op == self.reciprocal and (not nclients_check or env.nclients(r) == 1):
# num, denum of the sole argument are added to the denum, num respectively # num, denum of the sole argument are added to the denum, num respectively
nums = [results[0][1]] nums = [results[0][1]]
denums = [results[0][0]] denums = [results[0][0]]
...@@ -111,12 +110,6 @@ class Canonizer(gof.Optimizer): ...@@ -111,12 +110,6 @@ class Canonizer(gof.Optimizer):
for input in (env.follow(r) or []): for input in (env.follow(r) or []):
canonize(input) canonize(input)
return return
# if r.owner is None:
# return
# else:
# for input in r.owner.inputs:
# canonize(input)
# return
# Terms that are both in the num and denum lists cancel each other # Terms that are both in the num and denum lists cancel each other
for d in list(denum): for d in list(denum):
...@@ -126,8 +119,8 @@ class Canonizer(gof.Optimizer): ...@@ -126,8 +119,8 @@ class Canonizer(gof.Optimizer):
denum.remove(d) denum.remove(d)
# We identify the constants in num and denum # We identify the constants in num and denum
numct, num = utils.partition(lambda factor: getattr(factor, 'constant', False) and factor.data is not None, num) numct, num = utils.partition(lambda factor: isinstance(factor, Constant) and factor.data is not None, num)
denumct, denum = utils.partition(lambda factor: getattr(factor, 'constant', False) and factor.data is not None, denum) denumct, denum = utils.partition(lambda factor: isinstance(factor, Constant) and factor.data is not None, denum)
# All constants in num and denum are combined into a single constant which we add to num (unless it's a neutral constant) # All constants in num and denum are combined into a single constant which we add to num (unless it's a neutral constant)
v = self.invfn(self.mainfn(*[x.data for x in numct]), self.mainfn(*[x.data for x in denumct])) v = self.invfn(self.mainfn(*[x.data for x in numct]), self.mainfn(*[x.data for x in denumct]))
...@@ -147,7 +140,7 @@ class Canonizer(gof.Optimizer): ...@@ -147,7 +140,7 @@ class Canonizer(gof.Optimizer):
elif n == 1: elif n == 1:
return factors[0] return factors[0]
else: else:
return self.main(*factors).out return self.main(*factors)
numr, denumr = make(num), make(denum) numr, denumr = make(num), make(denum)
...@@ -155,17 +148,15 @@ class Canonizer(gof.Optimizer): ...@@ -155,17 +148,15 @@ class Canonizer(gof.Optimizer):
if denumr is None: if denumr is None:
# Everything cancelled each other so we're left with # Everything cancelled each other so we're left with
# the neutral element. # the neutral element.
new_r = Scalar(dtype = r.dtype) new_r = Constant(r.type, self.neutral)
new_r.constant = True
new_r.data = self.neutral
else: else:
# There's no numerator so we use reciprocal # There's no numerator so we use reciprocal
new_r = self.reciprocal(denumr).out new_r = self.reciprocal(denumr)
else: else:
if denumr is None: if denumr is None:
new_r = numr new_r = numr
else: else:
new_r = self.inverse(numr, denumr).out new_r = self.inverse(numr, denumr)
# Hopefully this won't complain! # Hopefully this won't complain!
env.replace(r, new_r) env.replace(r, new_r)
...@@ -191,7 +182,6 @@ def group_powers(env, num, denum): ...@@ -191,7 +182,6 @@ def group_powers(env, num, denum):
Examples: Examples:
group_powers([x, exp(x), exp(y)], [exp(z)]) -> [x, exp(x+y-z)], [] group_powers([x, exp(x), exp(y)], [exp(z)]) -> [x, exp(x+y-z)], []
""" """
# maps a base to the list of powers it is raised to in the # maps a base to the list of powers it is raised to in the
# numerator/denominator lists. # numerator/denominator lists.
num_powers = {} num_powers = {}
...@@ -201,14 +191,15 @@ def group_powers(env, num, denum): ...@@ -201,14 +191,15 @@ def group_powers(env, num, denum):
# For each instance of exp or pow in seq, removes it from seq # For each instance of exp or pow in seq, removes it from seq
# and does d[base].append(power). # and does d[base].append(power).
for factor in list(seq): for factor in list(seq):
op = factor.owner
if env.edge(factor): if env.edge(factor):
continue continue
if isinstance(op, Exp): node = factor.owner
d.setdefault('e', []).append(op.inputs[0]) op = node.op
if op == exp:
d.setdefault('e', []).append(node.inputs[0])
seq.remove(factor) seq.remove(factor)
elif isinstance(op, Pow): elif op == pow:
d.setdefault(op.inputs[0], []).append(op.inputs[1]) d.setdefault(node.inputs[0], []).append(node.inputs[1])
seq.remove(factor) seq.remove(factor)
populate(num_powers, num) populate(num_powers, num)
......
...@@ -6,9 +6,8 @@ import numpy ...@@ -6,9 +6,8 @@ import numpy
from copy import copy from copy import copy
from gof import Result, Op, utils, Destroyer, Viewer, AbstractFunctionError from gof import Result, Op, utils, Destroyer, Viewer, AbstractFunctionError, Type, Result, Constant, Apply
import gof.result import gof
import gof.op
import blas # for gemm, dot import blas # for gemm, dot
...@@ -18,76 +17,73 @@ import scalar as scal ...@@ -18,76 +17,73 @@ import scalar as scal
from functools import partial from functools import partial
class Tensor(Result): def as_tensor(x, name = None):
""" if isinstance(x, gof.Apply):
L{Result} to store L{numpy.ndarray} or equivalent via .data if len(x.outputs) != 1:
raise ValueError("It is ambiguous which output of a multi-output Op has to be fetched.", x)
else:
x = x.outputs[0]
if isinstance(x, Result):
if not isinstance(x.type, Tensor):
raise TypeError("Result type field must be a Tensor.", x, x.type)
return x
if isinstance(x, Constant):
if not isinstance(x.type, Tensor):
raise TypeError("Constant type field must be a Tensor.", x, x.type)
return x
try:
return constant(x)
except TypeError:
raise
raise TypeError("Cannot convert %s to Tensor" % x, type(x))
This class does not implement python operators and has no dependencies
on the L{Op}s that use it. def constant(x):
if not isinstance(x, numpy.ndarray):
x = numpy.asarray(x)
try:
return TensorConstant(Tensor(dtype = x.dtype,
broadcastable = [d == 1 for d in x.shape]), x)
except:
raise
raise TypeError("Could not convert %s to Tensor" % _x, type(_x))
class Tensor(Type):
"""
L{Type} representing L{numpy.ndarray} in Theano.
@todo: At some point we should document a glossary, such as terms like @todo: At some point we should document a glossary, such as terms like
broadcasting and shape. broadcasting and shape.
@type _dtype: numpy dtype string such as 'int64' or 'float64' (among others) @type dtype: numpy dtype string such as 'int64' or 'float64' (among others)
@type _broadcastable: tuple or list or array of boolean values, whose length @type broadcastable: tuple or list or array of boolean values, whose length
is the number of dimensions of the contained L{ndarray}. is the number of dimensions of the L{ndarray} represented by this Type.
@ivar _broadcastable: Each element of the broadcastable vector tells us @ivar broadcastable: Each element of the broadcastable vector tells us
something about the corresponding dimension: something about the corresponding dimension:
- False means the dimension can be anything. - False means the dimension can be anything.
- True means the dimension must be 1. Also, this dimension will be considered - True means the dimension must be 1. Also, this dimension will be considered
for L{broadcasting}, as described and implemented in Numpy. for L{broadcasting}, as described and implemented in Numpy.
""" """
def __init__(self, dtype, broadcastable, name=None): def __init__(self, dtype, broadcastable):
"""Initialize a L{Tensor} self.dtype = str(dtype)
self.broadcastable = broadcastable
@note: This does not actually allocate any data. self.dtype_specs() # error checking is done there
"""
def filter(self, data, strict = False):
_data = data
if strict:
assert isinstance(data, numpy.ndarray)
assert str(data.dtype) == self.dtype
else:
data = numpy.asarray(data, dtype = self.dtype)
if not self.ndim == data.ndim:
raise TypeError("Wrong number of dimensions: expected %s, got %s." % (self.ndim, data.ndim), _data)
if any(b and d != 1 for d, b in zip(data.shape, self.broadcastable)):
raise ValueError("Non-unit value on shape on a broadcastable dimension.", data.shape, self.broadcastable)
return data
# data is not given here. This may seem a bit strange, but when data was
# an argument, it made sense to use *either* the given dtype,
# broadcastable, or override them from the fields of data. This makes
# the function ugly, especially because it isn't obvious how to set
# broadcastable from data.
#
# The only clean option I could think of, when passing a data arg was to
# require the broadcastable field to be given. Since broadcastable is
# the argument that is awkward to construct, I decided to put all this
# into the tensor(data,...) function below, which is like a second
# constructor that works with an ndarray.
Result.__init__(self, role=None, name=name)
self._dtype = str(dtype)
self.dtype_specs() # this is just for error checking
self._broadcastable = tuple(broadcastable)
######################
# Result interface
######################
#
# filter
#
def filter(self, arr):
"""Cast to an L{numpy.ndarray} and ensure arr has correct rank and shape."""
if not (isinstance(arr, numpy.ndarray) \
and arr.dtype==self.dtype):
arr = numpy.asarray(arr, dtype = self.dtype)
if len(self.broadcastable) != len(arr.shape):
raise ValueError(Tensor.filter.E_rank,
self.broadcastable,
arr.shape,
self.owner)
for b, s in zip(self.broadcastable, arr.shape):
if b and (s != 1):
raise ValueError(Tensor.filter.E_shape)
return arr
# these strings are here so that tests can use them
filter.E_rank = 'wrong rank'
filter.E_shape = 'non-unit size on broadcastable dimension'
#
# type information
#
def dtype_specs(self): def dtype_specs(self):
"""Return python - C type correspondance tuple for self.data """Return python - C type correspondance tuple for self.data
...@@ -108,21 +104,23 @@ class Tensor(Result): ...@@ -108,21 +104,23 @@ class Tensor(Result):
except KeyError: except KeyError:
raise TypeError("Unsupported dtype for %s: %s" % (self.__class__.__name__, self.dtype)) raise TypeError("Unsupported dtype for %s: %s" % (self.__class__.__name__, self.dtype))
# def __eq__(self, other):
# Description for constant folding return type(self) == type(other) and other.dtype == self.dtype and other.broadcastable == self.broadcastable
#
def desc(self): def __hash__(self):
""" return hash(self.dtype) ^ hash(self.broadcastable)
Returns a hashable description of this L{Tensor}.
""" ndim = property(lambda self: len(self.broadcastable), doc = "read-only access to the number of dimensions")
if self.data is not None:
return (Tensor, self.dtype, self.broadcastable, self.data.data[:]) def make_result(self, name = None):
else: return TensorResult(self, name = name)
return (Tensor, self.dtype, self.broadcastable, None)
def __str__(self):
# return "%s(%s)" % (str(self.dtype), str(self.broadcastable))
# C codegen stubs
# def __repr__(self):
return "Tensor{%s, %s}" % (str(self.dtype), str(self.broadcastable))
def c_declare(self, name, sub): def c_declare(self, name, sub):
return """ return """
PyArrayObject* %(name)s; PyArrayObject* %(name)s;
...@@ -226,35 +224,69 @@ class Tensor(Result): ...@@ -226,35 +224,69 @@ class Tensor(Result):
# todo: use C templating # todo: use C templating
############################ # Easy constructors
# Tensor specific attributes
############################
dtype = property(lambda self: self._dtype, doc = "read-only access to _dtype, which should not be changed")
broadcastable = property(lambda self: self._broadcastable, doc = "read-only access to _broadcastable, which should not be changed")
ndim = property(lambda self: len(self.broadcastable), doc = "read-only access to the number of dimensions")
############################
# Cloning facilities
############################
def __copy__(self): def _multi(*fns):
return self.clone(True) def f2(f, names):
if len(names) == 1:
def clone(self, transfer_data = False): return f(names)
"""Return a copy of this instance (with its own attributes) else:
return [f(name) for name in names]
If transfer_data is True, a copy of self.data is assigned to the copy's if len(fns) == 1:
data property, otherwise the copy's data is left as None. return partial(f2, fns)
""" else:
cpy = self.__class__(self.dtype, self.broadcastable, self.name) return [partial(f2, f) for f in fns]
if transfer_data:
cpy.data = copy(self.data)
return cpy
fscalar = Tensor('float32', ())
dscalar = Tensor('float64', ())
iscalar = Tensor('int32', ())
lscalar = Tensor('int64', ())
def scalar(name = None, dtype = 'float64'):
type = Tensor(dtype, ())
return type(name)
scalars, fscalars, dscalars, iscalars, lscalars = _multi(scalar, fscalar, dscalar, iscalar, lscalar)
fvector = Tensor('float32', (False, ))
dvector = Tensor('float64', (False, ))
ivector = Tensor('int32', (False, ))
lvector = Tensor('int64', (False, ))
def vector(name = None, dtype = 'float64'):
type = Tensor(dtype, (False, ))
return type(name)
vectors, fvectors, dvectors, ivectors, lvectors = _multi(vector, fvector, dvector, ivector, lvector)
fmatrix = Tensor('float32', (False, False))
dmatrix = Tensor('float64', (False, False))
imatrix = Tensor('int32', (False, False))
lmatrix = Tensor('int64', (False, False))
def matrix(name = None, dtype = 'float64'):
type = Tensor(dtype, (False, False))
return type(name)
matrices, fmatrices, dmatrices, imatrices, lmatrices = _multi(matrix, fmatrix, dmatrix, imatrix, lmatrix)
frow = Tensor('float32', (True, False))
drow = Tensor('float64', (True, False))
irow = Tensor('int32', (True, False))
lrow = Tensor('int64', (True, False))
def row(name = None, dtype = 'float64'):
type = Tensor(dtype, (True, False))
return type(name)
rows, frows, drows, irows, lrows = _multi(row, frow, drow, irow, lrow)
fcol = Tensor('float32', (False, True))
dcol = Tensor('float64', (False, True))
icol = Tensor('int32', (False, True))
lcol = Tensor('int64', (False, True))
def col(name = None, dtype = 'float64'):
type = Tensor(dtype, (False, True))
return type(name)
cols, fcols, dcols, icols, lcols = _multi(col, fcol, dcol, icol, lcol)
class _tensor_py_operators:
#UNARY #UNARY
def __abs__(self): return Abs(self).out def __abs__(self): return _abs(self)
def __neg__(self): return Neg(self).out def __neg__(self): return neg(self)
#CASTS #CASTS
def __int__(self): return AsInt(self).out def __int__(self): return AsInt(self).out
...@@ -297,86 +329,17 @@ class Tensor(Result): ...@@ -297,86 +329,17 @@ class Tensor(Result):
#COPYING #COPYING
def copy(self): return tensor_copy(self) def copy(self): return tensor_copy(self)
s2t.Tensor = Tensor
# alternate Tensor constructor
def astensor(data, broadcastable=None, name=None):
"""Return a L{Tensor} containing given data"""
if isinstance(data, Op):
if len(data.outputs) != 1:
raise ValueError("It is ambiguous which output of a multi-output Op has to be fetched.", data)
else:
data = data.outputs[0]
if isinstance(data, Tensor):
if broadcastable is not None and list(data.broadcastable) != list(broadcastable):
raise TypeError("The data to wrap as a Tensor has the wrong broadcastable pattern. Expected %s, got %s." % (broadcastable, data.broadcastable))
if name is not None and name != data.name:
raise ValueError("Cannot rename an existing Tensor.")
return data
elif isinstance(data, Result):
raise TypeError("Cannot make a Tensor out of a result that is not an instance of Tensor: %s (%s)" % (data, data.__class__.__name__), data)
if data is None and broadcastable is None:
raise TypeError("Cannot make a Tensor out of None.")
_data = data
data = numpy.asarray(data)
if broadcastable is None:
broadcastable = [s==1 for s in data.shape]
elif broadcastable in [0, 1]:
broadcastable = [broadcastable] * len(data.shape)
try:
rval = Tensor(data.dtype, broadcastable, name = name)
except TypeError:
raise TypeError("Cannot convert %s to Tensor." % repr(_data))
rval.data = data # will raise if broadcastable was mis-specified
return rval
s2t.astensor = astensor
# Easy constructors
def _multi(*fns):
def f2(f, names):
if len(names) == 1:
return f(names)
else:
return [f(name) for name in names]
if len(fns) == 1:
return partial(f2, fns)
else:
return [partial(f2, f) for f in fns]
def _int_float(f): class TensorResult(Result, _tensor_py_operators):
return partial(f, dtype = 'int64'), partial(f, dtype = 'float64') pass
def scalar(name, dtype = 'float64'): class TensorConstant(Constant, _tensor_py_operators):
return Tensor(name = name, dtype = dtype, broadcastable = ()) pass
iscalar, fscalar = _int_float(scalar)
scalars, iscalars, fscalars = _multi(scalar, iscalar, fscalar)
def vector(name, dtype = 'float64'): s2t.as_tensor = as_tensor
return Tensor(name = name, dtype = dtype, broadcastable = (False)) s2t.Tensor = Tensor
ivector, fvector = _int_float(vector) s2t.TensorResult = TensorResult
vectors, ivectors, fvectors = _multi(vector, ivector, fvector) s2t.TensorConstant = TensorConstant
def matrix(name, dtype = 'float64'):
return Tensor(name = name, dtype = dtype, broadcastable = (False, False))
imatrix, fmatrix = _int_float(matrix)
matrices, imatrices, fmatrices = _multi(matrix, imatrix, fmatrix)
def row(name, dtype = 'float64'):
return Tensor(name = name, dtype = dtype, broadcastable = (True, False))
irow, frow = _int_float(row)
rows, irows, frows = _multi(row, irow, frow)
def col(name, dtype = 'float64'):
return Tensor(name = name, dtype = dtype, broadcastable = (False, True))
icol, fcol = _int_float(col)
cols, icols, fcols = _multi(col, icol, fcol)
...@@ -386,58 +349,59 @@ cols, icols, fcols = _multi(col, icol, fcol) ...@@ -386,58 +349,59 @@ cols, icols, fcols = _multi(col, icol, fcol)
# this has a different name, because _as_tensor is the function which ops use # this has a different name, because _as_tensor is the function which ops use
# to upcast their arguments... this internal-use function is a good place to put debugging stuff, better than the global astensor. # to upcast their arguments... this internal-use function is a good place to put debugging stuff, better than the global astensor.
_as_tensor = astensor _as_tensor = as_tensor
class _Op(Op):
""" # class _Op(Op):
A basic L{Op} subclass that can be used to make L{Op}s that operate on L{Tensor}s. # """
It is not mandatory to inherit from this class, but it is practical. # A basic L{Op} subclass that can be used to make L{Op}s that operate on L{Tensor}s.
# It is not mandatory to inherit from this class, but it is practical.
@ivar nin: number of inputs
@ivar nout: number of outputs # @ivar nin: number of inputs
@ivar out_tensor_class: L{Tensor} subclass used to instantiate the outputs # @ivar nout: number of outputs
# @ivar out_tensor_class: L{Tensor} subclass used to instantiate the outputs
- input_wrapper: returns a L{Tensor} from its argument
- propagate_dtype: returns a list of dtypes corresponding to the # - input_wrapper: returns a L{Tensor} from its argument
output dtypes from a list of input dtypes (if an input is not a # - propagate_dtype: returns a list of dtypes corresponding to the
L{Tensor}, the passed value will be None) # output dtypes from a list of input dtypes (if an input is not a
- propagate_broadcastable: returns a list of tuples corresponding # L{Tensor}, the passed value will be None)
to the output broadcastable flags from the input broadcastable flags # - propagate_broadcastable: returns a list of tuples corresponding
(if an input is not a L{Tensor}, the passed value will be None). # to the output broadcastable flags from the input broadcastable flags
""" # (if an input is not a L{Tensor}, the passed value will be None).
# """
nin = -1 # nin == -1 means: arbitrary number of inputs # nin = -1 # nin == -1 means: arbitrary number of inputs
nout = 1 # nout = 1
def __init__(self, *inputs): # def __init__(self, *inputs):
inputs = map(_as_tensor, inputs) # inputs = map(_as_tensor, inputs)
if self.nin >= 0: # if self.nin >= 0:
if len(inputs) != self.nin: # if len(inputs) != self.nin:
raise TypeError("Wrong number of inputs for %s (got %i, expected %i)") \ # raise TypeError("Wrong number of inputs for %s (got %i, expected %i)") \
% (self, len(inputs), self.nin) # % (self, len(inputs), self.nin)
i_broadcastables = [getattr(input, 'broadcastable', None) for input in inputs] # i_broadcastables = [getattr(input, 'broadcastable', None) for input in inputs]
i_dtypes = [getattr(input, 'dtype', None) for input in inputs] # i_dtypes = [getattr(input, 'dtype', None) for input in inputs]
o_broadcastables = utils.from_return_values(self.propagate_broadcastable(*i_broadcastables)) # o_broadcastables = utils.from_return_values(self.propagate_broadcastable(*i_broadcastables))
o_dtypes = utils.from_return_values(self.propagate_dtype(*i_dtypes)) # o_dtypes = utils.from_return_values(self.propagate_dtype(*i_dtypes))
self.inputs = inputs # self.inputs = inputs
self.outputs = [Tensor(dtype, broadcastable) for broadcastable, dtype in zip(o_broadcastables, o_dtypes)] # self.outputs = [Tensor(dtype, broadcastable) for broadcastable, dtype in zip(o_broadcastables, o_dtypes)]
def propagate_broadcastable(self, *inputs): # def propagate_broadcastable(self, *inputs):
raise AbstractFunctionError() # raise AbstractFunctionError()
def propagate_dtype(self, *i_dtypes): # def propagate_dtype(self, *i_dtypes):
rval = set([dtype for dtype in i_dtypes if dtype is not None]) # rval = set([dtype for dtype in i_dtypes if dtype is not None])
if len(rval) == 0: # if len(rval) == 0:
raise ValueError("Cannot infer the dtypes of the outputs with no Tensor inputs.") # raise ValueError("Cannot infer the dtypes of the outputs with no Tensor inputs.")
elif len(rval) > 1: # elif len(rval) > 1:
raise ValueError("The dtypes of all inputs should be identical.") # raise ValueError("The dtypes of all inputs should be identical.")
return [rval.pop()] * self.nout # return [rval.pop()] * self.nout
...@@ -445,109 +409,116 @@ class _Op(Op): ...@@ -445,109 +409,116 @@ class _Op(Op):
# Unary Operations # Unary Operations
########################## ##########################
def broadcast(scalar_opclass, name, module_name = None, inplace_versions = True): # def broadcast(scalar_opclass, name, module_name = None, inplace_versions = True):
C = s2t.make_broadcast(scalar_opclass, name = name, module_name = module_name) # this returns a class # C = s2t.make_broadcast(scalar_opclass, name = name, module_name = module_name) # this returns a class
C.__module__ = module_name # C.__module__ = module_name
c = gof.op.constructor(s2t.wrap_broadcast(C)) # c = gof.op.constructor(s2t.wrap_broadcast(C))
if inplace_versions: # if inplace_versions:
CInplace = s2t.make_broadcast(scalar_opclass, {0:0}, name = name+"Inplace") # CInplace = s2t.make_broadcast(scalar_opclass, {0:0}, name = name+"Inplace")
CInplace.__module__ = module_name # CInplace.__module__ = module_name
c_inplace = gof.op.constructor(s2t.wrap_broadcast(CInplace)) # c_inplace = gof.op.constructor(s2t.wrap_broadcast(CInplace))
return C, c, CInplace, c_inplace # return C, c, CInplace, c_inplace
else: # else:
return C, c # return C, c
def _broadcast(scalar_opclass, name, inplace_versions = True): # def _broadcast(scalar_opclass, name, inplace_versions = True):
return broadcast(scalar_opclass, name, 'tensor', inplace_versions) # return broadcast(scalar_opclass, name, 'tensor', inplace_versions)
class Argmax(Op): # class Argmax(Op):
"""Calculate the max and argmax over a given axis""" # """Calculate the max and argmax over a given axis"""
nin=2 # tensor, axis # nin=2 # tensor, axis
nout=2 # max val, max idx # nout=2 # max val, max idx
E_axis = 'invalid axis' # E_axis = 'invalid axis'
debug = 0 # debug = 0
def __init__(self, x, axis=None): # def __init__(self, x, axis=None):
x = _as_tensor(x) # x = _as_tensor(x)
if axis is None: # if axis is None:
axis = len(x.broadcastable) -1 # axis = len(x.broadcastable) -1
axis = _as_tensor(axis) # axis = _as_tensor(axis)
self.inputs = [x, axis] # self.inputs = [x, axis]
broadcastable = [0] * (len(x.broadcastable) - 1) # broadcastable = [0] * (len(x.broadcastable) - 1)
self.outputs = [Tensor(x.dtype, broadcastable), # self.outputs = [Tensor(x.dtype, broadcastable),
Tensor(axis.dtype, broadcastable)] # Tensor(axis.dtype, broadcastable)]
def perform(self): # def perform(self):
axis = self.inputs[1].data # axis = self.inputs[1].data
x = self.inputs[0].data # x = self.inputs[0].data
self.outputs[0].data = numpy.max(x, axis) # self.outputs[0].data = numpy.max(x, axis)
self.outputs[1].data = numpy.argmax(x,axis) # self.outputs[1].data = numpy.argmax(x,axis)
argmax = gof.op.constructor(Argmax) # argmax = gof.op.constructor(Argmax)
def max(x, axis=None): # def max(x, axis=None):
"""Return maximum elements obtained by iterating over given axis # """Return maximum elements obtained by iterating over given axis
Default axis is the last one. # Default axis is the last one.
""" # """
# In python (using Argmax.perform()) this leads to an wasteful # # In python (using Argmax.perform()) this leads to an wasteful
# implementation that goes through the data twice instead of once # # implementation that goes through the data twice instead of once
# but when Argmax.c_impl() is in place, it should be fine. # # but when Argmax.c_impl() is in place, it should be fine.
return argmax(x,axis)[0] # return argmax(x,axis)[0]
Abs, _abs, AbsInplace, abs_inplace = _broadcast(scal.Abs, 'Abs')
Exp, exp, ExpInplace, exp_inplace = _broadcast(scal.Exp, 'Exp') def _elemwise(scalar_op, name):
Neg, neg, NegInplace, neg_inplace = _broadcast(scal.Neg, 'Neg') straight = s2t.Elemwise(scalar_op)
Log, log, LogInplace, log_inplace = _broadcast(scal.Log, 'Log') inplace_scalar_op = scalar_op.__class__(scal.transfer_type(0))
Log2, log2, Log2Inplace, log2_inplace = _broadcast(scal.Log2, 'Log2') inplace = s2t.Elemwise(inplace_scalar_op, {0: 0})
Sgn, sgn, SgnInplace, sgn_inplace = _broadcast(scal.Sgn, 'Sgn') return straight, inplace
Sqr, sqr, SqrInplace, sqr_inplace = _broadcast(scal.Sqr, 'Sqr')
Sqrt, sqrt, SqrtInplace, sqrt_inplace = _broadcast(scal.Sqrt, 'Sqrt') _abs, abs_inplace = _elemwise(scal.abs, 'abs')
Cos, cos, CosInplace, cos_inplace = _broadcast(scal.Cos, 'Cos') exp, exp_inplace = _elemwise(scal.exp, 'exp')
Sin, sin, SinInplace, sin_inplace = _broadcast(scal.Sin, 'Sin') neg, neg_inplace = _elemwise(scal.neg, 'neg')
Tan, tan, TanInplace, tan_inplace = _broadcast(scal.Tan, 'Tan') log, log_inplace = _elemwise(scal.log, 'log')
Cosh, cosh, CoshInplace, cosh_inplace = _broadcast(scal.Cosh, 'Cosh') log2, log2_inplace = _elemwise(scal.log2, 'log2')
Sinh, sinh, SinhInplace, sinh_inplace = _broadcast(scal.Sinh, 'Sinh') sgn, sgn_inplace = _elemwise(scal.sgn, 'sgn')
Tanh, tanh, TanhInplace, tanh_inplace = _broadcast(scal.Tanh, 'Tanh') sqr, sqr_inplace = _elemwise(scal.sqr, 'sqr')
sqrt, sqrt_inplace = _elemwise(scal.sqrt, 'sqrt')
Fill, fill, FillInplace, fill_inplace = _broadcast(scal.Second, 'Fill') cos, cos_inplace = _elemwise(scal.cos, 'cos')
sin, sin_inplace = _elemwise(scal.sin, 'sin')
tan, tan_inplace = _elemwise(scal.tan, 'tan')
cosh, cosh_inplace = _elemwise(scal.cosh, 'cosh')
sinh, sinh_inplace = _elemwise(scal.sinh, 'sinh')
tanh, tanh_inplace = _elemwise(scal.tanh, 'tanh')
fill, fill_inplace = _elemwise(scal.second, 'fill')
def ones_like(model): def ones_like(model):
return fill(model, 1.0) return fill(model, 1.0)
def zeros_like(model): def zeros_like(model):
return fill(model, 0.0) return fill(model, 0.0)
TensorCopy, tensor_copy = _broadcast(scal.Identity, 'TensorCopy', inplace_versions = False) tensor_copy = s2t.Elemwise(scal.identity)
Sum = s2t.Sum def sum(input, axis = None):
sum = gof.op.constructor(Sum) return s2t.Sum(axis)(input)
########################## ##########################
# Arithmetics # Arithmetics
########################## ##########################
Add, add, AddInplace, add_inplace = _broadcast(scal.Add, 'Add') add, add_inplace = _elemwise(scal.add, 'add')
Sub, sub, SubInplace, sub_inplace = _broadcast(scal.Sub, 'Sub') sub, sub_inplace = _elemwise(scal.sub, 'sub')
Mul, mul, MulInplace, mul_inplace = _broadcast(scal.Mul, 'Mul') mul, mul_inplace = _elemwise(scal.mul, 'mul')
Div, div, DivInplace, div_inplace = _broadcast(scal.Div, 'Div') div, div_inplace = _elemwise(scal.div, 'div')
Pow, pow, PowInplace, pow_inplace = _broadcast(scal.Pow, 'Pow') pow, pow_inplace = _elemwise(scal.pow, 'pow')
########################## ##########################
# View Operations # View Operations
########################## ##########################
class TransposeInplace(s2t.DimShuffle): class TransposeInplace(Op):
def __init__(self, input): def make_node(self, input):
s2t.DimShuffle.__init__(self, input, range(len(input.broadcastable)-1, -1, -1), True) return Apply(self, [input], [input.type()])
def perform(self): def perform(self, node, (x, ), (z, )):
self.outputs[0].data = self.inputs[0].data.T z[0] = x.T
def grad(self, (x,), (gz,)): def grad(self, (x,), (gz,)):
return transpose(gz), return transpose(gz),
def c_code(self, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
return """ return """
PyArrayObject* transposed = (PyArrayObject*)PyArray_Transpose(%(x)s, NULL); PyArrayObject* transposed = (PyArrayObject*)PyArray_Transpose(%(x)s, NULL);
if (%(z)s) { if (%(z)s) {
...@@ -556,319 +527,334 @@ class TransposeInplace(s2t.DimShuffle): ...@@ -556,319 +527,334 @@ class TransposeInplace(s2t.DimShuffle):
%(z)s = transposed; %(z)s = transposed;
""" % locals() """ % locals()
transpose_inplace = gof.op.constructor(TransposeInplace) transpose_inplace = TransposeInplace()
def transpose(x, **kwargs): def transpose(x, **kwargs):
return transpose_inplace(tensor_copy(x), **kwargs) return transpose_inplace(tensor_copy(x), **kwargs)
class Subtensor(Op, Viewer): # class Subtensor(Op, Viewer):
nin = 2 # nin = 2
nout = 1 # nout = 1
e_invalid = 'invalid index' # e_invalid = 'invalid index'
debug = 0 # debug = 0
def __init__(self, *args,**kwargs): # def __init__(self, *args,**kwargs):
def as_tuple_result(obj): # def as_tuple_result(obj):
if isinstance(obj, Result): # if isinstance(obj, Result):
return obj # return obj
r = gof.result.PythonResult(None) # r = gof.result.PythonResult(None)
if isinstance(obj, tuple): # if isinstance(obj, tuple):
r.data = obj # r.data = obj
else: # else:
r.data = (obj,) # r.data = (obj,)
return r # return r
def pad(tplR, N): # def pad(tplR, N):
l = list(tplR.data) # l = list(tplR.data)
for i in range(len(l), N): # for i in range(len(l), N):
l.append(slice(0,sys.maxint,1)) # l.append(slice(0,sys.maxint,1))
tplR.data = tuple(l) # tplR.data = tuple(l)
if Subtensor.debug: # if Subtensor.debug:
print 'Subtensor.__init__', args, kwargs # print 'Subtensor.__init__', args, kwargs
#Olivier says not to call this # #Olivier says not to call this
#Op.__init__(self, *args,**kwargs) # #Op.__init__(self, *args,**kwargs)
#Viewer.__init__(self, *args,**kwargs) # #Viewer.__init__(self, *args,**kwargs)
t, coord = args # t, coord = args
t = _as_tensor(t) # t = _as_tensor(t)
coord = as_tuple_result(coord) # coord = as_tuple_result(coord)
if len(coord.data) > len(t.broadcastable): # if len(coord.data) > len(t.broadcastable):
raise ValueError(Subtensor.e_invalid) # raise ValueError(Subtensor.e_invalid)
# add the implicit extra unbounded slices # # add the implicit extra unbounded slices
# e.g. n[0] on a 3d tensor pads to n[0,:,:] # # e.g. n[0] on a 3d tensor pads to n[0,:,:]
pad(coord, len(t.broadcastable)) # pad(coord, len(t.broadcastable))
broadcastable = [0 for c in coord.data if isinstance(c, slice)] # broadcastable = [0 for c in coord.data if isinstance(c, slice)]
if Subtensor.debug: # if Subtensor.debug:
print 'brdcstble', broadcastable # print 'brdcstble', broadcastable
print 't', t.data # print 't', t.data
print 'coord', coord.data # print 'coord', coord.data
self.inputs = [t, coord] # self.inputs = [t, coord]
self.outputs = [Tensor(t.dtype, broadcastable)] # self.outputs = [Tensor(t.dtype, broadcastable)]
def view_map(self): # def view_map(self):
return {self.out: [self.inputs[0]]} # return {self.out: [self.inputs[0]]}
def perform(self): # def perform(self):
x = self.inputs[0].data # x = self.inputs[0].data
c = self.inputs[1].data # c = self.inputs[1].data
if Subtensor.debug: # if Subtensor.debug:
print 'perform: x', x # print 'perform: x', x
print 'perform: c', c # print 'perform: c', c
if len(c) == 1: # if len(c) == 1:
self.outputs[0].data = x.__getitem__(c[0]) # self.outputs[0].data = x.__getitem__(c[0])
else: # else:
self.outputs[0].data = x.__getitem__(c) # self.outputs[0].data = x.__getitem__(c)
def grad(self, (x,), (gz,)): # def grad(self, (x,), (gz,)):
# - option: allocate a potentially large matrix of zeros, and fill in # # - option: allocate a potentially large matrix of zeros, and fill in
# the appropriate elements from gz # # the appropriate elements from gz
# - option: return a sparse matrix # # - option: return a sparse matrix
# - option: return gz, but think about how to include a special addition # # - option: return gz, but think about how to include a special addition
# function that works on a corresponding view of the original data # # function that works on a corresponding view of the original data
raise NotImplementedError() # raise NotImplementedError()
subtensor = gof.op.constructor(Subtensor) # subtensor = gof.op.constructor(Subtensor)
######################### #########################
# Linalg : Dot # Linalg : Dot
######################### #########################
class Dot(_Op): class Dot(Op):
nin=2 # nin=2
nout=1 # nout=1
def propagate_broadcastable(self, bx, by): def make_node(self, *inputs):
inputs = map(as_tensor, inputs)
if len(inputs) != 2:
raise TypeError("Wrong number of inputs for %s (got %i, expected 2)" % self)
i_broadcastables = [input.type.broadcastable for input in inputs]
i_dtypes = [input.type.dtype for input in inputs]
# o_broadcastables = utils.from_return_values(self.propagate_broadcastable(*i_broadcastables))
# o_dtypes = utils.from_return_values(self.propagate_dtype(*i_dtypes))
bx, by = i_broadcastables
if len(bx) == 0: # x is a scalar if len(bx) == 0: # x is a scalar
rval = by bz = by
else: else:
if len(by) >= 2: #y is a matrix or tensor if len(by) >= 2: #y is a matrix or tensor
rval = bx[:-1] + by[:-2] + by[-1:] bz = bx[:-1] + by[:-2] + by[-1:]
elif len(by)==1: #y is vector elif len(by)==1: #y is vector
rval = bx[:-1] bz = bx[:-1]
else: #y is a scalar else: #y is a scalar
rval = bx bz = bx
return [rval] o_broadcastables = [bz]
def impl(self, x, y): o_dtypes = [scal.upcast(*i_dtypes)]
return numpy.dot(x, y)
def grad(self, (x, y), (gz,)):
return dot(gz, y.T), dot(x.T, gz)
dot = gof.op.constructor(Dot)
class Gemm(_Op):
nin=5
nout=1
E_rank = 'gemm only works for rank 2'
E_scalar = 'gemm requires scalar argument'
E_z_uniq = 'argument z aliased to x or y'
debug = False
def __init__(self, *args, **kwargs):
_Op.__init__(self, *args, **kwargs)
z, a, x, y, b = self.inputs
zr, xr, yr = [set(gof.view_roots(i)) for i in z,x,y]
if zr.intersection(xr):
raise ValueError(Gemm.E_z_uniq, (z, x))
if zr.intersection(yr):
raise ValueError(Gemm.E_z_uniq, (z, y))
def destroy_map(self):
return {self.out:[self.inputs[0]]}
def propagate_broadcastable(self, bz, ba, bx, by, bb):
if len(bz) != 2: raise ValueError(Gemm.E_rank, len(bz))
if len(bx) != 2: raise ValueError(Gemm.E_rank, len(bx))
if len(by) != 2: raise ValueError(Gemm.E_rank, len(by))
if len(ba): raise ValueError(Gemm.E_scalar, ba)
if len(bb): raise ValueError(Gemm.E_scalar, bb)
return [bz]
def impl(self, z, a, x, y, b):
assert a.shape == ()
assert b.shape == ()
if z.shape == ():
z.itemset(z*a + b*numpy.dot(x,y))
return z
else:
if b == 0.0:
if a == 1.0:
z[:] = numpy.dot(x,y)
elif a == -1.0:
z[:] = -numpy.dot(x,y)
else:
z[:] = a * numpy.dot(x,y)
elif b == 1.0:
if a == 1.0:
z += numpy.dot(x,y)
elif a == -1.0:
z -= numpy.dot(x,y)
else:
z += a * numpy.dot(x,y)
else:
z *= b
z += a * numpy.dot(x,y)
return z
def grad(self, (z, a, x, y, b), (gz,)):
raise NotImplementedError()
def c_support_code(self):
#return blas.cblas_header_text()
mod_str = """
#ifndef MOD
#define MOD %
#endif
"""
return blas.blas_proto() + mod_str
def c_headers(self):
return ['<iostream>']
def c_libraries(self):
return blas.ldflags()
def c_validate_update(self, *args):
return ""
def c_validate_update_cleanup(self, *args):
return ""
def c_code(self, (_z, _a, _x, _y, _b), (_zout, ), sub):
return """
int unit = 0;
int type_num = %(_x)s->descr->type_num;
int type_size = %(_x)s->descr->elsize; // in bytes
npy_intp* Nx = %(_x)s->dimensions;
npy_intp* Ny = %(_y)s->dimensions;
npy_intp* Nz = %(_z)s->dimensions;
npy_intp* Sx = %(_x)s->strides;
npy_intp* Sy = %(_y)s->strides;
npy_intp* Sz = %(_z)s->strides;
//strides for x, y, z in dimensions 0, 1
int sx_0, sx_1, sy_0, sy_1, sz_0, sz_1;
if (%(_zout)s != %(_z)s)
{
if (%(_zout)s)
{
Py_DECREF(%(_zout)s);
}
%(_zout)s = %(_z)s;
Py_INCREF(%(_zout)s);
}
if (%(_x)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(x) != 2"); %(fail)s;}
if (%(_y)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(y) != 2"); %(fail)s;}
if (%(_z)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(z) != 2"); %(fail)s;}
if ((%(_a)s->descr->type_num != PyArray_DOUBLE)
&& (%(_a)s->descr->type_num != PyArray_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(a) is not double or float"); %(fail)s;}
if ((%(_b)s->descr->type_num != PyArray_DOUBLE)
&& (%(_b)s->descr->type_num != PyArray_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(b) is not double or float"); %(fail)s;}
if ((%(_x)s->descr->type_num != PyArray_DOUBLE)
&& (%(_x)s->descr->type_num != PyArray_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(x) is not double or float"); %(fail)s;}
if ((%(_y)s->descr->type_num != PyArray_DOUBLE) outputs = [Tensor(t, b)() for b, t in zip(o_broadcastables, o_dtypes)]
&& (%(_y)s->descr->type_num != PyArray_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(y) is not double or float"); %(fail)s;}
if ((%(_z)s->descr->type_num != PyArray_DOUBLE) return Apply(self, inputs, outputs)
&& (%(_z)s->descr->type_num != PyArray_FLOAT)) def perform(self, node, (x, y), (z, )):
{PyErr_SetString(PyExc_NotImplementedError, "type(z) is not double or float"); %(fail)s;} z[0] = numpy.dot(x, y)
def grad(self, (x, y), (gz,)):
if ((%(_x)s->descr->type_num != %(_y)s->descr->type_num) return dot(gz, y.T), dot(x.T, gz)
||(%(_x)s->descr->type_num != %(_z)s->descr->type_num)) dot = Dot()
{ PyErr_SetString(PyExc_NotImplementedError, "type(z), type(y), type(z) are not all the same"); %(fail)s; }
# class Gemm(_Op):
if ((Nx[0] != Nz[0]) || (Nx[1] != Ny[0]) || (Ny[1] != Nz[1])) # nin=5
{ # nout=1
PyErr_SetString(PyExc_ValueError, "Input dimensions do not agree"); # E_rank = 'gemm only works for rank 2'
%(fail)s; # E_scalar = 'gemm requires scalar argument'
} # E_z_uniq = 'argument z aliased to x or y'
if ((Sx[0] < 1) || (Sx[1] < 1) || (Sx[0] MOD type_size) || (Sx[1] MOD type_size) # debug = False
|| (Sy[0] < 1) || (Sy[1] < 1) || (Sy[0] MOD type_size) || (Sy[1] MOD type_size) # def __init__(self, *args, **kwargs):
|| (Sz[0] < 1) || (Sz[1] < 1) || (Sz[0] MOD type_size) || (Sz[1] MOD type_size)) # _Op.__init__(self, *args, **kwargs)
{ # z, a, x, y, b = self.inputs
PyErr_SetString(PyExc_ValueError, "stride is not multiple of element size"); %(fail)s; # zr, xr, yr = [set(gof.view_roots(i)) for i in z,x,y]
} # if zr.intersection(xr):
# raise ValueError(Gemm.E_z_uniq, (z, x))
/* # if zr.intersection(yr):
encode the stride structure of _x,_y,_z into a single integer # raise ValueError(Gemm.E_z_uniq, (z, y))
*/ # def destroy_map(self):
unit |= ((Sx[1] == type_size) ? 0x0 : (Sx[0] == type_size) ? 0x1 : 0x2) << 8; # return {self.out:[self.inputs[0]]}
unit |= ((Sy[1] == type_size) ? 0x0 : (Sy[0] == type_size) ? 0x1 : 0x2) << 4; # def propagate_broadcastable(self, bz, ba, bx, by, bb):
unit |= ((Sz[1] == type_size) ? 0x0 : (Sz[0] == type_size) ? 0x1 : 0x2) << 0; # if len(bz) != 2: raise ValueError(Gemm.E_rank, len(bz))
# if len(bx) != 2: raise ValueError(Gemm.E_rank, len(bx))
/* create appropriate strides for malformed matrices that are row or column # if len(by) != 2: raise ValueError(Gemm.E_rank, len(by))
* vectors # if len(ba): raise ValueError(Gemm.E_scalar, ba)
*/ # if len(bb): raise ValueError(Gemm.E_scalar, bb)
sx_0 = (Nx[0] > 1) ? Sx[0]/type_size : Nx[1];
sx_1 = (Nx[1] > 1) ? Sx[1]/type_size : Nx[0]; # return [bz]
sy_0 = (Ny[0] > 1) ? Sy[0]/type_size : Ny[1]; # def impl(self, z, a, x, y, b):
sy_1 = (Ny[1] > 1) ? Sy[1]/type_size : Ny[0]; # assert a.shape == ()
sz_0 = (Nz[0] > 1) ? Sz[0]/type_size : Nz[1]; # assert b.shape == ()
sz_1 = (Nz[1] > 1) ? Sz[1]/type_size : Nz[0]; # if z.shape == ():
# z.itemset(z*a + b*numpy.dot(x,y))
switch (type_num) # return z
{ # else:
case PyArray_FLOAT: # if b == 0.0:
{ # if a == 1.0:
#define REAL float # z[:] = numpy.dot(x,y)
float a = (%(_a)s->descr->type_num == PyArray_FLOAT) # elif a == -1.0:
? (REAL)(((float*)%(_a)s->data)[0]) # z[:] = -numpy.dot(x,y)
: (REAL)(((double*)%(_a)s->data)[0]); # else:
float b = (%(_b)s->descr->type_num == PyArray_FLOAT) ? # z[:] = a * numpy.dot(x,y)
(REAL)(((float*)%(_b)s->data)[0]) # elif b == 1.0:
: (REAL)(((double*)%(_b)s->data)[0]); # if a == 1.0:
# z += numpy.dot(x,y)
float* x = (float*)PyArray_DATA(%(_x)s); # elif a == -1.0:
float* y = (float*)PyArray_DATA(%(_y)s); # z -= numpy.dot(x,y)
float* z = (float*)PyArray_DATA(%(_z)s); # else:
char N = 'N'; # z += a * numpy.dot(x,y)
char T = 'T'; # else:
int Nz0 = Nz[0], Nz1 = Nz[1], Nx1 = Nx[1]; # z *= b
//std::cerr << (unit/256) MOD 16 << (unit / 16) MOD 16 << unit MOD 16<< '\\n'; # z += a * numpy.dot(x,y)
switch(unit) # return z
{ # def grad(self, (z, a, x, y, b), (gz,)):
case 0x000: sgemm_(&N, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_0, &b, z, &sz_0); break; # raise NotImplementedError()
case 0x100: sgemm_(&N, &T, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_1, &b, z, &sz_0); break;
case 0x010: sgemm_(&T, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_1, x, &sx_0, &b, z, &sz_0); break; # def c_support_code(self):
case 0x110: sgemm_(&T, &T, &Nz1, &Nz0, &Nx1, &a, y, &sy_1, x, &sx_1, &b, z, &sz_0); break; # #return blas.cblas_header_text()
case 0x001: sgemm_(&T, &T, &Nz0, &Nz1, &Nx1, &a, x, &sx_0, y, &sy_0, &b, z, &sz_1); break; # mod_str = """
case 0x101: sgemm_(&N, &T, &Nz0, &Nz1, &Nx1, &a, x, &sx_1, y, &sy_0, &b, z, &sz_1); break; # #ifndef MOD
case 0x011: sgemm_(&T, &N, &Nz0, &Nz1, &Nx1, &a, x, &sx_0, y, &sy_1, &b, z, &sz_1); break; # #define MOD %
case 0x111: sgemm_(&N, &N, &Nz0, &Nz1, &Nx1, &a, x, &sx_1, y, &sy_1, &b, z, &sz_1); break; # #endif
default: PyErr_SetString(PyExc_ValueError, "some matrix has no unit stride"); %(fail)s; # """
}; # return blas.blas_proto() + mod_str
#undef REAL # def c_headers(self):
} # return ['<iostream>']
break; # def c_libraries(self):
case PyArray_DOUBLE: # return blas.ldflags()
{ # def c_validate_update(self, *args):
#define REAL double # return ""
# def c_validate_update_cleanup(self, *args):
double a = (%(_a)s->descr->type_num == PyArray_FLOAT) # return ""
? (REAL)(((float*)%(_a)s->data)[0]) # def c_code(self, (_z, _a, _x, _y, _b), (_zout, ), sub):
: (REAL)(((double*)%(_a)s->data)[0]); # return """
double b = (%(_b)s->descr->type_num == PyArray_FLOAT) ? # int unit = 0;
(REAL)(((float*)%(_b)s->data)[0])
: (REAL)(((double*)%(_b)s->data)[0]); # int type_num = %(_x)s->descr->type_num;
double* x = (double*)PyArray_DATA(%(_x)s); # int type_size = %(_x)s->descr->elsize; // in bytes
double* y = (double*)PyArray_DATA(%(_y)s);
double* z = (double*)PyArray_DATA(%(_z)s); # npy_intp* Nx = %(_x)s->dimensions;
char N = 'N'; # npy_intp* Ny = %(_y)s->dimensions;
char T = 'T'; # npy_intp* Nz = %(_z)s->dimensions;
int Nz0 = Nz[0], Nz1 = Nz[1], Nx1 = Nx[1];
//std::cerr << (unit/256) MOD 16 << (unit / 16) MOD 16 << unit MOD 16<< '\\n'; # npy_intp* Sx = %(_x)s->strides;
switch(unit) # npy_intp* Sy = %(_y)s->strides;
{ # npy_intp* Sz = %(_z)s->strides;
case 0x000: dgemm_(&N, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_0, &b, z, &sz_0); break;
case 0x100: dgemm_(&N, &T, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_1, &b, z, &sz_0); break; # //strides for x, y, z in dimensions 0, 1
case 0x010: dgemm_(&T, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_1, x, &sx_0, &b, z, &sz_0); break; # int sx_0, sx_1, sy_0, sy_1, sz_0, sz_1;
case 0x110: dgemm_(&T, &T, &Nz1, &Nz0, &Nx1, &a, y, &sy_1, x, &sx_1, &b, z, &sz_0); break;
case 0x001: dgemm_(&T, &T, &Nz0, &Nz1, &Nx1, &a, x, &sx_0, y, &sy_0, &b, z, &sz_1); break; # if (%(_zout)s != %(_z)s)
case 0x101: dgemm_(&N, &T, &Nz0, &Nz1, &Nx1, &a, x, &sx_1, y, &sy_0, &b, z, &sz_1); break; # {
case 0x011: dgemm_(&T, &N, &Nz0, &Nz1, &Nx1, &a, x, &sx_0, y, &sy_1, &b, z, &sz_1); break; # if (%(_zout)s)
case 0x111: dgemm_(&N, &N, &Nz0, &Nz1, &Nx1, &a, x, &sx_1, y, &sy_1, &b, z, &sz_1); break; # {
default: PyErr_SetString(PyExc_ValueError, "some matrix has no unit stride"); %(fail)s; # Py_DECREF(%(_zout)s);
}; # }
#undef REAL # %(_zout)s = %(_z)s;
} # Py_INCREF(%(_zout)s);
break; # }
}
# if (%(_x)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(x) != 2"); %(fail)s;}
""" % dict(locals(), **sub) # if (%(_y)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(y) != 2"); %(fail)s;}
gemm = gof.op.constructor(Gemm) # if (%(_z)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(z) != 2"); %(fail)s;}
# if ((%(_a)s->descr->type_num != PyArray_DOUBLE)
# && (%(_a)s->descr->type_num != PyArray_FLOAT))
# {PyErr_SetString(PyExc_NotImplementedError, "type(a) is not double or float"); %(fail)s;}
# if ((%(_b)s->descr->type_num != PyArray_DOUBLE)
# && (%(_b)s->descr->type_num != PyArray_FLOAT))
# {PyErr_SetString(PyExc_NotImplementedError, "type(b) is not double or float"); %(fail)s;}
# if ((%(_x)s->descr->type_num != PyArray_DOUBLE)
# && (%(_x)s->descr->type_num != PyArray_FLOAT))
# {PyErr_SetString(PyExc_NotImplementedError, "type(x) is not double or float"); %(fail)s;}
# if ((%(_y)s->descr->type_num != PyArray_DOUBLE)
# && (%(_y)s->descr->type_num != PyArray_FLOAT))
# {PyErr_SetString(PyExc_NotImplementedError, "type(y) is not double or float"); %(fail)s;}
# if ((%(_z)s->descr->type_num != PyArray_DOUBLE)
# && (%(_z)s->descr->type_num != PyArray_FLOAT))
# {PyErr_SetString(PyExc_NotImplementedError, "type(z) is not double or float"); %(fail)s;}
# if ((%(_x)s->descr->type_num != %(_y)s->descr->type_num)
# ||(%(_x)s->descr->type_num != %(_z)s->descr->type_num))
# { PyErr_SetString(PyExc_NotImplementedError, "type(z), type(y), type(z) are not all the same"); %(fail)s; }
# if ((Nx[0] != Nz[0]) || (Nx[1] != Ny[0]) || (Ny[1] != Nz[1]))
# {
# PyErr_SetString(PyExc_ValueError, "Input dimensions do not agree");
# %(fail)s;
# }
# if ((Sx[0] < 1) || (Sx[1] < 1) || (Sx[0] MOD type_size) || (Sx[1] MOD type_size)
# || (Sy[0] < 1) || (Sy[1] < 1) || (Sy[0] MOD type_size) || (Sy[1] MOD type_size)
# || (Sz[0] < 1) || (Sz[1] < 1) || (Sz[0] MOD type_size) || (Sz[1] MOD type_size))
# {
# PyErr_SetString(PyExc_ValueError, "stride is not multiple of element size"); %(fail)s;
# }
# /*
# encode the stride structure of _x,_y,_z into a single integer
# */
# unit |= ((Sx[1] == type_size) ? 0x0 : (Sx[0] == type_size) ? 0x1 : 0x2) << 8;
# unit |= ((Sy[1] == type_size) ? 0x0 : (Sy[0] == type_size) ? 0x1 : 0x2) << 4;
# unit |= ((Sz[1] == type_size) ? 0x0 : (Sz[0] == type_size) ? 0x1 : 0x2) << 0;
# /* create appropriate strides for malformed matrices that are row or column
# * vectors
# */
# sx_0 = (Nx[0] > 1) ? Sx[0]/type_size : Nx[1];
# sx_1 = (Nx[1] > 1) ? Sx[1]/type_size : Nx[0];
# sy_0 = (Ny[0] > 1) ? Sy[0]/type_size : Ny[1];
# sy_1 = (Ny[1] > 1) ? Sy[1]/type_size : Ny[0];
# sz_0 = (Nz[0] > 1) ? Sz[0]/type_size : Nz[1];
# sz_1 = (Nz[1] > 1) ? Sz[1]/type_size : Nz[0];
# switch (type_num)
# {
# case PyArray_FLOAT:
# {
# #define REAL float
# float a = (%(_a)s->descr->type_num == PyArray_FLOAT)
# ? (REAL)(((float*)%(_a)s->data)[0])
# : (REAL)(((double*)%(_a)s->data)[0]);
# float b = (%(_b)s->descr->type_num == PyArray_FLOAT) ?
# (REAL)(((float*)%(_b)s->data)[0])
# : (REAL)(((double*)%(_b)s->data)[0]);
# float* x = (float*)PyArray_DATA(%(_x)s);
# float* y = (float*)PyArray_DATA(%(_y)s);
# float* z = (float*)PyArray_DATA(%(_z)s);
# char N = 'N';
# char T = 'T';
# int Nz0 = Nz[0], Nz1 = Nz[1], Nx1 = Nx[1];
# //std::cerr << (unit/256) MOD 16 << (unit / 16) MOD 16 << unit MOD 16<< '\\n';
# switch(unit)
# {
# case 0x000: sgemm_(&N, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_0, &b, z, &sz_0); break;
# case 0x100: sgemm_(&N, &T, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_1, &b, z, &sz_0); break;
# case 0x010: sgemm_(&T, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_1, x, &sx_0, &b, z, &sz_0); break;
# case 0x110: sgemm_(&T, &T, &Nz1, &Nz0, &Nx1, &a, y, &sy_1, x, &sx_1, &b, z, &sz_0); break;
# case 0x001: sgemm_(&T, &T, &Nz0, &Nz1, &Nx1, &a, x, &sx_0, y, &sy_0, &b, z, &sz_1); break;
# case 0x101: sgemm_(&N, &T, &Nz0, &Nz1, &Nx1, &a, x, &sx_1, y, &sy_0, &b, z, &sz_1); break;
# case 0x011: sgemm_(&T, &N, &Nz0, &Nz1, &Nx1, &a, x, &sx_0, y, &sy_1, &b, z, &sz_1); break;
# case 0x111: sgemm_(&N, &N, &Nz0, &Nz1, &Nx1, &a, x, &sx_1, y, &sy_1, &b, z, &sz_1); break;
# default: PyErr_SetString(PyExc_ValueError, "some matrix has no unit stride"); %(fail)s;
# };
# #undef REAL
# }
# break;
# case PyArray_DOUBLE:
# {
# #define REAL double
# double a = (%(_a)s->descr->type_num == PyArray_FLOAT)
# ? (REAL)(((float*)%(_a)s->data)[0])
# : (REAL)(((double*)%(_a)s->data)[0]);
# double b = (%(_b)s->descr->type_num == PyArray_FLOAT) ?
# (REAL)(((float*)%(_b)s->data)[0])
# : (REAL)(((double*)%(_b)s->data)[0]);
# double* x = (double*)PyArray_DATA(%(_x)s);
# double* y = (double*)PyArray_DATA(%(_y)s);
# double* z = (double*)PyArray_DATA(%(_z)s);
# char N = 'N';
# char T = 'T';
# int Nz0 = Nz[0], Nz1 = Nz[1], Nx1 = Nx[1];
# //std::cerr << (unit/256) MOD 16 << (unit / 16) MOD 16 << unit MOD 16<< '\\n';
# switch(unit)
# {
# case 0x000: dgemm_(&N, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_0, &b, z, &sz_0); break;
# case 0x100: dgemm_(&N, &T, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_1, &b, z, &sz_0); break;
# case 0x010: dgemm_(&T, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_1, x, &sx_0, &b, z, &sz_0); break;
# case 0x110: dgemm_(&T, &T, &Nz1, &Nz0, &Nx1, &a, y, &sy_1, x, &sx_1, &b, z, &sz_0); break;
# case 0x001: dgemm_(&T, &T, &Nz0, &Nz1, &Nx1, &a, x, &sx_0, y, &sy_0, &b, z, &sz_1); break;
# case 0x101: dgemm_(&N, &T, &Nz0, &Nz1, &Nx1, &a, x, &sx_1, y, &sy_0, &b, z, &sz_1); break;
# case 0x011: dgemm_(&T, &N, &Nz0, &Nz1, &Nx1, &a, x, &sx_0, y, &sy_1, &b, z, &sz_1); break;
# case 0x111: dgemm_(&N, &N, &Nz0, &Nz1, &Nx1, &a, x, &sx_1, y, &sy_1, &b, z, &sz_1); break;
# default: PyErr_SetString(PyExc_ValueError, "some matrix has no unit stride"); %(fail)s;
# };
# #undef REAL
# }
# break;
# }
# """ % dict(locals(), **sub)
# gemm = gof.op.constructor(Gemm)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论