提交 89c9b628 authored 作者: hantek's avatar hantek

changes according to abergeron

上级 7a4bedf2
......@@ -1078,7 +1078,7 @@ __device__ ga_half atomicExch(ga_half *addr, ga_half val) {
""" % locals()
class GpuExtractDiag(Subtensor):
class GpuExtractDiag(Op):
__props__ = ("offset", "axis1", "axis2", "view")
def __init__(self, offset=0, axis1=0, axis2=1, view=False):
......
......@@ -33,6 +33,7 @@ from theano.compile import Rebroadcast, Shape, shape
# We use these exceptions as well.
import theano.scalar.sharedvar
from theano.gradient import grad_undefined
from theano.gradient import grad_not_implemented
from theano.gradient import DisconnectedType
# set up the external interface
......@@ -6143,18 +6144,21 @@ class ExtractDiag(Op):
z[0] = z[0].copy()
def grad(self, inputs, gout):
    """
    Build the symbolic gradient of the extracted diagonal w.r.t. the input.

    Parameters
    ----------
    inputs : list
        Single-element list holding the input variable `x`.
    gout : list
        Single-element list holding `gz`, the gradient with respect to
        this op's output.

    Returns
    -------
    list
        Single-element list. When `x` is 2-d, it holds zeros shaped like
        `x` whose top-left block is overwritten with the matrix that
        `AllocDiag(offset=self.offset)` builds from `gz`. For any other
        `ndim` it holds a `grad_not_implemented` placeholder.
    """
    (x,) = inputs
    (gz,) = gout

    if x.ndim == 2:
        # The following code is moved from tensor.nlinalg.ExtractDiag, only
        # works for matrices.
        x = theano.tensor.zeros_like(x)
        xdiag = theano.tensor.AllocDiag(offset=self.offset)(gz)
        return [theano.tensor.set_subtensor(
            x[:xdiag.shape[0], :xdiag.shape[1]], xdiag)]
    else:
        # Warn only on the unsupported path: emitting the warning for
        # matrices too would be misleading, since that case is handled
        # above. Note the trailing space after "only" -- without it the
        # adjacent literals concatenate to "onlyworks".
        warnings.warn("gradient of theano.tensor.nlinalg.ExtractDiag only "
                      "works for matrices.")
        return [grad_not_implemented(self, 0, x)]
def infer_shape(self, node, shapes):
in_shape, = shapes
......@@ -6204,9 +6208,9 @@ class AllocDiag(Op):
Usage: T.AllocDiag()(x)
`x` should be a tensor vector. The argument in the first pair of parentheses indicates
which main diagonal the vector value goes into. By default it is set to (
which main diagonal the vector value goes into. By default it is set to
`0`, which corresponds to setting the values of x to the main diagonal in
the returned matrix. Currently the gradient is valid only when `offset=0`.
the returned matrix.
Parameters
----------
......@@ -6233,6 +6237,8 @@ class AllocDiag(Op):
def make_node(self, diag):
    """
    Create the Apply node for this op from a single vector input.

    Parameters
    ----------
    diag
        Value converted with `as_tensor_variable`; it must have exactly
        one dimension.

    Returns
    -------
    Apply
        Node with the converted input and one `matrix` output of the same
        dtype as the input.

    Raises
    ------
    TypeError
        If the converted input is not 1-dimensional.
    """
    vec = as_tensor_variable(diag)
    if vec.type.ndim == 1:
        out = matrix(dtype=vec.dtype)
        return Apply(self, [vec], [out])
    raise TypeError('data argument must be a vector', vec.type)
def perform(self, node, inputs, outputs):
......@@ -6270,11 +6276,11 @@ def diag(v, k=0):
"""
if v.ndim == 1:
return AllocDiag()(v)
return AllocDiag(k)(v)
elif v.ndim >= 2:
return diagonal(v, offset=k)
else:
raise ValueError("Input must has v.dim >= 1.")
raise ValueError("Input must has v.ndim >= 1.")
def stacklists(arg):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论