提交 4d667a6d authored 作者: David Warde-Farley's avatar David Warde-Farley

PEP8 and docstring fixes.

上级 6dde460c
...@@ -4618,6 +4618,7 @@ class Reshape(Op): ...@@ -4618,6 +4618,7 @@ class Reshape(Op):
oshape.append(os_i) oshape.append(os_i)
return [tuple(oshape)] return [tuple(oshape)]
def reshape(x, newshape, ndim=None, name=None): def reshape(x, newshape, ndim=None, name=None):
if ndim is None: if ndim is None:
ndim = get_vector_length(newshape) ndim = get_vector_length(newshape)
...@@ -4625,24 +4626,34 @@ def reshape(x, newshape, ndim=None, name=None): ...@@ -4625,24 +4626,34 @@ def reshape(x, newshape, ndim=None, name=None):
rval = op(x, newshape) rval = op(x, newshape)
return rval return rval
class Flatten(Op): class Flatten(Op):
"""Flattens a tensor to `outdim` dimensions by preserving the leading outdim-1 shape
components.
""" """
view_map = {0:[0]} Flattens a tensor to `outdim` dimensions by preserving the leading
outdim - 1 shape components.
"""
view_map = {0: [0]}
def __init__(self, outdim=1): def __init__(self, outdim=1):
self.outdim = int(outdim) self.outdim = int(outdim)
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) and self.outdim == other.outdim return type(self) == type(other) and self.outdim == other.outdim
def __hash__(self): def __hash__(self):
return hashtype(self)^hash(self.outdim) return hashtype(self) ^ hash(self.outdim)
def __str__(self): def __str__(self):
return '%s{%s}' % (self.__class__.__name__, self.outdim) return '%s{%s}' % (self.__class__.__name__, self.outdim)
def make_node(self, x): def make_node(self, x):
t_x = as_tensor_variable(x) t_x = as_tensor_variable(x)
if self.outdim < 1 or (x.ndim and self.outdim > x.ndim): if self.outdim < 1 or (x.ndim and self.outdim > x.ndim):
raise ValueError('invalid output ndimensions(%i) for tensor of rank %i' %(self.outdim, t_x.ndim)) raise ValueError('invalid output ndimensions (%i) for tensor of '
return gof.Apply(self, [t_x], [tensor(x.type.dtype, (False,)*self.outdim)]) 'rank %i' % (self.outdim, t_x.ndim))
return gof.Apply(self, [t_x], [tensor(x.type.dtype,
(False,) * self.outdim)])
def perform(self, node, inp, out_): def perform(self, node, inp, out_):
x, = inp x, = inp
out, = out_ out, = out_
...@@ -4655,9 +4666,10 @@ class Flatten(Op): ...@@ -4655,9 +4666,10 @@ class Flatten(Op):
elif outdim == len(x.shape): elif outdim == len(x.shape):
out[0] = x out[0] = x
else: else:
newshape = x.shape[:outdim-1] + (numpy.prod(x.shape[outdim-1:]),) newshape = (x.shape[:outdim - 1] +
#print 'newshape', newshape, x.shape, x.shape (numpy.prod(x.shape[outdim - 1:]),))
out[0] = x.reshape(newshape) out[0] = x.reshape(newshape)
def grad(self, inp, grads): def grad(self, inp, grads):
x, = inp x, = inp
g_out, = grads g_out, = grads
...@@ -4668,23 +4680,29 @@ class Flatten(Op): ...@@ -4668,23 +4680,29 @@ class Flatten(Op):
return [None] return [None]
return self.make_node(*eval_points).outputs return self.make_node(*eval_points).outputs
def flatten(x, outdim=1): def flatten(x, outdim=1):
return Flatten(outdim)(x) return Flatten(outdim)(x)
class TileGrad(Op): class TileGrad(Op):
"""Calculates the gradient of the Tile Op""" """
Calculates the gradient of the Tile Op.
"""
#this is so weird, I can't think of how to make this a general thing. #this is so weird, I can't think of how to make this a general thing.
def make_node(self, x, reps, g_out): def make_node(self, x, reps, g_out):
return gof.Apply(self, [x, reps, g_out], [x.type()]) return gof.Apply(self, [x, reps, g_out], [x.type()])
def perform(self, node, inp, out): def perform(self, node, inp, out):
x, reps, g_out = inp x, reps, g_out = inp
gx, = out gx, = out
xsh = x.shape xsh = x.shape
if len(reps)==2 and reps[1] == 1 and len(x.shape) == 1: if len(reps) == 2 and reps[1] == 1 and len(x.shape) == 1:
gx[0] = numpy.sum(g_out, axis=0) gx[0] = numpy.sum(g_out, axis=0)
else: else:
raise NotImplementedError('x.shape, reps combination not supported', raise NotImplementedError('x.shape, reps combination not'
(x.shape, reps)) 'supported', (x.shape, reps))
tilegrad = TileGrad() tilegrad = TileGrad()
...@@ -4692,43 +4710,56 @@ class Tile(Op): ...@@ -4692,43 +4710,56 @@ class Tile(Op):
""" """
Construct an array by repeating the input x according to reps pattern. Construct an array by repeating the input x according to reps pattern.
Tiles its input according to reps. The len of reps is the number of Tiles its input according to reps. The length of reps is the number of
dimension of x and contains the number of times to tile x in each dimension. dimension of x and contains the number of times to tile x in each
dimension.
:see: `numpy.tile http://docs.scipy.org/doc/numpy/reference/generated/numpy.tile.html`_ :see: `numpy.tile
<http://docs.scipy.org/doc/numpy/reference/generated/numpy.tile.html>`_
""" """
def __init__(self, ndim): def __init__(self, ndim):
self.ndim = ndim self.ndim = ndim
def __eq__(self, other): def __eq__(self, other):
return (type(other) is Tile) and (other.ndim == self.ndim) return (type(other) is Tile) and (other.ndim == self.ndim)
def __hash__(self): def __hash__(self):
return hash(Tile) ^ hash(self.ndim) return hash(Tile) ^ hash(self.ndim)
def make_node(self, x, reps): def make_node(self, x, reps):
x = as_tensor_variable(x) x = as_tensor_variable(x)
reps = as_tensor_variable(reps) reps = as_tensor_variable(reps)
return gof.Apply(self, [x, reps], [tensor(x.type.dtype, [False,] * self.ndim)]) return gof.Apply(self, [x, reps], [tensor(x.type.dtype, [False] *
self.ndim)])
def perform(self, node, inp, out_): def perform(self, node, inp, out_):
x, reps = inp x, reps = inp
out, = out_ out, = out_
out[0] = numpy.tile(x, reps) out[0] = numpy.tile(x, reps)
if len(out[0].shape) != self.ndim: if len(out[0].shape) != self.ndim:
raise ValueError('Tile.perform produced incorrect shape') raise ValueError('Tile.perform produced incorrect shape')
def grad(self, inp, grads): def grad(self, inp, grads):
x, reps = inp x, reps = inp
g_out, = grads g_out, = grads
return [tilegrad(x, reps, g_out), None] return [tilegrad(x, reps, g_out), None]
def tile(x, reps, ndim=None): def tile(x, reps, ndim=None):
"""
Tile input array `x` according to `reps`. See the docstring of `numpy.tile`
for details.
TODO: expand this.
"""
if not hasattr(tile, 'op'): if not hasattr(tile, 'op'):
tile.op = {} tile.op = {}
if ndim is None: if ndim is None:
ndim = len(reps) ndim = len(reps)
#backport # backport
#ndim = len(reps) if ndim is None else ndim #not sure if len(shp) is going to work. # ndim = len(reps) if ndim is None else ndim #not sure if len(shp) is going
# to work.
if ndim not in tile.op: if ndim not in tile.op:
tile.op[ndim] = Tile(ndim) tile.op[ndim] = Tile(ndim)
return tile.op[ndim](x, reps) return tile.op[ndim](x, reps)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论