提交 2813506a authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Add more shape-tracking possibilities to ShapeFeature

上级 f95dc2d5
...@@ -602,13 +602,64 @@ class ShapeFeature(object): ...@@ -602,13 +602,64 @@ class ShapeFeature(object):
s_i, type(s_i), getattr(s_i, 'type', None)) s_i, type(s_i), getattr(s_i, 'type', None))
def set_shape(self, r, s): def set_shape(self, r, s):
assert r not in self.shape_of assert r not in self.shape_of, 'r already in shape_of'
if s is None: if s is None:
self.shape_of[r] = s self.shape_of[r] = s
else: else:
self.shape_of[r] = tuple([self.unpack(s_i) for s_i in s]) self.shape_of[r] = tuple([self.unpack(s_i) for s_i in s])
def init_r(self,r): def update_shape(self, r, other_r):
'''Replace shape of r by shape of other_r.
If, on some dimensions, the shape of other_r is not informative,
keep the shape of r on those dimensions.
'''
# other_r should already have a shape
assert other_r in self.shape_of, ('other_r not in shape_of', other_r)
other_shape = self.shape_of[other_r]
# If no info is known on r's shape, use other_shape
try:
r_shape = self.shape_tuple(r)
except AttributeError, e:
#print e
self.shape_of[r] = other_shape
return
# If other_shape has no information, use r_shape
if other_shape is None:
self.shape_of[r] = r_shape
return
# Merge other_shape with r_shape, giving the priority to other_shape
merged_shape = []
for i, ps in enumerate(other_shape):
# If other_shape[i] is uninformative (if it is just
# Shape_i(i)(other_r)), use r_shape[i]
if (isinstance(getattr(ps,'op',None), Shape_i) and
ps.i == i and
ps.inputs[0] == other_r):
merged_shape.append(r_shape[i])
else:
merged_shape.append(other_shape[i])
self.shape_of[r] = tuple(merged_shape)
def set_shape_i(self, r, i, s_i):
'''Replace element i of shape_of[r] by s_i'''
assert r in self.shape_of
prev_shape = self.shape_of[r]
# prev_shape is a tuple, so we cannot change it inplace,
# so we build another one.
new_shape = []
for j, s_j in enumerate(prev_shape):
if j == i:
new_shape.append(self.unpack(s_i))
else:
new_shape.append(s_j)
self.shape_of[r] = tuple(new_shape)
def init_r(self, r):
'''Register r's shape in the shape_of dictionary.'''
if r not in self.shape_of: if r not in self.shape_of:
try: try:
self.set_shape(r, self.shape_tuple(r)) self.set_shape(r, self.shape_tuple(r))
...@@ -619,7 +670,7 @@ class ShapeFeature(object): ...@@ -619,7 +670,7 @@ class ShapeFeature(object):
return make_vector(*self.shape_of[r]) return make_vector(*self.shape_of[r])
# #
# #
# Feature inteface # Feature interface
# #
# #
def on_attach(self, env): def on_attach(self, env):
...@@ -669,10 +720,10 @@ class ShapeFeature(object): ...@@ -669,10 +720,10 @@ class ShapeFeature(object):
self.set_shape(r, s) self.set_shape(r, s)
def on_change_input(self, env, node, i, r, new_r): def on_change_input(self, env, node, i, r, new_r):
# TODO:
# This tells us that r and new_r must have the same shape # This tells us that r and new_r must have the same shape
# if we didn't know that the shapes are related, now we do. # if we didn't know that the shapes are related, now we do.
self.init_r(new_r) self.update_shape(new_r, r)
# change_input happens in two cases: # change_input happens in two cases:
# 1) we are trying to get rid of r, or # 1) we are trying to get rid of r, or
# 2) we are putting things back after a failed transaction. # 2) we are putting things back after a failed transaction.
...@@ -690,6 +741,15 @@ class ShapeFeature(object): ...@@ -690,6 +741,15 @@ class ShapeFeature(object):
if v == r: if v == r:
del self.scheduled[k] del self.scheduled[k]
# In either case, r could be in shape_of.values(), that is, r itself
# is the shape of something. In that case, we want to update
# the value in shape_of, to keep it up-to-date.
for k,v in self.shape_of.iteritems():
if v is not None:
for ii, vi in enumerate(v):
if vi == r:
self.set_shape_i(k, ii, new_r)
class ShapeOptimizer(Optimizer): class ShapeOptimizer(Optimizer):
"""Optimizer that serves to add ShapeFeature as an env feature. """Optimizer that serves to add ShapeFeature as an env feature.
""" """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论