added scalar.composite

上级 e5b8c40c
...@@ -27,6 +27,50 @@ class _test_ScalarOps(unittest.TestCase): ...@@ -27,6 +27,50 @@ class _test_ScalarOps(unittest.TestCase):
assert fn(1.0, 2.0) == 1.5 assert fn(1.0, 2.0) == 1.5
class _test_composite(unittest.TestCase):
def test_straightforward(self):
x, y, z = inputs()
e = mul(add(x, y), div(x, y))
C = composite([x, y], [e])
c = C(x, y)
# print c.c_code(['x', 'y'], ['z'], dict(id = 0))
c.perform()
assert c.outputs[0].data == 1.5
g = env([x, y], [c.out])
fn = gof.DualLinker(g).make_function()
assert fn(1.0, 2.0) == 1.5
def test_with_constants(self):
x, y, z = inputs()
e = mul(add(70.0, y), div(x, y))
C = composite([x, y], [e])
c = C(x, y)
assert "70.0" in c.c_code(['x', 'y'], ['z'], dict(id = 0))
# print c.c_code(['x', 'y'], ['z'], dict(id = 0))
c.perform()
assert c.outputs[0].data == 36.0
g = env([x, y], [c.out])
fn = gof.DualLinker(g).make_function()
assert fn(1.0, 2.0) == 36.0
def test_many_outputs(self):
x, y, z = inputs()
e0 = x + y + z
e1 = x + y * z
e2 = x / y
C = composite([x, y, z], [e0, e1, e2])
c = C(x, y, z)
# print c.c_code(['x', 'y', 'z'], ['out0', 'out1', 'out2'], dict(id = 0))
c.perform()
assert c.outputs[0].data == 6.0
assert c.outputs[1].data == 7.0
assert c.outputs[2].data == 0.5
g = env([x, y], c.outputs)
fn = gof.DualLinker(g).make_function()
assert fn(1.0, 2.0) == [6.0, 7.0, 0.5]
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
......
...@@ -409,7 +409,7 @@ class CLinker(Linker): ...@@ -409,7 +409,7 @@ class CLinker(Linker):
elif result in self.orphans: elif result in self.orphans:
self.orphans.remove(result) self.orphans.remove(result)
continue continue
except AbstractFunctionError: except (AbstractFunctionError, NotImplementedError):
pass pass
# policy = [[what to declare in the struct, what to do at construction, what to do at destruction], # policy = [[what to declare in the struct, what to do at construction, what to do at destruction],
# [what to declare in each run, what to do at the beginning of each run, what to do at the end of each run]] # [what to declare in each run, what to do at the beginning of each run, what to do at the end of each run]]
......
...@@ -5,7 +5,8 @@ import math ...@@ -5,7 +5,8 @@ import math
from copy import copy from copy import copy
import inspect import inspect
from gof import Result, GuardedOp, utils import gof
from gof import Result, GuardedOp, Env, utils
def as_scalar(x, name = None): def as_scalar(x, name = None):
...@@ -29,6 +30,8 @@ class Scalar(Result): ...@@ -29,6 +30,8 @@ class Scalar(Result):
self.dtype_specs() self.dtype_specs()
def __get_constant(self): def __get_constant(self):
if not hasattr(self, '_constant'):
return False
return self._constant return self._constant
def __set_constant(self, value): def __set_constant(self, value):
...@@ -58,6 +61,11 @@ class Scalar(Result): ...@@ -58,6 +61,11 @@ class Scalar(Result):
except KeyError: except KeyError:
raise TypeError("Unsupported dtype for %s: %s" % (self.__class__.__name__, self.dtype)) raise TypeError("Unsupported dtype for %s: %s" % (self.__class__.__name__, self.dtype))
def c_literal(self):
if 'complex' in self.dtype:
raise NotImplementedError("No literal for complex values.")
return str(self.data)
def c_declare(self, name, sub): def c_declare(self, name, sub):
return """ return """
%(dtype)s %(name)s; %(dtype)s %(name)s;
...@@ -184,7 +192,7 @@ class ScalarMixedOp(GuardedOp): ...@@ -184,7 +192,7 @@ class ScalarMixedOp(GuardedOp):
inputs = [as_scalar(input) for input in inputs] inputs = [as_scalar(input) for input in inputs]
i_dtypes = [getattr(input, 'dtype', None) for input in inputs] i_dtypes = [getattr(input, 'dtype', None) for input in inputs]
o_dtypes = utils.from_return_values(self.propagate_dtypes(*i_dtypes)) o_dtypes = self.propagate_dtypes(*i_dtypes)
self.inputs = inputs self.inputs = inputs
self.outputs = [Scalar(dtype) for dtype in o_dtypes] self.outputs = [Scalar(dtype) for dtype in o_dtypes]
...@@ -217,7 +225,7 @@ class PureScalarOp(ScalarMixedOp): ...@@ -217,7 +225,7 @@ class PureScalarOp(ScalarMixedOp):
for dtype in i_dtypes: for dtype in i_dtypes:
if dtype is None: if dtype is None:
raise TypeError("Expected a Scalar.") raise TypeError("Expected a Scalar.")
return self.cast_method(*i_dtypes) return [self.cast_method(*i_dtypes)] * self.nout
class UnaryScalarOp(PureScalarOp): class UnaryScalarOp(PureScalarOp):
...@@ -378,4 +386,105 @@ modes.make_constructors(globals()) ...@@ -378,4 +386,105 @@ modes.make_constructors(globals())
def composite(inputs, outputs):
"""
Usage: composite(inputs, outputs)
Produces an Op class which represents the computations
between the provided inputs and outputs as a single
operation.
The operations between inputs and outputs (as given by
Env(inputs, outputs).ops()) must all be instances of
PureScalarOp.
Examples:
x, y = Scalar(), Scalar()
SquareDiff = composite([x, y], [(x - y)**2])
TimesTen = composite([x], [x * 10.0])
Neighbors = composite([x], [x - 1, x + 1])
"""
env = Env(inputs, outputs).clone()
gof.opt.ConstantFinder().apply(env)
inputs, outputs = env.inputs, env.outputs
for op in env.ops():
if not isinstance(op, PureScalarOp):
raise ValueError("The input env to composite must be exclusively composed of PureScalarOp instances.")
subd = dict(zip(inputs,
["%%(i%i)s"%i for i in range(len(inputs))]) +
zip(outputs,
["%%(o%i)s"%i for i in range(len(outputs))]))
for orphan in env.orphans():
if orphan.constant:
subd[orphan] = orphan.c_literal()
else:
raise ValueError("All orphans in the input env to composite must be constant.")
_c_code = "{\n"
i = 0
j = 0
for op in env.toposort():
j += 1
for output in op.outputs:
if output not in subd:
i += 1
name = "V%%(id)s_tmp%i" % i
subd[output] = name
# the c code is not robust to any other dtypes than those of the specified inputs
# a solution would be to require Composite.c_code to fill in the dtypes using
# a proper upcast
_c_code += "%s %s;\n" % (output.dtype_specs()[1], name)
_c_code += op.c_code([subd[input] for input in op.inputs],
[subd[output] for output in op.outputs],
dict(fail = "%(fail)s",
id = "%%(id)s_%i" % j))
_c_code += "\n"
_c_code += "}\n"
def compose_impl(r):
# this is not optimal at all eg in add(*1 -> mul(x, y), *1)
# it will calculate *1 twice
# it also doesn't follow env.toposort but that's (presumably)
# still correct since we only have pure scalar ops
if r in env.inputs:
idx = env.inputs.index(r)
return lambda inputs: inputs[idx]
elif r in env.orphans():
return lambda inputs: r.data
op = r.owner
producers = [compose_impl(input) for input in op.inputs]
return lambda inputs: op.impl(*[p(inputs) for p in producers])
_impls = [compose_impl(r) for r in env.outputs]
class Composite(PureScalarOp):
nin = len(inputs)
nout = len(outputs)
# todo: propagate_dtypes?
def perform(self):
inputs = [input.data for input in self.inputs]
for output, impl in zip(self.outputs, _impls):
output.data = impl(inputs)
def grad(self, inputs, output_grads):
raise NotImplementedError("grad is not implemented for Composite")
def c_code(self, inames, onames, sub):
d = dict(zip(["i%i"%i for i in range(len(inames))],
inames) +
zip(["o%i"%i for i in range(len(onames))],
onames),
**sub)
return _c_code % d
return Composite
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论