提交 d50c1eba authored 作者: Guillaume Desjardins's avatar Guillaume Desjardins 提交者: Frederic

Reviewed and added tests for AllocDiag and fixed bugs in ExtractDiag:

* gradient now supports non-diagonal inputs
* ExtractDiag won't crash if input has some 0-dimensions
上级 9111d00c
...@@ -561,7 +561,7 @@ solve = Solve() # general solve ...@@ -561,7 +561,7 @@ solve = Solve() # general solve
#TODO: Optimizations to replace multiplication by matrix inverse with solve() Op (still unwritten) #TODO: Optimizations to replace multiplication by matrix inverse with solve() Op (still unwritten)
class ExtractDiag(Op): class ExtractDiag(Op):
""" Return the diagonal of matrix """ """ Return the diagonal of a matrix """
def __init__(self, view=False): def __init__(self, view=False):
self.view = view self.view = view
if self.view: if self.view:
...@@ -580,10 +580,14 @@ class ExtractDiag(Op): ...@@ -580,10 +580,14 @@ class ExtractDiag(Op):
return Apply(self, [x], [tensor.vector(dtype=x.type.dtype)]) return Apply(self, [x], [tensor.vector(dtype=x.type.dtype)])
def perform(self, node, ins, outs): def perform(self, node, ins, outs):
""" For some reason numpy.diag(x) is really slow, so we implemented our own """
x, = ins x, = ins
z, = outs z, = outs
#for some reason numpy.diag(x) is really slow
N,M = x.shape # zero-dimensional matrices ...
if x.shape[0] == 0 or x.shape[1] == 0:
z[0] = x
return
if x.shape[0] < x.shape [1]: if x.shape[0] < x.shape [1]:
rval = x[:,0] rval = x[:,0]
...@@ -600,7 +604,9 @@ class ExtractDiag(Op): ...@@ -600,7 +604,9 @@ class ExtractDiag(Op):
return 'ExtractDiag{view=%s}'%self.view return 'ExtractDiag{view=%s}'%self.view
def grad(self, inputs, g_outputs): def grad(self, inputs, g_outputs):
return [alloc_diag(g_outputs[0])] x = tensor.zeros_like(inputs[0])
xdiag = alloc_diag(g_outputs[0])
return [tensor.set_subtensor(x[:xdiag.shape[0], :xdiag.shape[1]], xdiag, inplace=True)]
def infer_shape(self, node, shapes): def infer_shape(self, node, shapes):
x_s, = shapes x_s, = shapes
...@@ -610,23 +616,34 @@ class ExtractDiag(Op): ...@@ -610,23 +616,34 @@ class ExtractDiag(Op):
extract_diag = ExtractDiag() extract_diag = ExtractDiag()
#TODO: optimization to insert ExtractDiag with view=True #TODO: optimization to insert ExtractDiag with view=True
class AllocDiag(Op): class AllocDiag(Op):
"""
Allocates a square matrix with the given vector as its diagonal.
"""
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) return type(self) == type(other)
def __hash__(self): def __hash__(self):
return hash(type(self)) return hash(type(self))
def make_node(self, _x): def make_node(self, _x):
x = as_tensor_variable(_x) x = as_tensor_variable(_x)
if x.type.ndim != 1: if x.type.ndim != 1:
raise TypeError('AllocDiag only works on vectors', _x) raise TypeError('AllocDiag only works on vectors', _x)
return Apply(self, [x], [tensor.matrix(dtype=x.type.dtype)]) return Apply(self, [x], [tensor.matrix(dtype=x.type.dtype)])
def grad(self, inputs, g_outputs): def grad(self, inputs, g_outputs):
return [extract_diag(g_outputs[0])] return [extract_diag(g_outputs[0])]
def perform(self, node, (x,), (z,)): def perform(self, node, (x,), (z,)):
if x.ndim != 1: if x.ndim != 1:
raise TypeError(x) raise TypeError(x)
z[0] = numpy.diag(x) z[0] = numpy.diag(x)
def infer_shape(self, node, shapes):
    """Shape of the output: a square matrix whose side length equals
    the length of the input vector (no need to run the op itself)."""
    (vec_shape,) = shapes
    side = vec_shape[0]
    return [(side, side)]
alloc_diag = AllocDiag() alloc_diag = AllocDiag()
def diag(x): def diag(x):
......
...@@ -15,10 +15,11 @@ from theano.sandbox.linalg.ops import (cholesky, ...@@ -15,10 +15,11 @@ from theano.sandbox.linalg.ops import (cholesky,
CholeskyGrad, CholeskyGrad,
matrix_inverse, matrix_inverse,
#solve, #solve,
#diag, diag,
ExtractDiag, ExtractDiag,
extract_diag, extract_diag,
#alloc_diag, AllocDiag,
alloc_diag,
det, det,
#PSD_hint, #PSD_hint,
trace, trace,
...@@ -227,6 +228,62 @@ def test_det_shape(): ...@@ -227,6 +228,62 @@ def test_det_shape():
f_shape = theano.function([x], det(x).shape) f_shape = theano.function([x], det(x).shape)
assert numpy.all(f(r).shape == f_shape(r)) assert numpy.all(f(r).shape == f_shape(r))
def test_alloc_diag():
    """AllocDiag builds a square matrix with the given vector on its
    diagonal: check the forward value against numpy.diag, the input-rank
    check, and infer_shape."""
    rng = numpy.random.RandomState(utt.fetch_seed())
    x = theano.tensor.vector()
    g = alloc_diag(x)
    f = theano.function([x], g)
    # test "normal" scenario (5x5 matrix) and special cases of 0x0 and 1x1
    for shp in [5, 0, 1]:
        m = rng.rand(shp).astype(config.floatX)
        # numpy.diag on a vector builds the reference square matrix
        v = numpy.diag(m)
        r = f(m)
        # The op's output matches the numpy reference exactly
        assert (r == v).all()
    # Test we accept only vectors: a matrix input must raise TypeError
    xx = theano.tensor.matrix()
    ok = False
    try:
        alloc_diag(xx)
    except TypeError:
        ok = True
    assert ok
    # Test infer_shape: computing only g.shape should not require running
    # AllocDiag itself, so the op is optimized out of the compiled graph
    # (except under FAST_COMPILE, where shape optimizations are disabled)
    f = theano.function([x], g.shape)
    topo = f.maker.env.toposort()
    if config.mode != 'FAST_COMPILE':
        assert sum([node.op.__class__ == AllocDiag for node in topo]) == 0
    for shp in [5, 0, 1]:
        m = rng.rand(shp).astype(config.floatX)
        assert (f(m) == m.shape).all()
def test_alloc_diag_grad():
    """Numerically verify AllocDiag.grad on a length-5 random vector."""
    rng = numpy.random.RandomState(utt.fetch_seed())
    x = rng.rand(5)
    tensor.verify_grad(alloc_diag, [x], rng=rng)
def test_diag():
    """diag() dispatches on the input's rank: a vector yields AllocDiag
    (build a matrix with that diagonal), a matrix yields ExtractDiag
    (take its diagonal), and any other rank raises TypeError."""
    # test that it builds a matrix with given diagonal when using vector inputs
    x = theano.tensor.vector()
    y = diag(x)
    assert y.owner.op.__class__ == AllocDiag
    # test that it extracts the diagonal when using matrix input
    # (bug fix: call diag(x), not extract_diag(x) directly, so the
    # dispatch in diag() itself is what gets exercised)
    x = theano.tensor.matrix()
    y = diag(x)
    assert y.owner.op.__class__ == ExtractDiag
    # other types should raise error
    x = theano.tensor.tensor3()
    ok = False
    try:
        y = diag(x)
    except TypeError:
        ok = True
    assert ok
def test_extract_diag(): def test_extract_diag():
rng = numpy.random.RandomState(utt.fetch_seed()) rng = numpy.random.RandomState(utt.fetch_seed())
...@@ -234,7 +291,7 @@ def test_extract_diag(): ...@@ -234,7 +291,7 @@ def test_extract_diag():
g = extract_diag(x) g = extract_diag(x)
f = theano.function([x], g) f = theano.function([x], g)
for shp in [(2, 3), (3, 2), (3, 3)]: for shp in [(2, 3), (3, 2), (3, 3), (1,1), (0,0)]:
m = rng.rand(*shp).astype(config.floatX) m = rng.rand(*shp).astype(config.floatX)
v = numpy.diag(m) v = numpy.diag(m)
r = f(m) r = f(m)
...@@ -259,8 +316,12 @@ def test_extract_diag(): ...@@ -259,8 +316,12 @@ def test_extract_diag():
m = rng.rand(*shp).astype(config.floatX) m = rng.rand(*shp).astype(config.floatX)
assert f(m) == min(shp) assert f(m) == min(shp)
# not testing the view=True case since it is not used anywhere. def test_extract_diag_grad():
rng = numpy.random.RandomState(utt.fetch_seed())
x = rng.rand(5,4)
tensor.verify_grad(extract_diag, [x], rng=rng)
# not testing the view=True case since it is not used anywhere.
def test_trace(): def test_trace():
rng = numpy.random.RandomState(utt.fetch_seed()) rng = numpy.random.RandomState(utt.fetch_seed())
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论