提交 55f0aa5d authored 作者: abalkin's avatar abalkin

WIP: Started implementation of diag().

上级 3de5ca32
...@@ -7238,3 +7238,39 @@ class Diagonal(Op): ...@@ -7238,3 +7238,39 @@ class Diagonal(Op):
def diagonal(a, offset=0, axis1=0, axis2=1): def diagonal(a, offset=0, axis1=0, axis2=1):
return Diagonal(offset, axis1, axis2)(a) return Diagonal(offset, axis1, axis2)(a)
class Diag(Op):
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
def make_node(self, diag):
diag = as_tensor_variable(diag)
if diag.type.ndim != 1:
raise TypeError('data argument must be a vector', diag.type)
return Apply(self, [diag], [matrix(dtype=diag.dtype)])
def perform(self, node, inputs, (z,)):
z[0] = numpy.diag(inputs[0])
def grad(self, inputs, (gz,)):
return [diagonal(gz)]
def infer_shape(self, nodes, shapes):
return [(shapes[0][0],) * 2]
def __str__(self):
return self.__class__.__name__
def diag(v, k=0):
if v.ndim == 1:
assert k == 0, "diagonals other than main are not implemented"
return Diag()(v)
elif v.ndim == 2:
return diagonal(v, k)
else:
raise ValueError("Input must be 1- or 2-d.")
...@@ -40,7 +40,7 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as, ...@@ -40,7 +40,7 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as,
tile, patternbroadcast, Eye, Shape, Default, Dot, PermuteRowElements, tile, patternbroadcast, Eye, Shape, Default, Dot, PermuteRowElements,
ScalarFromTensor, TensorFromScalar, dtensor4, Rebroadcast, Alloc, ScalarFromTensor, TensorFromScalar, dtensor4, Rebroadcast, Alloc,
dtensor3, SpecifyShape, Mean, IncSubtensor, AdvancedIncSubtensor1, dtensor3, SpecifyShape, Mean, IncSubtensor, AdvancedIncSubtensor1,
itensor3, Tile, AdvancedIncSubtensor, switch, Diagonal) itensor3, Tile, AdvancedIncSubtensor, switch, Diagonal, Diag)
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
from theano.printing import debugprint from theano.printing import debugprint
...@@ -6557,6 +6557,13 @@ class TestInferShape(utt.InferShapeTester): ...@@ -6557,6 +6557,13 @@ class TestInferShape(utt.InferShapeTester):
atens3_diag = Diagonal(1,0,2)(atens3) atens3_diag = Diagonal(1,0,2)(atens3)
self._compile_and_check([atens3], [atens3_diag], self._compile_and_check([atens3], [atens3_diag],
[atens3_val], Diagonal) [atens3_val], Diagonal)
# Diag
advec = dvector()
advec_val = rand(4)
self._compile_and_check([advec], [Diag()(advec)],
[advec_val], Diag)
# Shape # Shape
# 'opt.Makevector' precludes optimizer from disentangling # 'opt.Makevector' precludes optimizer from disentangling
# elements of shape # elements of shape
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论