changed the way modes work, finalized a Linker interface, improved cc, some work…

changed the way modes work, finalized a Linker interface, improved cc, some work on scalar/scalar_op and thinking I should commit more often
上级 6ec9db61
...@@ -2,20 +2,20 @@ ...@@ -2,20 +2,20 @@
import unittest import unittest
from gof import ResultBase, Op, Env, modes from gof import ResultBase, Op, Env, modes
import gof
from scalar_ops import * from scalar_ops import *
def inputs(): def inputs():
x = modes.BuildEvalMode(as_scalar(1.0, 'x')) x = modes.build_eval(as_scalar(1.0, 'x'))
y = modes.BuildEvalMode(as_scalar(2.0, 'y')) y = modes.build_eval(as_scalar(2.0, 'y'))
z = modes.BuildEvalMode(as_scalar(3.0, 'z')) z = modes.build_eval(as_scalar(3.0, 'z'))
return x, y, z return x, y, z
def env(inputs, outputs, validate = True, features = []): def env(inputs, outputs, validate = True, features = []):
inputs = [input.r for input in inputs] # inputs = [input.r for input in inputs]
outputs = [output.r for output in outputs] # outputs = [output.r for output in outputs]
return Env(inputs, outputs, features = features, consistency_check = validate) return Env(inputs, outputs, features = features, consistency_check = validate)
...@@ -24,7 +24,15 @@ class _test_ScalarOps(unittest.TestCase): ...@@ -24,7 +24,15 @@ class _test_ScalarOps(unittest.TestCase):
def test_0(self): def test_0(self):
x, y, z = inputs() x, y, z = inputs()
e = mul(add(x, y), div(x, y)) e = mul(add(x, y), div(x, y))
assert e.r.data == 1.5 assert e.data == 1.5
def test_1(self):
x, y, z = inputs()
e = mul(add(x, y), div(x, y))
g = env([x, y], [e])
fn = gof.cc.CLinker(g).make_function()
assert fn(1.0, 2.0) == 1.5
assert e.data == 1.5
if __name__ == '__main__': if __name__ == '__main__':
......
import op, result, ext, link, env, features, toolbox, graph import op, result, ext, link, env, features, toolbox, graph, cc
from op import * from op import *
from result import * from result import *
...@@ -8,7 +8,7 @@ from link import * ...@@ -8,7 +8,7 @@ from link import *
from env import * from env import *
from features import * from features import *
from toolbox import * from toolbox import *
from cc import *
......
...@@ -105,14 +105,14 @@ import modes ...@@ -105,14 +105,14 @@ import modes
modes.make_constructors(globals()) modes.make_constructors(globals())
def inputs(): def inputs():
x = modes.BuildMode(Double(1.0, 'x')) x = modes.build(Double(1.0, 'x'))
y = modes.BuildMode(Double(2.0, 'y')) y = modes.build(Double(2.0, 'y'))
z = modes.BuildMode(Double(3.0, 'z')) z = modes.build(Double(3.0, 'z'))
return x, y, z return x, y, z
def env(inputs, outputs, validate = True, features = []): def env(inputs, outputs, validate = True, features = []):
inputs = [input.r for input in inputs] # inputs = [input.r for input in inputs]
outputs = [output.r for output in outputs] # outputs = [output.r for output in outputs]
return Env(inputs, outputs, features = features, consistency_check = validate) return Env(inputs, outputs, features = features, consistency_check = validate)
...@@ -121,21 +121,38 @@ class _test_CLinker(unittest.TestCase): ...@@ -121,21 +121,38 @@ class _test_CLinker(unittest.TestCase):
def test_0(self): def test_0(self):
x, y, z = inputs() x, y, z = inputs()
e = add(mul(add(x, y), div(x, y)), sub(sub(x, y), z)) e = add(mul(add(x, y), div(x, y)), sub(sub(x, y), z))
lnk = CLinker(env([x, y, z], [e]), [x.r, y.r, z.r], [e.r]) lnk = CLinker(env([x, y, z], [e])) #, [x.r, y.r, z.r], [e.r])
cgen = lnk.code_gen() cgen = lnk.code_gen()
fn = lnk.make_function([x.r, y.r, z.r], [e.r]) fn = lnk.make_function() #[x.r, y.r, z.r], [e.r])
print fn(2.0, 2.0, 2.0) print fn(2.0, 2.0, 2.0)
# fn = 0 # fn = 0
def test_1(self): def test_1(self):
x, y, z = inputs() x, y, z = inputs()
z.r.constant = True z.constant = True
e = add(mul(add(x, y), div(x, y)), sub(sub(x, y), z)) e = add(mul(add(x, y), div(x, y)), sub(sub(x, y), z))
lnk = CLinker(env([x, y], [e]), [x.r, y.r], [e.r]) lnk = CLinker(env([x, y], [e])) #, [x.r, y.r], [e.r])
cgen = lnk.code_gen() cgen = lnk.code_gen()
fn = lnk.make_function([x.r, y.r], [e.r]) fn = lnk.make_function() #[x.r, y.r], [e.r])
print fn(2.0, 2.0) print fn(2.0, 2.0)
# fn = 0 # fn = 0
def test_2(self):
x, y, z = inputs()
op = Add(x, y)
lnk = CLinker(op)
cgen = lnk.code_gen()
fn = lnk.make_function() #[x.r, y.r], [op.out])
print fn(2.0, 7.0)
# fn = 0
def test_3(self):
x, y, z = inputs()
e = add(mul(add(x, y), div(x, y)), sub(sub(x, y), z))
lnk = OpWiseCLinker(env([x, y, z], [e]))
fn = lnk.make_function()
print fn(2.0, 2.0, 2.0)
# fn = 0
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -81,15 +81,21 @@ s2t = OpSubOptimizer(Sigmoid, TransposeView) ...@@ -81,15 +81,21 @@ s2t = OpSubOptimizer(Sigmoid, TransposeView)
import modes import modes
modes.make_constructors(globals(), name_filter = lambda x:x) modes.make_constructors(globals(), name_filter = lambda x:x)
# def inputs():
# x = modes.BuildMode(MyResult('x'))
# y = modes.BuildMode(MyResult('y'))
# z = modes.BuildMode(MyResult('z'))
# return x, y, z
def inputs(): def inputs():
x = modes.BuildMode(MyResult('x')) x = modes.build(MyResult('x'))
y = modes.BuildMode(MyResult('y')) y = modes.build(MyResult('y'))
z = modes.BuildMode(MyResult('z')) z = modes.build(MyResult('z'))
return x, y, z return x, y, z
def env(inputs, outputs, validate = True): def env(inputs, outputs, validate = True):
inputs = [input.r for input in inputs] inputs = [input for input in inputs]
outputs = [output.r for output in outputs] outputs = [output for output in outputs]
return Env(inputs, outputs, features = [EquivTool], consistency_check = validate) return Env(inputs, outputs, features = [EquivTool], consistency_check = validate)
...@@ -139,7 +145,7 @@ class _test_all(unittest.TestCase): ...@@ -139,7 +145,7 @@ class _test_all(unittest.TestCase):
e = Dot(AddInPlace(x,y), TransposeView(x)) e = Dot(AddInPlace(x,y), TransposeView(x))
g = env([x,y,z], [e], False) g = env([x,y,z], [e], False)
assert not g.consistent() assert not g.consistent()
g.replace(e.r.owner.inputs[1], Add(x,z).r) g.replace(e.owner.inputs[1], Add(x,z))
assert g.consistent() assert g.consistent()
def test_5(self): def test_5(self):
...@@ -147,7 +153,7 @@ class _test_all(unittest.TestCase): ...@@ -147,7 +153,7 @@ class _test_all(unittest.TestCase):
e = Dot(AddInPlace(x,y), TransposeView(TransposeView(TransposeView(TransposeView(Sigmoid(x)))))) e = Dot(AddInPlace(x,y), TransposeView(TransposeView(TransposeView(TransposeView(Sigmoid(x))))))
g = env([x,y,z], [e]) g = env([x,y,z], [e])
assert g.consistent() assert g.consistent()
g.replace(e.r.owner.inputs[1].owner.inputs[0], x.r, False) g.replace(e.owner.inputs[1].owner.inputs[0], x, False)
assert not g.consistent() assert not g.consistent()
def test_6(self): def test_6(self):
...@@ -168,9 +174,9 @@ class _test_all(unittest.TestCase): ...@@ -168,9 +174,9 @@ class _test_all(unittest.TestCase):
chk = g.checkpoint() chk = g.checkpoint()
dtv_elim.optimize(g) dtv_elim.optimize(g)
assert str(g) == "[x]" assert str(g) == "[x]"
g.replace(g.equiv(e.r), Add(x,y).r) g.replace(g.equiv(e), Add(x,y))
assert str(g) == "[Add(x, y)]" assert str(g) == "[Add(x, y)]"
g.replace(g.equiv(e.r), Dot(AddInPlace(x,y), TransposeView(x)).r, False) g.replace(g.equiv(e), Dot(AddInPlace(x,y), TransposeView(x)), False)
assert str(g) == "[Dot(AddInPlace(x, y), TransposeView(x))]" assert str(g) == "[Dot(AddInPlace(x, y), TransposeView(x))]"
assert not g.consistent() assert not g.consistent()
g.revert(chk) g.revert(chk)
...@@ -188,21 +194,21 @@ class _test_all(unittest.TestCase): ...@@ -188,21 +194,21 @@ class _test_all(unittest.TestCase):
def test_9(self): def test_9(self):
x, y, z = inputs() x, y, z = inputs()
x.r.indestructible = True x.indestructible = True
e = AddInPlace(x, y) e = AddInPlace(x, y)
g = env([x,y,z], [e], False) g = env([x,y,z], [e], False)
assert not g.consistent() assert not g.consistent()
g.replace(e.r, Add(x, y).r) g.replace(e, Add(x, y))
assert g.consistent() assert g.consistent()
def test_10(self): def test_10(self):
x, y, z = inputs() x, y, z = inputs()
x.r.indestructible = True x.indestructible = True
tv = TransposeView(x) tv = TransposeView(x)
e = AddInPlace(tv, y) e = AddInPlace(tv, y)
g = env([x,y,z], [e], False) g = env([x,y,z], [e], False)
assert not g.consistent() assert not g.consistent()
g.replace(tv.r, Sigmoid(x).r) g.replace(tv, Sigmoid(x))
assert g.consistent() assert g.consistent()
......
...@@ -65,19 +65,18 @@ import modes ...@@ -65,19 +65,18 @@ import modes
modes.make_constructors(globals()) modes.make_constructors(globals())
def inputs(): def inputs():
x = modes.BuildMode(Double(1.0, 'x')) x = modes.build(Double(1.0, 'x'))
y = modes.BuildMode(Double(2.0, 'y')) y = modes.build(Double(2.0, 'y'))
z = modes.BuildMode(Double(3.0, 'z')) z = modes.build(Double(3.0, 'z'))
return x, y, z return x, y, z
def env(inputs, outputs, validate = True, features = []): def env(inputs, outputs, validate = True, features = []):
inputs = [input.r for input in inputs] # inputs = [input.r for input in inputs]
outputs = [output.r for output in outputs] # outputs = [output.r for output in outputs]
return Env(inputs, outputs, features = features, consistency_check = validate) return Env(inputs, outputs, features = features, consistency_check = validate)
def perform_linker(env): def perform_linker(env):
lnk = PerformLinker(env) lnk = PerformLinker(env)
lnk.compile()
return lnk return lnk
...@@ -86,8 +85,23 @@ class _test_PerformLinker(unittest.TestCase): ...@@ -86,8 +85,23 @@ class _test_PerformLinker(unittest.TestCase):
def test_0(self): def test_0(self):
x, y, z = inputs() x, y, z = inputs()
e = mul(add(x, y), div(x, y)) e = mul(add(x, y), div(x, y))
perform_linker(env([x, y, z], [e])).run() fn, i, o = perform_linker(env([x, y, z], [e])).make_thunk(True)
assert e.r.data == 1.5 fn()
assert e.data == 1.5
def test_1(self):
x, y, z = inputs()
e = mul(add(x, y), div(x, y))
fn, i, o = perform_linker(env([x, y, z], [e])).make_thunk(False)
fn()
assert e.data != 1.5
def test_2(self):
x, y, z = inputs()
e = mul(add(x, y), div(x, y))
fn = perform_linker(env([x, y, z], [e])).make_function()
assert fn(1.0, 2.0, 3.0) == 1.5
assert e.data != 1.5
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -22,7 +22,7 @@ class Double(ResultBase): ...@@ -22,7 +22,7 @@ class Double(ResultBase):
return self.name return self.name
def __add__(self, other): def __add__(self, other):
return Add(self, other) return add(self, other)
def convert(x): def convert(x):
...@@ -80,51 +80,51 @@ def inputs(mode): ...@@ -80,51 +80,51 @@ def inputs(mode):
return x, y, z return x, y, z
def env(inputs, outputs, validate = True): def env(inputs, outputs, validate = True):
inputs = [input.r for input in inputs] # inputs = [input.r for input in inputs]
outputs = [output.r for output in outputs] # outputs = [output.r for output in outputs]
return Env(inputs, outputs, features = [], consistency_check = validate) return Env(inputs, outputs, features = [], consistency_check = validate)
class _test_Modes(unittest.TestCase): class _test_Modes(unittest.TestCase):
def test_0(self): def test_0(self):
x, y, z = inputs(BuildMode) x, y, z = inputs(build)
e = add(add(x, y), z) e = add(add(x, y), z)
g = env([x, y, z], [e]) g = env([x, y, z], [e])
assert str(g) == "[Add(Add(x, y), z)]" assert str(g) == "[Add(Add(x, y), z)]"
assert e.r.data == 0.0 assert e.data == 0.0
def test_1(self): def test_1(self):
x, y, z = inputs(BuildEvalMode) x, y, z = inputs(build_eval)
e = add(add(x, y), z) e = add(add(x, y), z)
g = env([x, y, z], [e]) g = env([x, y, z], [e])
assert str(g) == "[Add(Add(x, y), z)]" assert str(g) == "[Add(Add(x, y), z)]"
assert e.r.data == 6.0 assert e.data == 6.0
def test_2(self): def test_2(self):
x, y, z = inputs(EvalMode) x, y, z = inputs(eval)
e = add(add(x, y), z) e = add(add(x, y), z)
g = env([x, y, z], [e]) g = env([x, y, z], [e])
assert str(g) == "[Add_R]" assert str(g) == "[Add_R]"
assert e.r.data == 6.0 assert e.data == 6.0
def test_3(self): def test_3(self):
x, y, z = inputs(BuildMode) x, y, z = inputs(build)
e = x + y + z e = x + y + z
g = env([x, y, z], [e]) g = env([x, y, z], [e])
assert str(g) == "[Add(Add(x, y), z)]" assert str(g) == "[Add(Add(x, y), z)]"
assert e.r.data == 0.0 assert e.data == 0.0
def test_4(self): def test_4(self):
x, y, z = inputs(BuildEvalMode) x, y, z = inputs(build_eval)
e = x + 34.0 e = x + 34.0
g = env([x, y, z], [e]) g = env([x, y, z], [e])
assert str(g) == "[Add(x, oignon)]" assert str(g) == "[Add(x, oignon)]"
assert e.r.data == 35.0 assert e.data == 35.0
def test_5(self): def test_5(self):
xb, yb, zb = inputs(BuildMode) xb, yb, zb = inputs(build)
xe, ye, ze = inputs(EvalMode) xe, ye, ze = inputs(eval)
try: try:
e = xb + ye e = xb + ye
except TypeError: except TypeError:
......
...@@ -49,14 +49,14 @@ modes.make_constructors(globals()) ...@@ -49,14 +49,14 @@ modes.make_constructors(globals())
def inputs(): def inputs():
x = modes.BuildMode(MyResult('x')) x = modes.build(MyResult('x'))
y = modes.BuildMode(MyResult('y')) y = modes.build(MyResult('y'))
z = modes.BuildMode(MyResult('z')) z = modes.build(MyResult('z'))
return x, y, z return x, y, z
def env(inputs, outputs, validate = True): def env(inputs, outputs, validate = True):
inputs = [input.r for input in inputs] # inputs = [input.r for input in inputs]
outputs = [output.r for output in outputs] # outputs = [output.r for output in outputs]
return Env(inputs, outputs, features = [EquivTool], consistency_check = validate) return Env(inputs, outputs, features = [EquivTool], consistency_check = validate)
......
...@@ -47,14 +47,14 @@ import modes ...@@ -47,14 +47,14 @@ import modes
modes.make_constructors(globals()) modes.make_constructors(globals())
def inputs(): def inputs():
x = modes.BuildMode(MyResult('x')) x = modes.build(MyResult('x'))
y = modes.BuildMode(MyResult('y')) y = modes.build(MyResult('y'))
z = modes.BuildMode(MyResult('z')) z = modes.build(MyResult('z'))
return x, y, z return x, y, z
def env(inputs, outputs, validate = True, features = []): def env(inputs, outputs, validate = True, features = []):
inputs = [input.r for input in inputs] # inputs = [input.r for input in inputs]
outputs = [output.r for output in outputs] # outputs = [output.r for output in outputs]
return Env(inputs, outputs, features = features, consistency_check = validate) return Env(inputs, outputs, features = features, consistency_check = validate)
...@@ -65,10 +65,10 @@ class _test_EquivTool(unittest.TestCase): ...@@ -65,10 +65,10 @@ class _test_EquivTool(unittest.TestCase):
sx = sigmoid(x) sx = sigmoid(x)
e = add(sx, sigmoid(y)) e = add(sx, sigmoid(y))
g = env([x, y, z], [e], features = [EquivTool]) g = env([x, y, z], [e], features = [EquivTool])
assert g.equiv(sx.r) is sx.r assert g.equiv(sx) is sx
g.replace(sx.r, dot(x, z).r) g.replace(sx, dot(x, z))
assert g.equiv(sx.r) is not sx.r assert g.equiv(sx) is not sx
assert isinstance(g.equiv(sx.r).owner, Dot) assert isinstance(g.equiv(sx).owner, Dot)
......
...@@ -90,9 +90,9 @@ def struct_gen(args, struct_builders, blocks, sub): ...@@ -90,9 +90,9 @@ def struct_gen(args, struct_builders, blocks, sub):
behavior = code_gen(blocks) behavior = code_gen(blocks)
storage_decl = "\n".join(["PyObject* %s;" % arg for arg in args]) storage_decl = "\n".join(["PyObject* %s;" % arg for arg in args])
# we're borrowing the references to the storage pointers because Python
# has (needs) references to them to feed inputs or get the results
storage_set = "\n".join(["this->%s = %s;" % (arg, arg) for arg in args]) storage_set = "\n".join(["this->%s = %s;" % (arg, arg) for arg in args])
storage_incref = "\n".join(["Py_XINCREF(%s);" % arg for arg in args])
storage_decref = "\n".join(["Py_XDECREF(this->%s);" % arg for arg in args])
args_names = ", ".join(args) args_names = ", ".join(args)
args_decl = ", ".join(["PyObject* %s" % arg for arg in args]) args_decl = ", ".join(["PyObject* %s" % arg for arg in args])
...@@ -139,6 +139,7 @@ def struct_gen(args, struct_builders, blocks, sub): ...@@ -139,6 +139,7 @@ def struct_gen(args, struct_builders, blocks, sub):
} }
int init(PyObject* __ERROR, %(args_decl)s) { int init(PyObject* __ERROR, %(args_decl)s) {
%(storage_incref)s
%(storage_set)s %(storage_set)s
int %(failure_var)s = 0; int %(failure_var)s = 0;
%(struct_init_head)s %(struct_init_head)s
...@@ -150,6 +151,7 @@ def struct_gen(args, struct_builders, blocks, sub): ...@@ -150,6 +151,7 @@ def struct_gen(args, struct_builders, blocks, sub):
} }
void cleanup(void) { void cleanup(void) {
%(struct_cleanup)s %(struct_cleanup)s
%(storage_decref)s
} }
int run(void) { int run(void) {
int %(failure_var)s = 0; int %(failure_var)s = 0;
...@@ -224,38 +226,34 @@ def struct_result_codeblocks(result, policies, id, symbol_table, sub): ...@@ -224,38 +226,34 @@ def struct_result_codeblocks(result, policies, id, symbol_table, sub):
class CLinker(Linker): class CLinker(Linker):
def __init__(self, env, inputs = None, outputs = None): def __init__(self, env):
self.env = env self.env = env
self.inputs = inputs self.fetch_results()
self.outputs = outputs
def fetch_results(self): def fetch_results(self):
env = self.env env = self.env
results = env.results()
if self.inputs: self.inputs = env.inputs
assert set(self.inputs) == set(env.inputs) self.outputs = env.outputs
inputs = self.inputs
else:
inputs = env.inputs
if self.outputs: try: self.results = list(env.results())
assert set(self.outputs) == set(env.outputs) except AttributeError: self.results = self.inputs + self.outputs
outputs = self.outputs
else: try: self.orphans = list(env.orphans())
outputs = env.outputs except AttributeError: self.orphans = []
outputs = env.outputs try: self.temps = list(set(self.results).difference(self.inputs).difference(self.outputs).difference(self.orphans))
orphans = env.orphans() except AttributeError: self.temps = []
temps = results.difference(inputs).difference(outputs).difference(orphans)
return results, inputs, outputs, orphans, temps try: self.op_order = env.toposort()
except AttributeError: self.op_order = [env]
def code_gen(self, reuse_storage = True): def code_gen(self, reuse_storage = True):
env = self.env if getattr(self, 'struct_code', False) and self.reuse_storage == reuse_storage:
op_order = env.toposort() return self.struct_code
results, inputs, outputs, orphans, temps = self.fetch_results() env = self.env
consts = [] consts = []
...@@ -272,32 +270,32 @@ class CLinker(Linker): ...@@ -272,32 +270,32 @@ class CLinker(Linker):
sub = dict(failure_var = failure_var) sub = dict(failure_var = failure_var)
for result in results: for result in self.results:
if getattr(result, 'constant', False): if getattr(result, 'constant', False):
if result in outputs or result in temps: if result in self.outputs or result in self.temps:
raise Exception("Temporaries and outputs should not be marked constant. Check your graph.") raise Exception("Temporaries and outputs should not be marked constant. Check your graph.")
try: try:
symbol[result] = result.c_literal() symbol[result] = result.c_literal()
consts.append(result) consts.append(result)
if result in inputs: if result in self.inputs:
print "Warning: input %s is marked as constant and has been compiled as a literal." % result print "Warning: input %s is marked as constant and has been compiled as a literal." % result
elif result in orphans: elif result in self.orphans:
orphans.remove(result) self.orphans.remove(result)
continue continue
except AbstractFunctionError: except AbstractFunctionError:
pass pass
# policy = [[what to declare in the struct, what to do at construction, what to do at destruction], # policy = [[what to declare in the struct, what to do at construction, what to do at destruction],
# [what to declare in each run, what to do at the beginning of each run, what to do at the end of each run]] # [what to declare in each run, what to do at the beginning of each run, what to do at the end of each run]]
if result in inputs: if result in self.inputs:
# we need to extract the new inputs at each run # we need to extract the new inputs at each run
# they do not need to be relayed to Python, so we don't sync # they do not need to be relayed to Python, so we don't sync
policy = [[get_nothing, get_nothing, get_nothing], policy = [[get_nothing, get_nothing, get_nothing],
[get_c_declare, get_c_extract, get_c_cleanup]] [get_c_declare, get_c_extract, get_c_cleanup]]
elif result in orphans: elif result in self.orphans:
# orphans are not inputs so we'll just get fetch them when we initialize the struct and assume they stay the same # orphans are not inputs so we'll just get fetch them when we initialize the struct and assume they stay the same
policy = [[get_c_declare, get_c_extract, get_c_cleanup], policy = [[get_c_declare, get_c_extract, get_c_cleanup],
[get_nothing, get_nothing, get_nothing]] [get_nothing, get_nothing, get_nothing]]
elif result in temps or not reuse_storage: elif result in self.temps or not reuse_storage:
# temps don't need to be extracted from Python, so we call c_init rather than c_extract # temps don't need to be extracted from Python, so we call c_init rather than c_extract
# they do not need to be relayed to Python, so we don't sync # they do not need to be relayed to Python, so we don't sync
if result.c_is_simple() or not reuse_storage: if result.c_is_simple() or not reuse_storage:
...@@ -307,7 +305,7 @@ class CLinker(Linker): ...@@ -307,7 +305,7 @@ class CLinker(Linker):
# it is useful for complex temps to reuse storage at each run, so we only clean up in the destructor # it is useful for complex temps to reuse storage at each run, so we only clean up in the destructor
policy = [[get_c_declare, get_c_init, get_c_cleanup], policy = [[get_c_declare, get_c_init, get_c_cleanup],
[get_nothing, get_nothing, get_nothing]] [get_nothing, get_nothing, get_nothing]]
elif result in outputs: elif result in self.outputs:
# outputs don't need to be extracted from Python, so we call c_init rather than c_extract # outputs don't need to be extracted from Python, so we call c_init rather than c_extract
if result.c_is_simple() or not reuse_storage: if result.c_is_simple() or not reuse_storage:
...@@ -328,9 +326,7 @@ class CLinker(Linker): ...@@ -328,9 +326,7 @@ class CLinker(Linker):
id += 2 id += 2
print symbol for op in self.op_order:
for op in op_order:
ivnames, ovnames = op.c_var_names() ivnames, ovnames = op.c_var_names()
sub = dict(failure_var = failure_var) sub = dict(failure_var = failure_var)
...@@ -365,17 +361,9 @@ class CLinker(Linker): ...@@ -365,17 +361,9 @@ class CLinker(Linker):
args = [] args = []
in_arg_order = [] in_arg_order = []
for result in list(inputs):
in_arg_order.append(result) args += ["storage_%s" % symbol[result] for result in self.inputs + self.outputs + self.orphans]
args.append("storage_%s" % symbol[result])
out_arg_order = []
for result in list(outputs):
out_arg_order.append(result)
args.append("storage_%s" % symbol[result])
orphan_arg_order = []
for result in list(orphans):
orphan_arg_order.append(result)
args.append("storage_%s" % symbol[result])
struct_code = struct_gen(args, init_blocks, blocks, dict(failure_var = failure_var)) struct_code = struct_gen(args, init_blocks, blocks, dict(failure_var = failure_var))
hash = md5.md5(struct_code).hexdigest() hash = md5.md5(struct_code).hexdigest()
...@@ -383,19 +371,16 @@ class CLinker(Linker): ...@@ -383,19 +371,16 @@ class CLinker(Linker):
struct_code %= dict(name = struct_name) struct_code %= dict(name = struct_name)
self.struct_code = struct_code self.struct_code = struct_code
self.reuse_storage = reuse_storage
self.struct_name = struct_name self.struct_name = struct_name
self.hash = hash self.hash = hash
self.args = args self.args = args
self.inputs = in_arg_order
self.outputs = out_arg_order
self.orphans = orphan_arg_order
self.r2symbol = symbol self.r2symbol = symbol
self.init_blocks = init_blocks self.init_blocks = init_blocks
self.init_tasks = init_tasks self.init_tasks = init_tasks
self.blocks = blocks self.blocks = blocks
self.tasks = tasks self.tasks = tasks
return struct_code
def find_task(self, failure_code): def find_task(self, failure_code):
n = len(self.init_tasks) n = len(self.init_tasks)
...@@ -406,393 +391,209 @@ class CLinker(Linker): ...@@ -406,393 +391,209 @@ class CLinker(Linker):
def support_code(self): def support_code(self):
ret = "" ret = ""
for x in self.env.results().union(self.env.ops()): for x in self.results + self.op_order:
try: ret += x.c_support_code() try: ret += x.c_support_code()
except AbstractFunctionError: pass except AbstractFunctionError: pass
return ret return ret
def compile_args(self): def compile_args(self):
ret = set() ret = set()
for x in self.env.results().union(self.env.ops()): for x in self.results + self.op_order:
try: ret.update(x.c_compile_args()) try: ret.update(x.c_compile_args())
except AbstractFunctionError: pass except AbstractFunctionError: pass
return ret return ret
def headers(self): def headers(self):
ret = set() ret = set()
for x in self.env.results().union(self.env.ops()): for x in self.results + self.op_order:
try: ret.update(x.c_headers()) try: ret.update(x.c_headers())
except AbstractFunctionError: pass except AbstractFunctionError: pass
return ret return ret
def libraries(self): def libraries(self):
ret = set() ret = set()
for x in self.env.results().union(self.env.ops()): for x in self.results + self.op_order:
try: ret.update(x.c_libraries()) try: ret.update(x.c_libraries())
except AbstractFunctionError: pass except AbstractFunctionError: pass
return ret return ret
def make_function(self, in_order, out_order): # def make_function(self, in_order, out_order):
nin = len(self.inputs) # nin = len(self.inputs)
nout = len(self.outputs) # nout = len(self.outputs)
if nin != len(in_order): # if nin != len(in_order):
raise TypeError("Wrong number of inputs.") # raise TypeError("Wrong number of inputs.")
if nout != len(out_order): # if nout != len(out_order):
raise TypeError("Wrong number of outputs.") # raise TypeError("Wrong number of outputs.")
in_storage = [] # in_storage = []
out_storage = [] # out_storage = []
cthunk_in_args = [None] * nin # cthunk_in_args = [None] * nin
cthunk_out_args = [None] * nout # cthunk_out_args = [None] * nout
for result in in_order: # for result in in_order:
idx = self.inputs.index(result) # idx = self.inputs.index(result)
storage = [None] # storage = [None]
cthunk_in_args[idx] = storage # cthunk_in_args[idx] = storage
in_storage.append(storage) # in_storage.append(storage)
for result in out_order: # for result in out_order:
idx = self.outputs.index(result) # idx = self.outputs.index(result)
storage = [None] # storage = [None]
cthunk_out_args[idx] = storage # cthunk_out_args[idx] = storage
out_storage.append(storage) # out_storage.append(storage)
for arg in cthunk_in_args + cthunk_out_args: # for arg in cthunk_in_args + cthunk_out_args:
if arg is None: # if arg is None:
raise Exception("The inputs or outputs are underspecified.") # raise Exception("The inputs or outputs are underspecified.")
# error_storage = [None, None, None]
# cthunk = self.cthunk_factory(error_storage, cthunk_in_args, cthunk_out_args)
# def execute(*args):
# for arg, storage in zip(args, in_storage):
# storage[0] = arg
# failure = cutils.run_cthunk(cthunk)
# if failure:
# raise error_storage[0], error_storage[1] + " " + str(self.find_task(failure - 1))
# return utils.to_return_values([storage[0] for storage in out_storage])
# return execute
def __compile__(self, inplace = False):
if inplace:
in_results = self.inputs
out_results = self.outputs
else:
in_results = [copy(input) for input in self.inputs]
out_results = [copy(output) for output in self.outputs]
error_storage = [None, None, None] error_storage = [None, None, None]
cthunk = self.cthunk_factory(error_storage, cthunk_in_args, cthunk_out_args) thunk = self.cthunk_factory(error_storage,
[result._data for result in in_results],
[result._data for result in out_results])
if not inplace:
for r in in_results + out_results:
r._role = None # we just need the wrapper, not the (copied) graph associated to it
return thunk, in_results, out_results, error_storage
def make_thunk(self, inplace = False):
cthunk, in_results, out_results, error_storage = self.__compile__(inplace)
def execute():
failure = cutils.run_cthunk(cthunk)
if failure:
raise error_storage[0], error_storage[1] + " " + str(self.find_task(failure - 1))
return execute, in_results, out_results
def make_function(self, inplace = False):
cthunk, in_results, out_results, error_storage = self.__compile__(inplace)
# out_storage = [result._data for result in out_results]
def execute(*args): def execute(*args):
for arg, storage in zip(args, in_storage): for arg, result in zip(args, in_results):
storage[0] = arg result.data = arg
failure = cutils.run_cthunk(cthunk) failure = cutils.run_cthunk(cthunk)
if failure: if failure:
raise error_storage[0], error_storage[1] + " " + str(self.find_task(failure - 1)) raise error_storage[0], error_storage[1] + " " + str(self.find_task(failure - 1))
return utils.to_return_values([storage[0] for storage in out_storage]) return utils.to_return_values([result.data for result in out_results])
# return utils.to_return_values([storage[0] for storage in out_storage])
return execute return execute
def cthunk_factory(self, error_storage, in_storage, out_storage): def cthunk_factory(self, error_storage, in_storage, out_storage):
cthunk = object() if not getattr(self, 'instantiate', False):
module_name = self.hash self.code_gen()
mod = weave.ext_tools.ext_module(module_name)
cthunk = object()
argnames = ["i%i" % i for i in xrange(len(in_storage))] \ module_name = self.hash
+ ["o%i" % i for i in xrange(len(out_storage))] \ mod = weave.ext_tools.ext_module(module_name)
+ ["orph%i" % i for i in xrange(len(self.orphans))]
argnames = ["i%i" % i for i in xrange(len(in_storage))] \
code = """ + ["o%i" % i for i in xrange(len(out_storage))] \
%(struct_name)s* struct_ptr = new %(struct_name)s(); + ["orph%i" % i for i in xrange(len(self.orphans))]
struct_ptr->init(error_storage, %(args)s);
PyObject* thunk = PyCObject_FromVoidPtrAndDesc((void*)(&%(struct_name)s_executor), struct_ptr, %(struct_name)s_destructor); code = """
return thunk; %(struct_name)s* struct_ptr = new %(struct_name)s();
// return_val = thunk; // oh my god weave why does this leak >:\ struct_ptr->init(error_storage, %(args)s);
""" % dict(struct_name = self.struct_name, PyObject* thunk = PyCObject_FromVoidPtrAndDesc((void*)(&%(struct_name)s_executor), struct_ptr, %(struct_name)s_destructor);
args = ", ".join(argnames)) return thunk;
// return_val = thunk; // oh my god weave why does this leak >:\
d = dict(error_storage = object()) """ % dict(struct_name = self.struct_name,
for argname in argnames: args = ", ".join(argnames))
d[argname] = object()
d = dict(error_storage = object())
instantiate = weave.ext_tools.ext_function('instantiate', for argname in argnames:
code, d[argname] = object()
['error_storage'] + argnames,
local_dict = d, instantiate = weave.ext_tools.ext_function('instantiate',
global_dict = {}) code,
['error_storage'] + argnames,
static = """ local_dict = d,
int %(struct_name)s_executor(%(struct_name)s* self) { global_dict = {})
return self->run();
} static = """
int %(struct_name)s_executor(%(struct_name)s* self) {
void %(struct_name)s_destructor(void* executor, void* self) { return self->run();
printf("doing cleanup\\n"); }
((%(struct_name)s*)self)->cleanup();
free(self); void %(struct_name)s_destructor(void* executor, void* self) {
} //printf("doing cleanup\\n");
""" % dict(struct_name = self.struct_name) ((%(struct_name)s*)self)->cleanup();
free(self);
}
""" % dict(struct_name = self.struct_name)
instantiate.customize.add_support_code(self.support_code() + self.struct_code + static)
instantiate.customize.add_extra_compile_arg("-w")
for arg in self.compile_args():
instantiate.customize.add_extra_compile_arg(arg)
for header in self.headers():
instantiate.customize.add_header(header)
for lib in self.libraries():
instantiate.customize.add_library(lib)
mod.add_function(instantiate)
mod.compile(location = compile_dir())
module = __import__("%s" % (module_name), {}, {}, [module_name])
self.instantiate = module.instantiate
instantiate.customize.add_support_code(self.support_code() + self.struct_code + static)
for arg in self.compile_args():
instantiate.customize.add_extra_compile_arg(arg)
for header in self.headers():
instantiate.customize.add_header(header)
for lib in self.libraries():
instantiate.customize.add_library(lib)
mod.add_function(instantiate)
mod.compile(location = compile_dir())
module = __import__("%s" % (module_name), {}, {}, [module_name])
ret = module.instantiate(error_storage, *(in_storage + out_storage + [orphan._data for orphan in self.orphans])) ret = module.instantiate(error_storage, *(in_storage + out_storage + [orphan._data for orphan in self.orphans]))
assert sys.getrefcount(ret) == 2 # refcount leak check assert sys.getrefcount(ret) == 2 # refcount leak check
return ret return ret
# def c_thunk_factory(self):
# self.refresh()
# d, names, code, struct, converters = self.c_code()
# cthunk = object()
# module_name = md5.md5(code).hexdigest()
# mod = weave.ext_tools.ext_module(module_name)
# instantiate = weave.ext_tools.ext_function('instantiate',
# code,
# names,
# local_dict = d,
# global_dict = {},
# type_converters = converters)
# instantiate.customize.add_support_code(self.c_support_code() + struct)
# for arg in self.c_compile_args():
# instantiate.customize.add_extra_compile_arg(arg)
# for header in self.c_headers():
# instantiate.customize.add_header(header)
# for lib in self.c_libs():
# instantiate.customize.add_library(lib)
# #add_library_dir
# #print dir(instantiate.customize)
# #print instantiate.customize._library_dirs
# if os.getenv('OMEGA_BLAS_LD_LIBRARY_PATH'):
# instantiate.customize.add_library_dir(os.getenv('OMEGA_BLAS_LD_LIBRARY_PATH'))
# mod.add_function(instantiate)
# mod.compile(location = _compile_dir())
# module = __import__("%s" % (module_name), {}, {}, [module_name])
# def creator():
# return module.instantiate(*[x.data for x in self.inputs + self.outputs])
# return creator
# def code_gen(self, reuse_storage = True):
# env = self.env
# op_order = env.toposort()
# to_extract = env.inputs.union(env.orphans())
# to_sync = env.outputs
# temporaries = env.results().difference(to_extract).difference(to_sync)
# symbol = {}
# init_tasks = []
# tasks = []
# init_blocks = []
# blocks = []
# failure_var = "__failure"
# id = 0
# sub = dict(failure_var = failure_var)
# on_stack = [result for result in temporaries.union(to_sync) if not reuse_storage or result.c_is_simple()]
# for result_set, type in [[to_extract, 'input'],
# [to_sync, 'output'],
# [temporaries, 'temporary']]:
# for result in result_set:
# builder, block = struct_result_codeblocks(result, type, id, symbol, sub, on_stack)
# init_tasks.append((result, 'init'))
# init_blocks.append(builder)
# tasks.append((result, 'get'))
# blocks.append(block)
# id += 2
# for op in op_order:
# ivnames, ovnames = op.c_var_names()
# sub = dict(failure_var = failure_var)
# for result, vname in zip(op.inputs + op.outputs, ivnames + ovnames):
# sub[vname] = symbol[result]
# # c_validate_update
# try: validate_behavior = op.c_validate_update()
# except AbstractFunctionError:
# validate_behavior = ""
# try: validate_behavior = op.c_validate_update_cleanup()
# except AbstractFunctionError:
# validate_cleanup = ""
# sub['id'] = id
# blocks.append(CodeBlock("", validate_behavior, validate_cleanup, sub))
# tasks.append((op, 'validate_update'))
# id += 1
# # c_code
# behavior = op.c_code() # this one must be implemented!
# try: cleanup = op.c_code_cleanup()
# except AbstractFunctionError:
# cleanup = ""
# sub['id'] = id
# blocks.append(CodeBlock("", behavior, cleanup, sub))
# tasks.append((op, 'code'))
# id += 1
# args = []
# in_arg_order = []
# for result in list(to_extract):
# in_arg_order.append(result)
# args.append("storage_%s" % symbol[result])
# out_arg_order = []
# for result in to_sync:
# out_arg_order.append(result)
# args.append("storage_%s" % symbol[result])
# struct_code = struct_gen(args, init_blocks, blocks, dict(failure_var = failure_var))
# hash = md5.md5(struct_code).hexdigest()
# struct_name = 'compiled_op_%s' % hash
# struct_code %= dict(name = struct_name)
# self.struct_code = struct_code
# self.struct_name = struct_name
# self.hash = hash
# self.args = args
# self.inputs = in_arg_order
# self.outputs = out_arg_order
# self.r2symbol = symbol
# self.init_blocks = init_blocks
# self.init_tasks = init_tasks
# self.blocks = blocks
# self.tasks = tasks
# return struct_code
class OpWiseCLinker(Linker):
    """Linker that compiles each op of the env individually through CLinker
    and chains the resulting thunks in topological order.

    NOTE(review): reconstructed from a garbled side-by-side diff dump; the
    right-hand (new) column was taken as authoritative.
    """

    def __init__(self, env):
        self.env = env

    def make_thunk(self, inplace = False):
        """Return a triplet (execute, inputs, outputs).

        If inplace is False, the env is cloned first so the caller's graph
        is left untouched; else execution operates on the env's own storage.
        """
        if inplace:
            env = self.env
        else:
            env = self.env.clone(True)
        op_order = env.toposort()
        inputs, outputs = env.inputs, env.outputs
        env = None  # drop the reference; only the per-op thunks are kept below
        thunks = []
        for op in op_order:
            cl = CLinker(op)
            # each op is compiled and linked on its own, in place (True)
            thunk, in_results, out_results = cl.make_thunk(True)
            thunks.append(thunk)
        def execute():
            for thunk in thunks:
                thunk()
        return execute, inputs, outputs
# pass
# def code_gen(self):
# env = self.env
# order = env.toposort()
# to_extract = env.inputs.union(env.outputs).union(env.orphans())
# head = ""
# tail = ""
# label_id = 0
# name_id = 0
# result_names = {}
# for result in env.results():
# name = "__v_%i" % name_id
# result_names[result] = name
# name_id += 1
# for result in to_extract:
# head += """
# {
# %(extract)s
# """
# tail = """
# __label_%(label_id)s:
# %(sync)s
# }
# """ + tail
# name = result_names[result]
# type = result.c_type()
# head %= dict(extract = result.c_extract())
# head %= dict(name = name,
# type = type,
# fail = "{goto __label_%i;}" % label_id)
# tail %= dict(sync = result.c_sync(),
# label_id = label_id)
# tail %= dict(name = name,
# type = type)
# label_id += 1
# for op in order:
# inames, onames = op.c_var_names()
# return head + tail
# def struct_result_codeblocks(result, type, id, symbol_table, sub, on_stack):
# if type == 'output':
# sync = get_c_sync(result)
# else:
# sync = ""
# if type == 'input':
# struct_declare = ""
# run_declare = result.c_declare()
# struct_behavior = ""
# run_behavior = get_c_extract(result)
# struct_cleanup = ""
# run_cleanup = get_c_cleanup(result)
# else:
# if result in on_stack:
# struct_declare = ""
# run_declare = result.c_declare()
# struct_behavior = ""
# run_behavior = result.c_init()
# struct_cleanup = ""
# run_cleanup = sync + get_c_cleanup(result)
# else:
# struct_declare = result.c_declare()
# run_declare = ""
# struct_behavior = result.c_init()
# run_behavior = ""
# struct_cleanup = get_c_cleanup(result)
# run_cleanup = sync
# name = "V%i" % id
# symbol_table[result] = name
# sub = copy(sub)
# sub['name'] = name
# sub['id'] = id
# struct_builder = CodeBlock(struct_declare, struct_behavior, struct_cleanup, sub)
# sub['id'] = id + 1
# block = CodeBlock(run_declare, run_behavior, run_cleanup, sub)
# return struct_builder, block
...@@ -75,8 +75,8 @@ class Env(graph.Graph): ...@@ -75,8 +75,8 @@ class Env(graph.Graph):
self._tools = {} self._tools = {}
# The inputs and outputs set bound the subgraph this Env operates on. # The inputs and outputs set bound the subgraph this Env operates on.
self.inputs = set(inputs) self.inputs = list(inputs)
self.outputs = set(outputs) self.outputs = list(outputs)
for feature_class in uniq_features(features): for feature_class in uniq_features(features):
self.add_feature(feature_class, False) self.add_feature(feature_class, False)
...@@ -110,9 +110,9 @@ class Env(graph.Graph): ...@@ -110,9 +110,9 @@ class Env(graph.Graph):
### Public interface ### ### Public interface ###
def add_output(self, output): # def add_output(self, output):
self.outputs.add(output) # self.outputs.add(output)
self.__import_r__([output]) # self.__import_r__([output])
def clients(self, r): def clients(self, r):
"Set of all the (op, i) pairs such that op.inputs[i] is r." "Set of all the (op, i) pairs such that op.inputs[i] is r."
...@@ -249,8 +249,9 @@ class Env(graph.Graph): ...@@ -249,8 +249,9 @@ class Env(graph.Graph):
new_was_output = True new_was_output = True
if r in self.outputs: if r in self.outputs:
was_output = True was_output = True
self.outputs.remove(r) self.outputs[self.outputs.index(r)] = new_r
self.outputs.add(new_r) # self.outputs.remove(r)
# self.outputs.add(new_r)
# The actual replacement operation occurs here. This might raise # The actual replacement operation occurs here. This might raise
# an error. # an error.
...@@ -261,8 +262,9 @@ class Env(graph.Graph): ...@@ -261,8 +262,9 @@ class Env(graph.Graph):
# Restore self.outputs # Restore self.outputs
if was_output: if was_output:
if not new_was_output: if not new_was_output:
self.outputs.remove(new_r) self.outputs[self.outputs.index(new_r)] = r
self.outputs.add(r) # self.outputs.remove(new_r)
# self.outputs.add(r)
# Move back the clients. This should never raise an error. # Move back the clients. This should never raise an error.
self.__move_clients__(clients, new_r, r) self.__move_clients__(clients, new_r, r)
......
...@@ -2,35 +2,80 @@ ...@@ -2,35 +2,80 @@
# from features import Tool # from features import Tool
from utils import AbstractFunctionError from utils import AbstractFunctionError
import utils
class Linker:
    """Abstract base for objects that turn an Env (a subgraph with declared
    inputs and outputs) into something executable.

    NOTE(review): reconstructed from a garbled side-by-side diff dump; the
    right-hand (new) column was taken as authoritative.
    """

    def __init__(self, env):
        self.env = env

    def make_thunk(self, inplace = False):
        """
        This function must return a triplet (function, input_results, output_results)
        where function is a thunk that operates on the returned results. If inplace
        is True, the input_results and output_results lists will be the same as the
        inputs and outputs of the graph provided to the Linker. Else, independent
        results will be returned.

        Example:
         e = x + y
         env = Env([x, y], [e])
         fn, (new_x, new_y), (new_e, ) = MyLinker(env).make_thunk(inplace)
         new_x.data = 1.0
         new_y.data = 2.0
         fn()
         print new_e.data # 3.0
         print e.data # 3.0 iff inplace == True (else unknown)
        """
        raise AbstractFunctionError()

    def make_function(self, inplace = False):
        """
        Returns a function that takes values corresponding to the inputs of the
        env used by this Linker and returns values corresponding to the outputs
        of that env. If inplace is True, the calculations will operate in the
        same storage the env uses, else independent storage will be allocated
        for the function.

        Example:
         e = x + y
         env = Env([x, y], [e])
         fn = MyLinker(env).make_function(inplace)
         print fn(1.0, 2.0) # 3.0
         print e.data # 3.0 iff inplace == True (else unknown)
        """
        thunk, inputs, outputs = self.make_thunk(inplace)
        def execute(*args):
            # load the caller's values into the graph's input storage
            for arg, result in zip(args, inputs):
                result.data = arg
            thunk()
            # unwrap a single output, return a list of values otherwise
            return utils.to_return_values([result.data for result in outputs])
        return execute
class PerformLinker(Linker):
    """Linker that executes the graph through each op's Python `perform`
    implementation, in topological order.

    NOTE(review): reconstructed from a garbled side-by-side diff dump; the
    right-hand (new) column was taken as authoritative.
    """

    def make_thunk(self, inplace = False):
        # work on a clone unless the caller asked for in-place execution
        if inplace:
            env = self.env
        else:
            env = self.env.clone(True)
        order = env.toposort()
        thunks = [op.perform for op in order]
        def f():
            for thunk in thunks:
                thunk()
        return f, env.inputs, env.outputs
class ProfilePerformLinker(Linker): class ProfilePerformLinker(Linker):
......
...@@ -4,10 +4,13 @@ from op import Op ...@@ -4,10 +4,13 @@ from op import Op
__all__ = ['ModalConstructor', __all__ = ['ModalConstructor',
'add_modal_members', 'add_modal_members',
'ModalWrapper', 'build',
'BuildMode', 'eval',
'EvalMode', 'build_eval',
'BuildEvalMode', # 'ModalWrapper',
# 'BuildMode',
# 'EvalMode',
# 'BuildEvalMode',
'make_constructors', 'make_constructors',
] ]
...@@ -15,27 +18,41 @@ class ModalConstructor: ...@@ -15,27 +18,41 @@ class ModalConstructor:
class ModalConstructor:
    """Wraps an op-building function so that the 'mode' of its arguments
    (a callable stored as a `__mode__` attribute on results) is checked for
    consistency, applied to the constructed op, and propagated to its outputs.

    NOTE(review): reconstructed from a garbled side-by-side diff dump; the
    right-hand (new) column was taken as authoritative.
    """

    def __init__(self, fn):
        self.fn = fn

    def __call__(self, *args):
        modal_wrapper = None
        fn_args = []
        for arg in args:
            # a result's mode is a callable tagged on it (see modes.build/eval/...)
            mode = getattr(arg, '__mode__', False)
            if mode:
                if modal_wrapper is None:
                    modal_wrapper = mode
                else:
                    if mode != modal_wrapper:
                        raise TypeError("Inconsistent modes.")
            fn_args.append(arg)
        op = self.fn(*fn_args)
        if modal_wrapper:
            # apply the mode to the op (e.g. eval modes call op.perform())
            modal_wrapper(op)
            # propagate the mode so further ops built from these stay modal
            for output in op.outputs:
                output.__mode__ = modal_wrapper
        if len(op.outputs) == 1:
            return op.outputs[0]
        else:
            return op.outputs
def add_modal_members(cls, *members): def add_modal_members(cls, *members):
...@@ -48,39 +65,69 @@ def add_modal_members(cls, *members): ...@@ -48,39 +65,69 @@ def add_modal_members(cls, *members):
setattr(cls, member, fn(member)) setattr(cls, member, fn(member))
class ModalWrapper: # class ModalWrapper:
def __init__(self, r): # def __init__(self, r):
self.r = r # self.r = r
@classmethod # def __as_result__(self):
def filter(cls, op): # return self.r
raise AbstractFunctionError()
members1 = 'add sub mul div pow floordiv mod pow lshift rshift and or xor'.split(' ') # def __get_owner(self):
members = [] # return self.r.owner
members += ["__%s__" % x for x in members1 + 'neg invert'.split(' ')]
members += ["__r%s__" % x for x in members1] # owner = property(__get_owner)
add_modal_members(ModalWrapper, *members)
# @classmethod
# def filter(cls, op):
# raise AbstractFunctionError()
# members1 = 'add sub mul div pow floordiv mod pow lshift rshift and or xor'.split(' ')
# members = []
# members += ["__%s__" % x for x in members1 + 'neg invert'.split(' ')]
# members += ["__r%s__" % x for x in members1]
# add_modal_members(ModalWrapper, *members)
class BuildMode(ModalWrapper):
@classmethod
def filter(cls, op):
pass
def build_mode(op):
    """Mode that only builds the graph: nothing happens at construction time."""
    pass

def eval_mode(op):
    """Mode that evaluates the op immediately and detaches its outputs
    from the graph (their role is cleared)."""
    op.perform()
    for output in op.outputs:
        output._role = None

def build_eval_mode(op):
    """Mode that builds the graph and also computes the op's value eagerly."""
    op.perform()

def mode_setter(mode):
    """Return a function that tags a result with `mode` (via its
    `__mode__` attribute) and returns the result itself."""
    def f(r):
        r.__mode__ = mode
        return r
    return f

# public constructors: tag a result with the desired evaluation mode
build = mode_setter(build_mode)
eval = mode_setter(eval_mode)  # NOTE(review): shadows the builtin `eval` on import *
build_eval = mode_setter(build_eval_mode)
# class BuildMode(ModalWrapper):
# @classmethod
# def filter(cls, op):
# pass
# class EvalMode(ModalWrapper):
# @classmethod
# def filter(cls, op):
# op.perform()
# for output in op.outputs:
# output._role = None
class BuildEvalMode(ModalWrapper): # class BuildEvalMode(ModalWrapper):
@classmethod # @classmethod
def filter(cls, op): # def filter(cls, op):
op.perform() # op.perform()
def _is_op(x): def _is_op(x):
......
...@@ -64,6 +64,7 @@ class ResultBase(object): ...@@ -64,6 +64,7 @@ class ResultBase(object):
self.__set_data(data) self.__set_data(data)
self.name = name self.name = name
# #
# role # role
# #
...@@ -123,7 +124,7 @@ class ResultBase(object): ...@@ -123,7 +124,7 @@ class ResultBase(object):
self.state = Empty self.state = Empty
return return
try: try:
self.validate(data) data = self.filter(data)
except AbstractFunctionError: except AbstractFunctionError:
pass pass
self._data[0] = data self._data[0] = data
...@@ -132,7 +133,7 @@ class ResultBase(object): ...@@ -132,7 +133,7 @@ class ResultBase(object):
data = property(__get_data, __set_data, data = property(__get_data, __set_data,
doc = "The storage associated with this result") doc = "The storage associated with this result")
def validate(self, data): def filter(self, data):
"""(abstract) Raise an exception if the data is not of an """(abstract) Raise an exception if the data is not of an
acceptable type. acceptable type.
...@@ -140,7 +141,8 @@ class ResultBase(object): ...@@ -140,7 +141,8 @@ class ResultBase(object):
it to check that the argument can be used properly. This gives it to check that the argument can be used properly. This gives
a subclass the opportunity to ensure that the contents of a subclass the opportunity to ensure that the contents of
self._data remain sensible. self._data remain sensible.
Returns data or an appropriately wrapped data.
""" """
raise AbstractFunctionError() raise AbstractFunctionError()
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
import numpy import numpy
from copy import copy from copy import copy
import inspect
from gof import ResultBase, GuardedOp, utils from gof import ResultBase, GuardedOp, utils
...@@ -17,10 +18,10 @@ def as_scalar(x, name = None): ...@@ -17,10 +18,10 @@ def as_scalar(x, name = None):
class Scalar(ResultBase): class Scalar(ResultBase):
def __init__(self, dtype, name=None): def __init__(self, dtype, data = None, name=None):
self.dtype = dtype self.dtype = dtype
self.constant = False self.constant = False
ResultBase.__init__(self, role = None, data = None, name = name) ResultBase.__init__(self, role = None, data = data, name = name)
def __get_constant(self): def __get_constant(self):
return self._constant return self._constant
...@@ -28,14 +29,13 @@ class Scalar(ResultBase): ...@@ -28,14 +29,13 @@ class Scalar(ResultBase):
def __set_constant(self, value): def __set_constant(self, value):
if value: if value:
self.indestructible = True self.indestructible = True
self.constant = value self._constant = value
constant = property(__get_constant, __set_constant) constant = property(__get_constant, __set_constant)
def validate(self, data): def filter(self, data):
py_type = self.py_type() py_type = self.dtype_specs()[0]
if not isinstance(data, py_type): return py_type(data)
raise TypeError("Expected %s instance." % py_type)
def same_properties(self, other): def same_properties(self, other):
return other.dtype == self.dtype return other.dtype == self.dtype
...@@ -44,51 +44,56 @@ class Scalar(ResultBase): ...@@ -44,51 +44,56 @@ class Scalar(ResultBase):
return getattr(self, 'constant', False) \ return getattr(self, 'constant', False) \
and getattr(other, 'constant', False) \ and getattr(other, 'constant', False) \
and self.data == other.data and self.data == other.data
def dtype_specs(self):
return {'float64': (float, 'double', 'PyFloat_Check', 'PyFloat_AsDouble', 'PyFloat_FromDouble')}[self.dtype]
def py_type(self): # def py_type(self):
return {'float64': float}[self.dtype] # return {'float64': float}[self.dtype]
def c_type(self): # def c_type(self):
return {'float64': 'double'}[self.dtype] # return {'float64': 'double'}[self.dtype]
def c_from(self): # def c_from(self):
return {'float64': 'PyFloat_FromDouble'}[self.dtype] # return {'float64': 'PyFloat_FromDouble'}[self.dtype]
def c_as(self): # def c_as(self):
return {'float64': 'PyFloat_AsDouble'}[self.dtype] # return {'float64': 'PyFloat_AsDouble'}[self.dtype]
def c_declare(self): def c_declare(self):
return """ return """
%(dtype)s* %%(name)s; %(dtype)s %%(name)s;
typedef %(dtype)s %%(name)s_dtype; typedef %(dtype)s %%(name)s_dtype;
""" % dict(dtype = self.c_type()) """ % dict(dtype = self.dtype_specs()[1])
def c_data_extract(self): def c_init(self):
return """ return """
%%(name)s = (%(dtype)s)%(conv)s(py_%%(name)s); %(name)s = 0;
if (!%%(name)s) """
def c_extract(self):
specs = self.dtype_specs()
return """
if (!%(check)s(py_%%(name)s))
%%(fail)s %%(fail)s
""" % dict(dtype = self.c_type(), %%(name)s = (%(dtype)s)%(conv)s(py_%%(name)s);
conv = self.c_as()) """ % dict(dtype = specs[1],
check = specs[2],
conv = specs[3])
def c_data_sync(self): def c_sync(self):
specs = self.dtype_specs()
return """ return """
Py_XDECREF(py_%%(name)s); Py_XDECREF(py_%%(name)s);
py_%%(name)s = %(conv)s((%(dtype)s)%%(name)s); py_%%(name)s = %(conv)s((%(dtype)s)%%(name)s);
if (!py_%%(name)s) if (!py_%%(name)s)
py_%%(name)s = Py_None; py_%%(name)s = Py_None;
""" % dict(dtype = self.c_type(), """ % dict(dtype = specs[1],
conv = self.c_as()) conv = specs[4])
def c_data_cleanup(self): def c_cleanup(self):
return "" return ""
def c_headers(self):
return []
def c_libraries(self):
return []
class ScalarMixedOp(GuardedOp): class ScalarMixedOp(GuardedOp):
...@@ -120,6 +125,15 @@ class ScalarMixedOp(GuardedOp): ...@@ -120,6 +125,15 @@ class ScalarMixedOp(GuardedOp):
def perform(self): def perform(self):
self.outputs[0].data = self.impl(*[input.data for input in self.inputs]) self.outputs[0].data = self.impl(*[input.data for input in self.inputs])
def c_var_names(self):
(self, inames, onames), _1, _2, _3 = inspect.getargspec(self.c_impl)
inames = utils.from_return_values(inames)
onames = utils.from_return_values(onames)
return [inames, onames]
def c_code(self):
return self.c_impl(self.inputs, self.outputs)
def upcast(dtype, *dtypes): def upcast(dtype, *dtypes):
z = numpy.zeros((), dtype = dtype) z = numpy.zeros((), dtype = dtype)
......
...@@ -23,7 +23,7 @@ class Mul(BinaryScalarOp): ...@@ -23,7 +23,7 @@ class Mul(BinaryScalarOp):
def impl(self, x, y): def impl(self, x, y):
return x * y return x * y
def c_impl(self, (x, y), z): def c_impl(self, (x, y), z):
return "%(z)s = %(x)s + %(y)s;" return "%(z)s = %(x)s * %(y)s;"
def grad(self, (x, y), gz): def grad(self, (x, y), gz):
return mul(y, gz), mul(x, gz) return mul(y, gz), mul(x, gz)
......
...@@ -21,11 +21,12 @@ class NumpyR(ResultBase): ...@@ -21,11 +21,12 @@ class NumpyR(ResultBase):
elif not str(data.dtype) == self.dtype: elif not str(data.dtype) == self.dtype:
raise TypeError("Expected ndarray with data type %i." % self.dtype) raise TypeError("Expected ndarray with data type %i." % self.dtype)
def to_c_type(self, dtype):
if dtype == "float64": # def to_c_type(self, dtype):
return "double" # if dtype == "float64":
else: # return "double"
raise TypeError("Cannot translate dtype to C.") # else:
# raise TypeError("Cannot translate dtype to C.")
def c_declare(self): def c_declare(self):
return """ return """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论