提交 e926c476 作者: James Bergstra

adding comments and TODOs to ShapeFeature

上级 3540ba9e
...@@ -629,18 +629,23 @@ class ShapeFeature(object): ...@@ -629,18 +629,23 @@ class ShapeFeature(object):
""" """
def shape_ir(self, i, r): def shape_ir(self, i, r):
#TODO: Write a doc string for this method """Return symbolic r.shape[i] for tensor variable r, int i"""
if hasattr(r.type,"broadcastable") and r.type.broadcastable[i]: if hasattr(r.type,"broadcastable") and r.type.broadcastable[i]:
return self.lscalar_one return self.lscalar_one
else: else:
return Shape_i(i).make_node(r).outputs[0] return Shape_i(i).make_node(r).outputs[0]
def shape_tuple(self, r): def shape_tuple(self, r):
#TODO: Write a doc string for this method """Return a tuple of symbolic shape vars for tensor variable r"""
return tuple([self.shape_ir(i,r) for i in xrange(r.ndim)]) return tuple([self.shape_ir(i,r) for i in xrange(r.ndim)])
def default_infer_shape(self, node, i_shapes): def default_infer_shape(self, node, i_shapes):
"""Return a list of shape tuple or None for the outputs of node.
This function is used for Ops that don't implement infer_shape.
Ops that do implement infer_shape should use the i_shapes parameter,
but this default implementation ignores it.
"""
rval = [] rval = []
for r in node.outputs: for r in node.outputs:
try: try:
...@@ -650,16 +655,21 @@ class ShapeFeature(object): ...@@ -650,16 +655,21 @@ class ShapeFeature(object):
return rval return rval
def unpack(self, s_i): def unpack(self, s_i):
"""Return a symbolic integer scalar for the shape element s_i.
The s_i argument was produced by the infer_shape() of an Op subclass.
"""
# unpack the s_i that the Op returned # unpack the s_i that the Op returned
assert s_i is not None assert s_i is not None
if s_i == 1: if s_i == 1:
# don't make the optimizer merge a zillion ones together # don't make the optimizer merge a zillion ones together
# by always returning the same object to represent 1
return self.lscalar_one return self.lscalar_one
if type(s_i) in (int,long) or isinstance(s_i, numpy.integer): if type(s_i) in (int,long) or isinstance(s_i, numpy.integer):
# this shape is a constant # this shape is a constant
assert s_i >= 0 assert s_i >= 0
return T.constant(s_i, dtype='int64') return T.constant(s_i, dtype='int64')
if type(s_i) in (tuple,list): if type(s_i) in (tuple, list):
# this dimension is the same as many of the inputs # this dimension is the same as many of the inputs
# which tells us that if one of the inputs is known, # which tells us that if one of the inputs is known,
# the others all become known. # the others all become known.
...@@ -676,12 +686,19 @@ class ShapeFeature(object): ...@@ -676,12 +686,19 @@ class ShapeFeature(object):
s_i, type(s_i), getattr(s_i, 'type', None)) s_i, type(s_i), getattr(s_i, 'type', None))
def set_shape(self, r, s): def set_shape(self, r, s):
"""Assign the shape `s` to previously un-shaped variable `r`.
:type r: a variable
:type s: None or a tuple of symbolic integers
"""
assert r not in self.shape_of, 'r already in shape_of' assert r not in self.shape_of, 'r already in shape_of'
if s is None: if s is None:
self.shape_of[r] = s self.shape_of[r] = s
else: else:
self.shape_of[r] = tuple([self.unpack(s_i) for s_i in s]) self.shape_of[r] = tuple([self.unpack(s_i) for s_i in s])
# XXX: add a reverse index from the tuple elements -> r
def update_shape(self, r, other_r): def update_shape(self, r, other_r):
'''Replace shape of r by shape of other_r. '''Replace shape of r by shape of other_r.
...@@ -697,10 +714,14 @@ class ShapeFeature(object): ...@@ -697,10 +714,14 @@ class ShapeFeature(object):
else: else:
# If no info is known on r's shape, use other_shape # If no info is known on r's shape, use other_shape
self.shape_of[r] = other_shape self.shape_of[r] = other_shape
#XXX: add reverse index from elements of other_shape -> r
return return
# If other_shape has no information, use r_shape # If other_shape has no information, call is pointless.
# XXX: move this above the previous if/else block
if other_shape is None: if other_shape is None:
# XXX: no need to assign back, delete following line
self.shape_of[r] = r_shape self.shape_of[r] = r_shape
return return
...@@ -719,6 +740,7 @@ class ShapeFeature(object): ...@@ -719,6 +740,7 @@ class ShapeFeature(object):
else: else:
merged_shape.append(other_shape[i]) merged_shape.append(other_shape[i])
self.shape_of[r] = tuple(merged_shape) self.shape_of[r] = tuple(merged_shape)
# XXX: update reverse index
def set_shape_i(self, r, i, s_i): def set_shape_i(self, r, i, s_i):
'''Replace element i of shape_of[r] by s_i''' '''Replace element i of shape_of[r] by s_i'''
...@@ -733,13 +755,15 @@ class ShapeFeature(object): ...@@ -733,13 +755,15 @@ class ShapeFeature(object):
else: else:
new_shape.append(s_j) new_shape.append(s_j)
self.shape_of[r] = tuple(new_shape) self.shape_of[r] = tuple(new_shape)
# XXX: update reverse index
def init_r(self, r): def init_r(self, r):
'''Register r's shape in the shape_of dictionary.''' '''Register r's shape in the shape_of dictionary.'''
if r not in self.shape_of: if r not in self.shape_of:
try: try:
self.set_shape(r, self.shape_tuple(r)) self.set_shape(r, self.shape_tuple(r))
except AttributeError: # XXX: update reverse index
except AttributeError: #XXX: where would this come from?
self.set_shape(r,None) self.set_shape(r,None)
def make_vector_shape(self, r): def make_vector_shape(self, r):
...@@ -759,6 +783,7 @@ class ShapeFeature(object): ...@@ -759,6 +783,7 @@ class ShapeFeature(object):
self.shape_of = {} # Variable -> tuple(scalars) or None (All tensor vars map to tuple) self.shape_of = {} # Variable -> tuple(scalars) or None (All tensor vars map to tuple)
self.scheduled = {} # Variable -> self.scheduled = {} # Variable ->
# XXX: create reverse index
for node in env.toposort(): for node in env.toposort():
self.on_import(env, node) self.on_import(env, node)
...@@ -798,9 +823,11 @@ class ShapeFeature(object): ...@@ -798,9 +823,11 @@ class ShapeFeature(object):
# this is packed information # this is packed information
# an element of o_shapes is either None or a tuple # an element of o_shapes is either None or a tuple
# elements of the tuple can be either strings, or ints # elements of the tuple can be either strings, or ints
if len(o_shapes) != len(node.outputs): if len(o_shapes) != len(node.outputs):
raise Exception('len(o_shapes) = '+str(len(o_shapes))+' != len(node.outputs) = '+str(len(node.outputs))) raise Exception('len(o_shapes) = '
+ str(len(o_shapes))
+ ' != len(node.outputs) = '
+ str(len(node.outputs)))
for r, s in zip(node.outputs, o_shapes): for r, s in zip(node.outputs, o_shapes):
self.set_shape(r, s) self.set_shape(r, s)
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论