提交 961d94ce authored 作者: James Bergstra's avatar James Bergstra

ShapeFeature DRAFT of how to update shape_i's when variables are replaced during

optimizations.
上级 ba72dc0e
......@@ -402,6 +402,7 @@ class ShapeFeature(object):
assert not hasattr(env, 'shape_feature')
env.shape_feature = self
self.shape_of = {} # Variable -> tuple(scalars) or None (All tensor vars map to tuple)
self.scheduled = {} # Variable ->
self.lscalar_one = T.constant(1, dtype='int64')
assert self.lscalar_one.type == T.lscalar
for node in env.toposort():
......@@ -448,7 +449,24 @@ class ShapeFeature(object):
# TODO:
# This tells us that r and new_r must have the same shape
# if we didn't know that the shapes are related, now we do.
pass
# change_input happens in two cases:
# 1) we are trying to get rid of r, or
# 2) we are putting things back after a failed transaction.
# In case 1, if r has a shape_i client, we will want to replace the shape_i of r with
# the shape of new_r. Say that r is *scheduled*.
for (shpnode, idx) in r.clients:
if isinstance(getattr(shpnode,'op', None), Shape_i):
self.scheduled[shpnode] = new_r
print >> sys.stderr, 'SCHEDULING SOMETHING', self.scheduled
# In case 2, if new_r is a variable that we've scheduled for shape update, then we
# should cancel it.
# TODO: store some kind of reverse index?
for k,v in self.scheduled.items():
if v == r:
del self.scheduled[k]
print>> sys.stderr, 'UNSCHEDULING SOMETHING', self.scheduled
class ShapeOptimizer(Optimizer):
"""Optimizer that serves to add ShapeFeature as an env feature.
......@@ -534,6 +552,20 @@ def local_shape_to_shape_i(node):
shape_feature = node.env.shape_feature
return [shape_feature.make_vector_shape(node.inputs[0])]
@register_specialize
@register_canonicalize
@gof.local_optimizer([T._shape])
def local_track_shape_i(node):
try:
shape_feature = node.env.shape_feature
except:
return
if node in node.env.shape_feature.scheduled:
assert isinstance(node.op, Shape_i)
replacement = node.env.shape_feature.scheduled[node.inputs[0]]
print >> sys.stderr, "REPLACING SOMETHING"
return [node.env.shape_feature.shape_of[replacement][i]]
@register_specialize
@register_canonicalize
@gof.local_optimizer([T.Subtensor])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论