提交 03e6e398 authored 作者: David Warde-Farley's avatar David Warde-Farley

Merge pull request #135 from abergeron/diag_review

Review/tests for ExtractDiag and trace
......@@ -399,40 +399,46 @@ class ExtractDiag(Op):
self.view = view
if self.view:
self.view_map = {0:[0]}
self.perform = self.perform_view
else:
self.perform = self.perform_noview
def __eq__(self, other):
    """Ops compare equal when they share the same type and `view` flag."""
    if not (type(self) == type(other)):
        return False
    return self.view == other.view
def __hash__(self):
    """Combine the hash of the op's type with the hash of its `view` flag."""
    type_hash = hash(type(self))
    view_hash = hash(self.view)
    return type_hash ^ view_hash
def make_node(self, _x):
    """Build the Apply node that extracts the diagonal of `_x`.

    `_x` is converted with `as_tensor_variable`; the single output is a
    vector variable with the same dtype as the converted input.

    Raises TypeError when the converted input is not 2-dimensional.
    """
    x = as_tensor_variable(_x)
    if x.type.ndim == 2:
        diagonal = tensor.vector(dtype=x.type.dtype)
        return Apply(self, [x], [diagonal])
    raise TypeError('ExtractDiag only works on matrices', _x)
def perform_noview(self, node, (x,), (z,)):
def perform(self, node, ins, outs):
x, = ins
z, = outs
#for some reason numpy.diag(x) is really slow
N,M = x.shape
assert N==M
rval = x[0]
rval.strides = (x.strides[0]+x.strides[1],)
z[0] = rval.copy()
def perform_view(self, node, (x,), (z,)):
N,M = x.shape
a,b = x.strides
assert N==M
rval = x[0]
rval.strides = a+b,
z[0] = rval
if self.view:
z[0] = rval
else:
z[0] = rval.copy()
def __str__(self):
    """Return the op name together with its `view` setting."""
    template = 'ExtractDiag{view=%s}'
    return template % self.view
def grad(self, inputs, g_outputs):
    """Return the input gradient: `alloc_diag` of the first output gradient."""
    diag_grad = g_outputs[0]
    return [alloc_diag(diag_grad)]
extract_diag = ExtractDiag()
def infer_shape(self, node, shapes):
    """Return the output shape: a 1-tuple holding the input's first dimension.

    `shapes` must contain exactly one entry (the shape of the single input).
    """
    (input_shape,) = shapes
    first_dim = input_shape[0]
    return [(first_dim,)]
extract_diag = ExtractDiag()
#TODO: optimization to insert ExtractDiag with view=True
class AllocDiag(Op):
def __eq__(self, other):
return type(self) == type(other)
......@@ -453,8 +459,11 @@ alloc_diag = AllocDiag()
def diag(x):
"""Numpy-compatibility method
If `x` is a matrix, return its diagonal.
If `x` is a vector return a matrix with it as its diagonal.
For vector `x`, return a zero matrix except for `x` as diagonal.
* This method does not support the `k` argument that numpy supports.
"""
xx = as_tensor_variable(x)
if xx.type.ndim == 1:
......@@ -494,6 +503,7 @@ def trace(X):
"""
return extract_diag(X).sum()
def spectral_radius_bound(X, log2_exponent):
"""
Returns upper bound on the largest eigenvalue of square symmetric matrix X.
......
......@@ -5,6 +5,10 @@ import numpy
import theano
from theano import tensor, function
from theano.tensor.basic import _allclose
from theano.tests import unittest_tools as utt
from theano import config
utt.seed_rng()
try:
import scipy
......@@ -19,11 +23,12 @@ from theano.sandbox.linalg.ops import (cholesky,
matrix_inverse,
#solve,
#diag,
#extract_diag,
ExtractDiag,
extract_diag,
#alloc_diag,
det,
#PSD_hint,
#trace,
trace,
#spectral_radius_bound
)
......@@ -90,3 +95,61 @@ def test_det_grad():
r = rng.randn(5,5)
tensor.verify_grad(det, [r], rng=numpy.random)
def test_extract_diag():
    """Exercise extract_diag: correct values, rejected inputs, shape graph."""
    rng = numpy.random.RandomState(utt.fetch_seed())
    mat_var = theano.tensor.matrix()
    diag_var = extract_diag(mat_var)
    diag_fn = theano.function([mat_var], diag_var)

    # The right diagonal is extracted from a square matrix.
    square = rng.rand(3, 3).astype(config.floatX)
    expected = numpy.diag(square)
    assert (diag_fn(square) == expected).all()

    # Calling the compiled function on a non-square matrix must fail.
    rect = rng.rand(2, 3).astype(config.floatX)
    failed = False
    try:
        diag_fn(rect)
    except Exception:
        failed = True
    assert failed

    # Building the graph on a vector must raise TypeError.
    vec_var = theano.tensor.vector()
    rejected = False
    try:
        extract_diag(vec_var)
    except TypeError:
        rejected = True
    assert rejected

    # When only the output shape is requested, the compiled graph must not
    # contain any ExtractDiag node.
    shape_fn = theano.function([mat_var], diag_var.shape)
    nodes = shape_fn.maker.env.toposort()
    assert not any(node.op.__class__ == ExtractDiag for node in nodes)
    square = rng.rand(3, 3).astype(config.floatX)
    assert shape_fn(square) == 3
    # not testing the view=True case since it is not used anywhere.
def test_trace():
    """Check trace() against numpy.trace and that vector input is rejected."""
    rng = numpy.random.RandomState(utt.fetch_seed())
    x = theano.tensor.matrix()
    g = trace(x)
    f = theano.function([x], g)

    m = rng.rand(4, 4).astype(config.floatX)
    v = numpy.trace(m)
    # Compare with a tolerance rather than exact equality: both sides are
    # floating-point sums and are not guaranteed to round identically.
    assert numpy.allclose(v, f(m))

    # Building the graph on a vector must raise TypeError.
    xx = theano.tensor.vector()
    ok = False
    try:
        trace(xx)
    except TypeError:
        ok = True
    assert ok
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论