提交 4d667a6d authored 作者: David Warde-Farley's avatar David Warde-Farley

PEP8 and docstring fixes.

上级 6dde460c
......@@ -4618,6 +4618,7 @@ class Reshape(Op):
oshape.append(os_i)
return [tuple(oshape)]
def reshape(x, newshape, ndim=None, name=None):
if ndim is None:
ndim = get_vector_length(newshape)
......@@ -4625,24 +4626,34 @@ def reshape(x, newshape, ndim=None, name=None):
rval = op(x, newshape)
return rval
class Flatten(Op):
"""Flattens a tensor to `outdim` dimensions by preserving the leading outdim-1 shape
components.
"""
view_map = {0:[0]}
Flattens a tensor to `outdim` dimensions by preserving the leading
outdim - 1 shape components.
"""
view_map = {0: [0]}
def __init__(self, outdim=1):
self.outdim = int(outdim)
def __eq__(self, other):
return type(self) == type(other) and self.outdim == other.outdim
def __hash__(self):
return hashtype(self)^hash(self.outdim)
return hashtype(self) ^ hash(self.outdim)
def __str__(self):
return '%s{%s}' % (self.__class__.__name__, self.outdim)
def make_node(self, x):
t_x = as_tensor_variable(x)
if self.outdim < 1 or (x.ndim and self.outdim > x.ndim):
raise ValueError('invalid output ndimensions(%i) for tensor of rank %i' %(self.outdim, t_x.ndim))
return gof.Apply(self, [t_x], [tensor(x.type.dtype, (False,)*self.outdim)])
raise ValueError('invalid output ndimensions (%i) for tensor of '
'rank %i' % (self.outdim, t_x.ndim))
return gof.Apply(self, [t_x], [tensor(x.type.dtype,
(False,) * self.outdim)])
def perform(self, node, inp, out_):
x, = inp
out, = out_
......@@ -4655,9 +4666,10 @@ class Flatten(Op):
elif outdim == len(x.shape):
out[0] = x
else:
newshape = x.shape[:outdim-1] + (numpy.prod(x.shape[outdim-1:]),)
#print 'newshape', newshape, x.shape, x.shape
newshape = (x.shape[:outdim - 1] +
(numpy.prod(x.shape[outdim - 1:]),))
out[0] = x.reshape(newshape)
def grad(self, inp, grads):
x, = inp
g_out, = grads
......@@ -4668,23 +4680,29 @@ class Flatten(Op):
return [None]
return self.make_node(*eval_points).outputs
def flatten(x, outdim=1):
return Flatten(outdim)(x)
class TileGrad(Op):
"""Calculates the gradient of the Tile Op"""
"""
Calculates the gradient of the Tile Op.
"""
#this is so weird, I can't think of how to make this a general thing.
def make_node(self, x, reps, g_out):
return gof.Apply(self, [x, reps, g_out], [x.type()])
def perform(self, node, inp, out):
x, reps, g_out = inp
gx, = out
xsh = x.shape
if len(reps)==2 and reps[1] == 1 and len(x.shape) == 1:
if len(reps) == 2 and reps[1] == 1 and len(x.shape) == 1:
gx[0] = numpy.sum(g_out, axis=0)
else:
raise NotImplementedError('x.shape, reps combination not supported',
(x.shape, reps))
raise NotImplementedError('x.shape, reps combination not'
'supported', (x.shape, reps))
tilegrad = TileGrad()
......@@ -4692,43 +4710,56 @@ class Tile(Op):
"""
Construct an array by repeating the input x according to reps pattern.
Tiles its input according to reps. The len of reps is the number of
dimension of x and contains the number of times to tile x in each dimension.
Tiles its input according to reps. The length of reps is the number of
dimension of x and contains the number of times to tile x in each
dimension.
:see: `numpy.tile http://docs.scipy.org/doc/numpy/reference/generated/numpy.tile.html`_
:see: `numpy.tile
<http://docs.scipy.org/doc/numpy/reference/generated/numpy.tile.html>`_
"""
def __init__(self, ndim):
self.ndim = ndim
def __eq__(self, other):
return (type(other) is Tile) and (other.ndim == self.ndim)
def __hash__(self):
return hash(Tile) ^ hash(self.ndim)
def make_node(self, x, reps):
x = as_tensor_variable(x)
reps = as_tensor_variable(reps)
return gof.Apply(self, [x, reps], [tensor(x.type.dtype, [False,] * self.ndim)])
return gof.Apply(self, [x, reps], [tensor(x.type.dtype, [False] *
self.ndim)])
def perform(self, node, inp, out_):
x, reps = inp
out, = out_
out[0] = numpy.tile(x, reps)
if len(out[0].shape) != self.ndim:
raise ValueError('Tile.perform produced incorrect shape')
def grad(self, inp, grads):
x, reps = inp
g_out, = grads
return [tilegrad(x, reps, g_out), None]
def tile(x, reps, ndim=None):
"""
Tile input array `x` according to `reps`. See the docstring of `numpy.tile`
for details.
TODO: expand this.
"""
if not hasattr(tile, 'op'):
tile.op = {}
if ndim is None:
ndim = len(reps)
#backport
#ndim = len(reps) if ndim is None else ndim #not sure if len(shp) is going to work.
# backport
# ndim = len(reps) if ndim is None else ndim #not sure if len(shp) is going
# to work.
if ndim not in tile.op:
tile.op[ndim] = Tile(ndim)
return tile.op[ndim](x, reps)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论