提交 7a4bedf2 authored 作者: hantek's avatar hantek

add gradient from nlinalg.ExtractDiag

上级 e9169843
...@@ -33,7 +33,6 @@ from theano.compile import Rebroadcast, Shape, shape ...@@ -33,7 +33,6 @@ from theano.compile import Rebroadcast, Shape, shape
# We use these exceptions as well. # We use these exceptions as well.
import theano.scalar.sharedvar import theano.scalar.sharedvar
from theano.gradient import grad_undefined from theano.gradient import grad_undefined
from theano.gradient import grad_not_implemented
from theano.gradient import DisconnectedType from theano.gradient import DisconnectedType
# set up the external interface # set up the external interface
...@@ -6144,9 +6143,18 @@ class ExtractDiag(Op): ...@@ -6144,9 +6143,18 @@ class ExtractDiag(Op):
z[0] = z[0].copy() z[0] = z[0].copy()
def grad(self, inputs, gout): def grad(self, inputs, gout):
"""
The following code is moved from tensor.nlinalg.ExtractDiag, only works
for matrices.
"""
warnings.warn("gradient of theano.tensor.nlinalg.ExtractDiag only"
"works for matrices.")
(x,) = inputs (x,) = inputs
(gz,) = gout (gz,) = gout
return [grad_not_implemented(self, 0, x)] x = theano.tensor.zeros_like(x)
xdiag = theano.tensor.AllocDiag(offset=self.offset)(gz)
return [theano.tensor.set_subtensor(
x[:xdiag.shape[0], :xdiag.shape[1]], xdiag)]
def infer_shape(self, node, shapes): def infer_shape(self, node, shapes):
in_shape, = shapes in_shape, = shapes
......
...@@ -7400,7 +7400,8 @@ class test_diag(unittest.TestCase): ...@@ -7400,7 +7400,8 @@ class test_diag(unittest.TestCase):
f = theano.function([x], g.shape) f = theano.function([x], g.shape)
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
if config.mode != 'FAST_COMPILE': if config.mode != 'FAST_COMPILE':
assert sum([isinstance(node.op, AllocDiag) for node in topo]) == 0 assert numpy.sum(
[isinstance(node.op, AllocDiag) for node in topo]) == 0
for shp in [5, 0, 1]: for shp in [5, 0, 1]:
m = rng.rand(shp).astype(self.floatX) m = rng.rand(shp).astype(self.floatX)
assert (f(m) == numpy.diag(m).shape).all() assert (f(m) == numpy.diag(m).shape).all()
...@@ -7410,7 +7411,8 @@ class test_diag(unittest.TestCase): ...@@ -7410,7 +7411,8 @@ class test_diag(unittest.TestCase):
f = theano.function([x], g.shape) f = theano.function([x], g.shape)
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
if config.mode != 'FAST_COMPILE': if config.mode != 'FAST_COMPILE':
assert sum([isinstance(node.op, ExtractDiag) for node in topo]) == 0 assert numpy.sum(
[isinstance(node.op, ExtractDiag) for node in topo]) == 0
for shp in [(5, 3), (3, 5), (5, 1), (1, 5), (5, 0), (0, 5), for shp in [(5, 3), (3, 5), (5, 1), (1, 5), (5, 0), (0, 5),
(1, 0), (0, 1)]: (1, 0), (0, 1)]:
m = rng.rand(*shp).astype(self.floatX) m = rng.rand(*shp).astype(self.floatX)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论