提交 f934d22e authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

PEP8-ify ExtractDiag and associated tests functions

上级 bdf90d1c
...@@ -394,15 +394,19 @@ class ExtractDiag(Op): ...@@ -394,15 +394,19 @@ class ExtractDiag(Op):
self.view = view self.view = view
if self.view: if self.view:
self.view_map = {0:[0]} self.view_map = {0:[0]}
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) and self.view == other.view return type(self) == type(other) and self.view == other.view
def __hash__(self): def __hash__(self):
return hash(type(self))^hash(self.view) return hash(type(self))^hash(self.view)
def make_node(self, _x): def make_node(self, _x):
x = as_tensor_variable(_x) x = as_tensor_variable(_x)
if x.type.ndim != 2: if x.type.ndim != 2:
raise TypeError('ExtractDiag only works on matrices', _x) raise TypeError('ExtractDiag only works on matrices', _x)
return Apply(self, [x], [tensor.vector(dtype=x.type.dtype)]) return Apply(self, [x], [tensor.vector(dtype=x.type.dtype)])
def perform(self, node, ins, outs): def perform(self, node, ins, outs):
x, = ins x, = ins
z, = outs z, = outs
...@@ -415,14 +419,17 @@ class ExtractDiag(Op): ...@@ -415,14 +419,17 @@ class ExtractDiag(Op):
z[0] = rval z[0] = rval
else: else:
z[0] = rval.copy() z[0] = rval.copy()
def __str__(self): def __str__(self):
return 'ExtractDiag{view=%s}'%self.view return 'ExtractDiag{view=%s}'%self.view
def grad(self, inputs, g_outputs): def grad(self, inputs, g_outputs):
return [alloc_diag(g_outputs[0])] return [alloc_diag(g_outputs[0])]
extract_diag = ExtractDiag()
extract_diag = ExtractDiag()
#TODO: optimization to insert ExtractDiag with view=True #TODO: optimization to insert ExtractDiag with view=True
class AllocDiag(Op): class AllocDiag(Op):
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) return type(self) == type(other)
...@@ -486,6 +493,7 @@ def trace(X): ...@@ -486,6 +493,7 @@ def trace(X):
""" """
return extract_diag(X).sum() return extract_diag(X).sum()
def spectral_radius_bound(X, log2_exponent): def spectral_radius_bound(X, log2_exponent):
""" """
Returns upper bound on the largest eigenvalue of square symmetrix matrix X. Returns upper bound on the largest eigenvalue of square symmetrix matrix X.
......
...@@ -93,6 +93,7 @@ def test_det_grad(): ...@@ -93,6 +93,7 @@ def test_det_grad():
r = rng.randn(5,5) r = rng.randn(5,5)
tensor.verify_grad(det, [r], rng=numpy.random) tensor.verify_grad(det, [r], rng=numpy.random)
def test_extract_diag(): def test_extract_diag():
rng = numpy.random.RandomState(utt.fetch_seed()) rng = numpy.random.RandomState(utt.fetch_seed())
x = theano.tensor.matrix() x = theano.tensor.matrix()
...@@ -122,6 +123,7 @@ def test_extract_diag(): ...@@ -122,6 +123,7 @@ def test_extract_diag():
assert ok assert ok
# not testing the view=True case since it is not used anywhere. # not testing the view=True case since it is not used anywhere.
def test_trace(): def test_trace():
rng = numpy.random.RandomState(utt.fetch_seed()) rng = numpy.random.RandomState(utt.fetch_seed())
x = theano.tensor.matrix() x = theano.tensor.matrix()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论