提交 a8b455fc authored 作者: Olivier Breuleux's avatar Olivier Breuleux

corrected env bug, activated constant folding and subgraph merging, added Op.desc and Result.desc

上级 c8c2c4f7
...@@ -97,9 +97,9 @@ class BaseTensor(ResultBase): ...@@ -97,9 +97,9 @@ class BaseTensor(ResultBase):
raise TypeError("Unsupported dtype for %s: %s" % (self.__class__.__name__, self.dtype)) raise TypeError("Unsupported dtype for %s: %s" % (self.__class__.__name__, self.dtype))
# #
# Hash for constant folding # Description for constant folding
# #
def hash(self): def desc(self):
if self.data is not None: if self.data is not None:
return (BaseTensor, self.dtype, self.broadcastable, self.data.data[:]) return (BaseTensor, self.dtype, self.broadcastable, self.data.data[:])
else: else:
......
...@@ -16,8 +16,8 @@ exec_opt.optimizer = None ...@@ -16,8 +16,8 @@ exec_opt.optimizer = None
def default_optimizer(env): def default_optimizer(env):
#TODO: pass tests with these un-commented #TODO: pass tests with these un-commented
# default_optimizer.const(env) default_optimizer.const(env)
# default_optimizer.merge(env) default_optimizer.merge(env)
pass pass
default_optimizer.merge = gof.opt.MergeOptimizer() default_optimizer.merge = gof.opt.MergeOptimizer()
default_optimizer.const = gof.opt.ConstantFinder() default_optimizer.const = gof.opt.ConstantFinder()
......
...@@ -7,7 +7,7 @@ from base_tensor import BaseTensor as Tensor ...@@ -7,7 +7,7 @@ from base_tensor import BaseTensor as Tensor
from scalar import upcast, Scalar from scalar import upcast, Scalar
import scalar_ops import scalar_ops
import gof import gof
from gof.python25 import all
def astensor(data): def astensor(data):
assert isinstance(data, Tensor) assert isinstance(data, Tensor)
...@@ -68,6 +68,9 @@ class DimShuffle(Op, Viewer): ...@@ -68,6 +68,9 @@ class DimShuffle(Op, Viewer):
else: else:
return {} return {}
def desc(self):
return (self.__class__, tuple(self.new_order))
def perform(self): def perform(self):
res = self.inputs[0].data res = self.inputs[0].data
shape = list(res.shape) shape = list(res.shape)
...@@ -153,8 +156,8 @@ class Broadcast(Op, Destroyer): ...@@ -153,8 +156,8 @@ class Broadcast(Op, Destroyer):
def clone_with_new_inputs(self, *new_inputs): def clone_with_new_inputs(self, *new_inputs):
return Broadcast(self.scalar_opclass, new_inputs, self.inplace_pattern) return Broadcast(self.scalar_opclass, new_inputs, self.inplace_pattern)
def id(self): def desc(self):
return (self.__class__, self.scalar_opclass, self.inplace_pattern) return (self.__class__, self.scalar_opclass, tuple(self.inplace_pattern.items()))
def destroy_map(self): def destroy_map(self):
ret = {} ret = {}
...@@ -388,8 +391,8 @@ class CAReduce(Op): ...@@ -388,8 +391,8 @@ class CAReduce(Op):
self.shadow = scalar_opclass(*[Scalar(dtype = inputs[0].dtype) for i in xrange(scalar_opclass.nin)]) self.shadow = scalar_opclass(*[Scalar(dtype = inputs[0].dtype) for i in xrange(scalar_opclass.nin)])
self.ufunc = numpy.frompyfunc(self.shadow.impl, scalar_opclass.nin, scalar_opclass.nout) self.ufunc = numpy.frompyfunc(self.shadow.impl, scalar_opclass.nin, scalar_opclass.nout)
def id(self): def desc(self):
return (self.__class__, self.scalar_opclass, self.dimensions_to_reduce) return (self.__class__, self.scalar_opclass, tuple(self.dimensions_to_reduce))
def clone_with_new_inputs(self, *new_inputs): def clone_with_new_inputs(self, *new_inputs):
return CAReduce(self.scalar_opclass, new_inputs, self.dimensions_to_reduce) return CAReduce(self.scalar_opclass, new_inputs, self.dimensions_to_reduce)
......
...@@ -21,7 +21,7 @@ class MyResult(ResultBase): ...@@ -21,7 +21,7 @@ class MyResult(ResultBase):
def __repr__(self): def __repr__(self):
return self.name return self.name
def hash(self): def desc(self):
return self.data return self.data
......
...@@ -400,7 +400,7 @@ class Env(graph.Graph): ...@@ -400,7 +400,7 @@ class Env(graph.Graph):
# in new ops, so we use all results we know of as if they were the input set. # in new ops, so we use all results we know of as if they were the input set.
# (the functions in the graph module only use the input set to # (the functions in the graph module only use the input set to
# know where to stop going down) # know where to stop going down)
new_ops = graph.io_toposort(self.results(), op.outputs) new_ops = graph.io_toposort(self.results().difference(self.orphans()), op.outputs)
for op in new_ops: for op in new_ops:
...@@ -443,6 +443,8 @@ class Env(graph.Graph): ...@@ -443,6 +443,8 @@ class Env(graph.Graph):
# Cannot prune an op which is an output or used somewhere # Cannot prune an op which is an output or used somewhere
if self.clients(output) or output in self.outputs: #output in self.outputs or self.clients(output): if self.clients(output) or output in self.outputs: #output in self.outputs or self.clients(output):
return return
if op not in self._ops: # this can happen from replacing an orphan
return
self._ops.remove(op) self._ops.remove(op)
self._results.difference_update(op.outputs) self._results.difference_update(op.outputs)
for listener in self._listeners.values(): for listener in self._listeners.values():
......
...@@ -217,6 +217,9 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool): ...@@ -217,6 +217,9 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
self.seen.add(op) self.seen.add(op)
view_map, destroy_map = self.get_maps(op) view_map, destroy_map = self.get_maps(op)
for input in op.inputs:
self.children.setdefault(input, set())
for i, output in enumerate(op.outputs): for i, output in enumerate(op.outputs):
views = view_map.get(output, None) views = view_map.get(output, None)
destroyed = destroy_map.get(output, None) destroyed = destroy_map.get(output, None)
......
...@@ -73,6 +73,9 @@ class Op(object): ...@@ -73,6 +73,9 @@ class Op(object):
self._hash_id = utils.hashgen() self._hash_id = utils.hashgen()
return self._hash_id return self._hash_id
def desc(self):
return self.__class__
# #
# #
# #
......
...@@ -310,7 +310,7 @@ class PatternOptimizer(OpSpecificOptimizer): ...@@ -310,7 +310,7 @@ class PatternOptimizer(OpSpecificOptimizer):
and getattr(pattern, 'constant', False) \ and getattr(pattern, 'constant', False) \
and isinstance(expr, ResultBase) \ and isinstance(expr, ResultBase) \
and getattr(expr, 'constant', False) \ and getattr(expr, 'constant', False) \
and pattern.hash() == expr.hash(): and pattern.desc() == expr.desc():
return u return u
else: else:
return False return False
...@@ -371,7 +371,7 @@ class ConstantFinder(Optimizer): ...@@ -371,7 +371,7 @@ class ConstantFinder(Optimizer):
for r in env.inputs: for r in env.inputs:
r.indestructible = True r.indestructible = True
import graph
class MergeOptimizer(Optimizer): class MergeOptimizer(Optimizer):
""" """
Merges parts of the graph that are identical, i.e. parts that Merges parts of the graph that are identical, i.e. parts that
...@@ -381,11 +381,11 @@ class MergeOptimizer(Optimizer): ...@@ -381,11 +381,11 @@ class MergeOptimizer(Optimizer):
""" """
def apply(self, env): def apply(self, env):
cid = {} #result -> result.hash() (for constants) cid = {} #result -> result.desc() (for constants)
inv_cid = {} #hash -> result (for constants) inv_cid = {} #desc -> result (for constants)
for i, r in enumerate(env.orphans().union(env.inputs)): for i, r in enumerate(env.orphans().union(env.inputs)):
if getattr(r, 'constant', False) and hasattr(r, 'hash'): if getattr(r, 'constant', False):
ref = ('const', r.hash()) ref = ('const', r.desc())
other_r = inv_cid.get(ref, None) other_r = inv_cid.get(ref, None)
if other_r is not None: if other_r is not None:
env.replace(r, other_r) env.replace(r, other_r)
...@@ -397,20 +397,16 @@ class MergeOptimizer(Optimizer): ...@@ -397,20 +397,16 @@ class MergeOptimizer(Optimizer):
inv_cid[i] = r inv_cid[i] = r
for op in env.io_toposort(): for op in env.io_toposort():
# this could be made more robust by having an op.hash() that op_cid = (op.desc(), tuple([cid[input] for input in op.inputs]))
# doesn't depend on the inputs but can depend on additional properties
# of the op.
op_cid = (op.__class__, tuple([cid[input] for input in op.inputs]))
dup = inv_cid.get(op_cid, None) dup = inv_cid.get(op_cid, None)
success = False success = False
if dup is not None: if dup is not None:
success = True success = True
for output, other_output in zip(op.outputs, dup.outputs): d = dict(zip(op.outputs, dup.outputs))
try: try:
env.replace(output, other_output) env.replace_all(d)
except: except Exception, e:
success = False success = False
break
if not success: if not success:
cid[op] = op_cid cid[op] = op_cid
inv_cid[op_cid] = op inv_cid[op_cid] = op
......
...@@ -78,6 +78,9 @@ class ResultBase(object): ...@@ -78,6 +78,9 @@ class ResultBase(object):
def __hash__(self): def __hash__(self):
return self._hash_id return self._hash_id
def desc(self):
return id(self)
# #
# role # role
# #
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论