changed the way modes work, finalized a Linker interface, improved cc, some work…

changed the way modes work, finalized a Linker interface, improved cc, some work on scalar/scalar_op and thinking I should commit more often
上级 6ec9db61
...@@ -2,20 +2,20 @@ ...@@ -2,20 +2,20 @@
import unittest import unittest
from gof import ResultBase, Op, Env, modes from gof import ResultBase, Op, Env, modes
import gof
from scalar_ops import * from scalar_ops import *
def inputs(): def inputs():
x = modes.BuildEvalMode(as_scalar(1.0, 'x')) x = modes.build_eval(as_scalar(1.0, 'x'))
y = modes.BuildEvalMode(as_scalar(2.0, 'y')) y = modes.build_eval(as_scalar(2.0, 'y'))
z = modes.BuildEvalMode(as_scalar(3.0, 'z')) z = modes.build_eval(as_scalar(3.0, 'z'))
return x, y, z return x, y, z
def env(inputs, outputs, validate = True, features = []): def env(inputs, outputs, validate = True, features = []):
inputs = [input.r for input in inputs] # inputs = [input.r for input in inputs]
outputs = [output.r for output in outputs] # outputs = [output.r for output in outputs]
return Env(inputs, outputs, features = features, consistency_check = validate) return Env(inputs, outputs, features = features, consistency_check = validate)
...@@ -24,7 +24,15 @@ class _test_ScalarOps(unittest.TestCase): ...@@ -24,7 +24,15 @@ class _test_ScalarOps(unittest.TestCase):
def test_0(self): def test_0(self):
x, y, z = inputs() x, y, z = inputs()
e = mul(add(x, y), div(x, y)) e = mul(add(x, y), div(x, y))
assert e.r.data == 1.5 assert e.data == 1.5
def test_1(self):
x, y, z = inputs()
e = mul(add(x, y), div(x, y))
g = env([x, y], [e])
fn = gof.cc.CLinker(g).make_function()
assert fn(1.0, 2.0) == 1.5
assert e.data == 1.5
if __name__ == '__main__': if __name__ == '__main__':
......
import op, result, ext, link, env, features, toolbox, graph import op, result, ext, link, env, features, toolbox, graph, cc
from op import * from op import *
from result import * from result import *
...@@ -8,7 +8,7 @@ from link import * ...@@ -8,7 +8,7 @@ from link import *
from env import * from env import *
from features import * from features import *
from toolbox import * from toolbox import *
from cc import *
......
...@@ -105,14 +105,14 @@ import modes ...@@ -105,14 +105,14 @@ import modes
modes.make_constructors(globals()) modes.make_constructors(globals())
def inputs(): def inputs():
x = modes.BuildMode(Double(1.0, 'x')) x = modes.build(Double(1.0, 'x'))
y = modes.BuildMode(Double(2.0, 'y')) y = modes.build(Double(2.0, 'y'))
z = modes.BuildMode(Double(3.0, 'z')) z = modes.build(Double(3.0, 'z'))
return x, y, z return x, y, z
def env(inputs, outputs, validate = True, features = []): def env(inputs, outputs, validate = True, features = []):
inputs = [input.r for input in inputs] # inputs = [input.r for input in inputs]
outputs = [output.r for output in outputs] # outputs = [output.r for output in outputs]
return Env(inputs, outputs, features = features, consistency_check = validate) return Env(inputs, outputs, features = features, consistency_check = validate)
...@@ -121,21 +121,38 @@ class _test_CLinker(unittest.TestCase): ...@@ -121,21 +121,38 @@ class _test_CLinker(unittest.TestCase):
def test_0(self): def test_0(self):
x, y, z = inputs() x, y, z = inputs()
e = add(mul(add(x, y), div(x, y)), sub(sub(x, y), z)) e = add(mul(add(x, y), div(x, y)), sub(sub(x, y), z))
lnk = CLinker(env([x, y, z], [e]), [x.r, y.r, z.r], [e.r]) lnk = CLinker(env([x, y, z], [e])) #, [x.r, y.r, z.r], [e.r])
cgen = lnk.code_gen() cgen = lnk.code_gen()
fn = lnk.make_function([x.r, y.r, z.r], [e.r]) fn = lnk.make_function() #[x.r, y.r, z.r], [e.r])
print fn(2.0, 2.0, 2.0) print fn(2.0, 2.0, 2.0)
# fn = 0 # fn = 0
def test_1(self): def test_1(self):
x, y, z = inputs() x, y, z = inputs()
z.r.constant = True z.constant = True
e = add(mul(add(x, y), div(x, y)), sub(sub(x, y), z)) e = add(mul(add(x, y), div(x, y)), sub(sub(x, y), z))
lnk = CLinker(env([x, y], [e]), [x.r, y.r], [e.r]) lnk = CLinker(env([x, y], [e])) #, [x.r, y.r], [e.r])
cgen = lnk.code_gen() cgen = lnk.code_gen()
fn = lnk.make_function([x.r, y.r], [e.r]) fn = lnk.make_function() #[x.r, y.r], [e.r])
print fn(2.0, 2.0) print fn(2.0, 2.0)
# fn = 0 # fn = 0
def test_2(self):
x, y, z = inputs()
op = Add(x, y)
lnk = CLinker(op)
cgen = lnk.code_gen()
fn = lnk.make_function() #[x.r, y.r], [op.out])
print fn(2.0, 7.0)
# fn = 0
def test_3(self):
x, y, z = inputs()
e = add(mul(add(x, y), div(x, y)), sub(sub(x, y), z))
lnk = OpWiseCLinker(env([x, y, z], [e]))
fn = lnk.make_function()
print fn(2.0, 2.0, 2.0)
# fn = 0
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -81,15 +81,21 @@ s2t = OpSubOptimizer(Sigmoid, TransposeView) ...@@ -81,15 +81,21 @@ s2t = OpSubOptimizer(Sigmoid, TransposeView)
import modes import modes
modes.make_constructors(globals(), name_filter = lambda x:x) modes.make_constructors(globals(), name_filter = lambda x:x)
# def inputs():
# x = modes.BuildMode(MyResult('x'))
# y = modes.BuildMode(MyResult('y'))
# z = modes.BuildMode(MyResult('z'))
# return x, y, z
def inputs(): def inputs():
x = modes.BuildMode(MyResult('x')) x = modes.build(MyResult('x'))
y = modes.BuildMode(MyResult('y')) y = modes.build(MyResult('y'))
z = modes.BuildMode(MyResult('z')) z = modes.build(MyResult('z'))
return x, y, z return x, y, z
def env(inputs, outputs, validate = True): def env(inputs, outputs, validate = True):
inputs = [input.r for input in inputs] inputs = [input for input in inputs]
outputs = [output.r for output in outputs] outputs = [output for output in outputs]
return Env(inputs, outputs, features = [EquivTool], consistency_check = validate) return Env(inputs, outputs, features = [EquivTool], consistency_check = validate)
...@@ -139,7 +145,7 @@ class _test_all(unittest.TestCase): ...@@ -139,7 +145,7 @@ class _test_all(unittest.TestCase):
e = Dot(AddInPlace(x,y), TransposeView(x)) e = Dot(AddInPlace(x,y), TransposeView(x))
g = env([x,y,z], [e], False) g = env([x,y,z], [e], False)
assert not g.consistent() assert not g.consistent()
g.replace(e.r.owner.inputs[1], Add(x,z).r) g.replace(e.owner.inputs[1], Add(x,z))
assert g.consistent() assert g.consistent()
def test_5(self): def test_5(self):
...@@ -147,7 +153,7 @@ class _test_all(unittest.TestCase): ...@@ -147,7 +153,7 @@ class _test_all(unittest.TestCase):
e = Dot(AddInPlace(x,y), TransposeView(TransposeView(TransposeView(TransposeView(Sigmoid(x)))))) e = Dot(AddInPlace(x,y), TransposeView(TransposeView(TransposeView(TransposeView(Sigmoid(x))))))
g = env([x,y,z], [e]) g = env([x,y,z], [e])
assert g.consistent() assert g.consistent()
g.replace(e.r.owner.inputs[1].owner.inputs[0], x.r, False) g.replace(e.owner.inputs[1].owner.inputs[0], x, False)
assert not g.consistent() assert not g.consistent()
def test_6(self): def test_6(self):
...@@ -168,9 +174,9 @@ class _test_all(unittest.TestCase): ...@@ -168,9 +174,9 @@ class _test_all(unittest.TestCase):
chk = g.checkpoint() chk = g.checkpoint()
dtv_elim.optimize(g) dtv_elim.optimize(g)
assert str(g) == "[x]" assert str(g) == "[x]"
g.replace(g.equiv(e.r), Add(x,y).r) g.replace(g.equiv(e), Add(x,y))
assert str(g) == "[Add(x, y)]" assert str(g) == "[Add(x, y)]"
g.replace(g.equiv(e.r), Dot(AddInPlace(x,y), TransposeView(x)).r, False) g.replace(g.equiv(e), Dot(AddInPlace(x,y), TransposeView(x)), False)
assert str(g) == "[Dot(AddInPlace(x, y), TransposeView(x))]" assert str(g) == "[Dot(AddInPlace(x, y), TransposeView(x))]"
assert not g.consistent() assert not g.consistent()
g.revert(chk) g.revert(chk)
...@@ -188,21 +194,21 @@ class _test_all(unittest.TestCase): ...@@ -188,21 +194,21 @@ class _test_all(unittest.TestCase):
def test_9(self): def test_9(self):
x, y, z = inputs() x, y, z = inputs()
x.r.indestructible = True x.indestructible = True
e = AddInPlace(x, y) e = AddInPlace(x, y)
g = env([x,y,z], [e], False) g = env([x,y,z], [e], False)
assert not g.consistent() assert not g.consistent()
g.replace(e.r, Add(x, y).r) g.replace(e, Add(x, y))
assert g.consistent() assert g.consistent()
def test_10(self): def test_10(self):
x, y, z = inputs() x, y, z = inputs()
x.r.indestructible = True x.indestructible = True
tv = TransposeView(x) tv = TransposeView(x)
e = AddInPlace(tv, y) e = AddInPlace(tv, y)
g = env([x,y,z], [e], False) g = env([x,y,z], [e], False)
assert not g.consistent() assert not g.consistent()
g.replace(tv.r, Sigmoid(x).r) g.replace(tv, Sigmoid(x))
assert g.consistent() assert g.consistent()
......
...@@ -65,19 +65,18 @@ import modes ...@@ -65,19 +65,18 @@ import modes
modes.make_constructors(globals()) modes.make_constructors(globals())
def inputs(): def inputs():
x = modes.BuildMode(Double(1.0, 'x')) x = modes.build(Double(1.0, 'x'))
y = modes.BuildMode(Double(2.0, 'y')) y = modes.build(Double(2.0, 'y'))
z = modes.BuildMode(Double(3.0, 'z')) z = modes.build(Double(3.0, 'z'))
return x, y, z return x, y, z
def env(inputs, outputs, validate = True, features = []): def env(inputs, outputs, validate = True, features = []):
inputs = [input.r for input in inputs] # inputs = [input.r for input in inputs]
outputs = [output.r for output in outputs] # outputs = [output.r for output in outputs]
return Env(inputs, outputs, features = features, consistency_check = validate) return Env(inputs, outputs, features = features, consistency_check = validate)
def perform_linker(env): def perform_linker(env):
lnk = PerformLinker(env) lnk = PerformLinker(env)
lnk.compile()
return lnk return lnk
...@@ -86,8 +85,23 @@ class _test_PerformLinker(unittest.TestCase): ...@@ -86,8 +85,23 @@ class _test_PerformLinker(unittest.TestCase):
def test_0(self): def test_0(self):
x, y, z = inputs() x, y, z = inputs()
e = mul(add(x, y), div(x, y)) e = mul(add(x, y), div(x, y))
perform_linker(env([x, y, z], [e])).run() fn, i, o = perform_linker(env([x, y, z], [e])).make_thunk(True)
assert e.r.data == 1.5 fn()
assert e.data == 1.5
def test_1(self):
x, y, z = inputs()
e = mul(add(x, y), div(x, y))
fn, i, o = perform_linker(env([x, y, z], [e])).make_thunk(False)
fn()
assert e.data != 1.5
def test_2(self):
x, y, z = inputs()
e = mul(add(x, y), div(x, y))
fn = perform_linker(env([x, y, z], [e])).make_function()
assert fn(1.0, 2.0, 3.0) == 1.5
assert e.data != 1.5
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -22,7 +22,7 @@ class Double(ResultBase): ...@@ -22,7 +22,7 @@ class Double(ResultBase):
return self.name return self.name
def __add__(self, other): def __add__(self, other):
return Add(self, other) return add(self, other)
def convert(x): def convert(x):
...@@ -80,51 +80,51 @@ def inputs(mode): ...@@ -80,51 +80,51 @@ def inputs(mode):
return x, y, z return x, y, z
def env(inputs, outputs, validate = True): def env(inputs, outputs, validate = True):
inputs = [input.r for input in inputs] # inputs = [input.r for input in inputs]
outputs = [output.r for output in outputs] # outputs = [output.r for output in outputs]
return Env(inputs, outputs, features = [], consistency_check = validate) return Env(inputs, outputs, features = [], consistency_check = validate)
class _test_Modes(unittest.TestCase): class _test_Modes(unittest.TestCase):
def test_0(self): def test_0(self):
x, y, z = inputs(BuildMode) x, y, z = inputs(build)
e = add(add(x, y), z) e = add(add(x, y), z)
g = env([x, y, z], [e]) g = env([x, y, z], [e])
assert str(g) == "[Add(Add(x, y), z)]" assert str(g) == "[Add(Add(x, y), z)]"
assert e.r.data == 0.0 assert e.data == 0.0
def test_1(self): def test_1(self):
x, y, z = inputs(BuildEvalMode) x, y, z = inputs(build_eval)
e = add(add(x, y), z) e = add(add(x, y), z)
g = env([x, y, z], [e]) g = env([x, y, z], [e])
assert str(g) == "[Add(Add(x, y), z)]" assert str(g) == "[Add(Add(x, y), z)]"
assert e.r.data == 6.0 assert e.data == 6.0
def test_2(self): def test_2(self):
x, y, z = inputs(EvalMode) x, y, z = inputs(eval)
e = add(add(x, y), z) e = add(add(x, y), z)
g = env([x, y, z], [e]) g = env([x, y, z], [e])
assert str(g) == "[Add_R]" assert str(g) == "[Add_R]"
assert e.r.data == 6.0 assert e.data == 6.0
def test_3(self): def test_3(self):
x, y, z = inputs(BuildMode) x, y, z = inputs(build)
e = x + y + z e = x + y + z
g = env([x, y, z], [e]) g = env([x, y, z], [e])
assert str(g) == "[Add(Add(x, y), z)]" assert str(g) == "[Add(Add(x, y), z)]"
assert e.r.data == 0.0 assert e.data == 0.0
def test_4(self): def test_4(self):
x, y, z = inputs(BuildEvalMode) x, y, z = inputs(build_eval)
e = x + 34.0 e = x + 34.0
g = env([x, y, z], [e]) g = env([x, y, z], [e])
assert str(g) == "[Add(x, oignon)]" assert str(g) == "[Add(x, oignon)]"
assert e.r.data == 35.0 assert e.data == 35.0
def test_5(self): def test_5(self):
xb, yb, zb = inputs(BuildMode) xb, yb, zb = inputs(build)
xe, ye, ze = inputs(EvalMode) xe, ye, ze = inputs(eval)
try: try:
e = xb + ye e = xb + ye
except TypeError: except TypeError:
......
...@@ -49,14 +49,14 @@ modes.make_constructors(globals()) ...@@ -49,14 +49,14 @@ modes.make_constructors(globals())
def inputs(): def inputs():
x = modes.BuildMode(MyResult('x')) x = modes.build(MyResult('x'))
y = modes.BuildMode(MyResult('y')) y = modes.build(MyResult('y'))
z = modes.BuildMode(MyResult('z')) z = modes.build(MyResult('z'))
return x, y, z return x, y, z
def env(inputs, outputs, validate = True): def env(inputs, outputs, validate = True):
inputs = [input.r for input in inputs] # inputs = [input.r for input in inputs]
outputs = [output.r for output in outputs] # outputs = [output.r for output in outputs]
return Env(inputs, outputs, features = [EquivTool], consistency_check = validate) return Env(inputs, outputs, features = [EquivTool], consistency_check = validate)
......
...@@ -47,14 +47,14 @@ import modes ...@@ -47,14 +47,14 @@ import modes
modes.make_constructors(globals()) modes.make_constructors(globals())
def inputs(): def inputs():
x = modes.BuildMode(MyResult('x')) x = modes.build(MyResult('x'))
y = modes.BuildMode(MyResult('y')) y = modes.build(MyResult('y'))
z = modes.BuildMode(MyResult('z')) z = modes.build(MyResult('z'))
return x, y, z return x, y, z
def env(inputs, outputs, validate = True, features = []): def env(inputs, outputs, validate = True, features = []):
inputs = [input.r for input in inputs] # inputs = [input.r for input in inputs]
outputs = [output.r for output in outputs] # outputs = [output.r for output in outputs]
return Env(inputs, outputs, features = features, consistency_check = validate) return Env(inputs, outputs, features = features, consistency_check = validate)
...@@ -65,10 +65,10 @@ class _test_EquivTool(unittest.TestCase): ...@@ -65,10 +65,10 @@ class _test_EquivTool(unittest.TestCase):
sx = sigmoid(x) sx = sigmoid(x)
e = add(sx, sigmoid(y)) e = add(sx, sigmoid(y))
g = env([x, y, z], [e], features = [EquivTool]) g = env([x, y, z], [e], features = [EquivTool])
assert g.equiv(sx.r) is sx.r assert g.equiv(sx) is sx
g.replace(sx.r, dot(x, z).r) g.replace(sx, dot(x, z))
assert g.equiv(sx.r) is not sx.r assert g.equiv(sx) is not sx
assert isinstance(g.equiv(sx.r).owner, Dot) assert isinstance(g.equiv(sx).owner, Dot)
......
差异被折叠。
...@@ -75,8 +75,8 @@ class Env(graph.Graph): ...@@ -75,8 +75,8 @@ class Env(graph.Graph):
self._tools = {} self._tools = {}
# The inputs and outputs set bound the subgraph this Env operates on. # The inputs and outputs set bound the subgraph this Env operates on.
self.inputs = set(inputs) self.inputs = list(inputs)
self.outputs = set(outputs) self.outputs = list(outputs)
for feature_class in uniq_features(features): for feature_class in uniq_features(features):
self.add_feature(feature_class, False) self.add_feature(feature_class, False)
...@@ -110,9 +110,9 @@ class Env(graph.Graph): ...@@ -110,9 +110,9 @@ class Env(graph.Graph):
### Public interface ### ### Public interface ###
def add_output(self, output): # def add_output(self, output):
self.outputs.add(output) # self.outputs.add(output)
self.__import_r__([output]) # self.__import_r__([output])
def clients(self, r): def clients(self, r):
"Set of all the (op, i) pairs such that op.inputs[i] is r." "Set of all the (op, i) pairs such that op.inputs[i] is r."
...@@ -249,8 +249,9 @@ class Env(graph.Graph): ...@@ -249,8 +249,9 @@ class Env(graph.Graph):
new_was_output = True new_was_output = True
if r in self.outputs: if r in self.outputs:
was_output = True was_output = True
self.outputs.remove(r) self.outputs[self.outputs.index(r)] = new_r
self.outputs.add(new_r) # self.outputs.remove(r)
# self.outputs.add(new_r)
# The actual replacement operation occurs here. This might raise # The actual replacement operation occurs here. This might raise
# an error. # an error.
...@@ -261,8 +262,9 @@ class Env(graph.Graph): ...@@ -261,8 +262,9 @@ class Env(graph.Graph):
# Restore self.outputs # Restore self.outputs
if was_output: if was_output:
if not new_was_output: if not new_was_output:
self.outputs.remove(new_r) self.outputs[self.outputs.index(new_r)] = r
self.outputs.add(r) # self.outputs.remove(new_r)
# self.outputs.add(r)
# Move back the clients. This should never raise an error. # Move back the clients. This should never raise an error.
self.__move_clients__(clients, new_r, r) self.__move_clients__(clients, new_r, r)
......
...@@ -2,35 +2,80 @@ ...@@ -2,35 +2,80 @@
# from features import Tool # from features import Tool
from utils import AbstractFunctionError from utils import AbstractFunctionError
import utils
class Linker: class Linker:
def __init__(self, env): def __init__(self, env):
self.env = env self.env = env
self.thunk = None
def compile(self): def make_thunk(self, inplace = False):
"""
This function must return a triplet (function, input_results, output_results)
where function is a thunk that operates on the returned results. If inplace
is True, the input_results and output_results lists will be the same as the
inputs and outputs of the graph provided to the Linker. Else, independent
results will be returned.
Example:
e = x + y
env = Env([x, y], [e])
fn, (new_x, new_y), (new_e, ) = MyLinker(env).make_thunk(inplace)
new_x.data = 1.0
new_y.data = 2.0
fn()
print new_e.data # 3.0
print e.data # 3.0 iff inplace == True (else unknown)
"""
raise AbstractFunctionError() raise AbstractFunctionError()
def run(self): def make_function(self, inplace = False):
self.thunk() """
Returns a function that takes values corresponding to the inputs of the
env used by this Linker and returns values corresponding the the outputs
of that env. If inplace is True, the calculations will operate in the
same storage the env uses, else independent storage will be allocated
for the function.
Example:
e = x + y
env = Env([x, y], [e])
fn = MyLinker(env).make_function(inplace)
print fn(1.0, 2.0) # 3.0
print e.data # 3.0 iff inplace == True (else unknown)
"""
thunk, inputs, outputs = self.make_thunk(inplace)
def execute(*args):
for arg, result in zip(args, inputs):
result.data = arg
thunk()
return utils.to_return_values([result.data for result in outputs])
return execute
def __call__(self):
self.thunk()
class PerformLinker(Linker): class PerformLinker(Linker):
def compile(self): def make_thunk(self, inplace = False):
order = self.env.toposort() if inplace:
env = self.env
else:
env = self.env.clone(True)
order = env.toposort()
thunks = [op.perform for op in order] thunks = [op.perform for op in order]
def f(): def f():
for thunk in thunks: for thunk in thunks:
thunk() thunk()
self.thunk = f return f, env.inputs, env.outputs
self.order = order
self.thunks = thunks # self.thunk = f
# self.order = order
# self.thunks = thunks
class ProfilePerformLinker(Linker): class ProfilePerformLinker(Linker):
......
...@@ -4,10 +4,13 @@ from op import Op ...@@ -4,10 +4,13 @@ from op import Op
__all__ = ['ModalConstructor', __all__ = ['ModalConstructor',
'add_modal_members', 'add_modal_members',
'ModalWrapper', 'build',
'BuildMode', 'eval',
'EvalMode', 'build_eval',
'BuildEvalMode', # 'ModalWrapper',
# 'BuildMode',
# 'EvalMode',
# 'BuildEvalMode',
'make_constructors', 'make_constructors',
] ]
...@@ -15,27 +18,41 @@ class ModalConstructor: ...@@ -15,27 +18,41 @@ class ModalConstructor:
def __init__(self, fn): def __init__(self, fn):
self.fn = fn self.fn = fn
def __call__(self, *args): def __call__(self, *args):
modal_wrapper = None modal_wrapper = None
fn_args = [] fn_args = []
for arg in args: for arg in args:
if isinstance(arg, ModalWrapper): mode = getattr(arg, '__mode__', False)
if mode:
if modal_wrapper is None: if modal_wrapper is None:
modal_wrapper = arg.__class__ modal_wrapper = mode
else: else:
if not isinstance(arg, modal_wrapper): if mode != modal_wrapper:
raise TypeError("Inconsistent modes.") raise TypeError("Inconsistent modes.")
fn_args.append(arg.r) fn_args.append(arg)
else: # for arg in args:
fn_args.append(arg) # if isinstance(arg, ModalWrapper):
# if modal_wrapper is None:
# modal_wrapper = arg.__class__
# else:
# if not isinstance(arg, modal_wrapper):
# raise TypeError("Inconsistent modes.")
# fn_args.append(arg.r)
# else:
# fn_args.append(arg)
op = self.fn(*fn_args) op = self.fn(*fn_args)
if modal_wrapper: if modal_wrapper:
modal_wrapper.filter(op) modal_wrapper(op)
# modal_wrapper.filter(op)
for output in op.outputs:
output.__mode__ = modal_wrapper
if len(op.outputs) == 1: if len(op.outputs) == 1:
return modal_wrapper(op.outputs[0]) return op.outputs[0]
#return modal_wrapper(op.outputs[0])
else: else:
return [modal_wrapper(output) for output in op.outputs] return op.outputs
#return [modal_wrapper(output) for output in op.outputs]
def add_modal_members(cls, *members): def add_modal_members(cls, *members):
...@@ -48,39 +65,69 @@ def add_modal_members(cls, *members): ...@@ -48,39 +65,69 @@ def add_modal_members(cls, *members):
setattr(cls, member, fn(member)) setattr(cls, member, fn(member))
class ModalWrapper: # class ModalWrapper:
def __init__(self, r): # def __init__(self, r):
self.r = r # self.r = r
@classmethod # def __as_result__(self):
def filter(cls, op): # return self.r
raise AbstractFunctionError()
members1 = 'add sub mul div pow floordiv mod pow lshift rshift and or xor'.split(' ') # def __get_owner(self):
members = [] # return self.r.owner
members += ["__%s__" % x for x in members1 + 'neg invert'.split(' ')]
members += ["__r%s__" % x for x in members1] # owner = property(__get_owner)
add_modal_members(ModalWrapper, *members)
# @classmethod
# def filter(cls, op):
# raise AbstractFunctionError()
# members1 = 'add sub mul div pow floordiv mod pow lshift rshift and or xor'.split(' ')
# members = []
# members += ["__%s__" % x for x in members1 + 'neg invert'.split(' ')]
# members += ["__r%s__" % x for x in members1]
# add_modal_members(ModalWrapper, *members)
class BuildMode(ModalWrapper):
@classmethod
def filter(cls, op):
pass
class EvalMode(ModalWrapper): def build_mode(op):
@classmethod pass
def filter(cls, op):
op.perform() def eval_mode(op):
for output in op.outputs: op.perform()
output._role = None for output in op.outputs:
output._role = None
def build_eval_mode(op):
op.perform()
def mode_setter(mode):
def f(r):
r.__mode__ = mode
return r
return f
build = mode_setter(build_mode)
eval = mode_setter(eval_mode)
build_eval = mode_setter(build_eval_mode)
# class BuildMode(ModalWrapper):
# @classmethod
# def filter(cls, op):
# pass
# class EvalMode(ModalWrapper):
# @classmethod
# def filter(cls, op):
# op.perform()
# for output in op.outputs:
# output._role = None
class BuildEvalMode(ModalWrapper): # class BuildEvalMode(ModalWrapper):
@classmethod # @classmethod
def filter(cls, op): # def filter(cls, op):
op.perform() # op.perform()
def _is_op(x): def _is_op(x):
......
...@@ -64,6 +64,7 @@ class ResultBase(object): ...@@ -64,6 +64,7 @@ class ResultBase(object):
self.__set_data(data) self.__set_data(data)
self.name = name self.name = name
# #
# role # role
# #
...@@ -123,7 +124,7 @@ class ResultBase(object): ...@@ -123,7 +124,7 @@ class ResultBase(object):
self.state = Empty self.state = Empty
return return
try: try:
self.validate(data) data = self.filter(data)
except AbstractFunctionError: except AbstractFunctionError:
pass pass
self._data[0] = data self._data[0] = data
...@@ -132,7 +133,7 @@ class ResultBase(object): ...@@ -132,7 +133,7 @@ class ResultBase(object):
data = property(__get_data, __set_data, data = property(__get_data, __set_data,
doc = "The storage associated with this result") doc = "The storage associated with this result")
def validate(self, data): def filter(self, data):
"""(abstract) Raise an exception if the data is not of an """(abstract) Raise an exception if the data is not of an
acceptable type. acceptable type.
...@@ -140,7 +141,8 @@ class ResultBase(object): ...@@ -140,7 +141,8 @@ class ResultBase(object):
it to check that the argument can be used properly. This gives it to check that the argument can be used properly. This gives
a subclass the opportunity to ensure that the contents of a subclass the opportunity to ensure that the contents of
self._data remain sensible. self._data remain sensible.
Returns data or an appropriately wrapped data.
""" """
raise AbstractFunctionError() raise AbstractFunctionError()
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
import numpy import numpy
from copy import copy from copy import copy
import inspect
from gof import ResultBase, GuardedOp, utils from gof import ResultBase, GuardedOp, utils
...@@ -17,10 +18,10 @@ def as_scalar(x, name = None): ...@@ -17,10 +18,10 @@ def as_scalar(x, name = None):
class Scalar(ResultBase): class Scalar(ResultBase):
def __init__(self, dtype, name=None): def __init__(self, dtype, data = None, name=None):
self.dtype = dtype self.dtype = dtype
self.constant = False self.constant = False
ResultBase.__init__(self, role = None, data = None, name = name) ResultBase.__init__(self, role = None, data = data, name = name)
def __get_constant(self): def __get_constant(self):
return self._constant return self._constant
...@@ -28,14 +29,13 @@ class Scalar(ResultBase): ...@@ -28,14 +29,13 @@ class Scalar(ResultBase):
def __set_constant(self, value): def __set_constant(self, value):
if value: if value:
self.indestructible = True self.indestructible = True
self.constant = value self._constant = value
constant = property(__get_constant, __set_constant) constant = property(__get_constant, __set_constant)
def validate(self, data): def filter(self, data):
py_type = self.py_type() py_type = self.dtype_specs()[0]
if not isinstance(data, py_type): return py_type(data)
raise TypeError("Expected %s instance." % py_type)
def same_properties(self, other): def same_properties(self, other):
return other.dtype == self.dtype return other.dtype == self.dtype
...@@ -44,51 +44,56 @@ class Scalar(ResultBase): ...@@ -44,51 +44,56 @@ class Scalar(ResultBase):
return getattr(self, 'constant', False) \ return getattr(self, 'constant', False) \
and getattr(other, 'constant', False) \ and getattr(other, 'constant', False) \
and self.data == other.data and self.data == other.data
def dtype_specs(self):
return {'float64': (float, 'double', 'PyFloat_Check', 'PyFloat_AsDouble', 'PyFloat_FromDouble')}[self.dtype]
def py_type(self): # def py_type(self):
return {'float64': float}[self.dtype] # return {'float64': float}[self.dtype]
def c_type(self): # def c_type(self):
return {'float64': 'double'}[self.dtype] # return {'float64': 'double'}[self.dtype]
def c_from(self): # def c_from(self):
return {'float64': 'PyFloat_FromDouble'}[self.dtype] # return {'float64': 'PyFloat_FromDouble'}[self.dtype]
def c_as(self): # def c_as(self):
return {'float64': 'PyFloat_AsDouble'}[self.dtype] # return {'float64': 'PyFloat_AsDouble'}[self.dtype]
def c_declare(self): def c_declare(self):
return """ return """
%(dtype)s* %%(name)s; %(dtype)s %%(name)s;
typedef %(dtype)s %%(name)s_dtype; typedef %(dtype)s %%(name)s_dtype;
""" % dict(dtype = self.c_type()) """ % dict(dtype = self.dtype_specs()[1])
def c_data_extract(self): def c_init(self):
return """ return """
%%(name)s = (%(dtype)s)%(conv)s(py_%%(name)s); %(name)s = 0;
if (!%%(name)s) """
def c_extract(self):
specs = self.dtype_specs()
return """
if (!%(check)s(py_%%(name)s))
%%(fail)s %%(fail)s
""" % dict(dtype = self.c_type(), %%(name)s = (%(dtype)s)%(conv)s(py_%%(name)s);
conv = self.c_as()) """ % dict(dtype = specs[1],
check = specs[2],
conv = specs[3])
def c_data_sync(self): def c_sync(self):
specs = self.dtype_specs()
return """ return """
Py_XDECREF(py_%%(name)s); Py_XDECREF(py_%%(name)s);
py_%%(name)s = %(conv)s((%(dtype)s)%%(name)s); py_%%(name)s = %(conv)s((%(dtype)s)%%(name)s);
if (!py_%%(name)s) if (!py_%%(name)s)
py_%%(name)s = Py_None; py_%%(name)s = Py_None;
""" % dict(dtype = self.c_type(), """ % dict(dtype = specs[1],
conv = self.c_as()) conv = specs[4])
def c_data_cleanup(self): def c_cleanup(self):
return "" return ""
def c_headers(self):
return []
def c_libraries(self):
return []
class ScalarMixedOp(GuardedOp): class ScalarMixedOp(GuardedOp):
...@@ -120,6 +125,15 @@ class ScalarMixedOp(GuardedOp): ...@@ -120,6 +125,15 @@ class ScalarMixedOp(GuardedOp):
def perform(self): def perform(self):
self.outputs[0].data = self.impl(*[input.data for input in self.inputs]) self.outputs[0].data = self.impl(*[input.data for input in self.inputs])
def c_var_names(self):
(self, inames, onames), _1, _2, _3 = inspect.getargspec(self.c_impl)
inames = utils.from_return_values(inames)
onames = utils.from_return_values(onames)
return [inames, onames]
def c_code(self):
return self.c_impl(self.inputs, self.outputs)
def upcast(dtype, *dtypes): def upcast(dtype, *dtypes):
z = numpy.zeros((), dtype = dtype) z = numpy.zeros((), dtype = dtype)
......
...@@ -23,7 +23,7 @@ class Mul(BinaryScalarOp): ...@@ -23,7 +23,7 @@ class Mul(BinaryScalarOp):
def impl(self, x, y): def impl(self, x, y):
return x * y return x * y
def c_impl(self, (x, y), z): def c_impl(self, (x, y), z):
return "%(z)s = %(x)s + %(y)s;" return "%(z)s = %(x)s * %(y)s;"
def grad(self, (x, y), gz): def grad(self, (x, y), gz):
return mul(y, gz), mul(x, gz) return mul(y, gz), mul(x, gz)
......
...@@ -21,11 +21,12 @@ class NumpyR(ResultBase): ...@@ -21,11 +21,12 @@ class NumpyR(ResultBase):
elif not str(data.dtype) == self.dtype: elif not str(data.dtype) == self.dtype:
raise TypeError("Expected ndarray with data type %i." % self.dtype) raise TypeError("Expected ndarray with data type %i." % self.dtype)
def to_c_type(self, dtype):
if dtype == "float64": # def to_c_type(self, dtype):
return "double" # if dtype == "float64":
else: # return "double"
raise TypeError("Cannot translate dtype to C.") # else:
# raise TypeError("Cannot translate dtype to C.")
def c_declare(self): def c_declare(self):
return """ return """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论