Commit a8b455fc, authored by Olivier Breuleux

corrected env bug, activated constant folding and subgraph merging, added Op.desc and Result.desc

Parent: c8c2c4f7
......@@ -97,9 +97,9 @@ class BaseTensor(ResultBase):
raise TypeError("Unsupported dtype for %s: %s" % (self.__class__.__name__, self.dtype))
#
# Hash for constant folding
# Description for constant folding
#
def hash(self):
def desc(self):
if self.data is not None:
return (BaseTensor, self.dtype, self.broadcastable, self.data.data[:])
else:
......
......@@ -16,8 +16,8 @@ exec_opt.optimizer = None
def default_optimizer(env):
    """Default optimization pipeline for an Env: constant folding followed
    by merging of identical subgraphs.

    Mutates `env` in place through the optimizer instances attached below
    as function attributes (`default_optimizer.const`,
    `default_optimizer.merge`), so tests can swap either stage out.
    """
    # Fold constants first so that merging can unify the duplicates
    # the folding step exposes.
    default_optimizer.const(env)
    default_optimizer.merge(env)
# Attach the two optimizer stages as function attributes so callers and
# tests can replace them; `gof` is a project-local module imported at the
# top of this file (not visible in this hunk).
default_optimizer.merge = gof.opt.MergeOptimizer()
default_optimizer.const = gof.opt.ConstantFinder()
......
......@@ -7,7 +7,7 @@ from base_tensor import BaseTensor as Tensor
from scalar import upcast, Scalar
import scalar_ops
import gof
from gof.python25 import all
def astensor(data):
assert isinstance(data, Tensor)
......@@ -68,6 +68,9 @@ class DimShuffle(Op, Viewer):
else:
return {}
def desc(self):
    """Description key used by constant folding / subgraph merging.

    Two DimShuffle ops with the same class and the same `new_order` are
    interchangeable; `new_order` is converted to a tuple so the key is
    hashable.
    """
    return (self.__class__, tuple(self.new_order))
def perform(self):
res = self.inputs[0].data
shape = list(res.shape)
......@@ -153,8 +156,8 @@ class Broadcast(Op, Destroyer):
def clone_with_new_inputs(self, *new_inputs):
    """Return a fresh Broadcast with identical parameters (scalar op class
    and inplace pattern) but the given inputs."""
    return Broadcast(self.scalar_opclass, new_inputs, self.inplace_pattern)
def id(self):
    """Identity key for this op (pre-`desc` API).

    NOTE(review): superseded by `desc` in this commit — `inplace_pattern`
    is a dict, which is not hashable, so this key cannot be used in the
    merge optimizer's lookup table. The name also shadows the builtin
    `id`; kept as-is for interface stability.
    """
    return (self.__class__, self.scalar_opclass, self.inplace_pattern)
def desc(self):
    """Hashable description for subgraph merging: the op class, the scalar
    op class, and the inplace pattern flattened from a dict into a tuple
    of (output_index, input_index) pairs."""
    return (self.__class__, self.scalar_opclass, tuple(self.inplace_pattern.items()))
def destroy_map(self):
ret = {}
......@@ -388,8 +391,8 @@ class CAReduce(Op):
self.shadow = scalar_opclass(*[Scalar(dtype = inputs[0].dtype) for i in xrange(scalar_opclass.nin)])
self.ufunc = numpy.frompyfunc(self.shadow.impl, scalar_opclass.nin, scalar_opclass.nout)
def id(self):
    """Identity key for this reduction (pre-`desc` API).

    NOTE(review): superseded by `desc` in this commit —
    `dimensions_to_reduce` may be a list, which is not hashable. The name
    shadows the builtin `id`; kept as-is for interface stability.
    """
    return (self.__class__, self.scalar_opclass, self.dimensions_to_reduce)
def desc(self):
    """Hashable description for subgraph merging: the op class, the scalar
    op class, and the reduction axes converted to a tuple."""
    return (self.__class__, self.scalar_opclass, tuple(self.dimensions_to_reduce))
def clone_with_new_inputs(self, *new_inputs):
    """Return a fresh CAReduce over the same axes but the given inputs."""
    return CAReduce(self.scalar_opclass, new_inputs, self.dimensions_to_reduce)
......
......@@ -21,7 +21,7 @@ class MyResult(ResultBase):
def __repr__(self):
    """Results display as their bare name, for readable graph printouts."""
    return self.name
def desc(self):
    """Description used for constant folding / merging: the result's value.

    Two constant results with equal `data` are treated as interchangeable.
    NOTE(review): reconstructed from a garbled diff that left the removed
    `def hash(self):` header without a body — the commit renamed the method
    from `hash` to `desc`.
    """
    return self.data
......
......@@ -400,7 +400,7 @@ class Env(graph.Graph):
# in new ops, so we use all results we know of as if they were the input set.
# (the functions in the graph module only use the input set to
# know where to stop going down)
new_ops = graph.io_toposort(self.results(), op.outputs)
new_ops = graph.io_toposort(self.results().difference(self.orphans()), op.outputs)
for op in new_ops:
......@@ -443,6 +443,8 @@ class Env(graph.Graph):
# Cannot prune an op which is an output or used somewhere
if self.clients(output) or output in self.outputs: #output in self.outputs or self.clients(output):
return
if op not in self._ops: # this can happen from replacing an orphan
return
self._ops.remove(op)
self._results.difference_update(op.outputs)
for listener in self._listeners.values():
......
......@@ -217,6 +217,9 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
self.seen.add(op)
view_map, destroy_map = self.get_maps(op)
for input in op.inputs:
self.children.setdefault(input, set())
for i, output in enumerate(op.outputs):
views = view_map.get(output, None)
destroyed = destroy_map.get(output, None)
......
......@@ -73,6 +73,9 @@ class Op(object):
self._hash_id = utils.hashgen()
return self._hash_id
def desc(self):
    """Default op description: ops of the same class are considered
    mergeable when their inputs match. Subclasses with extra parameters
    (axes, patterns, ...) override this to include them in the key."""
    return self.__class__
#
#
#
......
......@@ -310,7 +310,7 @@ class PatternOptimizer(OpSpecificOptimizer):
and getattr(pattern, 'constant', False) \
and isinstance(expr, ResultBase) \
and getattr(expr, 'constant', False) \
and pattern.hash() == expr.hash():
and pattern.desc() == expr.desc():
return u
else:
return False
......@@ -371,7 +371,7 @@ class ConstantFinder(Optimizer):
for r in env.inputs:
r.indestructible = True
import graph
class MergeOptimizer(Optimizer):
"""
Merges parts of the graph that are identical, i.e. parts that
......@@ -381,11 +381,11 @@ class MergeOptimizer(Optimizer):
"""
def apply(self, env):
cid = {} #result -> result.hash() (for constants)
inv_cid = {} #hash -> result (for constants)
cid = {} #result -> result.desc() (for constants)
inv_cid = {} #desc -> result (for constants)
for i, r in enumerate(env.orphans().union(env.inputs)):
if getattr(r, 'constant', False) and hasattr(r, 'hash'):
ref = ('const', r.hash())
if getattr(r, 'constant', False):
ref = ('const', r.desc())
other_r = inv_cid.get(ref, None)
if other_r is not None:
env.replace(r, other_r)
......@@ -397,20 +397,16 @@ class MergeOptimizer(Optimizer):
inv_cid[i] = r
for op in env.io_toposort():
# this could be made more robust by having an op.hash() that
# doesn't depend on the inputs but can depend on additional properties
# of the op.
op_cid = (op.__class__, tuple([cid[input] for input in op.inputs]))
op_cid = (op.desc(), tuple([cid[input] for input in op.inputs]))
dup = inv_cid.get(op_cid, None)
success = False
if dup is not None:
success = True
for output, other_output in zip(op.outputs, dup.outputs):
try:
env.replace(output, other_output)
except:
success = False
break
d = dict(zip(op.outputs, dup.outputs))
try:
env.replace_all(d)
except Exception, e:
success = False
if not success:
cid[op] = op_cid
inv_cid[op_cid] = op
......
......@@ -77,6 +77,9 @@ class ResultBase(object):
def __hash__(self):
    """Hash by the precomputed per-instance id (`_hash_id`), not by graph
    contents — results are identity-hashed."""
    return self._hash_id
def desc(self):
    """Fallback description: object identity, so a generic result never
    merges with anything else. Subclasses (e.g. constants) override this
    with a value-based key."""
    return id(self)
#
# role
......
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Register or sign in to comment