提交 5e19f60b authored 作者: James Bergstra's avatar James Bergstra

ShapeOpt - added doc about infer_shape() and removed the extra 'one' parameter

of infer_shape
上级 dfe77ee8
......@@ -3324,7 +3324,7 @@ class Dot(Op):
rval = dot(gz, y.T), dot(x.T, gz)
return cast(rval[0], x.dtype), cast(rval[1], y.dtype)
def infer_shape(self, node, (xshp,yshp), one):
def infer_shape(self, node, (xshp,yshp)):
x, y = node.inputs
if x.ndim == 2 and y.ndim == 2:
return [(xshp[0], yshp[1])]
......
......@@ -197,7 +197,7 @@ class DimShuffle(Op):
storage[0] = numpy.asarray(res) #asarray puts scalars back into array
def infer_shape(self, node, (ishp,), one):
def infer_shape(self, node, (ishp,)):
ishp = list(ishp)
for drop in reversed(self.drop):
del ishp[drop]
......@@ -206,7 +206,7 @@ class DimShuffle(Op):
# augment
for augm in self.augment:
rval.insert(augm, one)
rval.insert(augm, 1)
return [rval]
def c_code(self, node, name, (input,), (res,), sub):
......@@ -625,14 +625,14 @@ class Elemwise(Op):
# the following should be used instead of the previous loop, unfortunately it tends to segfault
# self.ufunc(*(ufunc_args+[s[0] for s in output_storage]))
def infer_shape(self, node, i_shapes, one):
def infer_shape(self, node, i_shapes):
rval = []
for o in node.outputs:
oshp = []
for dim, b in enumerate(o.type.broadcastable):
b_dim = None
if b: # this is broadcastable
b_dim = one
b_dim = 1
else: # there must be some input that is not broadcastable
for ishp, i in zip(i_shapes,node.inputs):
if not i.type.broadcastable[dim]:
......@@ -865,7 +865,7 @@ class CAReduce(Op):
else:
output[0] = numpy.copy(variable)
def infer_shape(self, node, (ishape,), one):
def infer_shape(self, node, (ishape,)):
axis = self.axis
if axis is None:
return (),
......
......@@ -268,6 +268,22 @@ class ShapeOptimizer(Optimizer):
extra computations make it appear as if many internal graph nodes have multiple clients.
Many optimizations refuse to work on nodes with multiple clients.
Lifting is done by using an `<Op>.infer_shape` function if one is present, or else using a
conservative default. An Op that supports shape-lifting should define a
infer_shape(self, node, input_shapes) function. The argument input_shapes is a tuple
of tuples... there is an interior tuple for each input to the node. The tuple has as many
elements as dimensions. The element in position i of tuple j represents the i'th shape
component of the j'th input. The function should return a tuple of tuples. One output
tuple for each node.output. Again, the i'th element of the j'th output tuple represents
the output[j].shape[i] of the function. If an output is not a TensorType, then None should
be returned instead of a tuple for that output.
For example the infer_shape for a matrix-matrix product would accept
input_shapes=((x0,x1), (y0,y1)) and return ((x0, y1),).
infer_sha
Infering the shape of internal nodes in the graph is important for doing size-driven
optimizations. If we know how big various intermediate results will be, we can estimate
the cost of many Ops accurately, and generate c-code that is specific [e.g. unrolled] to
......@@ -299,6 +315,9 @@ class ShapeOptimizer(Optimizer):
def unpack(s_i):
# unpack the s_i that the Op returned
assert s_i is not None
if s_i == 1:
# don't make the optimizer merge a zillion ones together
return lscalar_one
if type(s_i) is int:
# this shape is a constant
assert s_i >= 0
......@@ -323,7 +342,7 @@ class ShapeOptimizer(Optimizer):
else:
shape_of[r] = tuple([unpack(s_i) for s_i in s])
def default_infer_shape(node, i_shapes, lscalar_one):
def default_infer_shape(node, i_shapes):
rval = []
for r in node.outputs:
try:
......@@ -350,13 +369,12 @@ class ShapeOptimizer(Optimizer):
shape_infer = default_infer_shape
try:
o_shapes = shape_infer(node, [shape_of[r] for r in node.inputs], lscalar_one)
o_shapes = shape_infer(node, [shape_of[r] for r in node.inputs])
except Exception, e:
_logger.error('Failed to infer_shape from Op %s (i_shapes=%s): %s %s'% (node.op,
[shape_of[r] for r in node.inputs],
type(e), str(e)))
o_shapes = default_infer_shape(node, [shape_of[r] for r in node.inputs],
lscalar_one)
o_shapes = default_infer_shape(node, [shape_of[r] for r in node.inputs])
# this is packed information
# an element of o_shapes is either None or a tuple
......
......@@ -155,7 +155,7 @@ class RandomFunction(gof.Op):
[r, shape] + args,
[r.type(), self.outtype()])
def infer_shape(self, node, i_shapes, one):
def infer_shape(self, node, i_shapes):
r, shp = node.inputs[0:2]
#if shp is a constant array of len 0, then it means 'automatic shape'
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论