提交 65e08101 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

env redone, toolbox redone

上级 646f4c01
......@@ -2,13 +2,15 @@
import unittest
from type import Type
import graph
from graph import Result, as_result, Apply
from op import Op
from opt import PatternOptimizer, OpSubOptimizer
from ext import *
from env import Env, InconsistencyError
from toolbox import EquivTool
#from toolbox import EquivTool
from toolbox import ReplaceValidate
from copy import copy
......@@ -65,8 +67,11 @@ def inputs():
_Env = Env
def Env(inputs, outputs, validate = True):
e = _Env(inputs, outputs)
e.extend(EquivTool(e))
e.extend(DestroyHandler(e), validate = validate)
##e.extend(EquivTool(e))
e.extend(DestroyHandler())
e.extend(ReplaceValidate())
if validate:
e.validate()
return e
......@@ -108,19 +113,19 @@ class _test_all(unittest.TestCase):
g = Env([x,y,z], [e1, e2])
chk = g.checkpoint()
assert g.consistent()
g.replace(e1, add_in_place(x, y))
g.replace_validate(e1, add_in_place(x, y))
assert g.consistent()
try:
g.replace(e2, add_in_place(y, x))
g.replace_validate(e2, add_in_place(y, x))
self.fail()
except InconsistencyError:
pass
assert g.consistent()
g.revert(chk)
g.replace(e2, add_in_place(y, x))
g.replace_validate(e2, add_in_place(y, x))
assert g.consistent()
try:
g.replace(e1, add_in_place(x, y))
g.replace_validate(e1, add_in_place(x, y))
self.fail()
except InconsistencyError:
pass
......@@ -136,7 +141,7 @@ class _test_all(unittest.TestCase):
assert str(g) != "[Dot(Dot(AddInPlace(x, y), AddInPlace(y, z)), AddInPlace(z, x))]" # we don't want to see that!
e2 = dot(dot(add_in_place(x,y), add_in_place(y,z)), add_in_place(z,x))
try:
g2 = Env([x,y,z], [e2])
g2 = Env(*graph.clone([x,y,z], [e2]))
self.fail()
except InconsistencyError:
pass
......@@ -154,16 +159,18 @@ class _test_all(unittest.TestCase):
e = dot(aip, transpose_view(x))
g = Env([x,y,z], [e], False)
assert not g.consistent()
g.replace(aip, add(x, z))
g.replace_validate(aip, add(x, z))
assert g.consistent()
def test_usage_loop_through_views_2(self):
x, y, z = inputs()
e0 = transpose_view(transpose_view(transpose_view(sigmoid(x))))
e0 = transpose_view(transpose_view(sigmoid(x)))
e = dot(add_in_place(x,y), transpose_view(e0))
g = Env([x,y,z], [e])
assert g.consistent() # because sigmoid can do the copy
g.replace(e0, x, False)
# print g
# print g.destroy_handler.children
g.replace(e0, x)
assert not g.consistent() # we cut off the path to the sigmoid
def test_usage_loop_insert_views(self):
......@@ -184,10 +191,10 @@ class _test_all(unittest.TestCase):
chk = g.checkpoint()
PatternOptimizer((transpose_view, (transpose_view, 'x')), 'x').optimize(g)
assert str(g) == "[x]"
g.replace(g.equiv(e), add(x,y))
print g
new_e = add(x,y)
g.replace_validate(x, new_e)
assert str(g) == "[Add(x, y)]"
g.replace(g.equiv(e), dot(add_in_place(x,y), transpose_view(x)), False)
g.replace(new_e, dot(add_in_place(x,y), transpose_view(x)))
assert str(g) == "[Dot(AddInPlace(x, y), TransposeView(x))]"
assert not g.consistent()
g.revert(chk)
......@@ -202,7 +209,7 @@ class _test_all(unittest.TestCase):
e = add_in_place(x, y)
g = Env([x,y,z], [e], False)
assert not g.consistent()
g.replace(e, add(x, y))
g.replace_validate(e, add(x, y))
assert g.consistent()
def test_indestructible_through_views(self):
......@@ -212,7 +219,7 @@ class _test_all(unittest.TestCase):
e = add_in_place(tv, y)
g = Env([x,y,z], [e], False)
assert not g.consistent()
g.replace(tv, sigmoid(x))
g.replace_validate(tv, sigmoid(x))
assert g.consistent()
def test_repair_destroy_path(self):
......@@ -223,7 +230,7 @@ class _test_all(unittest.TestCase):
e4 = add_in_place(e1, z)
g = Env([x,y,z], [e3, e4], False)
assert not g.consistent()
g.replace(e2, transpose_view(x), False)
g.replace(e2, transpose_view(x))
assert not g.consistent()
def test_indirect(self):
......@@ -233,9 +240,9 @@ class _test_all(unittest.TestCase):
g = Env([x,y,z], [e], False)
assert not g.consistent()
new_e0 = add(x, y)
g.replace(e0, new_e0, False)
g.replace(e0, new_e0)
assert g.consistent()
g.replace(new_e0, add_in_place(x, y), False)
g.replace(new_e0, add_in_place(x, y))
assert not g.consistent()
def test_indirect_2(self):
......@@ -245,12 +252,12 @@ class _test_all(unittest.TestCase):
g = Env([x,y,z], [e], False)
assert not g.consistent()
new_e0 = add(e0, y)
g.replace(e0, new_e0, False)
g.replace(e0, new_e0)
assert g.consistent()
if __name__ == '__main__':
unittest.main()
#unittest.main()
_test_all('test_usage_loop_through_views').debug()
......@@ -59,19 +59,19 @@ def inputs():
return x, y, z
class _test_EquivTool(unittest.TestCase):
# class _test_EquivTool(unittest.TestCase):
def test_straightforward(self):
x, y, z = inputs()
sx = sigmoid(x)
e = add(sx, sigmoid(y))
g = Env([x, y, z], [e])
g.extend(EquivTool(g))
assert hasattr(g, 'equiv')
assert g.equiv(sx) is sx
g.replace(sx, dot(x, z))
assert g.equiv(sx) is not sx
assert g.equiv(sx).owner.op is dot
# def test_straightforward(self):
# x, y, z = inputs()
# sx = sigmoid(x)
# e = add(sx, sigmoid(y))
# g = Env([x, y, z], [e])
# g.extend(EquivTool(g))
# assert hasattr(g, 'equiv')
# assert g.equiv(sx) is sx
# g.replace(sx, dot(x, z))
# assert g.equiv(sx) is not sx
# assert g.equiv(sx).owner.op is dot
class _test_NodeFinder(unittest.TestCase):
......@@ -81,7 +81,7 @@ class _test_NodeFinder(unittest.TestCase):
e0 = dot(y, z)
e = add(add(sigmoid(x), sigmoid(sigmoid(z))), dot(add(x, y), e0))
g = Env([x, y, z], [e])
g.extend(NodeFinder(g))
g.extend(NodeFinder())
assert hasattr(g, 'get_nodes')
for type, num in ((add, 3), (sigmoid, 3), (dot, 2)):
if not len([x for x in g.get_nodes(type)]) == num:
......@@ -100,7 +100,7 @@ class _test_NodeFinder(unittest.TestCase):
x, y, z = inputs()
e = add(add(sigmoid(x), sigmoid(sigmoid(z))), dot(add(x, y), dot(y, z)))
g = Env([x, y, z], [e])
g.extend(NodeFinder(g))
g.extend(NodeFinder())
gen = g.get_nodes(sigmoid) # I want to get Sigmoid instances
g.replace(e, add(x, y)) # but here I prune them all
assert len([x for x in gen]) == 0 # the generator should not yield them
......
差异被折叠。
from features import Listener, Constraint, Orderings, Tool
#from features import Listener, Constraint, Orderings, Tool
import utils
from utils import AbstractFunctionError
from copy import copy
from env import InconsistencyError
from toolbox import Bookkeeper
class DestroyHandler(Listener, Constraint, Orderings, Tool):
from collections import defaultdict
class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool):
"""
This feature ensures that an env represents a consistent data flow
when some Ops overwrite their inputs and/or provide "views" over
......@@ -27,14 +35,32 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
This feature allows some optimizations (eg sub += for +) to be applied
safely.
"""
def __init__(self, env):
def __init__(self):
self.env = None
def on_attach(self, env):
if self.env is not None:
raise Exception("A DestroyHandler instance can only serve one Env.")
for attr in ('destroyers', 'destroy_handler'):
if hasattr(env, attr):
raise Exception("DestroyHandler feature is already present or in conflict with another plugin.")
def __destroyers(r):
ret = self.destroyers.get(r, {})
ret = ret.keys()
return ret
env.destroyers = __destroyers
env.destroy_handler = self
self.env = env
# For an Op that has a view_map, {output : input it is a view of}
self.parent = {}
# Reverse mapping of parent: {input : outputs that are a view of it}
self.children = {}
self.children = defaultdict(set)
# {foundation : {op that destroys it : path }}
# where foundation is a result such that (not self.parent[result])
......@@ -57,25 +83,37 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
# indestructible by the user.
self.illegal = set()
self.env = env
self.seen = set()
# Initialize the children if the inputs and orphans.
for input in env.orphans.union(env.inputs):
self.children[input] = set()
def publish(self):
"""
Publishes the following on the env:
- destroyers(r) -> returns all L{Op}s that destroy the result r
- destroy_handler -> self
"""
def __destroyers(r):
ret = self.destroyers.get(r, {})
ret = ret.keys()
return ret
self.env.destroyers = __destroyers
self.env.destroy_handler = self
Bookkeeper.on_attach(self, env)
# # Initialize the children if the inputs and orphans.
# for input in env.inputs: # env.orphans.union(env.inputs):
# self.children[input] = set()
def on_detach(self, env):
del self.parent
del self.children
del self.destroyers
del self.paths
del self.dups
del self.cycles
del self.illegal
del self.seen
self.env = None
# def publish(self):
# """
# Publishes the following on the env:
# - destroyers(r) -> returns all L{Op}s that destroy the result r
# - destroy_handler -> self
# """
# def __destroyers(r):
# ret = self.destroyers.get(r, {})
# ret = ret.keys()
# return ret
# self.env.destroyers = __destroyers
# self.env.destroy_handler = self
def __path__(self, r):
"""
......@@ -105,12 +143,12 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
"""
children = self.children[r]
if not children:
return set([r])
return [r]
else:
rval = set([r])
rval = [r]
for child in children:
rval.update(self.__views__(child))
return rval
rval += self.__views__(child)
return utils.uniq(rval)
def __users__(self, r):
"""
......@@ -120,12 +158,12 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
is returned.
"""
views = self.__views__(r)
rval = set()
rval = [] # set()
for view in views:
for op, i in self.env.clients(view):
if op in self.seen:
rval.update(op.outputs)
return rval
for node, i in view.clients: #self.env.clients(view):
if node != 'output':
rval += node.outputs
return utils.uniq(rval)
def __pre__(self, op):
"""
......@@ -178,7 +216,7 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
just_remove is True, we return immediately after removing the
cycles.
"""
users = self.__users__(start)
users = set(self.__users__(start))
users.add(start)
for user in users:
for cycle in copy(self.cycles):
......@@ -208,13 +246,14 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
dmap[node.outputs[oidx]] = [node.inputs[iidx] for iidx in iidxs]
return vmap, dmap
def on_import(self, op):
def on_import(self, env, op):
"""
Recomputes the dependencies and search for inconsistencies given
that we just added an op to the env.
"""
self.seen.add(op)
op.deps['destroy'] = []
view_map, destroy_map = self.get_maps(op)
for input in op.inputs:
......@@ -251,7 +290,7 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
self.__detect_cycles_helper__(output, [])
def on_prune(self, op):
def on_prune(self, env, op):
"""
Recomputes the dependencies and searches for inconsistencies to remove
given that we just removed an op to the env.
......@@ -295,6 +334,7 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
del self.children[output]
self.seen.remove(op)
del op.deps['destroy']
def __add_destroyer__(self, path):
......@@ -305,11 +345,18 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
foundation = path[0]
target = path[-1]
op = target.owner
node = target.owner
destroyers = self.destroyers.setdefault(foundation, {})
path = destroyers.setdefault(op, path)
path = destroyers.setdefault(node, path)
print "add", path
node.deps['destroy'] += [user.owner for user in self.__users__(foundation) if user not in node.outputs]
# for foundation, destroyers in self.destroyers.items():
# for op in destroyers.keys():
# ords.setdefault(op, set()).update([user.owner for user in self.__users__(foundation) if user not in op.outputs])
if len(destroyers) > 1:
self.dups.add(foundation)
......@@ -325,10 +372,17 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
foundation = path[0]
target = path[-1]
op = target.owner
node = target.owner
print "rm", path
print node.deps['destroy']
for user in self.__users__(foundation):
print " -- ", user
if user not in node.outputs:
node.deps['destroy'].remove(user.owner)
destroyers = self.destroyers[foundation]
del destroyers[op]
del destroyers[node]
if not destroyers:
if foundation in self.illegal:
......@@ -338,14 +392,18 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
self.dups.remove(foundation)
def on_rewire(self, clients, r_1, r_2):
def on_change_input(self, env, node, i, r, new_r):
if node != 'output':
self.on_rewire(env, [(node, i)], r, new_r)
def on_rewire(self, env, clients, r_1, r_2):
"""
Recomputes the dependencies and searches for inconsistencies to remove
given that all the clients are moved from r_1 to r_2, clients being
a list of (op, i) pairs such that op.inputs[i] used to be r_1 and is
now r_2.
"""
path_1 = self.__path__(r_1)
path_2 = self.__path__(r_2)
......@@ -396,7 +454,7 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
self.children.setdefault(r_2, set())
self.__detect_cycles__(r_2)
def validate(self):
def validate(self, env):
"""
Raises an L{InconsistencyError} on any of the following conditions:
- Some results are destroyed by more than one L{Op}
......@@ -412,9 +470,9 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
else:
return True
def orderings(self):
def orderings(self, env):
"""
Returns a dict of {op : set(ops that must be computed before it)} according
Returns a dict of {node : set(nodes that must be computed before it)} according
to L{DestroyHandler}.
In particular, all the users of a destroyed result have priority over the
L{Op} that destroys the result.
......@@ -426,6 +484,8 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
return ords
class Destroyer:
"""
Base class for Ops that destroy one or more of their inputs in an
......@@ -493,3 +553,4 @@ def view_roots(r):
return [r]
else:
return [r]
......@@ -202,7 +202,7 @@ def results_and_orphans(i, o, except_unreachable_input=False):
"""
results = set()
i = set(i)
results.update(i)
# results.update(i)
incomplete_paths = []
reached = set()
......@@ -287,7 +287,7 @@ def orphans(i, o):
return results_and_orphans(i, o)[1]
def clone(i, o, copy_inputs = False):
def clone(i, o, copy_inputs = True):
"""
@type i: list
@param i: input L{Result}s
......@@ -299,8 +299,8 @@ def clone(i, o, copy_inputs = False):
Copies the subgraph contained between i and o and returns the
outputs of that copy (corresponding to o).
"""
equiv = clone_get_equiv(i, o)
return [equiv[output] for output in o]
equiv = clone_get_equiv(i, o, copy_inputs)
return [equiv[input] for input in i], [equiv[output] for output in o]
def clone_get_equiv(i, o, copy_inputs_and_orphans = False):
......@@ -324,7 +324,7 @@ def clone_get_equiv(i, o, copy_inputs_and_orphans = False):
for input in i:
if copy_inputs_and_orphans:
cpy = copy(input)
cpy = input.clone()
cpy.owner = None
cpy.index = None
d[input] = cpy
......@@ -337,7 +337,7 @@ def clone_get_equiv(i, o, copy_inputs_and_orphans = False):
node = result.owner
if node is None: # result is an orphan
if copy_inputs_and_orphans:
cpy = copy(result)
cpy = result.clone()
d[result] = cpy
else:
d[result] = result
......
......@@ -115,7 +115,10 @@ class OpSpecificOptimizer(LocalOptimizer):
"""
def add_requirements(self, env):
env.extend(toolbox.NodeFinder(env))
try:
env.extend(toolbox.NodeFinder())
env.extend(toolbox.ReplaceValidate())
except: pass
def candidates(self, env):
"""
......@@ -135,7 +138,10 @@ class OpSubOptimizer(Optimizer):
"""
def add_requirements(self, env):
env.extend(toolbox.NodeFinder(env))
try:
env.extend(toolbox.NodeFinder())
env.extend(toolbox.ReplaceValidate())
except: pass
def __init__(self, op1, op2, failure_callback = None):
"""
......@@ -163,7 +169,7 @@ class OpSubOptimizer(Optimizer):
repl = self.op2.make_node(*node.inputs)
assert len(node.outputs) == len(repl.outputs)
for old, new in zip(node.outputs, repl.outputs):
env.replace(old, new)
env.replace_validate(old, new)
except Exception, e:
if self.failure_callback is not None:
self.failure_callback(node, repl, e)
......@@ -182,7 +188,10 @@ class OpRemover(Optimizer):
"""
def add_requirements(self, env):
env.extend(toolbox.NodeFinder(env))
try:
env.extend(toolbox.NodeFinder())
env.extend(toolbox.ReplaceValidate())
except: pass
def __init__(self, op, failure_callback = None):
"""
......
from random import shuffle
import utils
from functools import partial
import graph
class EquivTool(dict):
class Bookkeeper:
def on_attach(self, env):
for node in graph.io_toposort(env.inputs, env.outputs):
self.on_import(env, node)
def __init__(self, env):
self.env = env
def on_deattach(self, env):
for node in graph.io_toposort(env.inputs, env.outputs):
self.on_prune(env, node)
class History:
def __init__(self):
self.history = {}
def on_rewire(self, clients, r, new_r):
repl = self(new_r)
if repl is r:
self.ungroup(r, new_r)
elif repl is not new_r:
raise Exception("Improper use of EquivTool!")
else:
self.group(new_r, r)
def publish(self):
self.env.equiv = self
self.env.set_equiv = self.set_equiv
def unpublish(self):
del self.env.equiv
del self.env.set_equiv
def set_equiv(self, d):
self.update(d)
def group(self, main, *keys):
"Marks all the keys as having been replaced by the Result main."
keys = [key for key in keys if key is not main]
if self.has_key(main):
raise Exception("Only group results that have not been grouped before.")
for key in keys:
if self.has_key(key):
raise Exception("Only group results that have not been grouped before.")
if key is main:
continue
self.setdefault(key, main)
def ungroup(self, main, *keys):
"Undoes group(main, *keys)"
keys = [key for key in keys if key is not main]
for key in keys:
if self[key] is main:
del self[key]
def __call__(self, key):
"Returns the currently active replacement for the given key."
next = self.get(key, None)
while next:
key = next
next = self.get(next, None)
return key
class NodeFinder(dict):
def __init__(self, env):
def on_attach(self, env):
if hasattr(env, 'checkpoint') or hasattr(env, 'revert'):
raise Exception("History feature is already present or in conflict with another plugin.")
self.history[env] = []
env.checkpoint = lambda: len(self.history[env])
env.revert = partial(self.revert, env)
def on_deattach(self, env):
del env.checkpoint
del env.revert
del self.history[env]
def on_change_input(self, env, node, i, r, new_r):
if self.history[env] is None:
return
h = self.history[env]
h.append(lambda: env.change_input(node, i, r))
def revert(self, env, checkpoint):
    """
    Reverts the graph to whatever it was at the provided
    checkpoint (undoes all replacements). A checkpoint at any
    given time can be obtained using self.checkpoint().
    """
    h = self.history[env]
    # Disable recording while undoing, so the undo operations themselves
    # are not appended back onto the history (on_change_input checks for None).
    self.history[env] = None
    # Pop and run undo closures until the history is back to `checkpoint` length.
    while len(h) > checkpoint:
        f = h.pop()
        f()
    # Re-enable recording with the truncated history.
    self.history[env] = h
class Validator:
    """
    Env feature that publishes two callables on the env:
      - env.validate()   -> runs the 'validate' callback on every attached
                            feature; a feature signals inconsistency by raising.
      - env.consistent() -> True iff env.validate() does not raise.
    """
    def on_attach(self, env):
        # Refuse to attach if another feature already published `validate`.
        if hasattr(env, 'validate'):
            raise Exception("Validator feature is already present or in conflict with another plugin.")
        env.validate = lambda: env.execute_callbacks('validate')
        def consistent():
            # Consistency is defined as "validation raises nothing".
            try:
                env.validate()
                return True
            except:
                return False
        env.consistent = consistent
    def on_deattach(self, env):
        # Remove the published callables when the feature is detached.
        del env.validate
        del env.consistent
class ReplaceValidate(History, Validator):
    """
    Combines History (checkpoint/revert) and Validator, and publishes
    transactional replacement helpers on the env:
      - env.replace_validate(r, new_r)
      - env.replace_all_validate(replacements)
    Replacements that make the env fail validation are rolled back
    to the pre-replacement checkpoint and the validation error re-raised.
    """
    def on_attach(self, env):
        # Attach both parent features first (they publish checkpoint/revert
        # and validate/consistent, which replace_all_validate relies on).
        History.on_attach(self, env)
        Validator.on_attach(self, env)
        for attr in ('replace_validate', 'replace_all_validate'):
            if hasattr(env, attr):
                raise Exception("ReplaceValidate feature is already present or in conflict with another plugin.")
        # Bind env as the first argument so callers use env.replace_validate(r, new_r).
        env.replace_validate = partial(self.replace_validate, env)
        env.replace_all_validate = partial(self.replace_all_validate, env)
    def on_deattach(self, env):
        History.on_deattach(self, env)
        Validator.on_deattach(self, env)
        del env.replace_validate
        del env.replace_all_validate
    def replace_validate(self, env, r, new_r):
        # Single replacement is just the batched form with one pair.
        self.replace_all_validate(env, [(r, new_r)])
    def replace_all_validate(self, env, replacements):
        # Take a checkpoint so the whole batch can be undone atomically.
        chk = env.checkpoint()
        for r, new_r in replacements:
            env.replace(r, new_r)
        try:
            env.validate()
        except:
            # Validation failed: undo the batch, then propagate the error.
            env.revert(chk)
            raise
class NodeFinder(dict, Bookkeeper):
def __init__(self):
self.env = None
def on_attach(self, env):
if self.env is not None:
raise Exception("A NodeFinder instance can only serve one Env.")
if hasattr(env, 'get_nodes'):
raise Exception("NodeFinder is already present or in conflict with another plugin.")
self.env = env
env.get_nodes = partial(self.query, env)
Bookkeeper.on_attach(self, env)
def on_import(self, node):
def on_deattach(self, env):
if self.env is not env:
raise Exception("This NodeFinder instance was not attached to the provided env.")
self.env = None
del env.get_nodes
Bookkeeper.on_deattach(self, env)
def on_import(self, env, node):
try:
self.setdefault(node.op, set()).add(node)
except TypeError:
pass
self.setdefault(node.op, []).append(node)
except TypeError: #node.op is unhashable
return
def on_prune(self, node):
def on_prune(self, env, node):
try:
self[node.op].remove(node)
except TypeError:
nodes = self[node.op]
except TypeError: #node.op is unhashable
return
if not self[node.op]:
nodes.remove(node)
if not nodes:
del self[node.op]
def query(self, op):
def query(self, env, op):
try:
all = self.get(op, [])
except TypeError:
raise TypeError("%s in unhashable and cannot be queried by the optimizer" % op)
all = [x for x in all]
shuffle(all) # this helps a lot for debugging because the order of the replacements will vary
all = list(all)
while all:
next = all.pop()
if self.env.has_node(next):
if next in env.nodes:
yield next
def publish(self):
self.env.get_nodes = self.query
def __eq__(self, other):
return isinstance(other, NodeFinder) and self.env is other.env
class PrintListener(object):
    """
    Debugging feature: prints a line for every env event it receives
    (attach, detach, node import, node prune, input change).
    """
    def __init__(self, active = True):
        # active: when False, every callback is a silent no-op.
        self.active = active
    def on_attach(self, env):
        if self.active:
            print "-- attaching to: ", env
    def on_deattach(self, env):
        if self.active:
            print "-- deattaching from: ", env
    def on_import(self, env, node):
        if self.active:
            print "-- importing: %s" % node
    def on_prune(self, env, node):
        if self.active:
            print "-- pruning: %s" % node
    def on_change_input(self, env, node, i, r, new_r):
        if self.active:
            print "-- changing (%s.inputs[%s]) from %s to %s" % (node, i, r, new_r)
# class EquivTool(dict):
# def __init__(self, env):
# self.env = env
# def on_rewire(self, clients, r, new_r):
# repl = self(new_r)
# if repl is r:
# self.ungroup(r, new_r)
# elif repl is not new_r:
# raise Exception("Improper use of EquivTool!")
# else:
# self.group(new_r, r)
# def publish(self):
# self.env.equiv = self
# self.env.set_equiv = self.set_equiv
# def unpublish(self):
# del self.env.equiv
# del self.env.set_equiv
# def set_equiv(self, d):
# self.update(d)
# def group(self, main, *keys):
# "Marks all the keys as having been replaced by the Result main."
# keys = [key for key in keys if key is not main]
# if self.has_key(main):
# raise Exception("Only group results that have not been grouped before.")
# for key in keys:
# if self.has_key(key):
# raise Exception("Only group results that have not been grouped before.")
# if key is main:
# continue
# self.setdefault(key, main)
# def ungroup(self, main, *keys):
# "Undoes group(main, *keys)"
# keys = [key for key in keys if key is not main]
# for key in keys:
# if self[key] is main:
# del self[key]
# def __call__(self, key):
# "Returns the currently active replacement for the given key."
# next = self.get(key, None)
# while next:
# key = next
# next = self.get(next, None)
# return key
# class InstanceFinder(Listener, Tool, dict):
......@@ -158,28 +302,6 @@ class NodeFinder(dict):
class PrintListener(object):
def __init__(self, env, active = True):
self.env = env
self.active = active
if active:
print "-- initializing"
def on_import(self, node):
if self.active:
print "-- importing: %s" % node
def on_prune(self, node):
if self.active:
print "-- pruning: %s" % node
def on_rewire(self, clients, r, new_r):
if self.active:
if r.owner is not None: r = r.owner
if new_r.owner is not None: new_r = new_r.owner
print "-- moving from %s to %s" % (r, new_r)
### UNUSED AND UNTESTED ###
......
......@@ -26,6 +26,8 @@ class object2(object):
if hasattr(self, '__eq__') or hasattr(self, '__cmp__'):
raise TypeError("unhashable object: %s" % self)
return id(self)
def __ne__(self, other):
return not self == other
class scratchpad:
def clear(self):
......
......@@ -71,7 +71,7 @@ class Tensor(Type):
def __init__(self, dtype, broadcastable):
self.dtype = str(dtype)
self.broadcastable = broadcastable
self.broadcastable = tuple(broadcastable)
self.dtype_specs() # error checking is done there
def filter(self, data, strict = False):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论