提交 cadbdfb9 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

merge

...@@ -1590,7 +1590,6 @@ def one(): ...@@ -1590,7 +1590,6 @@ def one():
pprint.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, Filler) and r.owner.op.value == 0, printing.FunctionPrinter('zeros')) pprint.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, Filler) and r.owner.op.value == 0, printing.FunctionPrinter('zeros'))
pprint.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, Filler) and r.owner.op.value == 1, printing.FunctionPrinter('ones')) pprint.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, Filler) and r.owner.op.value == 1, printing.FunctionPrinter('ones'))
@_redefine(elemwise.Elemwise(scal.identity)) @_redefine(elemwise.Elemwise(scal.identity))
def tensor_copy(a): def tensor_copy(a):
"""Create a duplicate of `a` (with duplicated storage)""" """Create a duplicate of `a` (with duplicated storage)"""
...@@ -2707,87 +2706,217 @@ def tile(x, reps, ndim=None): ...@@ -2707,87 +2706,217 @@ def tile(x, reps, ndim=None):
tile.op = {} tile.op = {}
if ndim is None: if ndim is None:
ndim = len(reps) ndim = len(reps)
#backport #backport
#ndim = len(reps) if ndim is None else ndim #not sure if len(shp) is going to work. #ndim = len(reps) if ndim is None else ndim #not sure if len(shp) is going to work.
if ndim not in tile.op: if ndim not in tile.op:
tile.op[ndim] = Tile(ndim) tile.op[ndim] = Tile(ndim)
return tile.op[ndim](x, reps) return tile.op[ndim](x, reps)
class InversePermutation(Op):
"""Computes the inverse of permutations.
Each row of input should contain a permutation of the first integers. class ARange(Op):
"""Create an array containing evenly spaced values within a given interval.
Parameters and behaviour are the same as numpy.arange().
""" """
def make_node(self, x): def __init__(self, dtype):
x = as_tensor_variable(x) self.dtype = dtype
return Apply(self, [x], [x.type()])
def perform(self, node, (x,), (outs,)): def __eq__(self, other):
if outs[0] is None or outs[0].shape != x.shape: return type(self) == type(other) and self.dtype == other.dtype
outs[0] = numpy.empty_like(x)
for i in numpy.ndindex(x.shape[:-1]):
outs[0][i][x[i]] = numpy.arange(x.shape[-1], dtype=x.dtype)
def grad(self, (x,), (gz,)): def __hash__(self):
return [None] return hash(self.dtype)
def make_node(self, start, stop, step):
start, stop, step = map(as_tensor_variable, (start, stop, step))
assert start.ndim == 0
assert stop.ndim == 0
assert step.ndim == 0
inverse_permutation = InversePermutation() inputs = [start, stop, step]
outputs = [tensor(self.dtype, (False,))]
return Apply(self, inputs, outputs)
class ReorderRowElements(Op): def perform(self, node, (start, stop, step), (out,)):
"""Reorder each row (inner-most dim) of a tensor wrt a permutation. start = start.item()
stop = stop.item()
step = step.item()
out[0] = numpy.arange(start, stop, step, dtype=self.dtype)
The permutation argument (y) will be broadcasted to fit x, then each def grad(self, inputs, (gz,)):
row (vector) of x will be reordered according to the corresponding row return [None] * len(inputs)
of y.
WARNING: x will not be broadcasted to fit y (not implemented yet). _arange = {}
def arange(start, stop=None, step=1, dtype=None):
# If only one argument is provided, it is in fact the "stop" argument,
# and start is 0.
if stop is None:
start, stop = 0, start
start, stop, step = map(as_tensor_variable, (start, stop, step))
# If dtype is not provided, infer it from the other arguments
if dtype is None:
dtype = scal.upcast(start.type.dtype, stop.type.dtype, step.type.dtype)
if dtype not in _arange:
_arange[dtype] = ARange(dtype)
return _arange[dtype](start, stop, step)
class PermuteRowElements(Op):
"""Permute the elements of each row (inner-most dim) of a tensor.
A permutation will be applied to every row (vector) of the input tensor x.
Depending on the dimensionality of x and the permutation tensor y,
different cases are possible.
If y.ndim = 1, y is a single permutation, that will be applied to every
vector of x. For instance, if x is a matrix, the same permutation will be
applied to each row of x.
If x.ndim = y.ndim, each row of x corresponds to a row of y, containing
a permutation that will be applied to that row. For instance, if x and y
are two matrices, a different permutation will be applied to each row of x.
If x.ndim > y.ndim, y will be broadcasted to fit x, then each row (vector)
of x will be reordered according to the corresponding row of y. (This is
a generalization of the first case).
If x.ndim = 1, every permutation in y will be applied to x, and the output
will contain all the results.
If x.ndim < y.ndim, x will be broadcasted to fit y, and different
permutations contained in y will be applied to each vector in x. (This is
a generalization of the previous case).
If the "inverse" argument is True, the Op will perform the inverse
permutation instead.
""" """
def make_node(self, x, y): def make_node(self, x, y, inverse):
x = as_tensor_variable(x) x = as_tensor_variable(x)
y = as_tensor_variable(y) y = as_tensor_variable(y)
assert y.type.dtype.startswith('int') or y.type.dtype.startswith('uint') inverse = as_tensor_variable(inverse)
# extend y dimension to match x
assert x.type.ndim >= y.type.ndim
y = shape_padleft(y, n_ones=(x.type.ndim - y.type.ndim))
inputlist = [x, y] # y should contain integers
outputlist = [x.type()] assert y.type.dtype.startswith('int') or y.type.dtype.startswith('uint')
# Inverse should be an integer scalar
assert inverse.type.ndim == 0 and\
(inverse.type.dtype.startswith('int') or\
inverse.type.dtype.startswith('uint'))
# Match shapes of x and y
x_dim = x.type.ndim
y_dim = y.type.ndim
if x_dim > y_dim:
y = shape_padleft(y, n_ones=(x_dim - y_dim))
elif x_dim < y_dim:
x = shape_padleft(x, n_ones=(y_dim - x_dim))
# Compute the broadcastable pattern of the output
out_broadcastable = [xb and yb for xb, yb in zip(x.type.broadcastable, y.type.broadcastable)]
out_type = tensor(dtype = x.type.dtype, broadcastable = out_broadcastable)
inputlist = [x, y, inverse]
outputlist = [out_type]
return Apply(self, inputlist, outputlist) return Apply(self, inputlist, outputlist)
def _rec_perform(self, node, x, y, out, curdim): def _rec_perform(self, node, x, y, inverse, out, curdim):
"""Perform the permutation by doing a recursion over the input dimensions.
For every dimension, starting with the leftmost, the right set of
indices is determined (depending if broadcasting or not), then
the function is recursively called on the appropriate subtensors.
The terminal case is reached when the current tensors are vector,
then the permutation contained in y is applied to x.
:param x: The input tensor, on which the permutation is applied
:param y: Tensor containing the permutations to apply
:param out: Tensor storing the output result
:param curdim: Counter of the current depth of recursion
:param inverse: Wether to apply permutations or their inverse
"""
if len(x.shape) == 1: if len(x.shape) == 1:
# Numpy advanced indexing works in this case # Numpy advanced indexing works in this case
out[:] = x[y] if inverse:
out[y] = x[:]
else:
out[:] = x[y]
else: else:
xs0 = x.shape[0] xs0 = x.shape[0]
ys0 = y.shape[0] ys0 = y.shape[0]
if xs0 == ys0: if xs0 == ys0:
for i in range(xs0): for i in range(xs0):
self._rec_perform(node, x[i], y[i], out[i], curdim+1) self._rec_perform(node, x[i], y[i], inverse, out[i], curdim+1)
elif node.inputs[1].type.broadcastable[curdim]: elif ys0 == 1 and node.inputs[1].type.broadcastable[curdim]:
# Broadcast y # Broadcast y
for i in range(xs0): for i in range(xs0):
self._rec_perform(node, x[i], y[0], out[i], curdim+1) self._rec_perform(node, x[i], y[0], inverse, out[i], curdim+1)
elif xs0 == 1 and node.inputs[0].type.broadcastable[curdim]:
# Broadcast x
for i in range(ys0):
self._rec_perform(node, x[0], y[i], inverse, out[i], curdim+1)
else: else:
raise ValueError('Dimension mismatch: %s, %s' % (xs0, ys0)) raise ValueError('Dimension mismatch: %s, %s' % (xs0, ys0))
def perform(self, node, (x, y), (outs,)): def perform(self, node, (x, y, inverse), (outs,)):
x_s = x.shape x_s = x.shape
y_s = y.shape y_s = y.shape
assert len(x_s) == len(y_s) assert len(x_s) == len(y_s)
if outs[0] is None or outs[0].shape != x_s: # Make sure the output is big enough
outs[0] = numpy.empty_like(x) out_s = []
for xdim, ydim in zip(x_s, y_s):
if xdim == ydim:
outdim = xdim
elif xdim == 1:
outdim = ydim
elif ydim == 1:
outdim = xdim
else:
raise ValueError('Dimension mismatch: %s, %s' % (xdim, ydim))
out_s.append(outdim)
if outs[0] is None or outs[0].shape != out_s:
outs[0] = numpy.empty(out_s, dtype=x.dtype)
self._rec_perform(node, x, y, inverse, outs[0], curdim=0)
def grad(self, (x, y, inverse), (gz,)):
# First, compute the gradient wrt the broadcasted x.
# If 'inverse' is False (0), apply the inverse of y on gz.
# Else, apply y on gz.
gx = permute_row_elements(gz, y, eq(inverse, 0))
# If x has been broadcasted along some axes, we need to sum
# the gradient over these axes, but keep the dimension (as
# broadcastable)
broadcasted_dims = [dim for dim in range(gz.type.ndim)\
if x.type.broadcastable[dim] and not gz.type.broadcastable[dim]]
gx = Sum(axis = broadcasted_dims)(gx)
# Sum(...) removed the dimensions in broadcasted_dims,
# so we need to put them back.
newdims = []
i = 0
for dim in range(gz.type.ndim):
if dim in broadcasted_dims:
newdims.append('x')
else:
newdims.append(i)
i += 1
self._rec_perform(node, x, y, outs[0], curdim=0) gx = DimShuffle(gx.type.broadcastable, newdims)(gx)
return [gx, None, None]
def grad(self, (x, y), (gz,)): _permute_row_elements = PermuteRowElements()
gx = reorder_row_elements(gz, inverse_permutation(y)) def permute_row_elements(x, y, inverse=0):
return [gx, None] return _permute_row_elements(x, y, inverse)
reorder_row_elements = ReorderRowElements() def inverse_permutation(perm):
"""Computes the inverse of permutations.
Each row of input should contain a permutation of the first integers.
"""
return permute_row_elements(arange(perm.shape[-1]), perm, inverse=True)
######################### #########################
......
...@@ -6,7 +6,7 @@ import numpy ...@@ -6,7 +6,7 @@ import numpy
from theano.compile import module, In, Component from theano.compile import module, In, Component
from theano.gof import Container from theano.gof import Container
from theano.tensor import raw_random, reorder_row_elements from theano.tensor import raw_random, permute_row_elements
class RandomStreamsInstance(object): class RandomStreamsInstance(object):
"""RandomStreamsInstance""" """RandomStreamsInstance"""
...@@ -192,7 +192,7 @@ class RandomStreams(Component): ...@@ -192,7 +192,7 @@ class RandomStreams(Component):
def shuffle_row_elements(self, input): def shuffle_row_elements(self, input):
"""Return a variable with every row (rightmost index) shuffled""" """Return a variable with every row (rightmost index) shuffled"""
perm = self.permutation(input.ndim-1, input.shape[:-1], input.shape[-1]) perm = self.permutation(input.ndim-1, input.shape[:-1], input.shape[-1])
shuffled = reorder_row_elements(input, perm) shuffled = permute_row_elements(input, perm)
return shuffled return shuffled
...@@ -1790,6 +1790,133 @@ def test_tile(): ...@@ -1790,6 +1790,133 @@ def test_tile():
print >> sys.stderr, "WARNING: No testcase for Tile" print >> sys.stderr, "WARNING: No testcase for Tile"
pass pass
class TestARange(unittest.TestCase):
def setUp(self):
utt.seed_rng()
def test_Op_integers(self):
"""Test behaviour of ARange Op on integer inputs"""
start, stop, step = iscalars('start', 'stop', 'step')
out = ARange(start.type.dtype)(start, stop, step)
f = function([start, stop, step], out)
assert numpy.all(f(0,5,1) == numpy.arange(0,5,1))
assert numpy.all(f(2,11,4) == numpy.arange(2,11,4))
assert numpy.all(f(-5,1,1) == numpy.arange(-5,1,1))
assert numpy.all(f(10,2,-2) == numpy.arange(10,2,-2))
assert numpy.all(f(10,2,2) == numpy.arange(10,2,2))
assert numpy.all(f(0,0,1) == numpy.arange(0,0,1))
def test_integers(self):
"""Test arange constructor, on integer outputs"""
start, stop, step = iscalars('start', 'stop', 'step')
out = arange(start, stop, step)
f = function([start, stop, step], out)
assert out.dtype == start.type.dtype
assert numpy.all(f(0,5,1) == numpy.arange(0,5,1))
assert numpy.all(f(2,11,4) == numpy.arange(2,11,4))
assert numpy.all(f(-5,1,1) == numpy.arange(-5,1,1))
assert numpy.all(f(10,2,-2) == numpy.arange(10,2,-2))
assert numpy.all(f(10,2,2) == numpy.arange(10,2,2))
assert numpy.all(f(0,0,1) == numpy.arange(0,0,1))
def test_float32(self):
"""Test arange constructor, on integer outputs"""
start, stop, step = fscalars('start', 'stop', 'step')
out = arange(start, stop, step)
f = function([start, stop, step], out)
assert out.dtype == start.type.dtype
assert numpy.all(f(0,5,1) == numpy.arange(0,5,1, dtype=start.type.dtype))
assert numpy.all(f(2,11,4) == numpy.arange(2,11,4, dtype=start.type.dtype))
assert numpy.all(f(-5,1.1,1.2) == numpy.arange(-5,1.1,1.2, dtype=start.type.dtype))
assert numpy.all(f(1.3,2,-2.1) == numpy.arange(1.3,2,-2.1, dtype=start.type.dtype))
assert numpy.all(f(10,2,2) == numpy.arange(10,2,2, dtype=start.type.dtype))
def test_float64(self):
"""Test arange constructor, on integer outputs"""
start, stop, step = dscalars('start', 'stop', 'step')
out = arange(start, stop, step)
f = function([start, stop, step], out)
assert out.dtype == start.type.dtype
assert numpy.all(f(0,5,1) == numpy.arange(0,5,1, dtype=start.type.dtype))
assert numpy.all(f(2,11,4) == numpy.arange(2,11,4, dtype=start.type.dtype))
assert numpy.all(f(-5,1.1,1.2) == numpy.arange(-5,1.1,1.2, dtype=start.type.dtype))
assert numpy.all(f(1.3,2,-2.1) == numpy.arange(1.3,2,-2.1, dtype=start.type.dtype))
assert numpy.all(f(10,2,2) == numpy.arange(10,2,2, dtype=start.type.dtype))
def test_default_step(self):
"""Test that arange constructor uses the correct default step"""
start, stop = iscalars('start', 'stop')
out = arange(start, stop)
f = function([start, stop], out)
assert out.dtype == start.type.dtype
assert numpy.all(f(0,5) == numpy.arange(0,5))
assert numpy.all(f(-5,1) == numpy.arange(-5,1))
assert numpy.all(f(0,0) == numpy.arange(0,0))
dstart, dstop = dscalars('start', 'stop')
dout = arange(dstart, dstop)
df = function([dstart, dstop], dout)
assert dout.dtype == dstart.type.dtype
print df(0.2, 5.3)
print numpy.arange(0.2, 5.3)
assert numpy.all(df(0.2, 5.3) == numpy.arange(0.2, 5.3))
assert numpy.all(df(0.8, 5.3) == numpy.arange(0.8, 5.3))
assert numpy.all(df(-0.7, 5.3) == numpy.arange(-0.7, 5.3))
def test_default_start(self):
"""Test that arange constructor uses the correct default start"""
stop = iscalar('stop')
out = arange(stop)
f = function([stop], out)
assert out.dtype == stop.type.dtype
assert numpy.all(f(8) == numpy.arange(8))
assert numpy.all(f(-2) == numpy.arange(-2))
fstop = fscalar('stop')
fout = arange(fstop)
ff = function([fstop], fout)
assert fout.dtype == fstop.type.dtype
assert numpy.all(ff(0.2) == numpy.arange(0.2))
assert numpy.all(ff(-0.7) == numpy.arange(-0.7))
assert numpy.all(ff(8.5) == numpy.arange(8.5))
def test_upcast(self):
"""Test that arange compute output type adequately"""
assert arange(iscalar()).dtype == iscalar().dtype
assert arange(fscalar()).dtype == fscalar().dtype
assert arange(dscalar()).dtype == dscalar().dtype
# int32 + float32 -> float64
assert arange(iscalar(), fscalar()).dtype == dscalar().dtype
assert arange(iscalar(), dscalar()).dtype == dscalar().dtype
assert arange(fscalar(), dscalar()).dtype == dscalar().dtype
assert arange(iscalar(), fscalar(), dscalar()).dtype == dscalar().dtype
def test_dtype_cache(self):
"""Checks that the same Op is returned on repeated calls to arange
using the same dtype, but not for different dtypes."""
start, stop, step = iscalars('start', 'stop', 'step')
out1 = arange(start, stop, step)
out2 = arange(start, stop, step, dtype=start.type.dtype)
out3 = arange(start, stop, 2., dtype=start.type.dtype)
out4 = arange(start, stop, 2.)
assert out1.owner.op is out2.owner.op
assert out2.owner.op is out3.owner.op
assert out3.owner.op is not out4.owner.op
class TestInversePermutation(unittest.TestCase): class TestInversePermutation(unittest.TestCase):
def setUp(self): def setUp(self):
utt.seed_rng() utt.seed_rng()
...@@ -1812,7 +1939,7 @@ class TestInversePermutation(unittest.TestCase): ...@@ -1812,7 +1939,7 @@ class TestInversePermutation(unittest.TestCase):
assert numpy.all(inv_val[p_val] == numpy.arange(10)) assert numpy.all(inv_val[p_val] == numpy.arange(10))
def test_dim2(self): def test_dim2(self):
"""Test the inversion of several permutation at a time""" """Test the inversion of several permutations at a time"""
# Each row of p is a different permutation to inverse # Each row of p is a different permutation to inverse
p = imatrix() p = imatrix()
inv = inverse_permutation(p) inv = inverse_permutation(p)
...@@ -1832,65 +1959,65 @@ class TestInversePermutation(unittest.TestCase): ...@@ -1832,65 +1959,65 @@ class TestInversePermutation(unittest.TestCase):
assert numpy.all(i_row[p_row] == numpy.arange(10)) assert numpy.all(i_row[p_row] == numpy.arange(10))
class TestReorderRowElements(unittest.TestCase): class TestPermuteRowElements(unittest.TestCase):
def setUp(self): def setUp(self):
utt.seed_rng() utt.seed_rng()
def test_1_1(self): def test_1_1(self):
"""Test ReorderRowElements(vector, vector)""" """Test PermuteRowElements(vector, vector)"""
input = vector() input = vector()
p = ivector() p = ivector()
out = reorder_row_elements(input, p) out = permute_row_elements(input, p)
reorder = function([input, p], out) permute = function([input, p], out)
rng = numpy.random.RandomState(utt.fetch_seed()) rng = numpy.random.RandomState(utt.fetch_seed())
input_val = rng.uniform(size=(5,)) input_val = rng.uniform(size=(5,))
p_val = rng.permutation(5) p_val = rng.permutation(5)
out_val = reorder(input_val, p_val) out_val = permute(input_val, p_val)
# Should be equivalent to advanced indexing # Should be equivalent to advanced indexing
out_bis = input_val[p_val] out_bis = input_val[p_val]
assert numpy.all(out_val == out_bis) assert numpy.all(out_val == out_bis)
# Verify gradient # Verify gradient
def reorder_fixed(s_input): def permute_fixed(s_input):
"""Auxiliary op defined to get rid of gradient wrt p_val""" """Auxiliary op defined to get rid of gradient wrt p_val"""
return reorder_row_elements(s_input, p_val) return permute_row_elements(s_input, p_val)
utt.verify_grad(reorder_fixed, [input_val]) utt.verify_grad(permute_fixed, [input_val])
def test_2_1(self): def test_2_1(self):
"""Test broadcasting in ReorderRowElements(matrix, vector)""" """Test broadcasting in PermuteRowElements(matrix, vector)"""
input = matrix() input = matrix()
p = ivector() p = ivector()
out = reorder_row_elements(input, p) out = permute_row_elements(input, p)
reorder = function([input, p], out) permute = function([input, p], out)
rng = numpy.random.RandomState(utt.fetch_seed()) rng = numpy.random.RandomState(utt.fetch_seed())
input_val = rng.uniform(size=(3,5)) input_val = rng.uniform(size=(3,5))
p_val = rng.permutation(5) p_val = rng.permutation(5)
out_val = reorder(input_val, p_val) out_val = permute(input_val, p_val)
# The same permutation should be applied to every row of the input matrix. # The same permutation should be applied to every row of the input matrix.
out_bis = numpy.asarray([row[p_val] for row in input_val]) out_bis = numpy.asarray([row[p_val] for row in input_val])
assert numpy.all(out_val == out_bis) assert numpy.all(out_val == out_bis)
# Verify gradient # Verify gradient
def reorder_fixed(s_input): def permute_fixed(s_input):
"""Auxiliary op defined to get rid of gradient wrt p_val""" """Auxiliary op defined to get rid of gradient wrt p_val"""
return reorder_row_elements(s_input, p_val) return permute_row_elements(s_input, p_val)
utt.verify_grad(reorder_fixed, [input_val]) utt.verify_grad(permute_fixed, [input_val])
def test_2_2(self): def test_2_2(self):
"""Test ReorderRowElements(matrix, matrix)""" """Test PermuteRowElements(matrix, matrix)"""
input = matrix() input = matrix()
p = imatrix() p = imatrix()
out = reorder_row_elements(input, p) out = permute_row_elements(input, p)
reorder = function([input, p], out) permute = function([input, p], out)
rng = numpy.random.RandomState(utt.fetch_seed()) rng = numpy.random.RandomState(utt.fetch_seed())
input_val = rng.uniform(size=(3,5)) input_val = rng.uniform(size=(3,5))
p_val = numpy.asarray([rng.permutation(5) for i in range(3)]) p_val = numpy.asarray([rng.permutation(5) for i in range(3)])
out_val = reorder(input_val, p_val) out_val = permute(input_val, p_val)
# Each row of p contains a permutation to apply to the corresponding # Each row of p contains a permutation to apply to the corresponding
# row of input # row of input
...@@ -1898,10 +2025,60 @@ class TestReorderRowElements(unittest.TestCase): ...@@ -1898,10 +2025,60 @@ class TestReorderRowElements(unittest.TestCase):
assert numpy.all(out_val == out_bis) assert numpy.all(out_val == out_bis)
# Verify gradient # Verify gradient
def reorder_fixed(s_input): def permute_fixed(s_input):
"""Auxiliary op defined to get rid of gradient wrt p_val"""
return permute_row_elements(s_input, p_val)
utt.verify_grad(permute_fixed, [input_val])
def test_1_2(self):
"""Test PermuteRowElements(vector, matrix)
Different permutations will be applied to the same input vector"""
input = vector()
p = imatrix()
out = permute_row_elements(input, p)
permute = function([input, p], out)
rng = numpy.random.RandomState(utt.fetch_seed())
input_val = rng.uniform(size=(5,))
p_val = numpy.asarray([rng.permutation(5) for i in range(3)])
out_val = permute(input_val, p_val)
# Each row of p contains a permutation to apply to the input vector
out_bis = numpy.asarray([input_val[p_row] for p_row in p_val])
assert numpy.all(out_val == out_bis)
# Verify gradient
def permute_fixed(s_input):
"""Auxiliary op defined to get rid of gradient wrt p_val""" """Auxiliary op defined to get rid of gradient wrt p_val"""
return reorder_row_elements(s_input, p_val) return permute_row_elements(s_input, p_val)
utt.verify_grad(reorder_fixed, [input_val]) utt.verify_grad(permute_fixed, [input_val])
def test_3b_2(self):
"""Test permute_row_elements on a more complex broadcasting pattern:
input.type.broadcastable = (False, True, False),
p.type.broadcastable = (False, False)."""
input = TensorType('float64', (False, True, False))()
p = imatrix()
out = permute_row_elements(input, p)
permute = function([input, p], out)
rng = numpy.random.RandomState(utt.fetch_seed())
input_val = rng.uniform(size=(4,1,5))
p_val = numpy.asarray([rng.permutation(5) for i in range(3)])
out_val = permute(input_val, p_val)
# Each row of p contains a permutation to apply to each row
# of the input tensor
out_bis = numpy.asarray([[in_mat[0,p_row] for p_row in p_val] for in_mat in input_val])
assert numpy.all(out_val == out_bis)
# Verify gradient
def permute_fixed(s_input):
"""Auxiliary op defined to get rid of gradient wrt p_val"""
return permute_row_elements(s_input, p_val)
utt.verify_grad(permute_fixed, [input_val])
class test_tensordot(unittest.TestCase): class test_tensordot(unittest.TestCase):
def setUp(self): def setUp(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论