Commit 7a4bedf2 authored by hantek

add gradient from nlinalg.ExtractDiag

Parent e9169843
......@@ -33,7 +33,6 @@ from theano.compile import Rebroadcast, Shape, shape
# We use these exceptions as well.
import theano.scalar.sharedvar
from theano.gradient import grad_undefined
from theano.gradient import grad_not_implemented
from theano.gradient import DisconnectedType
# set up the external interface
......@@ -6144,9 +6143,18 @@ class ExtractDiag(Op):
z[0] = z[0].copy()
def grad(self, inputs, gout):
    """
    Gradient of ExtractDiag.

    The following code is moved from tensor.nlinalg.ExtractDiag; it only
    works for matrices (2-d inputs).

    Parameters
    ----------
    inputs : list
        Single-element list holding the symbolic input variable ``x``.
    gout : list
        Single-element list holding the gradient ``gz`` flowing in from
        the op's output (the extracted diagonal).

    Returns
    -------
    list
        Single-element list: a tensor shaped like ``x`` whose diagonal
        (at ``self.offset``) carries ``gz`` and which is zero elsewhere.
    """
    # NOTE: only the 2-d (matrix) case is supported, hence the warning.
    warnings.warn("gradient of theano.tensor.nlinalg.ExtractDiag only"
                  " works for matrices.")
    (x,) = inputs
    (gz,) = gout
    # Scatter gz back onto the diagonal: build a zero tensor shaped like
    # x, then overwrite its top-left corner with the (possibly offset)
    # diagonal matrix built from gz.
    x = theano.tensor.zeros_like(x)
    # AllocDiag(offset=...) presumably places gz on the same off-diagonal
    # that this op extracted — TODO confirm against AllocDiag's semantics.
    xdiag = theano.tensor.AllocDiag(offset=self.offset)(gz)
    return [theano.tensor.set_subtensor(
        x[:xdiag.shape[0], :xdiag.shape[1]], xdiag)]
def infer_shape(self, node, shapes):
in_shape, = shapes
......
......@@ -7400,7 +7400,8 @@ class test_diag(unittest.TestCase):
f = theano.function([x], g.shape)
topo = f.maker.fgraph.toposort()
if config.mode != 'FAST_COMPILE':
assert sum([isinstance(node.op, AllocDiag) for node in topo]) == 0
assert numpy.sum(
[isinstance(node.op, AllocDiag) for node in topo]) == 0
for shp in [5, 0, 1]:
m = rng.rand(shp).astype(self.floatX)
assert (f(m) == numpy.diag(m).shape).all()
......@@ -7410,7 +7411,8 @@ class test_diag(unittest.TestCase):
f = theano.function([x], g.shape)
topo = f.maker.fgraph.toposort()
if config.mode != 'FAST_COMPILE':
assert sum([isinstance(node.op, ExtractDiag) for node in topo]) == 0
assert numpy.sum(
[isinstance(node.op, ExtractDiag) for node in topo]) == 0
for shp in [(5, 3), (3, 5), (5, 1), (1, 5), (5, 0), (0, 5),
(1, 0), (0, 1)]:
m = rng.rand(*shp).astype(self.floatX)
......
Markdown is supported
0%
You are about to add 0 people to this discussion. Please proceed with caution.
Finish editing this message first!
Register or sign in to comment