提交 ea32b4db authored 作者: Olivier Breuleux's avatar Olivier Breuleux

too many things to list

上级 d2cf55aa
差异被折叠。
差异被折叠。
差异被折叠。
差异被折叠。
......@@ -6,7 +6,8 @@ from cc import *
from type import Type
from graph import Result, as_result, Apply, Constant
from op import Op
from env import Env
import env
import toolbox
class TDouble(Type):
def filter(self, data):
......@@ -125,6 +126,11 @@ def inputs():
return x, y, z
def Env(inputs, outputs):
e = env.Env(inputs, outputs)
return e
class _test_CLinker(unittest.TestCase):
def test_straightforward(self):
......
......@@ -257,7 +257,6 @@ class _test_all(unittest.TestCase):
if __name__ == '__main__':
#unittest.main()
_test_all('test_usage_loop_through_views').debug()
unittest.main()
......@@ -161,14 +161,14 @@ class _test_clone(unittest.TestCase):
def test_accurate(self):
r1, r2 = MyResult(1), MyResult(2)
node = MyOp.make_node(r1, r2)
new = clone([r1, r2], node.outputs)
_, new = clone([r1, r2], node.outputs, False)
assert self.str([r1, r2], new) == ["MyOp(1, 2)"]
def test_copy(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
node = MyOp.make_node(r1, r2)
node2 = MyOp.make_node(node.outputs[0], r5)
new = clone([r1, r2, r5], node2.outputs)
_, new = clone([r1, r2, r5], node2.outputs, False)
assert node2.outputs[0].type == new[0].type and node2.outputs[0] is not new[0] # the new output is like the old one but not the same object
assert node2 is not new[0].owner # the new output has a new owner
assert new[0].owner.inputs[1] is r5 # the inputs are not copied
......@@ -178,7 +178,7 @@ class _test_clone(unittest.TestCase):
# Checks that manipulating a cloned graph leaves the original unchanged.
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
node = MyOp.make_node(MyOp.make_node(r1, r2).outputs[0], r5)
new = clone([r1, r2, r5], node.outputs)
_, new = clone([r1, r2, r5], node.outputs, False)
new_node = new[0].owner
new_node.inputs = MyResult(7), MyResult(8)
......
......@@ -2,10 +2,12 @@
import unittest
from graph import Result, as_result, Apply
import graph
from graph import Result, as_result, Apply, Constant
from type import Type
from op import Op
from env import Env
import env
import toolbox
from link import *
......@@ -67,6 +69,10 @@ def perform_linker(env):
lnk = PerformLinker(env)
return lnk
def Env(inputs, outputs):
e = env.Env(inputs, outputs)
return e
class _test_PerformLinker(unittest.TestCase):
......@@ -94,16 +100,14 @@ class _test_PerformLinker(unittest.TestCase):
def test_input_output_same(self):
x, y, z = inputs()
a,d = add(x,y), div(x,y)
e = mul(a,d)
fn = perform_linker(Env([e], [e])).make_function()
fn = perform_linker(Env([x], [x])).make_function()
self.failUnless(1.0 is fn(1.0))
def test_input_dependency0(self):
x, y, z = inputs()
a,d = add(x,y), div(x,y)
e = mul(a,d)
fn = perform_linker(Env([x, y, a], [e])).make_function()
fn = perform_linker(Env(*graph.clone([x, y, a], [e]))).make_function()
self.failUnless(fn(1.0,2.0,9.0) == 4.5)
def test_skiphole(self):
......@@ -111,9 +115,11 @@ class _test_PerformLinker(unittest.TestCase):
a = add(x,y)
r = raise_err(a)
e = add(r,a)
fn = perform_linker(Env([x, y,r], [e])).make_function()
fn = perform_linker(Env(*graph.clone([x, y,r], [e]))).make_function()
self.failUnless(fn(1.0,2.0,4.5) == 7.5)
# def test_disconnected_input_output(self):
# x,y,z = inputs()
# a = add(x,y)
......
......@@ -415,4 +415,3 @@ if __name__ == '__main__':
unittest.main()
from graph import Constant
import graph
from graph import Constant, Value
from link import Linker, LocalLinker, raise_with_op, Filter, map_storage, PerformLinker
from copy import copy
from utils import AbstractFunctionError
......@@ -284,10 +285,11 @@ def apply_policy(policy, r, name, sub):
@type r: L{Result}
@return: C{policy[0](r) + policy[1](r) + ...}
"""
if isinstance(r, (list, tuple)):
if isinstance(policy, (list, tuple)):
ret = ""
for sub_policy in policy:
ret += sub_policy(r, name, sub)
return ret
return policy(r, name, sub)
......@@ -345,7 +347,7 @@ class CLinker(Linker):
self.outputs = env.outputs
self.results = list(env.results)
# The orphans field is listified to ensure a consistent order.
self.orphans = list(env.orphans.difference(self.outputs))
self.orphans = list(r for r in self.results if isinstance(r, Value) and r not in self.inputs) #list(env.orphans.difference(self.outputs))
self.temps = list(set(self.results).difference(self.inputs).difference(self.outputs).difference(self.orphans))
self.node_order = env.toposort()
......@@ -403,15 +405,16 @@ class CLinker(Linker):
policy = [[get_nothing, get_nothing, get_nothing],
[get_c_declare, get_c_extract, get_c_cleanup]]
elif result in self.orphans:
if not isinstance(result, Constant):
raise TypeError("All orphans to CLinker must be Constant.", result)
try:
symbol[result] = "(" + result.type.c_literal(result.data) + ")"
consts.append(result)
self.orphans.remove(result)
continue
except (AbstractFunctionError, NotImplementedError):
pass
if not isinstance(result, Value):
raise TypeError("All orphans to CLinker must be Value instances.", result)
if isinstance(result, Constant):
try:
symbol[result] = "(" + result.type.c_literal(result.data) + ")"
consts.append(result)
self.orphans.remove(result)
continue
except (AbstractFunctionError, NotImplementedError):
pass
# orphans are not inputs so we'll just get fetch them when we initialize the struct and assume they stay the same
policy = [[get_c_declare, get_c_extract, get_c_cleanup],
[get_nothing, get_nothing, get_nothing]]
......@@ -428,7 +431,6 @@ class CLinker(Linker):
elif result in self.outputs:
# outputs don't need to be extracted from Python, so we call c_init rather than c_extract
if result.type.c_is_simple() or result in no_recycling:
policy = [[get_nothing, get_nothing, get_nothing],
[get_c_declare, get_c_init, (get_c_sync, get_c_cleanup)]]
else:
......@@ -599,7 +601,12 @@ class CLinker(Linker):
if input_storage is None:
input_storage = [[None] for result in self.inputs]
if output_storage is None:
output_storage = [[None] for result in self.outputs]
map = {}
output_storage = []
for result in self.outputs:
if result not in map:
map[result] = [None]
output_storage.append(map[result])
thunk = self.cthunk_factory(error_storage,
input_storage,
output_storage)
......@@ -642,13 +649,13 @@ class CLinker(Linker):
if not getattr(self, 'instantiate', False):
self.code_gen()
module_name = self.hash
# Eliminate duplicate inputs and outputs from the storage that we will pass to instantiate
out_storage = [x for i, x in enumerate(out_storage) if (i+len(in_storage)) not in self.dupidx]
in_storage = [x for i, x in enumerate(in_storage) if i not in self.dupidx]
cthunk = object() # dummy so weave can get the type
module_name = self.hash
mod = weave.ext_tools.ext_module(module_name)
argnames = ["i%i" % i for i in xrange(len(in_storage))] \
......@@ -710,8 +717,11 @@ class CLinker(Linker):
# Eliminate duplicate inputs and outputs from the storage that we will pass to instantiate
out_storage = [x for i, x in enumerate(out_storage) if (i+len(in_storage)) not in self.dupidx]
in_storage = [x for i, x in enumerate(in_storage) if i not in self.dupidx]
module_name = self.hash
module = __import__("%s" % (module_name), {}, {}, [module_name])
ret = module.instantiate(error_storage, *(in_storage + out_storage + [orphan.data for orphan in self.orphans]))
orphd = [[orphan.data] for orphan in self.orphans]
ret = module.instantiate(error_storage, *(in_storage + out_storage + orphd))
assert sys.getrefcount(ret) == 2 # refcount leak check
return ret
......@@ -751,7 +761,9 @@ class OpWiseCLinker(LocalLinker):
node_input_storage = [storage_map[r] for r in node.inputs]
node_output_storage = [storage_map[r] for r in node.outputs]
try:
cl = CLinker(Env(node.inputs, node.outputs))
e = Env(*graph.clone(node.inputs, node.outputs))
e.toposort = lambda: e.nodes
cl = CLinker(e, [r for r, r2 in zip(e.outputs, node.outputs) if r2 in no_recycling])
thunk, node_input_filters, node_output_filters = cl.make_thunk(
input_storage = node_input_storage,
output_storage = node_output_storage)
......@@ -823,7 +835,7 @@ class DualLinker(Linker):
function.
"""
def __init__(self, env, checker = _default_checker):
def __init__(self, env, checker = _default_checker, no_recycling = []):
"""
Initialize a DualLinker.
......@@ -844,6 +856,7 @@ class DualLinker(Linker):
"""
self.env = env
self.checker = checker
self.no_recycling = no_recycling
def make_thunk(self, **kwargs):
# if inplace:
......@@ -865,8 +878,10 @@ class DualLinker(Linker):
# thunks2 = [c_make_thunk(op) for op in op_order_2]
env = self.env
_f, i1, o1, thunks1, order1 = PerformLinker(env).make_all(**kwargs)
_f, i2, o2, thunks2, order2 = OpWiseCLinker(env).make_all(**kwargs)
no_recycling = self.no_recycling
_f, i1, o1, thunks1, order1 = PerformLinker(env, no_recycling = no_recycling).make_all(**kwargs)
_f, i2, o2, thunks2, order2 = OpWiseCLinker(env, no_recycling = no_recycling).make_all(**kwargs)
def f():
for input1, input2 in zip(i1, i2):
......@@ -874,6 +889,12 @@ class DualLinker(Linker):
# the copy is necessary in order for inplace ops not to interfere
input2.storage[0] = copy(input1.storage[0])
for thunk1, thunk2, node1, node2 in zip(thunks1, thunks2, order1, order2):
for output, storage in zip(node1.outputs, thunk1.outputs):
if output in no_recycling:
storage[0] = None
for output, storage in zip(node2.outputs, thunk2.outputs):
if output in no_recycling:
storage[0] = None
try:
thunk1()
thunk2()
......
......@@ -26,15 +26,7 @@ class Env(object): #(graph.Graph):
The Env supports the replace operation which allows to replace a
result in the subgraph by another, e.g. replace (x + x).out by (2
* x).out. This is the basis for optimization in omega.
Regarding inputs and orphans:
In the context of a computation graph, the inputs and orphans are
both results that are the source nodes of computation. Those
results that are named as inputs will be assumed to contain fresh.
In other words, the backward search from outputs will stop at any
node that has been explicitly named as an input.
* x).out. This is the basis for optimization in theano.
"""
### Special ###
......@@ -68,10 +60,6 @@ class Env(object): #(graph.Graph):
self.node_locks = {}
self.result_locks = {}
# # List of functions that undo the replace operations performed.
# # e.g. to recover the initial graph one could write: for u in self.history.__reversed__(): u()
# self.history = []
### Setup a Result ###
......@@ -237,99 +225,13 @@ class Env(object): #(graph.Graph):
raise TypeError("The type of the replacement must be the same as the type of the original Result.", r, new_r)
assert r in self.results
for node, i in r.clients:
for node, i in list(r.clients):
assert node == 'output' and self.outputs[i] is r or node.inputs[i] is r
self.change_input(node, i, new_r)
# # Save where we are so we can backtrack
# if consistency_check:
# chk = self.checkpoint()
# # The copy is required so undo can know what clients to move back!
# clients = copy(self.clients(r))
# # Messy checks so we know what to do if we are replacing an output
# # result. Note that if v is an input result, we do nothing at all for
# # now (it's not clear what it means to replace an input result).
# was_output = False
# if r in self.outputs:
# was_output = True
# self.outputs[self.outputs.index(r)] = new_r
# was_input = False
# if r in self.inputs:
# was_input = True
# self.inputs[self.inputs.index(r)] = new_r
# # The actual replacement operation occurs here. This might raise
# # an error.
# self.__move_clients__(clients, r, new_r) # not sure how to order this wrt to adjusting the outputs
# # This function undoes the replacement.
# def undo():
# # Restore self.outputs
# if was_output:
# self.outputs[self.outputs.index(new_r)] = r
# # Restore self.inputs
# if was_input:
# self.inputs[self.inputs.index(new_r)] = r
# # Move back the clients. This should never raise an error.
# self.__move_clients__(clients, new_r, r)
# self.history.append(undo)
# if consistency_check:
# try:
# self.validate()
# except InconsistencyError, e:
# self.revert(chk)
# raise
def replace_all(self, d):
"""
For (r, new_r) in d.items(), replaces r with new_r. Checks for
consistency at the end and raises an InconsistencyError if the
graph is not consistent. If an error is raised, the graph is
restored to what it was before.
"""
for r, new_r in d.items():
self.replace(r, new_r, False)
# chk = self.checkpoint()
# try:
# for r, new_r in d.items():
# self.replace(r, new_r, False)
# except Exception, e:
# self.revert(chk)
# raise
# try:
# self.validate()
# except InconsistencyError, e:
# self.revert(chk)
# raise
# def checkpoint(self):
# """
# Returns an object that can be passed to self.revert in order to backtrack
# to a previous state.
# """
# return len(self.history)
# def consistent(self):
# """
# Returns True iff the subgraph is consistent and does not violate the
# constraints set by the listeners.
# """
# try:
# self.validate()
# except InconsistencyError:
# return False
# return True
def replace_all(self, pairs):
for r, new_r in pairs:
self.replace(r, new_r)
### features ###
......@@ -385,6 +287,16 @@ class Env(object): #(graph.Graph):
### misc ###
def toposort(self):
env = self
ords = {}
for feature in env._features:
if hasattr(feature, 'orderings'):
for op, prereqs in feature.orderings(env).items():
ords.setdefault(op, set()).update(prereqs)
order = graph.io_toposort(env.inputs, env.outputs, ords)
return order
def nclients(self, r):
"Same as len(self.clients(r))."
......@@ -438,118 +350,10 @@ class Env(object): #(graph.Graph):
raise Exception("Client not in env.", result, (node, i))
if node.inputs[i] is not result:
raise Exception("Inconsistent clients list.", result, node.inputs[i])
# def revert(self, checkpoint):
# """
# Reverts the graph to whatever it was at the provided
# checkpoint (undoes all replacements). A checkpoint at any
# given time can be obtained using self.checkpoint().
# """
# while len(self.history) > checkpoint:
# f = self.history.pop()
# f()
# def supplemental_orderings(self):
# """
# Returns a dictionary of {op: set(prerequisites)} that must
# be satisfied in addition to the order defined by the structure
# of the graph (returns orderings that not related to input/output
# relationships).
# """
# ords = {}
# for feature in self._features:
# if hasattr(feature, 'orderings'):
# for op, prereqs in feature.orderings().items():
# ords.setdefault(op, set()).update(prereqs)
# return ords
# def toposort(self):
# """
# Returns a list of nodes in the order that they must be executed
# in order to preserve the semantics of the graph and respect
# the constraints put forward by the listeners.
# """
# ords = self.supplemental_orderings()
# order = graph.io_toposort(self.inputs, self.outputs, ords)
# return order
# def validate(self):
# """
# Raises an error if the graph is inconsistent.
# """
# self.execute_callbacks('validate')
# # for constraint in self._constraints.values():
# # constraint.validate()
# return True
### Private interface ###
# def __move_clients__(self, clients, r, new_r):
# if not (r.type == new_r.type):
# raise TypeError("Cannot move clients between Results that have different types.", r, new_r)
# # We import the new result in the fold
# self.__import_r__([new_r])
# for op, i in clients:
# op.inputs[i] = new_r
# # try:
# # # Try replacing the inputs
# # for op, i in clients:
# # op.set_input(i, new_r)
# # except:
# # # Oops!
# # for op, i in clients:
# # op.set_input(i, r)
# # self.__prune_r__([new_r])
# # raise
# self.__remove_clients__(r, clients)
# self.__add_clients__(new_r, clients)
# # # We import the new result in the fold
# # # why was this line AFTER the set_inputs???
# # # if we do it here then satisfy in import fucks up...
# # self.__import_r__([new_r])
# self.execute_callbacks('on_rewire', clients, r, new_r)
# # for listener in self._listeners.values():
# # try:
# # listener.on_rewire(clients, r, new_r)
# # except AbstractFunctionError:
# # pass
# # We try to get rid of the old one
# self.__prune_r__([r])
def __str__(self):
return "[%s]" % ", ".join(graph.as_string(self.inputs, self.outputs))
# def clone_get_equiv(self, clone_inputs = True):
# equiv = graph.clone_get_equiv(self.inputs, self.outputs, clone_inputs)
# new = self.__class__([equiv[input] for input in self.inputs],
# [equiv[output] for output in self.outputs])
# for feature in self._features:
# new.extend(feature)
# return new, equiv
# def clone(self, clone_inputs = True):
# equiv = graph.clone_get_equiv(self.inputs, self.outputs, clone_inputs)
# new = self.__class__([equiv[input] for input in self.inputs],
# [equiv[output] for output in self.outputs])
# for feature in self._features:
# new.extend(feature)
# try:
# new.set_equiv(equiv)
# except AttributeError:
# pass
# return new
# def __copy__(self):
# return self.clone()
......
#from features import Listener, Constraint, Orderings, Tool
import graph
import utils
from utils import AbstractFunctionError
......@@ -253,7 +256,6 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool):
"""
self.seen.add(op)
op.deps['destroy'] = []
view_map, destroy_map = self.get_maps(op)
for input in op.inputs:
......@@ -334,7 +336,6 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool):
del self.children[output]
self.seen.remove(op)
del op.deps['destroy']
def __add_destroyer__(self, path):
......@@ -350,9 +351,6 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool):
destroyers = self.destroyers.setdefault(foundation, {})
path = destroyers.setdefault(node, path)
print "add", path
node.deps['destroy'] += [user.owner for user in self.__users__(foundation) if user not in node.outputs]
# for foundation, destroyers in self.destroyers.items():
# for op in destroyers.keys():
# ords.setdefault(op, set()).update([user.owner for user in self.__users__(foundation) if user not in op.outputs])
......@@ -361,7 +359,7 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool):
self.dups.add(foundation)
# results marked 'indestructible' must not be destroyed.
if getattr(foundation, 'indestructible', False):
if getattr(foundation, 'indestructible', False) or isinstance(foundation, graph.Constant):
self.illegal.add(foundation)
......@@ -374,13 +372,6 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool):
target = path[-1]
node = target.owner
print "rm", path
print node.deps['destroy']
for user in self.__users__(foundation):
print " -- ", user
if user not in node.outputs:
node.deps['destroy'].remove(user.owner)
destroyers = self.destroyers[foundation]
del destroyers[node]
......@@ -477,6 +468,7 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool):
In particular, all the users of a destroyed result have priority over the
L{Op} that destroys the result.
"""
self.validate(env)
ords = {}
for foundation, destroyers in self.destroyers.items():
for op in destroyers.keys():
......
......@@ -163,7 +163,6 @@ def as_apply(x):
@deprecated
def inputs(o):
"""
......@@ -173,7 +172,6 @@ def inputs(o):
Returns the set of inputs necessary to compute the outputs in o
such that input.owner is None.
"""
print 'gof.graph.inputs deprecated: April 29'
results = set()
def seek(r):
op = r.owner
......@@ -187,53 +185,71 @@ def inputs(o):
return results
def results_and_orphans(i, o, except_unreachable_input=False):
"""
@type i: list
@param i: input L{Result}s
@type o: list
@param o: output L{Result}s
# def results_and_orphans(i, o, except_unreachable_input=False):
# """
# @type i: list
# @param i: input L{Result}s
# @type o: list
# @param o: output L{Result}s
# Returns the pair (results, orphans). The former is the set of
# L{Result}s that are involved in the subgraph that lies between i and
# o. This includes i, o, orphans(i, o) and all results of all
# intermediary steps from i to o. The second element of the returned
# pair is orphans(i, o).
# """
# results = set()
# i = set(i)
# # results.update(i)
# incomplete_paths = []
# reached = set()
# def helper(r, path):
# if r in i:
# reached.add(r)
# results.update(path)
# elif r.owner is None:
# incomplete_paths.append(path)
# else:
# op = r.owner
# for r2 in op.inputs:
# helper(r2, path + [r2])
Returns the pair (results, orphans). The former is the set of
L{Result}s that are involved in the subgraph that lies between i and
o. This includes i, o, orphans(i, o) and all results of all
intermediary steps from i to o. The second element of the returned
pair is orphans(i, o).
"""
results = set()
i = set(i)
# results.update(i)
incomplete_paths = []
reached = set()
def helper(r, path):
if r in i:
reached.add(r)
results.update(path)
elif r.owner is None:
incomplete_paths.append(path)
else:
op = r.owner
for r2 in op.inputs:
helper(r2, path + [r2])
# for output in o:
# helper(output, [output])
for output in o:
helper(output, [output])
# orphans = set()
# for path in incomplete_paths:
# for r in path:
# if r not in results:
# orphans.add(r)
# break
orphans = set()
for path in incomplete_paths:
for r in path:
if r not in results:
orphans.add(r)
break
# if except_unreachable_input and len(i) != len(reached):
# raise Exception(results_and_orphans.E_unreached)
if except_unreachable_input and len(i) != len(reached):
raise Exception(results_and_orphans.E_unreached)
# results.update(orphans)
results.update(orphans)
# return results, orphans
# results_and_orphans.E_unreached = 'there were unreachable inputs'
def results_and_orphans(i, o):
results = set()
orphans = set()
def helper(r):
if r in results:
return
results.add(r)
if r.owner is None:
if r not in i:
orphans.add(r)
else:
for r2 in r.owner.inputs:
helper(r2)
for output in o:
helper(output)
return results, orphans
results_and_orphans.E_unreached = 'there were unreachable inputs'
def ops(i, o):
......
......@@ -2,7 +2,7 @@
from utils import AbstractFunctionError
import utils
from graph import Constant
from graph import Value
import sys
import traceback
......@@ -135,16 +135,20 @@ def map_storage(env, order, input_storage, output_storage):
storage_map = {}
for r, storage in zip(env.inputs, input_storage):
storage_map[r] = storage
for orphan in env.orphans:
if not isinstance(orphan, Constant):
raise TypeError("Cannot link a graph with non-constant orphans.", orphan)
storage_map[orphan] = [orphan.data]
# for orphan in env.orphans:
# if not isinstance(orphan, Constant):
# raise TypeError("Cannot link a graph with non-constant orphans.", orphan)
# storage_map[orphan] = [orphan.data]
if output_storage is not None:
assert len(env.outputs) == len(output_storage)
for r, storage in zip(env.outputs, output_storage):
storage_map[r] = storage
thunks = []
for node in order:
for r in node.inputs:
if r not in storage_map:
assert isinstance(r, Value)
storage_map[r] = [r.data]
for r in node.outputs:
storage_map.setdefault(r, [None])
......
......@@ -430,11 +430,16 @@ class MergeOptimizer(Optimizer):
are constant.
"""
def add_requirements(self, env):
try:
env.extend(toolbox.ReplaceValidate())
except: pass
def apply(self, env):
cid = _metadict() #result -> result.desc() (for constants)
inv_cid = _metadict() #desc -> result (for constants)
for i, r in enumerate(env.orphans.union(env.inputs)):
if isinstance(r, Constant):
for i, r in enumerate([r for r in env.results if isinstance(r, Constant)]): #env.orphans.union(env.inputs)):
#if isinstance(r, Constant):
sig = r.signature()
other_r = inv_cid.get(sig, None)
if other_r is not None:
......@@ -446,20 +451,19 @@ class MergeOptimizer(Optimizer):
# and it's more efficient to give them an integer cid like the other Results
cid.clear()
inv_cid.clear()
for i, r in enumerate(env.orphans.union(env.inputs)):
for i, r in enumerate(r for r in env.results if r.owner is None):
cid[r] = i
inv_cid[i] = r
for node in env.io_toposort():
for node in graph.io_toposort(env.inputs, env.outputs):
node_cid = (node.op, tuple([cid[input] for input in node.inputs]))
dup = inv_cid.get(node_cid, None)
success = False
if dup is not None:
success = True
d = dict(zip(node.outputs, dup.outputs))
try:
env.replace_all(d)
except Exception, e:
env.replace_all_validate(zip(node.outputs, dup.outputs))
except InconsistencyError, e:
success = False
if not success:
cid[node] = node_cid
......
......@@ -16,6 +16,51 @@ class Bookkeeper:
self.on_prune(env, node)
class Toposorter:
def on_attach(self, env):
if hasattr(env, 'toposort'):
raise Exception("Toposorter feature is already present or in conflict with another plugin.")
env.toposort = partial(self.toposort, env)
def on_deattach(self, env):
del env.toposort
def toposort(self, env):
ords = {}
for feature in env._features:
if hasattr(feature, 'orderings'):
for op, prereqs in feature.orderings(env).items():
ords.setdefault(op, set()).update(prereqs)
order = graph.io_toposort(env.inputs, env.outputs, ords)
return order
# def supplemental_orderings(self):
# """
# Returns a dictionary of {op: set(prerequisites)} that must
# be satisfied in addition to the order defined by the structure
# of the graph (returns orderings that not related to input/output
# relationships).
# """
# ords = {}
# for feature in self._features:
# if hasattr(feature, 'orderings'):
# for op, prereqs in feature.orderings().items():
# ords.setdefault(op, set()).update(prereqs)
# return ords
# def toposort(self):
# """
# Returns a list of nodes in the order that they must be executed
# in order to preserve the semantics of the graph and respect
# the constraints put forward by the listeners.
# """
# ords = self.supplemental_orderings()
# order = graph.io_toposort(self.inputs, self.outputs, ords)
# return order
class History:
def __init__(self):
......
......@@ -25,10 +25,6 @@ def as_scalar(x, name = None):
if not isinstance(x.type, Scalar):
raise TypeError("Result type field must be a Scalar.", x, x.type)
return x
if isinstance(x, Constant):
if not isinstance(x.type, Scalar):
raise TypeError("Constant type field must be a Scalar.", x, x.type)
return x
try:
return constant(x)
except TypeError:
......@@ -582,7 +578,7 @@ tanh = Tanh(upgrade_to_float, name = 'tanh')
class Composite(ScalarOp):
def __init__(self, inputs, outputs):
env = Env(inputs, outputs).clone()
env = Env(*gof.graph.clone(inputs, outputs))
inputs, outputs = env.inputs, env.outputs
for node in env.nodes:
......@@ -594,11 +590,12 @@ class Composite(ScalarOp):
zip(outputs,
["%%(o%i)s"%i for i in range(len(outputs))]))
for orphan in env.orphans:
if isinstance(orphan, Constant):
subd[orphan] = orphan.type.c_literal(orphan.data)
else:
raise ValueError("All orphans in the env to Composite must be Constant instances.")
for orphan in env.results: #env.orphans:
if orphan.owner is None and orphan not in env.inputs:
if isinstance(orphan, Constant):
subd[orphan] = orphan.type.c_literal(orphan.data)
else:
raise ValueError("All orphans in the env to Composite must be Constant instances.")
_c_code = "{\n"
i = 0
......@@ -611,7 +608,7 @@ class Composite(ScalarOp):
name = "V%%(id)s_tmp%i" % i
subd[output] = name
_c_code += "%s %s;\n" % (output.type.dtype_specs()[1], name)
_c_code += node.op.c_code(node.inputs,
_c_code += node.op.c_code(node,
"%(name)s",
[subd[input] for input in node.inputs],
[subd[output] for output in node.outputs],
......@@ -629,7 +626,7 @@ class Composite(ScalarOp):
if r in env.inputs:
idx = env.inputs.index(r)
return lambda inputs: inputs[idx]
elif r in env.orphans:
elif r.owner is None: # in env.orphans:
return lambda inputs: r.data
node = r.owner
producers = [compose_impl(input) for input in node.inputs]
......
差异被折叠。
......@@ -6,7 +6,7 @@ import numpy
from copy import copy
from gof import Result, Op, utils, Destroyer, Viewer, AbstractFunctionError, Type, Result, Constant, Apply
from gof import Result, Op, utils, Destroyer, Viewer, AbstractFunctionError, Type, Result, Constant, Apply, Value
import gof
import blas # for gemm, dot
......@@ -27,14 +27,9 @@ def as_tensor(x, name = None):
if not isinstance(x.type, Tensor):
raise TypeError("Result type field must be a Tensor.", x, x.type)
return x
if isinstance(x, Constant):
if not isinstance(x.type, Tensor):
raise TypeError("Constant type field must be a Tensor.", x, x.type)
return x
try:
return constant(x)
except TypeError:
raise
raise TypeError("Cannot convert %s to Tensor" % x, type(x))
# this has a different name, because _as_tensor is the function which ops use
# to upcast their arguments... this internal-use function is a good place to put debugging stuff, better than the global astensor.
......@@ -48,9 +43,18 @@ def constant(x):
return TensorConstant(Tensor(dtype = x.dtype,
broadcastable = [d == 1 for d in x.shape]), x)
except:
raise
raise TypeError("Could not convert %s to Tensor" % _x, type(_x))
def value(x):
if not isinstance(x, numpy.ndarray):
x = numpy.asarray(x)
try:
return TensorValue(Tensor(dtype = x.dtype,
broadcastable = [d == 1 for d in x.shape]), x)
except:
raise TypeError("Could not convert %s to Tensor" % _x, type(_x))
class Tensor(Type):
"""
......@@ -342,10 +346,14 @@ class TensorResult(Result, _tensor_py_operators):
class TensorConstant(Constant, _tensor_py_operators):
pass
class TensorValue(Value, _tensor_py_operators):
pass
s2t.as_tensor = as_tensor
s2t.Tensor = Tensor
s2t.TensorResult = TensorResult
s2t.TensorConstant = TensorConstant
s2t.TensorValue = TensorValue
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论