提交 e3726097 authored 作者: Shawn Tan's avatar Shawn Tan

Generalised AllocDiag working with existing tests.

offsets not implemented.
上级 cd6ca858
......@@ -6545,20 +6545,18 @@ class AllocDiag(Op):
raise ValueError('AllocDiag needs an input with 1 or more '
'dimensions', diag.type)
return Apply(
self, [diag],[
diag.type.__class__(
self, [diag],
[diag.type.__class__(
dtype=diag.dtype,
broadcastable=[False] * (diag.ndim + 1)
)()
])
broadcastable=[False] * (diag.ndim + 1))()]
)
def perform(self, node, inputs, outputs):
(x,) = inputs
(z,) = outputs
axis1 = min(axis1, axis2)
axis2 = max(axis1, axis2)
axis1 = np.minimum(self.axis1, self.axis2)
axis2 = np.maximum(self.axis1, self.axis2)
# Create array with one extra dimension for resulting matrix
result_shape = x.shape[:-1] + (x.shape[-1],) * 2
......@@ -6570,7 +6568,7 @@ class AllocDiag(Op):
# Fill in final 2 axes with x
result[diagonal_slice] = x
if len(x.shape) > 1:
# Re-order axes so they correspond to diagonals at axis1, axis2
axes = range(len(x.shape[:-1]))
axes = axes[:axis1] + [axes[-1] + 1] + axes[axis1:]
......@@ -6579,7 +6577,6 @@ class AllocDiag(Op):
z[0] = result
def grad(self, inputs, gout):
(gz,) = gout
return [diagonal(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论