提交 352f0040 作者: James Bergstra

Added a reverse index (shape variable -> graph variables) to ShapeFeature so that updating shapes is much faster during graph optimizations

上级 e926c476
...@@ -695,9 +695,10 @@ class ShapeFeature(object): ...@@ -695,9 +695,10 @@ class ShapeFeature(object):
if s is None: if s is None:
self.shape_of[r] = s self.shape_of[r] = s
else: else:
self.shape_of[r] = tuple([self.unpack(s_i) for s_i in s]) shape_vars = [self.unpack(s_i) for s_i in s]
self.shape_of[r] = tuple(shape_vars)
# XXX: add a reverse index from the tuple elements -> r for sv in shape_vars:
self.shape_of_reverse_index.setdefault(sv, set()).add(r)
def update_shape(self, r, other_r): def update_shape(self, r, other_r):
'''Replace shape of r by shape of other_r. '''Replace shape of r by shape of other_r.
...@@ -709,20 +710,17 @@ class ShapeFeature(object): ...@@ -709,20 +710,17 @@ class ShapeFeature(object):
assert other_r in self.shape_of, ('other_r not in shape_of', other_r) assert other_r in self.shape_of, ('other_r not in shape_of', other_r)
other_shape = self.shape_of[other_r] other_shape = self.shape_of[other_r]
# If other_shape has no information, call is pointless.
if other_shape is None:
return
if r in self.shape_of: if r in self.shape_of:
r_shape = self.shape_of[r] r_shape = self.shape_of[r]
else: else:
# If no info is known on r's shape, use other_shape # If no info is known on r's shape, use other_shape
self.shape_of[r] = other_shape self.shape_of[r] = other_shape
for sv in other_shape:
#XXX: add reverse index from elements of other_shape -> r self.shape_of_reverse_index.setdefault(sv, set()).add(r)
return
# If other_shape has no information, call is pointless.
# XXX: move this above the previous if/else block
if other_shape is None:
# XXX: no need to assign back, delete following line
self.shape_of[r] = r_shape
return return
# Merge other_shape with r_shape, giving the priority to other_shape # Merge other_shape with r_shape, giving the priority to other_shape
...@@ -732,15 +730,16 @@ class ShapeFeature(object): ...@@ -732,15 +730,16 @@ class ShapeFeature(object):
# For now, we consider 2 cases of uninformative other_shape[i]: # For now, we consider 2 cases of uninformative other_shape[i]:
# - Shape_i(i)(other_r); # - Shape_i(i)(other_r);
# - Shape_i(i)(r). # - Shape_i(i)(r).
if (ps.owner and if (ps.owner
isinstance(getattr(ps.owner,'op',None), Shape_i) and and isinstance(getattr(ps.owner, 'op', None), Shape_i)
ps.owner.op.i == i and and ps.owner.op.i == i
ps.owner.inputs[0] in (r, other_r)): and ps.owner.inputs[0] in (r, other_r)):
merged_shape.append(r_shape[i]) merged_shape.append(r_shape[i])
else: else:
merged_shape.append(other_shape[i]) merged_shape.append(other_shape[i])
self.shape_of[r] = tuple(merged_shape) self.shape_of[r] = tuple(merged_shape)
# XXX: update reverse index for sv in self.shape_of[r]:
self.shape_of_reverse_index.setdefault(sv, set()).add(r)
def set_shape_i(self, r, i, s_i): def set_shape_i(self, r, i, s_i):
'''Replace element i of shape_of[r] by s_i''' '''Replace element i of shape_of[r] by s_i'''
...@@ -755,16 +754,16 @@ class ShapeFeature(object): ...@@ -755,16 +754,16 @@ class ShapeFeature(object):
else: else:
new_shape.append(s_j) new_shape.append(s_j)
self.shape_of[r] = tuple(new_shape) self.shape_of[r] = tuple(new_shape)
# XXX: update reverse index for sv in self.shape_of[r]:
self.shape_of_reverse_index.setdefault(sv, set()).add(r)
def init_r(self, r): def init_r(self, r):
'''Register r's shape in the shape_of dictionary.''' '''Register r's shape in the shape_of dictionary.'''
if r not in self.shape_of: if r not in self.shape_of:
try: try:
self.set_shape(r, self.shape_tuple(r)) self.set_shape(r, self.shape_tuple(r))
# XXX: update reverse index
except AttributeError: #XXX: where would this come from? except AttributeError: #XXX: where would this come from?
self.set_shape(r,None) self.set_shape(r, None)
def make_vector_shape(self, r): def make_vector_shape(self, r):
return make_vector(*self.shape_of[r]) return make_vector(*self.shape_of[r])
...@@ -781,9 +780,15 @@ class ShapeFeature(object): ...@@ -781,9 +780,15 @@ class ShapeFeature(object):
self.lscalar_one = T.constant(1, dtype='int64') self.lscalar_one = T.constant(1, dtype='int64')
assert self.lscalar_one.type == T.lscalar assert self.lscalar_one.type == T.lscalar
self.shape_of = {} # Variable -> tuple(scalars) or None (All tensor vars map to tuple) self.shape_of = {}
self.scheduled = {} # Variable -> # Variable -> tuple(scalars) or None (All tensor vars map to tuple)
# XXX: create reverse index
self.scheduled = {}
# Variable ->
self.shape_of_reverse_index = {}
# shape var -> graph v
for node in env.toposort(): for node in env.toposort():
self.on_import(env, node) self.on_import(env, node)
...@@ -845,23 +850,28 @@ class ShapeFeature(object): ...@@ -845,23 +850,28 @@ class ShapeFeature(object):
# the shape of new_r. Say that r is *scheduled*. # the shape of new_r. Say that r is *scheduled*.
# At that point, node is no longer a client of r, but of new_r # At that point, node is no longer a client of r, but of new_r
for (shpnode, idx) in (r.clients + [(node, i)]): for (shpnode, idx) in (r.clients + [(node, i)]):
if isinstance(getattr(shpnode,'op', None), Shape_i): if isinstance(getattr(shpnode, 'op', None), Shape_i):
self.scheduled[shpnode] = new_r self.scheduled[shpnode] = new_r
# In case 2, if r is a variable that we've scheduled for shape update, then we # In case 2, if r is a variable that we've scheduled for shape update, then we
# should cancel it. # should cancel it.
# TODO: store some kind of reverse index? unscheduled = [k for k, v in self.scheduled.items() if v == r]
for k,v in self.scheduled.items(): for k in unscheduled:
if v == r:
del self.scheduled[k] del self.scheduled[k]
# In either case, r could be in shape_of.values(), that is, r itself # In either case, r could be in shape_of.values(), that is, r itself
# is the shape of something. In that case, we want to update # is the shape of something. In that case, we want to update
# the value in shape_of, to keep it up-to-date. # the value in shape_of, to keep it up-to-date.
for k,v in self.shape_of.iteritems(): for v in self.shape_of_reverse_index.get(r, []):
if v is not None: # The reverse index is only approximate. It is not updated on
for ii, vi in enumerate(v): # deletion of variables, or on change_input so it might be the
if vi == r: # case that there are a few extra `v`'s in it that no longer have
self.set_shape_i(k, ii, new_r) # a shape of r or possibly have been deleted from shape_of
# entirely. The important thing is that it permits to recall
# all variables with r in their shape.
for ii, svi in enumerate(self.shape_of.get(v, [])):
if svi == r:
self.set_shape_i(v, ii, new_r)
self.shape_of_reverse_index[r] = set()
class ShapeOptimizer(Optimizer): class ShapeOptimizer(Optimizer):
"""Optimizer that serves to add ShapeFeature as an env feature. """Optimizer that serves to add ShapeFeature as an env feature.
...@@ -953,6 +963,7 @@ def local_track_shape_i(node): ...@@ -953,6 +963,7 @@ def local_track_shape_i(node):
if node in shape_feature.scheduled: if node in shape_feature.scheduled:
assert isinstance(node.op, Shape_i) assert isinstance(node.op, Shape_i)
replacement = shape_feature.scheduled[node] replacement = shape_feature.scheduled[node]
# XXX: what the heck is up with node.op.i ???
return [shape_feature.shape_of[replacement][node.op.i]] return [shape_feature.shape_of[replacement][node.op.i]]
@register_specialize @register_specialize
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论