提交 dfe77ee8 authored 作者: James Bergstra's avatar James Bergstra

Added infer_shape() methods to several ops, to take advantage of the ShapeOpt

optimizer. Also, moved MakeVector to tensor.opt, since it is only used internally by ShapeOpt.
上级 b31bdb47
......@@ -1035,6 +1035,7 @@ class TensorConstantSignature(tuple):
except:
return False
#N.B. compare shape to ensure no broadcasting in ==
#N.B. compare elementwise last because it is the most expensive check
return (t0 == t1) and (d0.shape == d1.shape) \
and (self.sum == other.sum) and (numpy.all(d0 == d1))
def __hash__(self):
......@@ -1300,9 +1301,15 @@ def shape(a):
pprint.assign(_shape, printing.MemberPrinter('shape'))
class MaxAndArgmax(Op):
"""Calculate the max and argmax over a given axis"""
"""Calculate the max and argmax over a given axis.
.. note::
If axis is None it means to calculate the max over the last dimension which is
DIFFERENT FROM NUMPY!!
"""
nin=2 # tensor, axis
nout=2 # max val, max idx
E_axis = 'invalid axis'
......@@ -1313,7 +1320,8 @@ class MaxAndArgmax(Op):
axis = x.type.ndim - 1
axis = _as_tensor_variable(axis)
inputs = [x, axis]
broadcastable = [False] * (x.type.ndim - 1) #TODO: be less conservative
#TODO: figure things out if axis is a constant
broadcastable = [False] * (x.type.ndim - 1)
outputs = [tensor(x.type.dtype, broadcastable,name='max'),
tensor('int32', broadcastable,name='argmax')]
return Apply(self, inputs, outputs)
......@@ -2493,6 +2501,10 @@ class Join(Op):
join(2, x, y, z) # WRONG: the axis has to be an index into the shape
join(0, x, u) # WRONG: joined tensors must have the same rank
"""
def __eq__(self, other):
    # All Join instances are interchangeable, so equality is by class
    # only; this lets the graph optimizer merge duplicate Join nodes.
    return type(self) == type(other)
def __hash__(self):
    # Must be consistent with __eq__ above: equal Ops (same class)
    # hash equally, so hash on the type alone.
    return hash(type(self))
def make_node(self, *axis_and_tensors):
"""
......@@ -2769,37 +2781,6 @@ else:
pass
class MakeVector(Op):
    """Build a rank-1 tensor (a vector) from a list of scalar variables.

    Every input must have exactly the scalar type given to the
    constructor (`stype`); the output is a non-broadcastable vector of
    the same dtype.  Used internally by ShapeOpt to assemble symbolic
    shape tuples.
    """
    def __init__(self, stype):
        # stype: the scalar Type every input is required to have.
        self.stype = stype

    def make_node(self, *inputs):
        # list(map(...)) behaves like py2's map() and stays valid on py3
        inputs = list(map(as_tensor_variable, inputs))
        assert all(a.type == self.stype for a in inputs)
        return Apply(self, inputs, [TensorType(broadcastable=(False,),
                                               dtype=self.stype.dtype)()])

    def perform(self, node, inputs, out_storage):
        # Tuple parameter unpacking (PEP 3113) replaced by explicit
        # unpacking so the def is valid on Python 3 as well.
        out, = out_storage
        # Stack the scalar inputs into a single 1-d ndarray.
        out[0] = numpy.asarray(inputs)

    def grad(self, inputs, grads):
        # Index/shape vectors have no meaningful gradient.
        return [None] * len(inputs)
# Prebuilt instance for making vectors of int64 ("long") scalars; the
# stray no-op '"""WRITEME"""' string that followed the assignment has
# been replaced by this comment.
make_lvector = MakeVector(lscalar)
class MakeVectorPrinter:
    """Pretty-printer rendering a make_vector node as "[a, b, c]"."""

    def process(self, r, pstate):
        owner = r.owner
        if owner is not None and isinstance(owner.op, MakeVector):
            # Render each input at maximal precedence, comma-separated.
            parts = [pstate.pprinter.process(inp, pstate.clone(precedence = 1000))
                     for inp in owner.inputs]
            return "[%s]" % ", ".join(parts)
        # Anything that is not the result of a MakeVector is rejected.
        raise TypeError("Can only print make_vector.")
pprint.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, MakeVector), MakeVectorPrinter())
class Reshape(Op):
"""Perform a reshape operation of the input x to the new shape shp.
The number of dimensions to which to reshape to (ndim) must be known at graph
......@@ -3343,6 +3324,18 @@ class Dot(Op):
rval = dot(gz, y.T), dot(x.T, gz)
return cast(rval[0], x.dtype), cast(rval[1], y.dtype)
def infer_shape(self, node, in_shapes, one):
    """Return the output shape of the dot product from the operand shapes.

    `in_shapes` is the pair (xshp, yshp) of symbolic shape tuples.  The
    result follows the usual matrix/vector product rules for the
    2x2, 1x2, 2x1 and 1x1 ndim combinations; anything else (ndim > 2)
    raises NotImplementedError.

    Tuple parameter unpacking (PEP 3113) replaced by explicit unpacking
    so the signature is valid on Python 3; callers pass the same
    positional arguments as before.
    """
    xshp, yshp = in_shapes
    x, y = node.inputs
    if x.ndim == 2 and y.ndim == 2:
        return [(xshp[0], yshp[1])]
    if x.ndim == 1 and y.ndim == 2:
        return [(yshp[1],)]
    if x.ndim == 2 and y.ndim == 1:
        return [(xshp[0],)]
    if x.ndim == 1 and y.ndim == 1:
        # vector . vector -> scalar
        return [()]
    raise NotImplementedError()
def __str__(self):
    # Short name shown when printing graphs containing this Op.
    return "dot"
dot = Dot()
......
......@@ -197,6 +197,18 @@ class DimShuffle(Op):
storage[0] = numpy.asarray(res) #asarray puts scalars back into array
def infer_shape(self, node, in_shapes, one):
    """Output shape of the dimshuffle: drop, transpose, then augment.

    New ('x') dimensions are broadcastable and get the symbolic
    length `one`.  Tuple parameter unpacking (PEP 3113) replaced by
    explicit unpacking so the def is valid on Python 3.
    """
    (ishp,) = in_shapes
    ishp = list(ishp)
    # delete dropped dims back-to-front so earlier indices stay valid
    for drop in reversed(self.drop):
        del ishp[drop]
    # transpose
    rval = [ishp[i] for i in self.shuffle]
    # augment: insert the symbolic 1 at each new broadcastable position
    for augm in self.augment:
        rval.insert(augm, one)
    return [rval]
def c_code(self, node, name, (input,), (res,), sub):
basename = input + '__view_or_copy'
......@@ -613,6 +625,25 @@ class Elemwise(Op):
# the following should be used instead of the previous loop, unfortunately it tends to segfault
# self.ufunc(*(ufunc_args+[s[0] for s in output_storage]))
def infer_shape(self, node, i_shapes, one):
    """Return the shape of each output of this elementwise Op.

    A broadcastable output dimension has symbolic length `one`;
    otherwise the length is read from any input that is not
    broadcastable along that dimension (the Elemwise contract
    guarantees at least one exists).
    """
    rval = []
    for o in node.outputs:
        oshp = []
        for dim, b in enumerate(o.type.broadcastable):
            if b:
                # broadcastable output dim: length is the symbolic 1
                b_dim = one
            else:
                # find an input that is not broadcastable on this dim
                b_dim = None
                for ishp, i in zip(i_shapes, node.inputs):
                    if not i.type.broadcastable[dim]:
                        b_dim = ishp[dim]
                        break
                # BUG FIX: compare against None, not truthiness -- a
                # perfectly legal dimension of length 0 is falsy and the
                # old `assert b_dim` wrongly rejected it.
                assert b_dim is not None
            oshp.append(b_dim)
        rval.append(oshp)
    return rval
def _c_all(self, node, name, inames, onames, sub):
_inames = inames
_onames = onames
......@@ -834,6 +865,13 @@ class CAReduce(Op):
else:
output[0] = numpy.copy(variable)
def infer_shape(self, node, in_shapes, one):
    """Shape after reduction: drop the dimensions listed in self.axis.

    If axis is None everything is reduced and the output is 0-d.
    Tuple parameter unpacking (PEP 3113) replaced by explicit
    unpacking so the def is valid on Python 3.
    """
    (ishape,) = in_shapes
    axis = self.axis
    if axis is None:
        # full reduction -> scalar output
        return ((),)
    # keep only the dimensions that are not reduced over
    return ([ishape[i]
             for (i, b) in enumerate(node.inputs[0].type.broadcastable)
             if i not in axis],)
def _c_all(self, node, name, inames, onames, sub):
input = node.inputs[0]
......
......@@ -136,10 +136,7 @@ class RandomFunction(gof.Op):
draw.
"""
if shape == () or shape == []:
shape = tensor.as_tensor_variable(shape, dtype='int64')
else:
shape = tensor.as_tensor_variable(shape, ndim=1)
shape = tensor.as_tensor_variable(shape, ndim=1)
assert shape.type.ndim == 1
assert (shape.type.dtype == 'int64') or (shape.type.dtype == 'int32')
if not isinstance(r.type, RandomStateType):
......@@ -158,6 +155,22 @@ class RandomFunction(gof.Op):
[r, shape] + args,
[r.type(), self.outtype()])
def infer_shape(self, node, i_shapes, one):
    """Return shapes for the two outputs (new rng state, samples).

    The rng-state output has no tensor shape (None).  If the shape
    argument was given explicitly (a non-empty constant) and this Op
    adds no extra dimensions, the sample shape is that argument;
    otherwise it can only be read off the output variable itself.
    """
    r, shp = node.inputs[0:2]
    # if shp is a constant array of len 0, it means 'automatic shape';
    # non-constants have no .data, so the [0,1,2] default reads as known
    unknown_shape = len(getattr(shp, 'data', [0, 1, 2])) == 0
    if self.ndim_added == 0 and not unknown_shape:
        sample_shp = shp
    else:
        # shape == () or ndim_added != 0: depends on the other args
        sample_shp = node.outputs[1].shape
    # range (not py2-only xrange) iterates identically and is py3-safe
    return [None, [sample_shp[i] for i in range(node.outputs[1].ndim)]]
def perform(self, node, inputs, (rout, out)):
# Use self.fn to draw shape worth of random numbers.
# Numbers are drawn from r if self.inplace is True, and from a copy of r if
......
......@@ -998,6 +998,21 @@ def test_local_fill_useless():
f = function([x,y], T.fill(x,y)*2, mode=m)
assert [node.op for node in f.maker.env.toposort()] == [T.mul]
class test_shapeoptimizer(unittest.TestCase):
    """Check that ShapeOpt replaces computations by direct shape lookups."""

    def test0(self):
        # The shape of v+m must be computable without actually doing
        # the add, so no `add` node may survive in the compiled graph.
        v = T.vector()
        m = T.matrix()
        f = function([v, m], (v + m).shape)
        for node in f.maker.env.toposort():
            assert node.op != T.add

    def test_constant(self):
        # A dimshuffle'd-in broadcastable dim has constant length 1, so
        # the whole graph should constant-fold away to nothing.
        # (leftover py2 debug `print` of the toposort removed)
        v = T.vector()
        m = T.matrix()
        f = function([v, m], v.dimshuffle('x', 'x', 0).shape[1])
        assert [] == f.maker.env.toposort()
if __name__ == '__main__':
# unittest.main()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论