提交 66ba51fe authored 作者: abalkin's avatar abalkin

WIP: Restricted offset and axes arguments to constants and implemented infer_shape.

上级 6c67891a
...@@ -7191,6 +7191,11 @@ class Diagonal(Op): ...@@ -7191,6 +7191,11 @@ class Diagonal(Op):
:return: A vector representing the diagonal elements. :return: A vector representing the diagonal elements.
""" """
def __init__(self, offset=0, axis1=0, axis2=1):
    """Store the diagonal parameters on the Op instance.

    :param offset: offset of the diagonal from the main diagonal
        (positive = above, negative = below), as in numpy.diagonal.
    :param axis1: first axis of the 2-D planes the diagonal is taken from.
    :param axis2: second axis of the 2-D planes the diagonal is taken from.
    """
    # NOTE(review): the commit message says offset/axes are restricted to
    # constants; they are stored as plain Python values, not symbolic inputs.
    self.offset = offset
    self.axis1 = axis1
    self.axis2 = axis2
def __eq__(self, other):
    """Two Diagonal Ops are interchangeable only when they extract the
    same diagonal: same type AND same (offset, axis1, axis2).

    Comparing on type alone would let the graph merge optimizer replace
    e.g. Diagonal(offset=1) with Diagonal(offset=0), producing wrong
    results, since __init__ now parametrizes the Op.
    """
    return (type(self) == type(other) and
            self.offset == other.offset and
            self.axis1 == other.axis1 and
            self.axis2 == other.axis2)

def __hash__(self):
    """Hash consistent with __eq__: equal Ops hash equal."""
    return hash((type(self), self.offset, self.axis1, self.axis2))
def make_node(self, x):
    """Build the Apply node for extracting a diagonal from `x`.

    `x` must have at least 2 dimensions; the output drops one dimension
    (the two diagonal axes collapse into a single, non-broadcastable one).
    The offset/axis parameters live on the Op instance, so the node has a
    single input.
    """
    x = as_tensor_variable(x)
    assert x.ndim >= 2
    out_bcast = [False] * (x.ndim - 1)
    out_type = tensor(dtype=x.dtype, broadcastable=out_bcast)
    return Apply(self, [x], [out_type])
def perform(self, node, (x, off, ax1, ax2), (z,)): def perform(self, node, (x,), (z,)):
z[0] = x.diagonal(off, ax1, ax2) z[0] = x.diagonal(self.offset, self.axis1, self.axis2)
def grad(self, inp, grads):
    """Gradient: scatter the output gradient back onto a diagonal.

    Rewritten without tuple-unpacking parameters — `(x,), (gz,)` in the
    signature is Python-2-only syntax (removed in Python 3).

    NOTE(review): this ignores self.offset/axis1/axis2 — square_diagonal
    appears to handle only the main-diagonal, default-axes case. Commit is
    marked WIP; confirm before relying on gradients with nonzero offset or
    non-default axes.
    """
    x, = inp
    gz, = grads
    return [square_diagonal(gz)]
def infer_shape(self, node, shapes):
    """Infer the output shape: drop the two diagonal axes and append the
    diagonal length as the last dimension (matching numpy.diagonal).

    Fixes: the previous version ignored self.offset, so the inferred
    length was wrong for any nonzero offset. For offset k >= 0 the
    diagonal length is max(0, min(d1, d2 - k)); for k < 0 it is
    max(0, min(d1 + k, d2)). Negative axes are normalized so the
    enumerate-based filter below excludes the right dimensions.
    """
    in_shape, = shapes
    ndim = len(in_shape)
    # Normalize possibly-negative axis indices to [0, ndim).
    axis1 = self.axis1 % ndim
    axis2 = self.axis2 % ndim
    dim1 = in_shape[axis1]
    dim2 = in_shape[axis2]
    # self.offset is a Python constant (restricted by this commit), so we
    # may branch on its sign at graph-construction time.
    if self.offset >= 0:
        diag_len = minimum(dim1, dim2 - self.offset)
    else:
        diag_len = minimum(dim1 + self.offset, dim2)
    # Clamp at zero: an offset past the edge yields an empty diagonal.
    diag_len = maximum(diag_len, 0)
    out_shape = [d for i, d in enumerate(in_shape)
                 if i not in (axis1, axis2)]
    out_shape.append(diag_len)
    return [tuple(out_shape)]
def __str__(self):
    """Display the Op by its class name."""
    return self.__class__.__name__
def diagonal(a, offset=0, axis1=0, axis2=1):
    """Return the diagonal of `a`, like numpy.diagonal.

    Thin wrapper that builds a Diagonal Op configured with the given
    constant offset and axes and applies it to `a`.
    """
    op = Diagonal(offset, axis1, axis2)
    return op(a)
...@@ -40,7 +40,7 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as, ...@@ -40,7 +40,7 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as,
tile, patternbroadcast, Eye, Shape, Default, Dot, PermuteRowElements, tile, patternbroadcast, Eye, Shape, Default, Dot, PermuteRowElements,
ScalarFromTensor, TensorFromScalar, dtensor4, Rebroadcast, Alloc, ScalarFromTensor, TensorFromScalar, dtensor4, Rebroadcast, Alloc,
dtensor3, SpecifyShape, Mean, IncSubtensor, AdvancedIncSubtensor1, dtensor3, SpecifyShape, Mean, IncSubtensor, AdvancedIncSubtensor1,
itensor3, Tile, AdvancedIncSubtensor, switch) itensor3, Tile, AdvancedIncSubtensor, switch, Diagonal)
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
from theano.printing import debugprint from theano.printing import debugprint
...@@ -6542,6 +6542,12 @@ class TestInferShape(utt.InferShapeTester): ...@@ -6542,6 +6542,12 @@ class TestInferShape(utt.InferShapeTester):
[Eye()(aiscal, biscal, ciscal)], [Eye()(aiscal, biscal, ciscal)],
[3, 5, 0], Eye) [3, 5, 0], Eye)
# Diagonal
atens3 = tensor3()
atens3_val = rand(4, 5, 3)
atens3_diag = Diagonal()(atens3)
self._compile_and_check([atens3], [atens3_diag],
[atens3_val], Diagonal)
# Shape # Shape
# 'opt.Makevector' precludes optimizer from disentangling # 'opt.Makevector' precludes optimizer from disentangling
# elements of shape # elements of shape
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论