提交 9fb994db authored 作者: Olivier Breuleux's avatar Olivier Breuleux

added scalar-creating functions

上级 62bffca3
...@@ -48,6 +48,15 @@ class _test_inplace_opt(unittest.TestCase): ...@@ -48,6 +48,15 @@ class _test_inplace_opt(unittest.TestCase):
inplace_optimizer.optimize(g) inplace_optimizer.optimize(g)
assert str(g) == "[Broadcast{Add}(x, y), Broadcast{Mul}{0: 0}(x, y)]" assert str(g) == "[Broadcast{Add}(x, y), Broadcast{Mul}{0: 0}(x, y)]"
def test_inplace_on_second_argument(self):
    # Checks the printed form of the graph before and after running the
    # inplace optimizer: the Mul is built in-place from the start, and the
    # Add is expected to gain the `{0: 1}` marker once optimized
    # (presumably meaning "output 0 reuses input 1", i.e. y -- confirm
    # against the Broadcast string format).
    x, y, z = inputs()
    total = x + y
    product = tensor.mul_inplace(x, z)
    g = Env([x, y], [total, product])
    before = "[Broadcast{Add}(x, y), Broadcast{Mul}{0: 0}(x, z)]"
    after = "[Broadcast{Add}{0: 1}(x, y), Broadcast{Mul}{0: 0}(x, z)]"
    assert str(g) == before
    inplace_optimizer.optimize(g)
    assert str(g) == after
class _test_dimshuffle_lift(unittest.TestCase): class _test_dimshuffle_lift(unittest.TestCase):
......
...@@ -133,8 +133,8 @@ class Function: ...@@ -133,8 +133,8 @@ class Function:
self.fn = linker.make_function(inplace=True, self.fn = linker.make_function(inplace=True,
unpack_single=unpack_single, unpack_single=unpack_single,
profiler=profiler) profiler=profiler)
self.inputs = inputs self.inputs = env.inputs
self.outputs = outputs self.outputs = env.outputs
self.features = features self.features = features
self.optimizer = optimizer self.optimizer = optimizer
self.linker_cls = linker_cls self.linker_cls = linker_cls
...@@ -189,3 +189,100 @@ def eval_outputs(outputs, ...@@ -189,3 +189,100 @@ def eval_outputs(outputs,
return rval return rval
# StateFunction([x, y], [e], (w, w + lr * bla()))
# class _Function:
# def __init__(self,
# inputs,
# outputs,
# optimizer,
# linker_type = 'py',
# unpack_single = True,
# except_unreachable_input = True,
# disposable_inputs = [],
# borrow_outputs = []):
# _mark_indestructible(outputs)
# if len(inputs) != len(set(inputs)):
# raise Exception('duplicate inputs')
# if len(outputs) != len(set(outputs)):
# raise Exception('duplicate outputs')
# orphans = list(gof.graph.results_and_orphans(inputs, outputs,
# except_unreachable_input=except_unreachable_input)[1])
# orphan_data = eval_outputs(orphans, unpack_single=False)
# env = gof.env.Env(inputs, outputs, features + [gof.EquivTool], consistency_check = True)
# env = env.clone(clone_inputs=True)
# for d, o in zip(orphan_data, [env.equiv(orphan) for orphan in orphans]):
# o.data = d
# # optimize and link the cloned env
# if None is not optimizer:
# optimizer(env)
# linker = linker_cls(env)
# if keep_locals:# useful flag for debugging!
# self.__dict__.update(locals())
# if profiler is None:
# self.fn = linker.make_function(inplace=True,
# unpack_single=unpack_single)
# else:
# self.fn = linker.make_function(inplace=True,
# unpack_single=unpack_single,
# profiler=profiler)
# self.inputs = env.inputs
# self.outputs = env.outputs
# self.features = features
# self.optimizer = optimizer
# self.linker_cls = linker_cls
# self.profiler = profiler
# self.unpack_single = unpack_single
# self.except_unreachable_input = except_unreachable_input
# self.keep_locals = keep_locals
# def __call__(self, *args):
# return self.fn(*args)
# def __copy__(self):
# return Function(self.inputs, self.outputs,
# features = self.features,
# optimizer = self.optimizer,
# linker_cls = self.linker_cls,
# profiler = self.profiler,
# unpack_single = self.unpack_single,
# except_unreachable_input = self.except_unreachable_input,
# keep_locals = self.keep_locals)
# class StateFunction:
# def __init__(self, inputs, outputs, *states):
# in_states, out_states = zip(*states)
# env =
...@@ -356,7 +356,7 @@ class CLinker(Linker): ...@@ -356,7 +356,7 @@ class CLinker(Linker):
try: self.op_order = env.toposort() try: self.op_order = env.toposort()
except AttributeError: self.op_order = [env] except AttributeError: self.op_order = [env]
def code_gen(self, reuse_storage = True): def code_gen(self, do_not_reuse = []): # reuse_storage = True):
""" """
Generates code for a struct that does the computation of the env and Generates code for a struct that does the computation of the env and
stores it in the struct_code field of the instance. stores it in the struct_code field of the instance.
...@@ -370,7 +370,7 @@ class CLinker(Linker): ...@@ -370,7 +370,7 @@ class CLinker(Linker):
This method caches its computations. This method caches its computations.
""" """
if getattr(self, 'struct_code', False) and self.reuse_storage == reuse_storage: if getattr(self, 'struct_code', False) and self.do_not_reuse == do_not_reuse:
return self.struct_code return self.struct_code
env = self.env env = self.env
...@@ -424,7 +424,7 @@ class CLinker(Linker): ...@@ -424,7 +424,7 @@ class CLinker(Linker):
elif result in self.temps: elif result in self.temps:
# temps don't need to be extracted from Python, so we call c_init rather than c_extract # temps don't need to be extracted from Python, so we call c_init rather than c_extract
# they do not need to be relayed to Python, so we don't sync # they do not need to be relayed to Python, so we don't sync
if result.c_is_simple() or not reuse_storage: if result.c_is_simple() or result in do_not_reuse:
policy = [[get_nothing, get_nothing, get_nothing], policy = [[get_nothing, get_nothing, get_nothing],
[get_c_declare, get_c_init, get_c_cleanup]] [get_c_declare, get_c_init, get_c_cleanup]]
else: else:
...@@ -433,7 +433,7 @@ class CLinker(Linker): ...@@ -433,7 +433,7 @@ class CLinker(Linker):
[get_nothing, get_nothing, get_nothing]] [get_nothing, get_nothing, get_nothing]]
elif result in self.outputs: elif result in self.outputs:
# outputs don't need to be extracted from Python, so we call c_init rather than c_extract # outputs don't need to be extracted from Python, so we call c_init rather than c_extract
if result.c_is_simple() or not reuse_storage: if result.c_is_simple() or result in do_not_reuse:
policy = [[get_nothing, get_nothing, get_nothing], policy = [[get_nothing, get_nothing, get_nothing],
[get_c_declare, get_c_init, (get_c_sync, get_c_cleanup)]] [get_c_declare, get_c_init, (get_c_sync, get_c_cleanup)]]
...@@ -513,7 +513,7 @@ class CLinker(Linker): ...@@ -513,7 +513,7 @@ class CLinker(Linker):
struct_code %= dict(name = struct_name) struct_code %= dict(name = struct_name)
self.struct_code = struct_code self.struct_code = struct_code
self.reuse_storage = reuse_storage self.do_not_reuse = do_not_reuse
self.struct_name = struct_name self.struct_name = struct_name
self.hash = hash self.hash = hash
self.args = args self.args = args
......
...@@ -460,9 +460,7 @@ class PatternDescOptimizer(LocalOptimizer): ...@@ -460,9 +460,7 @@ class PatternDescOptimizer(LocalOptimizer):
class ConstantFinder(Optimizer): class ConstantFinder(Optimizer):
""" """
Sets as constant every orphan that is not destroyed Sets as constant every orphan that is not destroyed.
and sets as indestructible every input that is not
destroyed.
""" """
def apply(self, env): def apply(self, env):
...@@ -471,15 +469,15 @@ class ConstantFinder(Optimizer): ...@@ -471,15 +469,15 @@ class ConstantFinder(Optimizer):
if not env.destroyers(r): if not env.destroyers(r):
r.indestructible = True r.indestructible = True
r.constant = True r.constant = True
for r in env.inputs: # for r in env.inputs:
if not env.destroyers(r): # if not env.destroyers(r):
r.indestructible = True # r.indestructible = True
else: else:
for r in env.orphans(): for r in env.orphans():
r.indestructible = True r.indestructible = True
r.constant = True r.constant = True
for r in env.inputs: # for r in env.inputs:
r.indestructible = True # r.indestructible = True
import graph import graph
class MergeOptimizer(Optimizer): class MergeOptimizer(Optimizer):
......
...@@ -3,7 +3,8 @@ import numpy ...@@ -3,7 +3,8 @@ import numpy
import math import math
from copy import copy from copy import copy
import inspect
from functools import partial
import gof import gof
from gof import Result, GuardedOp, Env, utils from gof import Result, GuardedOp, Env, utils
...@@ -184,6 +185,28 @@ class Scalar(Result): ...@@ -184,6 +185,28 @@ class Scalar(Result):
def __rpow__(self,other): return pow(other,self) def __rpow__(self,other): return pow(other,self)
# Easy constructors
def _multi(*fns):
def f2(f, names):
if len(names) == 1:
return f(names)
else:
return [f(name) for name in names]
if len(fns) == 1:
return partial(f2, fns)
else:
return [partial(f2, f) for f in fns]
def intr(name):
    """Return a new Scalar called `name` whose dtype is 'int64'."""
    int_scalar = Scalar(dtype = 'int64', name = name)
    return int_scalar
# Multi-name counterpart of intr, produced by _multi.
ints = _multi(intr)
def floatr(name):
    """Return a new Scalar called `name` whose dtype is 'float64'."""
    float_scalar = Scalar(dtype = 'float64', name = name)
    return float_scalar
# Multi-name counterpart of floatr, produced by _multi.
floats = _multi(floatr)
def upcast(dtype, *dtypes): def upcast(dtype, *dtypes):
z = numpy.zeros((), dtype = dtype) z = numpy.zeros((), dtype = dtype)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论