added scalar.composite

上级 e5b8c40c
......@@ -27,6 +27,50 @@ class _test_ScalarOps(unittest.TestCase):
assert fn(1.0, 2.0) == 1.5
class _test_composite(unittest.TestCase):
def test_straightforward(self):
x, y, z = inputs()
e = mul(add(x, y), div(x, y))
C = composite([x, y], [e])
c = C(x, y)
# print c.c_code(['x', 'y'], ['z'], dict(id = 0))
c.perform()
assert c.outputs[0].data == 1.5
g = env([x, y], [c.out])
fn = gof.DualLinker(g).make_function()
assert fn(1.0, 2.0) == 1.5
def test_with_constants(self):
x, y, z = inputs()
e = mul(add(70.0, y), div(x, y))
C = composite([x, y], [e])
c = C(x, y)
assert "70.0" in c.c_code(['x', 'y'], ['z'], dict(id = 0))
# print c.c_code(['x', 'y'], ['z'], dict(id = 0))
c.perform()
assert c.outputs[0].data == 36.0
g = env([x, y], [c.out])
fn = gof.DualLinker(g).make_function()
assert fn(1.0, 2.0) == 36.0
def test_many_outputs(self):
x, y, z = inputs()
e0 = x + y + z
e1 = x + y * z
e2 = x / y
C = composite([x, y, z], [e0, e1, e2])
c = C(x, y, z)
# print c.c_code(['x', 'y', 'z'], ['out0', 'out1', 'out2'], dict(id = 0))
c.perform()
assert c.outputs[0].data == 6.0
assert c.outputs[1].data == 7.0
assert c.outputs[2].data == 0.5
g = env([x, y], c.outputs)
fn = gof.DualLinker(g).make_function()
assert fn(1.0, 2.0) == [6.0, 7.0, 0.5]
if __name__ == '__main__':
unittest.main()
......
......@@ -409,7 +409,7 @@ class CLinker(Linker):
elif result in self.orphans:
self.orphans.remove(result)
continue
except AbstractFunctionError:
except (AbstractFunctionError, NotImplementedError):
pass
# policy = [[what to declare in the struct, what to do at construction, what to do at destruction],
# [what to declare in each run, what to do at the beginning of each run, what to do at the end of each run]]
......
......@@ -5,7 +5,8 @@ import math
from copy import copy
import inspect
from gof import Result, GuardedOp, utils
import gof
from gof import Result, GuardedOp, Env, utils
def as_scalar(x, name = None):
......@@ -29,6 +30,8 @@ class Scalar(Result):
self.dtype_specs()
def __get_constant(self):
if not hasattr(self, '_constant'):
return False
return self._constant
def __set_constant(self, value):
......@@ -58,6 +61,11 @@ class Scalar(Result):
except KeyError:
raise TypeError("Unsupported dtype for %s: %s" % (self.__class__.__name__, self.dtype))
def c_literal(self):
if 'complex' in self.dtype:
raise NotImplementedError("No literal for complex values.")
return str(self.data)
def c_declare(self, name, sub):
return """
%(dtype)s %(name)s;
......@@ -184,7 +192,7 @@ class ScalarMixedOp(GuardedOp):
inputs = [as_scalar(input) for input in inputs]
i_dtypes = [getattr(input, 'dtype', None) for input in inputs]
o_dtypes = utils.from_return_values(self.propagate_dtypes(*i_dtypes))
o_dtypes = self.propagate_dtypes(*i_dtypes)
self.inputs = inputs
self.outputs = [Scalar(dtype) for dtype in o_dtypes]
......@@ -217,7 +225,7 @@ class PureScalarOp(ScalarMixedOp):
for dtype in i_dtypes:
if dtype is None:
raise TypeError("Expected a Scalar.")
return self.cast_method(*i_dtypes)
return [self.cast_method(*i_dtypes)] * self.nout
class UnaryScalarOp(PureScalarOp):
......@@ -378,4 +386,105 @@ modes.make_constructors(globals())
def composite(inputs, outputs):
"""
Usage: composite(inputs, outputs)
Produces an Op class which represents the computations
between the provided inputs and outputs as a single
operation.
The operations between inputs and outputs (as given by
Env(inputs, outputs).ops()) must all be instances of
PureScalarOp.
Examples:
x, y = Scalar(), Scalar()
SquareDiff = composite([x, y], [(x - y)**2])
TimesTen = composite([x], [x * 10.0])
Neighbors = composite([x], [x - 1, x + 1])
"""
env = Env(inputs, outputs).clone()
gof.opt.ConstantFinder().apply(env)
inputs, outputs = env.inputs, env.outputs
for op in env.ops():
if not isinstance(op, PureScalarOp):
raise ValueError("The input env to composite must be exclusively composed of PureScalarOp instances.")
subd = dict(zip(inputs,
["%%(i%i)s"%i for i in range(len(inputs))]) +
zip(outputs,
["%%(o%i)s"%i for i in range(len(outputs))]))
for orphan in env.orphans():
if orphan.constant:
subd[orphan] = orphan.c_literal()
else:
raise ValueError("All orphans in the input env to composite must be constant.")
_c_code = "{\n"
i = 0
j = 0
for op in env.toposort():
j += 1
for output in op.outputs:
if output not in subd:
i += 1
name = "V%%(id)s_tmp%i" % i
subd[output] = name
# the c code is not robust to any other dtypes than those of the specified inputs
# a solution would be to require Composite.c_code to fill in the dtypes using
# a proper upcast
_c_code += "%s %s;\n" % (output.dtype_specs()[1], name)
_c_code += op.c_code([subd[input] for input in op.inputs],
[subd[output] for output in op.outputs],
dict(fail = "%(fail)s",
id = "%%(id)s_%i" % j))
_c_code += "\n"
_c_code += "}\n"
def compose_impl(r):
# this is not optimal at all eg in add(*1 -> mul(x, y), *1)
# it will calculate *1 twice
# it also doesn't follow env.toposort but that's (presumably)
# still correct since we only have pure scalar ops
if r in env.inputs:
idx = env.inputs.index(r)
return lambda inputs: inputs[idx]
elif r in env.orphans():
return lambda inputs: r.data
op = r.owner
producers = [compose_impl(input) for input in op.inputs]
return lambda inputs: op.impl(*[p(inputs) for p in producers])
_impls = [compose_impl(r) for r in env.outputs]
class Composite(PureScalarOp):
nin = len(inputs)
nout = len(outputs)
# todo: propagate_dtypes?
def perform(self):
inputs = [input.data for input in self.inputs]
for output, impl in zip(self.outputs, _impls):
output.data = impl(inputs)
def grad(self, inputs, output_grads):
raise NotImplementedError("grad is not implemented for Composite")
def c_code(self, inames, onames, sub):
d = dict(zip(["i%i"%i for i in range(len(inames))],
inames) +
zip(["o%i"%i for i in range(len(onames))],
onames),
**sub)
return _c_code % d
return Composite
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论