提交 03e6e398 authored 作者: David Warde-Farley's avatar David Warde-Farley

Merge pull request #135 from abergeron/diag_review

Review/tests for ExtractDiag and trace
...@@ -399,40 +399,46 @@ class ExtractDiag(Op): ...@@ -399,40 +399,46 @@ class ExtractDiag(Op):
self.view = view self.view = view
if self.view: if self.view:
self.view_map = {0:[0]} self.view_map = {0:[0]}
self.perform = self.perform_view
else:
self.perform = self.perform_noview
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) and self.view == other.view return type(self) == type(other) and self.view == other.view
def __hash__(self): def __hash__(self):
return hash(type(self))^hash(self.view) return hash(type(self))^hash(self.view)
def make_node(self, _x): def make_node(self, _x):
x = as_tensor_variable(_x) x = as_tensor_variable(_x)
if x.type.ndim != 2: if x.type.ndim != 2:
raise TypeError('ExtractDiag only works on matrices', _x) raise TypeError('ExtractDiag only works on matrices', _x)
return Apply(self, [x], [tensor.vector(dtype=x.type.dtype)]) return Apply(self, [x], [tensor.vector(dtype=x.type.dtype)])
def perform_noview(self, node, (x,), (z,)):
def perform(self, node, ins, outs):
x, = ins
z, = outs
#for some reason numpy.diag(x) is really slow #for some reason numpy.diag(x) is really slow
N,M = x.shape N,M = x.shape
assert N==M assert N==M
rval = x[0] rval = x[0]
rval.strides = (x.strides[0]+x.strides[1],) rval.strides = (x.strides[0]+x.strides[1],)
z[0] = rval.copy() if self.view:
def perform_view(self, node, (x,), (z,)): z[0] = rval
N,M = x.shape else:
a,b = x.strides z[0] = rval.copy()
assert N==M
rval = x[0]
rval.strides = a+b,
z[0] = rval
def __str__(self): def __str__(self):
return 'ExtractDiag{view=%s}'%self.view return 'ExtractDiag{view=%s}'%self.view
def grad(self, inputs, g_outputs): def grad(self, inputs, g_outputs):
return [alloc_diag(g_outputs[0])] return [alloc_diag(g_outputs[0])]
extract_diag = ExtractDiag()
def infer_shape(self, node, shapes):
x_s, = shapes
return [(x_s[0],)]
extract_diag = ExtractDiag()
#TODO: optimization to insert ExtractDiag with view=True #TODO: optimization to insert ExtractDiag with view=True
class AllocDiag(Op): class AllocDiag(Op):
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) return type(self) == type(other)
...@@ -453,8 +459,11 @@ alloc_diag = AllocDiag() ...@@ -453,8 +459,11 @@ alloc_diag = AllocDiag()
def diag(x): def diag(x):
"""Numpy-compatibility method """Numpy-compatibility method
If `x` is a matrix, return its diagonal.
If `x` is a vector return a matrix with it as its diagonal.
For vector `x`, return a zero matrix except for `x` as diagonal. * This method does not support the `k` argument that numpy supports.
""" """
xx = as_tensor_variable(x) xx = as_tensor_variable(x)
if xx.type.ndim == 1: if xx.type.ndim == 1:
...@@ -494,6 +503,7 @@ def trace(X): ...@@ -494,6 +503,7 @@ def trace(X):
""" """
return extract_diag(X).sum() return extract_diag(X).sum()
def spectral_radius_bound(X, log2_exponent): def spectral_radius_bound(X, log2_exponent):
""" """
Returns upper bound on the largest eigenvalue of square symmetrix matrix X. Returns upper bound on the largest eigenvalue of square symmetrix matrix X.
......
...@@ -5,6 +5,10 @@ import numpy ...@@ -5,6 +5,10 @@ import numpy
import theano import theano
from theano import tensor, function from theano import tensor, function
from theano.tensor.basic import _allclose from theano.tensor.basic import _allclose
from theano.tests import unittest_tools as utt
from theano import config
utt.seed_rng()
try: try:
import scipy import scipy
...@@ -19,11 +23,12 @@ from theano.sandbox.linalg.ops import (cholesky, ...@@ -19,11 +23,12 @@ from theano.sandbox.linalg.ops import (cholesky,
matrix_inverse, matrix_inverse,
#solve, #solve,
#diag, #diag,
#extract_diag, ExtractDiag,
extract_diag,
#alloc_diag, #alloc_diag,
det, det,
#PSD_hint, #PSD_hint,
#trace, trace,
#spectral_radius_bound #spectral_radius_bound
) )
...@@ -90,3 +95,61 @@ def test_det_grad(): ...@@ -90,3 +95,61 @@ def test_det_grad():
r = rng.randn(5,5) r = rng.randn(5,5)
tensor.verify_grad(det, [r], rng=numpy.random) tensor.verify_grad(det, [r], rng=numpy.random)
def test_extract_diag():
rng = numpy.random.RandomState(utt.fetch_seed())
x = theano.tensor.matrix()
g = extract_diag(x)
f = theano.function([x], g)
m = rng.rand(3,3).astype(config.floatX)
v = numpy.diag(m)
r = f(m)
# The right diagonal is extracted
assert (r == v).all()
m = rng.rand(2, 3).astype(config.floatX)
ok = False
try:
r = f(m)
except Exception:
ok = True
assert ok
xx = theano.tensor.vector()
ok = False
try:
extract_diag(xx)
except TypeError:
ok = True
assert ok
f = theano.function([x], g.shape)
topo = f.maker.env.toposort()
assert sum([node.op.__class__ == ExtractDiag for node in topo]) == 0
m = rng.rand(3,3).astype(config.floatX)
assert f(m) == 3
# not testing the view=True case since it is not used anywhere.
def test_trace():
rng = numpy.random.RandomState(utt.fetch_seed())
x = theano.tensor.matrix()
g = trace(x)
f = theano.function([x], g)
m = rng.rand(4, 4).astype(config.floatX)
v = numpy.trace(m)
assert v == f(m)
xx = theano.tensor.vector()
ok = False
try:
trace(xx)
except TypeError:
ok = True
assert ok
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论