提交 e3726097 authored 作者: Shawn Tan's avatar Shawn Tan

Generalised AllocDiag working with existing tests.

offsets not implemented.
上级 cd6ca858
...@@ -6545,41 +6545,38 @@ class AllocDiag(Op): ...@@ -6545,41 +6545,38 @@ class AllocDiag(Op):
raise ValueError('AllocDiag needs an input with 1 or more ' raise ValueError('AllocDiag needs an input with 1 or more '
'dimensions', diag.type) 'dimensions', diag.type)
return Apply( return Apply(
self, [diag],[ self, [diag],
diag.type.__class__( [diag.type.__class__(
dtype=diag.dtype, dtype=diag.dtype,
broadcastable=[False] * (diag.ndim + 1) broadcastable=[False] * (diag.ndim + 1))()]
)() )
])
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
(x,) = inputs (x,) = inputs
(z,) = outputs (z,) = outputs
axis1 = np.minimum(self.axis1, self.axis2)
axis1 = min(axis1, axis2) axis2 = np.maximum(self.axis1, self.axis2)
axis2 = max(axis1, axis2)
# Create array with one extra dimension for resulting matrix # Create array with one extra dimension for resulting matrix
result_shape = x.shape[:-1] + (x.shape[-1],) * 2 result_shape = x.shape[:-1] + (x.shape[-1],) * 2
result = np.zeros(result_shape, dtype=x.dtype) result = np.zeros(result_shape, dtype=x.dtype)
# Create slice for diagonal in final 2 axes # Create slice for diagonal in final 2 axes
diagonal_slice = ((len(result_shape) - 2) * [slice(None)] + diagonal_slice = ((len(result_shape) - 2) * [slice(None)] +
[np.arange(x.shape[-1])] * 2) [np.arange(x.shape[-1])] * 2)
# Fill in final 2 axes with x # Fill in final 2 axes with x
result[diagonal_slice] = x result[diagonal_slice] = x
if len(x.shape) > 1:
# Re-order axes so they correspond to diagonals at axis1, axis2 # Re-order axes so they correspond to diagonals at axis1, axis2
axes = range(len(x.shape[:-1])) axes = range(len(x.shape[:-1]))
axes = axes[:axis1] + [axes[-1] + 1] + axes[axis1:] axes = axes[:axis1] + [axes[-1] + 1] + axes[axis1:]
axes = axes[:axis2] + [axes[-1] + 2] + axes[axis2:] axes = axes[:axis2] + [axes[-1] + 2] + axes[axis2:]
result = result.transpose(axes) result = result.transpose(axes)
z[0] = result z[0] = result
def grad(self, inputs, gout): def grad(self, inputs, gout):
(gz,) = gout (gz,) = gout
return [diagonal( return [diagonal(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论