提交 77a11909 作者: Olivier Breuleux

merge

......@@ -48,6 +48,15 @@ class _test_inplace_opt(unittest.TestCase):
inplace_optimizer.optimize(g)
assert str(g) == "[Broadcast{Add}(x, y), Broadcast{Mul}{0: 0}(x, y)]"
def test_inplace_on_second_argument(self):
    """Check the graph's string form before and after in-place optimization.

    The graph adds x to y and also multiplies x by z with mul_inplace.
    After inplace_optimizer runs, the Add node is expected to print with
    a ``{0: 1}`` annotation (presumably: output 0 overwrites input 1,
    i.e. y -- confirm against the Broadcast printing code).
    """
    x, y, z = inputs()
    total = x + y
    product = tensor.mul_inplace(x, z)
    g = Env([x, y], [total, product])

    expected_before = "[Broadcast{Add}(x, y), Broadcast{Mul}{0: 0}(x, z)]"
    expected_after = "[Broadcast{Add}{0: 1}(x, y), Broadcast{Mul}{0: 0}(x, z)]"

    assert str(g) == expected_before
    inplace_optimizer.optimize(g)
    assert str(g) == expected_after
class _test_dimshuffle_lift(unittest.TestCase):
......
......@@ -133,8 +133,8 @@ class Function:
self.fn = linker.make_function(inplace=True,
unpack_single=unpack_single,
profiler=profiler)
self.inputs = inputs
self.outputs = outputs
self.inputs = env.inputs
self.outputs = env.outputs
self.features = features
self.optimizer = optimizer
self.linker_cls = linker_cls
......@@ -189,3 +189,100 @@ def eval_outputs(outputs,
return rval
# StateFunction([x, y], [e], (w, w + lr * bla()))
# class _Function:
# def __init__(self,
# inputs,
# outputs,
# optimizer,
# linker_type = 'py',
# unpack_single = True,
# except_unreachable_input = True,
# disposable_inputs = [],
# borrow_outputs = []):
# _mark_indestructible(outputs)
# if len(inputs) != len(set(inputs)):
# raise Exception('duplicate inputs')
# if len(outputs) != len(set(outputs)):
# raise Exception('duplicate outputs')
# orphans = list(gof.graph.results_and_orphans(inputs, outputs,
# except_unreachable_input=except_unreachable_input)[1])
# orphan_data = eval_outputs(orphans, unpack_single=False)
# env = gof.env.Env(inputs, outputs, features + [gof.EquivTool], consistency_check = True)
# env = env.clone(clone_inputs=True)
# for d, o in zip(orphan_data, [env.equiv(orphan) for orphan in orphans]):
# o.data = d
# # optimize and link the cloned env
# if None is not optimizer:
# optimizer(env)
# linker = linker_cls(env)
# if keep_locals:# useful flag for debugging!
# self.__dict__.update(locals())
# if profiler is None:
# self.fn = linker.make_function(inplace=True,
# unpack_single=unpack_single)
# else:
# self.fn = linker.make_function(inplace=True,
# unpack_single=unpack_single,
# profiler=profiler)
# self.inputs = env.inputs
# self.outputs = env.outputs
# self.features = features
# self.optimizer = optimizer
# self.linker_cls = linker_cls
# self.profiler = profiler
# self.unpack_single = unpack_single
# self.except_unreachable_input = except_unreachable_input
# self.keep_locals = keep_locals
# def __call__(self, *args):
# return self.fn(*args)
# def __copy__(self):
# return Function(self.inputs, self.outputs,
# features = self.features,
# optimizer = self.optimizer,
# linker_cls = self.linker_cls,
# profiler = self.profiler,
# unpack_single = self.unpack_single,
# except_unreachable_input = self.except_unreachable_input,
# keep_locals = self.keep_locals)
# class StateFunction:
# def __init__(self, inputs, outputs, *states):
# in_states, out_states = zip(*states)
# env =
......@@ -356,7 +356,7 @@ class CLinker(Linker):
try: self.op_order = env.toposort()
except AttributeError: self.op_order = [env]
def code_gen(self, reuse_storage = True):
def code_gen(self, do_not_reuse = []): # reuse_storage = True):
"""
Generates code for a struct that does the computation of the env and
stores it in the struct_code field of the instance.
......@@ -370,7 +370,7 @@ class CLinker(Linker):
This method caches its computations.
"""
if getattr(self, 'struct_code', False) and self.reuse_storage == reuse_storage:
if getattr(self, 'struct_code', False) and self.do_not_reuse == do_not_reuse:
return self.struct_code
env = self.env
......@@ -424,7 +424,7 @@ class CLinker(Linker):
elif result in self.temps:
# temps don't need to be extracted from Python, so we call c_init rather than c_extract
# they do not need to be relayed to Python, so we don't sync
if result.c_is_simple() or not reuse_storage:
if result.c_is_simple() or result in do_not_reuse:
policy = [[get_nothing, get_nothing, get_nothing],
[get_c_declare, get_c_init, get_c_cleanup]]
else:
......@@ -433,7 +433,7 @@ class CLinker(Linker):
[get_nothing, get_nothing, get_nothing]]
elif result in self.outputs:
# outputs don't need to be extracted from Python, so we call c_init rather than c_extract
if result.c_is_simple() or not reuse_storage:
if result.c_is_simple() or result in do_not_reuse:
policy = [[get_nothing, get_nothing, get_nothing],
[get_c_declare, get_c_init, (get_c_sync, get_c_cleanup)]]
......@@ -513,7 +513,7 @@ class CLinker(Linker):
struct_code %= dict(name = struct_name)
self.struct_code = struct_code
self.reuse_storage = reuse_storage
self.do_not_reuse = do_not_reuse
self.struct_name = struct_name
self.hash = hash
self.args = args
......
......@@ -460,9 +460,7 @@ class PatternDescOptimizer(LocalOptimizer):
class ConstantFinder(Optimizer):
"""
Sets as constant every orphan that is not destroyed
and sets as indestructible every input that is not
destroyed.
Sets as constant every orphan that is not destroyed.
"""
def apply(self, env):
......@@ -471,15 +469,15 @@ class ConstantFinder(Optimizer):
if not env.destroyers(r):
r.indestructible = True
r.constant = True
for r in env.inputs:
if not env.destroyers(r):
r.indestructible = True
# for r in env.inputs:
# if not env.destroyers(r):
# r.indestructible = True
else:
for r in env.orphans():
r.indestructible = True
r.constant = True
for r in env.inputs:
r.indestructible = True
# for r in env.inputs:
# r.indestructible = True
import graph
class MergeOptimizer(Optimizer):
......
......@@ -3,7 +3,8 @@ import numpy
import math
from copy import copy
import inspect
from functools import partial
import gof
from gof import Result, GuardedOp, Env, utils
......@@ -184,6 +185,28 @@ class Scalar(Result):
def __rpow__(self,other): return pow(other,self)
# Easy constructors
def _multi(*fns):
    """Turn one or more single-name constructors into multi-name constructors.

    Each returned callable takes one argument, ``names``:

    * if ``len(names) == 1``, ``names`` is passed whole to the wrapped
      constructor and its result is returned directly;
    * otherwise the constructor is applied to every element of ``names``
      and the results are returned as a list.

    Parameters:
        *fns: constructor callables, each accepting a single name.

    Returns:
        A single wrapped callable when exactly one constructor is given,
        otherwise a list of wrapped callables, one per constructor, in
        the order given.
    """
    def f2(f, names):
        if len(names) == 1:
            return f(names)
        return [f(name) for name in names]

    if len(fns) == 1:
        # Bind the constructor itself, not the one-element ``fns`` tuple:
        # a tuple is not callable, so binding it made every call fail.
        return partial(f2, fns[0])
    return [partial(f2, f) for f in fns]
def intr(name):
    """Build a Scalar called ``name`` with dtype 'int64'."""
    return Scalar(dtype='int64', name=name)


# Multi-name variant of intr, produced by _multi.
ints = _multi(intr)
def floatr(name):
    """Build a Scalar called ``name`` with dtype 'float64'."""
    return Scalar(dtype='float64', name=name)


# Multi-name variant of floatr, produced by _multi.
floats = _multi(floatr)
def upcast(dtype, *dtypes):
z = numpy.zeros((), dtype = dtype)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论