提交 f934d22e authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

PEP8-ify ExtractDiag and associated tests functions

上级 bdf90d1c
......@@ -394,15 +394,19 @@ class ExtractDiag(Op):
self.view = view
if self.view:
self.view_map = {0:[0]}
def __eq__(self, other):
return type(self) == type(other) and self.view == other.view
def __hash__(self):
return hash(type(self))^hash(self.view)
def make_node(self, _x):
x = as_tensor_variable(_x)
if x.type.ndim != 2:
raise TypeError('ExtractDiag only works on matrices', _x)
return Apply(self, [x], [tensor.vector(dtype=x.type.dtype)])
def perform(self, node, ins, outs):
x, = ins
z, = outs
......@@ -415,14 +419,17 @@ class ExtractDiag(Op):
z[0] = rval
else:
z[0] = rval.copy()
def __str__(self):
return 'ExtractDiag{view=%s}'%self.view
def grad(self, inputs, g_outputs):
return [alloc_diag(g_outputs[0])]
extract_diag = ExtractDiag()
extract_diag = ExtractDiag()
#TODO: optimization to insert ExtractDiag with view=True
class AllocDiag(Op):
def __eq__(self, other):
return type(self) == type(other)
......@@ -486,6 +493,7 @@ def trace(X):
"""
return extract_diag(X).sum()
def spectral_radius_bound(X, log2_exponent):
"""
Returns upper bound on the largest eigenvalue of square symmetrix matrix X.
......
......@@ -93,6 +93,7 @@ def test_det_grad():
r = rng.randn(5,5)
tensor.verify_grad(det, [r], rng=numpy.random)
def test_extract_diag():
rng = numpy.random.RandomState(utt.fetch_seed())
x = theano.tensor.matrix()
......@@ -122,6 +123,7 @@ def test_extract_diag():
assert ok
# not testing the view=True case since it is not used anywhere.
def test_trace():
rng = numpy.random.RandomState(utt.fetch_seed())
x = theano.tensor.matrix()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论