changed the way modes work, finalized a Linker interface, improved cc, some work…

changed the way modes work, finalized a Linker interface, improved cc, some work on scalar/scalar_op and thinking I should commit more often
上级 6ec9db61
......@@ -2,20 +2,20 @@
import unittest
from gof import ResultBase, Op, Env, modes
import gof
from scalar_ops import *
def inputs():
x = modes.BuildEvalMode(as_scalar(1.0, 'x'))
y = modes.BuildEvalMode(as_scalar(2.0, 'y'))
z = modes.BuildEvalMode(as_scalar(3.0, 'z'))
x = modes.build_eval(as_scalar(1.0, 'x'))
y = modes.build_eval(as_scalar(2.0, 'y'))
z = modes.build_eval(as_scalar(3.0, 'z'))
return x, y, z
def env(inputs, outputs, validate = True, features = []):
inputs = [input.r for input in inputs]
outputs = [output.r for output in outputs]
# inputs = [input.r for input in inputs]
# outputs = [output.r for output in outputs]
return Env(inputs, outputs, features = features, consistency_check = validate)
......@@ -24,7 +24,15 @@ class _test_ScalarOps(unittest.TestCase):
def test_0(self):
x, y, z = inputs()
e = mul(add(x, y), div(x, y))
assert e.r.data == 1.5
assert e.data == 1.5
def test_1(self):
x, y, z = inputs()
e = mul(add(x, y), div(x, y))
g = env([x, y], [e])
fn = gof.cc.CLinker(g).make_function()
assert fn(1.0, 2.0) == 1.5
assert e.data == 1.5
if __name__ == '__main__':
......
import op, result, ext, link, env, features, toolbox, graph
import op, result, ext, link, env, features, toolbox, graph, cc
from op import *
from result import *
......@@ -8,7 +8,7 @@ from link import *
from env import *
from features import *
from toolbox import *
from cc import *
......
......@@ -105,14 +105,14 @@ import modes
modes.make_constructors(globals())
def inputs():
x = modes.BuildMode(Double(1.0, 'x'))
y = modes.BuildMode(Double(2.0, 'y'))
z = modes.BuildMode(Double(3.0, 'z'))
x = modes.build(Double(1.0, 'x'))
y = modes.build(Double(2.0, 'y'))
z = modes.build(Double(3.0, 'z'))
return x, y, z
def env(inputs, outputs, validate = True, features = []):
inputs = [input.r for input in inputs]
outputs = [output.r for output in outputs]
# inputs = [input.r for input in inputs]
# outputs = [output.r for output in outputs]
return Env(inputs, outputs, features = features, consistency_check = validate)
......@@ -121,21 +121,38 @@ class _test_CLinker(unittest.TestCase):
def test_0(self):
x, y, z = inputs()
e = add(mul(add(x, y), div(x, y)), sub(sub(x, y), z))
lnk = CLinker(env([x, y, z], [e]), [x.r, y.r, z.r], [e.r])
lnk = CLinker(env([x, y, z], [e])) #, [x.r, y.r, z.r], [e.r])
cgen = lnk.code_gen()
fn = lnk.make_function([x.r, y.r, z.r], [e.r])
fn = lnk.make_function() #[x.r, y.r, z.r], [e.r])
print fn(2.0, 2.0, 2.0)
# fn = 0
def test_1(self):
x, y, z = inputs()
z.r.constant = True
z.constant = True
e = add(mul(add(x, y), div(x, y)), sub(sub(x, y), z))
lnk = CLinker(env([x, y], [e]), [x.r, y.r], [e.r])
lnk = CLinker(env([x, y], [e])) #, [x.r, y.r], [e.r])
cgen = lnk.code_gen()
fn = lnk.make_function([x.r, y.r], [e.r])
fn = lnk.make_function() #[x.r, y.r], [e.r])
print fn(2.0, 2.0)
# fn = 0
def test_2(self):
x, y, z = inputs()
op = Add(x, y)
lnk = CLinker(op)
cgen = lnk.code_gen()
fn = lnk.make_function() #[x.r, y.r], [op.out])
print fn(2.0, 7.0)
# fn = 0
def test_3(self):
x, y, z = inputs()
e = add(mul(add(x, y), div(x, y)), sub(sub(x, y), z))
lnk = OpWiseCLinker(env([x, y, z], [e]))
fn = lnk.make_function()
print fn(2.0, 2.0, 2.0)
# fn = 0
if __name__ == '__main__':
unittest.main()
......@@ -81,15 +81,21 @@ s2t = OpSubOptimizer(Sigmoid, TransposeView)
import modes
modes.make_constructors(globals(), name_filter = lambda x:x)
# def inputs():
# x = modes.BuildMode(MyResult('x'))
# y = modes.BuildMode(MyResult('y'))
# z = modes.BuildMode(MyResult('z'))
# return x, y, z
def inputs():
x = modes.BuildMode(MyResult('x'))
y = modes.BuildMode(MyResult('y'))
z = modes.BuildMode(MyResult('z'))
x = modes.build(MyResult('x'))
y = modes.build(MyResult('y'))
z = modes.build(MyResult('z'))
return x, y, z
def env(inputs, outputs, validate = True):
inputs = [input.r for input in inputs]
outputs = [output.r for output in outputs]
inputs = [input for input in inputs]
outputs = [output for output in outputs]
return Env(inputs, outputs, features = [EquivTool], consistency_check = validate)
......@@ -139,7 +145,7 @@ class _test_all(unittest.TestCase):
e = Dot(AddInPlace(x,y), TransposeView(x))
g = env([x,y,z], [e], False)
assert not g.consistent()
g.replace(e.r.owner.inputs[1], Add(x,z).r)
g.replace(e.owner.inputs[1], Add(x,z))
assert g.consistent()
def test_5(self):
......@@ -147,7 +153,7 @@ class _test_all(unittest.TestCase):
e = Dot(AddInPlace(x,y), TransposeView(TransposeView(TransposeView(TransposeView(Sigmoid(x))))))
g = env([x,y,z], [e])
assert g.consistent()
g.replace(e.r.owner.inputs[1].owner.inputs[0], x.r, False)
g.replace(e.owner.inputs[1].owner.inputs[0], x, False)
assert not g.consistent()
def test_6(self):
......@@ -168,9 +174,9 @@ class _test_all(unittest.TestCase):
chk = g.checkpoint()
dtv_elim.optimize(g)
assert str(g) == "[x]"
g.replace(g.equiv(e.r), Add(x,y).r)
g.replace(g.equiv(e), Add(x,y))
assert str(g) == "[Add(x, y)]"
g.replace(g.equiv(e.r), Dot(AddInPlace(x,y), TransposeView(x)).r, False)
g.replace(g.equiv(e), Dot(AddInPlace(x,y), TransposeView(x)), False)
assert str(g) == "[Dot(AddInPlace(x, y), TransposeView(x))]"
assert not g.consistent()
g.revert(chk)
......@@ -188,21 +194,21 @@ class _test_all(unittest.TestCase):
def test_9(self):
x, y, z = inputs()
x.r.indestructible = True
x.indestructible = True
e = AddInPlace(x, y)
g = env([x,y,z], [e], False)
assert not g.consistent()
g.replace(e.r, Add(x, y).r)
g.replace(e, Add(x, y))
assert g.consistent()
def test_10(self):
x, y, z = inputs()
x.r.indestructible = True
x.indestructible = True
tv = TransposeView(x)
e = AddInPlace(tv, y)
g = env([x,y,z], [e], False)
assert not g.consistent()
g.replace(tv.r, Sigmoid(x).r)
g.replace(tv, Sigmoid(x))
assert g.consistent()
......
......@@ -65,19 +65,18 @@ import modes
modes.make_constructors(globals())
def inputs():
x = modes.BuildMode(Double(1.0, 'x'))
y = modes.BuildMode(Double(2.0, 'y'))
z = modes.BuildMode(Double(3.0, 'z'))
x = modes.build(Double(1.0, 'x'))
y = modes.build(Double(2.0, 'y'))
z = modes.build(Double(3.0, 'z'))
return x, y, z
def env(inputs, outputs, validate = True, features = []):
inputs = [input.r for input in inputs]
outputs = [output.r for output in outputs]
# inputs = [input.r for input in inputs]
# outputs = [output.r for output in outputs]
return Env(inputs, outputs, features = features, consistency_check = validate)
def perform_linker(env):
lnk = PerformLinker(env)
lnk.compile()
return lnk
......@@ -86,8 +85,23 @@ class _test_PerformLinker(unittest.TestCase):
def test_0(self):
x, y, z = inputs()
e = mul(add(x, y), div(x, y))
perform_linker(env([x, y, z], [e])).run()
assert e.r.data == 1.5
fn, i, o = perform_linker(env([x, y, z], [e])).make_thunk(True)
fn()
assert e.data == 1.5
def test_1(self):
x, y, z = inputs()
e = mul(add(x, y), div(x, y))
fn, i, o = perform_linker(env([x, y, z], [e])).make_thunk(False)
fn()
assert e.data != 1.5
def test_2(self):
x, y, z = inputs()
e = mul(add(x, y), div(x, y))
fn = perform_linker(env([x, y, z], [e])).make_function()
assert fn(1.0, 2.0, 3.0) == 1.5
assert e.data != 1.5
if __name__ == '__main__':
......
......@@ -22,7 +22,7 @@ class Double(ResultBase):
return self.name
def __add__(self, other):
return Add(self, other)
return add(self, other)
def convert(x):
......@@ -80,51 +80,51 @@ def inputs(mode):
return x, y, z
def env(inputs, outputs, validate = True):
inputs = [input.r for input in inputs]
outputs = [output.r for output in outputs]
# inputs = [input.r for input in inputs]
# outputs = [output.r for output in outputs]
return Env(inputs, outputs, features = [], consistency_check = validate)
class _test_Modes(unittest.TestCase):
def test_0(self):
x, y, z = inputs(BuildMode)
x, y, z = inputs(build)
e = add(add(x, y), z)
g = env([x, y, z], [e])
assert str(g) == "[Add(Add(x, y), z)]"
assert e.r.data == 0.0
assert e.data == 0.0
def test_1(self):
x, y, z = inputs(BuildEvalMode)
x, y, z = inputs(build_eval)
e = add(add(x, y), z)
g = env([x, y, z], [e])
assert str(g) == "[Add(Add(x, y), z)]"
assert e.r.data == 6.0
assert e.data == 6.0
def test_2(self):
x, y, z = inputs(EvalMode)
x, y, z = inputs(eval)
e = add(add(x, y), z)
g = env([x, y, z], [e])
assert str(g) == "[Add_R]"
assert e.r.data == 6.0
assert e.data == 6.0
def test_3(self):
x, y, z = inputs(BuildMode)
x, y, z = inputs(build)
e = x + y + z
g = env([x, y, z], [e])
assert str(g) == "[Add(Add(x, y), z)]"
assert e.r.data == 0.0
assert e.data == 0.0
def test_4(self):
x, y, z = inputs(BuildEvalMode)
x, y, z = inputs(build_eval)
e = x + 34.0
g = env([x, y, z], [e])
assert str(g) == "[Add(x, oignon)]"
assert e.r.data == 35.0
assert e.data == 35.0
def test_5(self):
xb, yb, zb = inputs(BuildMode)
xe, ye, ze = inputs(EvalMode)
xb, yb, zb = inputs(build)
xe, ye, ze = inputs(eval)
try:
e = xb + ye
except TypeError:
......
......@@ -49,14 +49,14 @@ modes.make_constructors(globals())
def inputs():
x = modes.BuildMode(MyResult('x'))
y = modes.BuildMode(MyResult('y'))
z = modes.BuildMode(MyResult('z'))
x = modes.build(MyResult('x'))
y = modes.build(MyResult('y'))
z = modes.build(MyResult('z'))
return x, y, z
def env(inputs, outputs, validate = True):
inputs = [input.r for input in inputs]
outputs = [output.r for output in outputs]
# inputs = [input.r for input in inputs]
# outputs = [output.r for output in outputs]
return Env(inputs, outputs, features = [EquivTool], consistency_check = validate)
......
......@@ -47,14 +47,14 @@ import modes
modes.make_constructors(globals())
def inputs():
x = modes.BuildMode(MyResult('x'))
y = modes.BuildMode(MyResult('y'))
z = modes.BuildMode(MyResult('z'))
x = modes.build(MyResult('x'))
y = modes.build(MyResult('y'))
z = modes.build(MyResult('z'))
return x, y, z
def env(inputs, outputs, validate = True, features = []):
inputs = [input.r for input in inputs]
outputs = [output.r for output in outputs]
# inputs = [input.r for input in inputs]
# outputs = [output.r for output in outputs]
return Env(inputs, outputs, features = features, consistency_check = validate)
......@@ -65,10 +65,10 @@ class _test_EquivTool(unittest.TestCase):
sx = sigmoid(x)
e = add(sx, sigmoid(y))
g = env([x, y, z], [e], features = [EquivTool])
assert g.equiv(sx.r) is sx.r
g.replace(sx.r, dot(x, z).r)
assert g.equiv(sx.r) is not sx.r
assert isinstance(g.equiv(sx.r).owner, Dot)
assert g.equiv(sx) is sx
g.replace(sx, dot(x, z))
assert g.equiv(sx) is not sx
assert isinstance(g.equiv(sx).owner, Dot)
......
差异被折叠。
......@@ -75,8 +75,8 @@ class Env(graph.Graph):
self._tools = {}
# The inputs and outputs set bound the subgraph this Env operates on.
self.inputs = set(inputs)
self.outputs = set(outputs)
self.inputs = list(inputs)
self.outputs = list(outputs)
for feature_class in uniq_features(features):
self.add_feature(feature_class, False)
......@@ -110,9 +110,9 @@ class Env(graph.Graph):
### Public interface ###
def add_output(self, output):
self.outputs.add(output)
self.__import_r__([output])
# def add_output(self, output):
# self.outputs.add(output)
# self.__import_r__([output])
def clients(self, r):
"Set of all the (op, i) pairs such that op.inputs[i] is r."
......@@ -249,8 +249,9 @@ class Env(graph.Graph):
new_was_output = True
if r in self.outputs:
was_output = True
self.outputs.remove(r)
self.outputs.add(new_r)
self.outputs[self.outputs.index(r)] = new_r
# self.outputs.remove(r)
# self.outputs.add(new_r)
# The actual replacement operation occurs here. This might raise
# an error.
......@@ -261,8 +262,9 @@ class Env(graph.Graph):
# Restore self.outputs
if was_output:
if not new_was_output:
self.outputs.remove(new_r)
self.outputs.add(r)
self.outputs[self.outputs.index(new_r)] = r
# self.outputs.remove(new_r)
# self.outputs.add(r)
# Move back the clients. This should never raise an error.
self.__move_clients__(clients, new_r, r)
......
......@@ -2,35 +2,80 @@
# from features import Tool
from utils import AbstractFunctionError
import utils
class Linker:
def __init__(self, env):
self.env = env
self.thunk = None
def compile(self):
def make_thunk(self, inplace = False):
"""
This function must return a triplet (function, input_results, output_results)
where function is a thunk that operates on the returned results. If inplace
is True, the input_results and output_results lists will be the same as the
inputs and outputs of the graph provided to the Linker. Else, independent
results will be returned.
Example:
e = x + y
env = Env([x, y], [e])
fn, (new_x, new_y), (new_e, ) = MyLinker(env).make_thunk(inplace)
new_x.data = 1.0
new_y.data = 2.0
fn()
print new_e.data # 3.0
print e.data # 3.0 iff inplace == True (else unknown)
"""
raise AbstractFunctionError()
def run(self):
self.thunk()
def make_function(self, inplace = False):
"""
Returns a function that takes values corresponding to the inputs of the
env used by this Linker and returns values corresponding the the outputs
of that env. If inplace is True, the calculations will operate in the
same storage the env uses, else independent storage will be allocated
for the function.
Example:
e = x + y
env = Env([x, y], [e])
fn = MyLinker(env).make_function(inplace)
print fn(1.0, 2.0) # 3.0
print e.data # 3.0 iff inplace == True (else unknown)
"""
thunk, inputs, outputs = self.make_thunk(inplace)
def execute(*args):
for arg, result in zip(args, inputs):
result.data = arg
thunk()
return utils.to_return_values([result.data for result in outputs])
return execute
def __call__(self):
self.thunk()
class PerformLinker(Linker):
def compile(self):
order = self.env.toposort()
def make_thunk(self, inplace = False):
if inplace:
env = self.env
else:
env = self.env.clone(True)
order = env.toposort()
thunks = [op.perform for op in order]
def f():
for thunk in thunks:
thunk()
self.thunk = f
self.order = order
self.thunks = thunks
return f, env.inputs, env.outputs
# self.thunk = f
# self.order = order
# self.thunks = thunks
class ProfilePerformLinker(Linker):
......
......@@ -4,10 +4,13 @@ from op import Op
__all__ = ['ModalConstructor',
'add_modal_members',
'ModalWrapper',
'BuildMode',
'EvalMode',
'BuildEvalMode',
'build',
'eval',
'build_eval',
# 'ModalWrapper',
# 'BuildMode',
# 'EvalMode',
# 'BuildEvalMode',
'make_constructors',
]
......@@ -15,27 +18,41 @@ class ModalConstructor:
def __init__(self, fn):
self.fn = fn
def __call__(self, *args):
modal_wrapper = None
fn_args = []
for arg in args:
if isinstance(arg, ModalWrapper):
mode = getattr(arg, '__mode__', False)
if mode:
if modal_wrapper is None:
modal_wrapper = arg.__class__
modal_wrapper = mode
else:
if not isinstance(arg, modal_wrapper):
if mode != modal_wrapper:
raise TypeError("Inconsistent modes.")
fn_args.append(arg.r)
else:
fn_args.append(arg)
fn_args.append(arg)
# for arg in args:
# if isinstance(arg, ModalWrapper):
# if modal_wrapper is None:
# modal_wrapper = arg.__class__
# else:
# if not isinstance(arg, modal_wrapper):
# raise TypeError("Inconsistent modes.")
# fn_args.append(arg.r)
# else:
# fn_args.append(arg)
op = self.fn(*fn_args)
if modal_wrapper:
modal_wrapper.filter(op)
modal_wrapper(op)
# modal_wrapper.filter(op)
for output in op.outputs:
output.__mode__ = modal_wrapper
if len(op.outputs) == 1:
return modal_wrapper(op.outputs[0])
return op.outputs[0]
#return modal_wrapper(op.outputs[0])
else:
return [modal_wrapper(output) for output in op.outputs]
return op.outputs
#return [modal_wrapper(output) for output in op.outputs]
def add_modal_members(cls, *members):
......@@ -48,39 +65,69 @@ def add_modal_members(cls, *members):
setattr(cls, member, fn(member))
class ModalWrapper:
# class ModalWrapper:
def __init__(self, r):
self.r = r
# def __init__(self, r):
# self.r = r
@classmethod
def filter(cls, op):
raise AbstractFunctionError()
# def __as_result__(self):
# return self.r
members1 = 'add sub mul div pow floordiv mod pow lshift rshift and or xor'.split(' ')
members = []
members += ["__%s__" % x for x in members1 + 'neg invert'.split(' ')]
members += ["__r%s__" % x for x in members1]
add_modal_members(ModalWrapper, *members)
# def __get_owner(self):
# return self.r.owner
# owner = property(__get_owner)
# @classmethod
# def filter(cls, op):
# raise AbstractFunctionError()
# members1 = 'add sub mul div pow floordiv mod pow lshift rshift and or xor'.split(' ')
# members = []
# members += ["__%s__" % x for x in members1 + 'neg invert'.split(' ')]
# members += ["__r%s__" % x for x in members1]
# add_modal_members(ModalWrapper, *members)
class BuildMode(ModalWrapper):
@classmethod
def filter(cls, op):
pass
class EvalMode(ModalWrapper):
@classmethod
def filter(cls, op):
op.perform()
for output in op.outputs:
output._role = None
def build_mode(op):
pass
def eval_mode(op):
op.perform()
for output in op.outputs:
output._role = None
def build_eval_mode(op):
op.perform()
def mode_setter(mode):
def f(r):
r.__mode__ = mode
return r
return f
build = mode_setter(build_mode)
eval = mode_setter(eval_mode)
build_eval = mode_setter(build_eval_mode)
# class BuildMode(ModalWrapper):
# @classmethod
# def filter(cls, op):
# pass
# class EvalMode(ModalWrapper):
# @classmethod
# def filter(cls, op):
# op.perform()
# for output in op.outputs:
# output._role = None
class BuildEvalMode(ModalWrapper):
@classmethod
def filter(cls, op):
op.perform()
# class BuildEvalMode(ModalWrapper):
# @classmethod
# def filter(cls, op):
# op.perform()
def _is_op(x):
......
......@@ -64,6 +64,7 @@ class ResultBase(object):
self.__set_data(data)
self.name = name
#
# role
#
......@@ -123,7 +124,7 @@ class ResultBase(object):
self.state = Empty
return
try:
self.validate(data)
data = self.filter(data)
except AbstractFunctionError:
pass
self._data[0] = data
......@@ -132,7 +133,7 @@ class ResultBase(object):
data = property(__get_data, __set_data,
doc = "The storage associated with this result")
def validate(self, data):
def filter(self, data):
"""(abstract) Raise an exception if the data is not of an
acceptable type.
......@@ -140,7 +141,8 @@ class ResultBase(object):
it to check that the argument can be used properly. This gives
a subclass the opportunity to ensure that the contents of
self._data remain sensible.
Returns data or an appropriately wrapped data.
"""
raise AbstractFunctionError()
......
......@@ -2,6 +2,7 @@
import numpy
from copy import copy
import inspect
from gof import ResultBase, GuardedOp, utils
......@@ -17,10 +18,10 @@ def as_scalar(x, name = None):
class Scalar(ResultBase):
def __init__(self, dtype, name=None):
def __init__(self, dtype, data = None, name=None):
self.dtype = dtype
self.constant = False
ResultBase.__init__(self, role = None, data = None, name = name)
ResultBase.__init__(self, role = None, data = data, name = name)
def __get_constant(self):
return self._constant
......@@ -28,14 +29,13 @@ class Scalar(ResultBase):
def __set_constant(self, value):
if value:
self.indestructible = True
self.constant = value
self._constant = value
constant = property(__get_constant, __set_constant)
def validate(self, data):
py_type = self.py_type()
if not isinstance(data, py_type):
raise TypeError("Expected %s instance." % py_type)
def filter(self, data):
py_type = self.dtype_specs()[0]
return py_type(data)
def same_properties(self, other):
return other.dtype == self.dtype
......@@ -44,51 +44,56 @@ class Scalar(ResultBase):
return getattr(self, 'constant', False) \
and getattr(other, 'constant', False) \
and self.data == other.data
def dtype_specs(self):
return {'float64': (float, 'double', 'PyFloat_Check', 'PyFloat_AsDouble', 'PyFloat_FromDouble')}[self.dtype]
def py_type(self):
return {'float64': float}[self.dtype]
# def py_type(self):
# return {'float64': float}[self.dtype]
def c_type(self):
return {'float64': 'double'}[self.dtype]
# def c_type(self):
# return {'float64': 'double'}[self.dtype]
def c_from(self):
return {'float64': 'PyFloat_FromDouble'}[self.dtype]
# def c_from(self):
# return {'float64': 'PyFloat_FromDouble'}[self.dtype]
def c_as(self):
return {'float64': 'PyFloat_AsDouble'}[self.dtype]
# def c_as(self):
# return {'float64': 'PyFloat_AsDouble'}[self.dtype]
def c_declare(self):
return """
%(dtype)s* %%(name)s;
%(dtype)s %%(name)s;
typedef %(dtype)s %%(name)s_dtype;
""" % dict(dtype = self.c_type())
""" % dict(dtype = self.dtype_specs()[1])
def c_data_extract(self):
def c_init(self):
return """
%%(name)s = (%(dtype)s)%(conv)s(py_%%(name)s);
if (!%%(name)s)
%(name)s = 0;
"""
def c_extract(self):
specs = self.dtype_specs()
return """
if (!%(check)s(py_%%(name)s))
%%(fail)s
""" % dict(dtype = self.c_type(),
conv = self.c_as())
%%(name)s = (%(dtype)s)%(conv)s(py_%%(name)s);
""" % dict(dtype = specs[1],
check = specs[2],
conv = specs[3])
def c_data_sync(self):
def c_sync(self):
specs = self.dtype_specs()
return """
Py_XDECREF(py_%%(name)s);
py_%%(name)s = %(conv)s((%(dtype)s)%%(name)s);
if (!py_%%(name)s)
py_%%(name)s = Py_None;
""" % dict(dtype = self.c_type(),
conv = self.c_as())
""" % dict(dtype = specs[1],
conv = specs[4])
def c_data_cleanup(self):
def c_cleanup(self):
return ""
def c_headers(self):
return []
def c_libraries(self):
return []
class ScalarMixedOp(GuardedOp):
......@@ -120,6 +125,15 @@ class ScalarMixedOp(GuardedOp):
def perform(self):
self.outputs[0].data = self.impl(*[input.data for input in self.inputs])
def c_var_names(self):
(self, inames, onames), _1, _2, _3 = inspect.getargspec(self.c_impl)
inames = utils.from_return_values(inames)
onames = utils.from_return_values(onames)
return [inames, onames]
def c_code(self):
return self.c_impl(self.inputs, self.outputs)
def upcast(dtype, *dtypes):
z = numpy.zeros((), dtype = dtype)
......
......@@ -23,7 +23,7 @@ class Mul(BinaryScalarOp):
def impl(self, x, y):
return x * y
def c_impl(self, (x, y), z):
return "%(z)s = %(x)s + %(y)s;"
return "%(z)s = %(x)s * %(y)s;"
def grad(self, (x, y), gz):
return mul(y, gz), mul(x, gz)
......
......@@ -21,11 +21,12 @@ class NumpyR(ResultBase):
elif not str(data.dtype) == self.dtype:
raise TypeError("Expected ndarray with data type %i." % self.dtype)
def to_c_type(self, dtype):
if dtype == "float64":
return "double"
else:
raise TypeError("Cannot translate dtype to C.")
# def to_c_type(self, dtype):
# if dtype == "float64":
# return "double"
# else:
# raise TypeError("Cannot translate dtype to C.")
def c_declare(self):
return """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论