提交 e3726097 authored 作者: Shawn Tan's avatar Shawn Tan

Generalised AllocDiag working with existing tests.

offsets not implemented.
上级 cd6ca858
......@@ -6545,41 +6545,38 @@ class AllocDiag(Op):
raise ValueError('AllocDiag needs an input with 1 or more '
'dimensions', diag.type)
return Apply(
self, [diag],[
diag.type.__class__(
self, [diag],
[diag.type.__class__(
dtype=diag.dtype,
broadcastable=[False] * (diag.ndim + 1)
)()
])
broadcastable=[False] * (diag.ndim + 1))()]
)
def perform(self, node, inputs, outputs):
(x,) = inputs
(z,) = outputs
axis1 = min(axis1, axis2)
axis2 = max(axis1, axis2)
axis1 = np.minimum(self.axis1, self.axis2)
axis2 = np.maximum(self.axis1, self.axis2)
# Create array with one extra dimension for resulting matrix
result_shape = x.shape[:-1] + (x.shape[-1],) * 2
result = np.zeros(result_shape, dtype=x.dtype)
# Create slice for diagonal in final 2 axes
diagonal_slice = ((len(result_shape) - 2) * [slice(None)] +
diagonal_slice = ((len(result_shape) - 2) * [slice(None)] +
[np.arange(x.shape[-1])] * 2)
# Fill in final 2 axes with x
result[diagonal_slice] = x
# Re-order axes so they correspond to diagonals at axis1, axis2
axes = range(len(x.shape[:-1]))
axes = axes[:axis1] + [axes[-1] + 1] + axes[axis1:]
axes = axes[:axis2] + [axes[-1] + 2] + axes[axis2:]
result = result.transpose(axes)
if len(x.shape) > 1:
# Re-order axes so they correspond to diagonals at axis1, axis2
axes = range(len(x.shape[:-1]))
axes = axes[:axis1] + [axes[-1] + 1] + axes[axis1:]
axes = axes[:axis2] + [axes[-1] + 2] + axes[axis2:]
result = result.transpose(axes)
z[0] = result
def grad(self, inputs, gout):
(gz,) = gout
return [diagonal(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论