提交 f89ec97d 作者: Shawn Tan

Subclassed GpuAllocDiag for infer_shape and modified grad to take self.axis1 and…

Made GpuAllocDiag a subclass of AllocDiag so that it inherits infer_shape, and modified grad to take self.axis1 and self.axis2 into account
上级 f766415c
......@@ -9,6 +9,7 @@ from theano.gof import ParamsType
from theano.gradient import grad_not_implemented
import theano.tensor as T
from theano.tensor.subtensor import IncSubtensor, Subtensor, get_idx_list
from theano.tensor import AllocDiag
from theano.scalar import bool as bool_t, int32 as int_t, uint32 as size_t
try:
......@@ -1356,7 +1357,7 @@ class GpuExtractDiag(Op):
return [tuple(out_shape)]
class GpuAllocDiag(Op):
class GpuAllocDiag(AllocDiag):
__props__ = ("offset", "axis1", "axis2")
def __init__(self, offset=0, axis1=0, axis2=1):
......@@ -1384,6 +1385,7 @@ class GpuAllocDiag(Op):
axis2 = np.maximum(self.axis1, self.axis2)
offset = self.offset
# Initialise a buffer the same size as the output
result_shape = x.shape[:-1] + (x.shape[-1] + abs(offset),) * 2
result_buffer_shape = ((np.prod(x.shape[:-1]).astype(np.int64),) +
((x.shape[-1] + abs(offset)) ** 2,))
......@@ -1391,6 +1393,7 @@ class GpuAllocDiag(Op):
dtype=x.dtype,
context=x.context)
# Slice out a view of the diagonals
if offset != 0:
row_size = x.shape[-1] + abs(offset)
if offset >= 0:
......@@ -1401,18 +1404,15 @@ class GpuAllocDiag(Op):
end_flattened_offset = abs(offset)
diag_view = result_buffer[:, start_flattened_offset:-end_flattened_offset:row_size + 1]
# print("offset", offset)
# print("buffer shape:", result_buffer.shape)
# print("result_buffer[%d:%d:%d]" % (start_flattened_offset, -end_flattened_offset, row_size + 1), diag_view.shape)
# print("input_shape:", x.shape)
else:
diag_view = result_buffer[:, ::x.shape[-1] + 1]
# Fill view with flattened array of diagonals
diag_view[:] = x.reshape(diag_view.shape)[:]
# Unflatten buffer into output size
result = result_buffer.reshape(result_shape)
# print(result)
# Fill in final 2 axes with x
if len(x.shape) > 1:
# Re-order axes so they correspond to diagonals at axis1, axis2
axes = list(range(len(x.shape[:-1])))
......@@ -1425,8 +1425,4 @@ class GpuAllocDiag(Op):
def grad(self, inputs, gout):
(gz,) = gout
return [GpuExtractDiag(offset=self.offset, axis1=0, axis2=1)(gz)]
def infer_shape(self, node, shapes):
dim = shapes[0][0] + abs(self.offset)
return [[dim, dim]]
return [GpuExtractDiag(offset=self.offset, axis1=self.axis1, axis2=self.axis2)(gz)]
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论