Commit dfe77ee8 authored by James Bergstra

Added infer_shape() methods to several ops, to take advantage of ShapeOpt

optimizer. Also, moved MakeVector to tensor.opt since it is only used internally by ShapeOpt.
Parent b31bdb47
...@@ -1035,6 +1035,7 @@ class TensorConstantSignature(tuple): ...@@ -1035,6 +1035,7 @@ class TensorConstantSignature(tuple):
except: except:
return False return False
#N.B. compare shape to ensure no broadcasting in == #N.B. compare shape to ensure no broadcasting in ==
#N.B. compare elementwise last because it is the most expensive check
return (t0 == t1) and (d0.shape == d1.shape) \ return (t0 == t1) and (d0.shape == d1.shape) \
and (self.sum == other.sum) and (numpy.all(d0 == d1)) and (self.sum == other.sum) and (numpy.all(d0 == d1))
def __hash__(self): def __hash__(self):
...@@ -1300,9 +1301,15 @@ def shape(a): ...@@ -1300,9 +1301,15 @@ def shape(a):
pprint.assign(_shape, printing.MemberPrinter('shape')) pprint.assign(_shape, printing.MemberPrinter('shape'))
class MaxAndArgmax(Op): class MaxAndArgmax(Op):
"""Calculate the max and argmax over a given axis""" """Calculate the max and argmax over a given axis.
.. note::
If axis is None it means to calculate the max over the last dimension which is
DIFFERENT FROM NUMPY!!
"""
nin=2 # tensor, axis nin=2 # tensor, axis
nout=2 # max val, max idx nout=2 # max val, max idx
E_axis = 'invalid axis' E_axis = 'invalid axis'
...@@ -1313,7 +1320,8 @@ class MaxAndArgmax(Op): ...@@ -1313,7 +1320,8 @@ class MaxAndArgmax(Op):
axis = x.type.ndim - 1 axis = x.type.ndim - 1
axis = _as_tensor_variable(axis) axis = _as_tensor_variable(axis)
inputs = [x, axis] inputs = [x, axis]
broadcastable = [False] * (x.type.ndim - 1) #TODO: be less conservative #TODO: figure things out if axis is a constant
broadcastable = [False] * (x.type.ndim - 1)
outputs = [tensor(x.type.dtype, broadcastable,name='max'), outputs = [tensor(x.type.dtype, broadcastable,name='max'),
tensor('int32', broadcastable,name='argmax')] tensor('int32', broadcastable,name='argmax')]
return Apply(self, inputs, outputs) return Apply(self, inputs, outputs)
...@@ -2493,6 +2501,10 @@ class Join(Op): ...@@ -2493,6 +2501,10 @@ class Join(Op):
join(2, x, y, z) # WRONG: the axis has to be an index into the shape join(2, x, y, z) # WRONG: the axis has to be an index into the shape
join(0, x, u) # WRONG: joined tensors must have the same rank join(0, x, u) # WRONG: joined tensors must have the same rank
""" """
def __eq__(self, other):
    # Every Join instance behaves identically, so equality is decided
    # purely by class: any two Join ops compare equal (lets the graph
    # optimizer merge duplicate Join nodes).
    return type(self) is type(other)

def __hash__(self):
    # Hash must agree with __eq__ above: hash by class, not by instance.
    return hash(type(self))
def make_node(self, *axis_and_tensors): def make_node(self, *axis_and_tensors):
""" """
...@@ -2769,37 +2781,6 @@ else: ...@@ -2769,37 +2781,6 @@ else:
pass pass
class MakeVector(Op):
    """Concatenate scalar inputs into a rank-1 tensor (a vector).

    Every input must have the type given by ``stype``; the output is a
    non-broadcastable vector with ``stype``'s dtype.  Used internally by
    ShapeOpt when assembling symbolic shape vectors.
    """
    def __init__(self, stype):
        # Type that every input variable must match (e.g. lscalar).
        self.stype = stype

    def make_node(self, *inputs):
        # list(...) keeps behaviour identical under Python 3, where map()
        # returns an iterator instead of a list.
        inputs = list(map(as_tensor_variable, inputs))
        assert all(a.type == self.stype for a in inputs)
        return Apply(self, inputs, [TensorType(broadcastable=(False,),
                                               dtype=self.stype.dtype)()])

    def perform(self, node, inputs, output_storage):
        # Tuple parameters in the signature were removed by PEP 3113;
        # unpack the single output cell inside the body instead.
        out, = output_storage
        out[0] = numpy.asarray(inputs)

    def grad(self, inputs, grads):
        # No gradient is defined w.r.t. the scalar inputs.
        return [None] * len(inputs)
make_lvector = MakeVector(lscalar)
"""WRITEME"""
class MakeVectorPrinter:
    """Pretty-printer that renders a MakeVector node as ``[e1, e2, ...]``."""
    def process(self, r, pstate):
        owner = r.owner
        if owner is not None and isinstance(owner.op, MakeVector):
            # Print each input at high precedence and join with commas.
            parts = [pstate.pprinter.process(inp, pstate.clone(precedence=1000))
                     for inp in owner.inputs]
            return "[%s]" % ", ".join(parts)
        # Either r has no owner or its op is not a MakeVector.
        raise TypeError("Can only print make_vector.")

pprint.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, MakeVector), MakeVectorPrinter())
class Reshape(Op): class Reshape(Op):
"""Perform a reshape operation of the input x to the new shape shp. """Perform a reshape operation of the input x to the new shape shp.
The number of dimensions to which to reshape to (ndim) must be known at graph The number of dimensions to which to reshape to (ndim) must be known at graph
...@@ -3343,6 +3324,18 @@ class Dot(Op): ...@@ -3343,6 +3324,18 @@ class Dot(Op):
rval = dot(gz, y.T), dot(x.T, gz) rval = dot(gz, y.T), dot(x.T, gz)
return cast(rval[0], x.dtype), cast(rval[1], y.dtype) return cast(rval[0], x.dtype), cast(rval[1], y.dtype)
def infer_shape(self, node, input_shapes, one):
    """Symbolically infer the output shape of a dot product.

    :param input_shapes: pair of (symbolic) shape tuples for the two operands.
    :param one: symbolic constant 1 (part of the ShapeOpt infer_shape
        protocol; unused here).
    :returns: one-element list holding the output's shape tuple.
    :raises NotImplementedError: for operand ranks other than 1 or 2.

    Note: the Python-2-only tuple parameter ``(xshp, yshp)`` was replaced
    by explicit unpacking (PEP 3113) for Python 3 compatibility.
    """
    xshp, yshp = input_shapes
    x, y = node.inputs
    if x.ndim == 2 and y.ndim == 2:
        # matrix @ matrix -> (rows of x, cols of y)
        return [(xshp[0], yshp[1])]
    if x.ndim == 1 and y.ndim == 2:
        # vector @ matrix -> (cols of y,)
        return [(yshp[1],)]
    if x.ndim == 2 and y.ndim == 1:
        # matrix @ vector -> (rows of x,)
        return [(xshp[0],)]
    if x.ndim == 1 and y.ndim == 1:
        # vector @ vector -> scalar
        return [()]
    raise NotImplementedError()
def __str__(self): def __str__(self):
return "dot" return "dot"
dot = Dot() dot = Dot()
......
...@@ -197,6 +197,18 @@ class DimShuffle(Op): ...@@ -197,6 +197,18 @@ class DimShuffle(Op):
storage[0] = numpy.asarray(res) #asarray puts scalars back into array storage[0] = numpy.asarray(res) #asarray puts scalars back into array
def infer_shape(self, node, input_shapes, one):
    """Symbolically infer the output shape of a DimShuffle.

    Applies the op's drop / shuffle / augment specification to the single
    input shape.  ``one`` is the symbolic constant 1 inserted for each new
    broadcastable axis.

    Note: the Python-2-only tuple parameter ``(ishp,)`` was replaced by
    explicit unpacking (PEP 3113) for Python 3 compatibility.
    """
    (ishp,) = input_shapes
    ishp = list(ishp)
    # Remove dropped dims; reversed order keeps earlier indices valid.
    for drop in reversed(self.drop):
        del ishp[drop]
    # Transpose the surviving dims.
    rval = [ishp[i] for i in self.shuffle]
    # Insert the symbolic 1 for each augmented (new broadcastable) axis.
    for augm in self.augment:
        rval.insert(augm, one)
    return [rval]
def c_code(self, node, name, (input,), (res,), sub): def c_code(self, node, name, (input,), (res,), sub):
basename = input + '__view_or_copy' basename = input + '__view_or_copy'
...@@ -613,6 +625,25 @@ class Elemwise(Op): ...@@ -613,6 +625,25 @@ class Elemwise(Op):
# the following should be used instead of the previous loop, unfortunately it tends to segfault # the following should be used instead of the previous loop, unfortunately it tends to segfault
# self.ufunc(*(ufunc_args+[s[0] for s in output_storage])) # self.ufunc(*(ufunc_args+[s[0] for s in output_storage]))
def infer_shape(self, node, i_shapes, one):
    """Symbolically infer each output shape of an elementwise op.

    For every output dimension: if it is broadcastable its (symbolic)
    length is ``one``; otherwise the length is copied from any input that
    is not broadcastable along that dimension.
    """
    rval = []
    for o in node.outputs:
        oshp = []
        for dim, b in enumerate(o.type.broadcastable):
            b_dim = None
            if b:
                # Broadcastable output dim -> symbolic length 1.
                b_dim = one
            else:
                # Some input must be non-broadcastable on this dim;
                # take its length for the output.
                for ishp, i in zip(i_shapes, node.inputs):
                    if not i.type.broadcastable[dim]:
                        b_dim = ishp[dim]
                        break
            # BUGFIX: compare against None instead of truthiness --
            # a known dimension of length 0 is falsy and used to trip
            # this assertion spuriously.
            assert b_dim is not None
            oshp.append(b_dim)
        rval.append(oshp)
    return rval
def _c_all(self, node, name, inames, onames, sub): def _c_all(self, node, name, inames, onames, sub):
_inames = inames _inames = inames
_onames = onames _onames = onames
...@@ -834,6 +865,13 @@ class CAReduce(Op): ...@@ -834,6 +865,13 @@ class CAReduce(Op):
else: else:
output[0] = numpy.copy(variable) output[0] = numpy.copy(variable)
def infer_shape(self, node, input_shapes, one):
    """Symbolically infer the output shape of the reduction.

    Reducing over all axes (``self.axis is None``) yields a scalar
    (shape ``()``); otherwise the reduced axes are removed from the
    input shape.  Returns a 1-tuple of shapes, one per output.

    Note: the Python-2-only tuple parameter ``(ishape,)`` was replaced by
    explicit unpacking (PEP 3113) for Python 3 compatibility.
    """
    (ishape,) = input_shapes
    axis = self.axis
    if axis is None:
        # Full reduction -> scalar output.
        return ((),)
    # Keep each input dimension whose index is not reduced over.
    kept = [ishape[i]
            for i, b in enumerate(node.inputs[0].type.broadcastable)
            if i not in axis]
    return (kept,)
def _c_all(self, node, name, inames, onames, sub): def _c_all(self, node, name, inames, onames, sub):
input = node.inputs[0] input = node.inputs[0]
......
...@@ -136,10 +136,7 @@ class RandomFunction(gof.Op): ...@@ -136,10 +136,7 @@ class RandomFunction(gof.Op):
draw. draw.
""" """
if shape == () or shape == []: shape = tensor.as_tensor_variable(shape, ndim=1)
shape = tensor.as_tensor_variable(shape, dtype='int64')
else:
shape = tensor.as_tensor_variable(shape, ndim=1)
assert shape.type.ndim == 1 assert shape.type.ndim == 1
assert (shape.type.dtype == 'int64') or (shape.type.dtype == 'int32') assert (shape.type.dtype == 'int64') or (shape.type.dtype == 'int32')
if not isinstance(r.type, RandomStateType): if not isinstance(r.type, RandomStateType):
...@@ -158,6 +155,22 @@ class RandomFunction(gof.Op): ...@@ -158,6 +155,22 @@ class RandomFunction(gof.Op):
[r, shape] + args, [r, shape] + args,
[r.type(), self.outtype()]) [r.type(), self.outtype()])
def infer_shape(self, node, i_shapes, one):
    """Symbolically infer the shapes of (rng_out, sample_out).

    The RNG state output has no shape (``None``).  When the op adds no
    dimensions and the shape input is known, that input directly gives the
    sample's shape; otherwise fall back to the symbolic shape of the
    sample output itself.
    """
    r, shp = node.inputs[0:2]
    # A constant shape argument of length 0 means 'automatic shape':
    # the real shape is decided at runtime from the other args.
    unknown_shape = len(getattr(shp, 'data', [0, 1, 2])) == 0
    if self.ndim_added == 0 and not unknown_shape:
        # The explicit shape input is exactly the sample's shape.
        sample_shp = shp
    else:
        # Shape depends on the other arguments (or on ndim_added extra
        # dims), so use the symbolic shape of the output itself.
        sample_shp = node.outputs[1].shape
    # BUGFIX-compat: range instead of the Python-2-only xrange
    # (identical behaviour here; the result feeds a list comprehension).
    return [None, [sample_shp[i] for i in range(node.outputs[1].ndim)]]
def perform(self, node, inputs, (rout, out)): def perform(self, node, inputs, (rout, out)):
# Use self.fn to draw shape worth of random numbers. # Use self.fn to draw shape worth of random numbers.
# Numbers are drawn from r if self.inplace is True, and from a copy of r if # Numbers are drawn from r if self.inplace is True, and from a copy of r if
......
...@@ -998,6 +998,21 @@ def test_local_fill_useless(): ...@@ -998,6 +998,21 @@ def test_local_fill_useless():
f = function([x,y], T.fill(x,y)*2, mode=m) f = function([x,y], T.fill(x,y)*2, mode=m)
assert [node.op for node in f.maker.env.toposort()] == [T.mul] assert [node.op for node in f.maker.env.toposort()] == [T.mul]
class test_shapeoptimizer(unittest.TestCase):
    """Tests that ShapeOpt replaces computable graphs with pure shape graphs."""

    def test0(self):
        # Requesting only the shape of (v + m) must not require actually
        # computing the elementwise add.
        v = T.vector()
        m = T.matrix()
        f = function([v, m], (v + m).shape)
        for node in f.maker.env.toposort():
            assert node.op != T.add

    def test_constant(self):
        # Shape entry of a broadcast ('x') axis is the constant 1, so the
        # whole graph should optimize away to a constant (empty toposort).
        v = T.vector()
        m = T.matrix()
        f = function([v, m], v.dimshuffle('x', 'x', 0).shape[1])
        # BUGFIX-compat: parenthesized print works under both Python 2
        # (print statement) and Python 3 (print function).
        print(f.maker.env.toposort())
        assert [] == f.maker.env.toposort()
if __name__ == '__main__': if __name__ == '__main__':
# unittest.main() # unittest.main()
......
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment