提交 97bb9197 authored 作者: Shawn Tan's avatar Shawn Tan

- Tests for AllocDiag

- Implement new infer_shape for AllocDiag - Corrected bug with perform
上级 c7eb702b
......@@ -6573,8 +6573,9 @@ class AllocDiag(Op):
if len(x.shape) > 1:
# Re-order axes so they correspond to diagonals at axis1, axis2
axes = range(len(x.shape[:-1]))
axes = axes[:axis1] + [axes[-1] + 1] + axes[axis1:]
axes = axes[:axis2] + [axes[-1] + 2] + axes[axis2:]
last_idx = axes[-1]
axes = axes[:axis1] + [last_idx + 1] + axes[axis1:]
axes = axes[:axis2] + [last_idx + 2] + axes[axis2:]
result = result.transpose(axes)
z[0] = result
......@@ -6589,7 +6590,15 @@ class AllocDiag(Op):
)]
def infer_shape(self, nodes, shapes):
return [(shapes[0][0],) * 2]
(x_shape,) = shapes
axis1 = np.minimum(self.axis1, self.axis2)
axis2 = np.maximum(self.axis1, self.axis2)
result_shape = list(x_shape[:-1])
diag_shape = x_shape[-1] + abs(self.offset)
result_shape = result_shape[:axis1] + [diag_shape] + result_shape[axis1:]
result_shape = result_shape[:axis2] + [diag_shape] + result_shape[axis2:]
return [tuple(result_shape)]
def diag(v, k=0):
......
......@@ -7561,33 +7561,37 @@ class test_diag(unittest.TestCase):
tensor.verify_grad(diag, [x], rng=rng)
def test_alloc_diag(self):
def test_alloc_diag():
dims = 4
shape = (5,) * dims
xv = np.random.randn(*shape).astype(config.floatX)
for d in xrange(1, dims + 1):
# Create a TensorType of the same dimensions as
# as the data we want to test.
x = T.TensorType(dtype=config.floatX, broadcastable=(False,) * d)('x')
x = TensorType(dtype=config.floatX, broadcastable=(False,) * d)('x')
# Make a slice of the test data that has the
# dimensions we need by doing xv[0,...,0]
# For example, for an array of shape (5,), we
# need to do xv[0, 0, 0, 0].
test_val = xv[((0,) * (dims - d))]
for offset, axis1, axis2 in [(0, 0, 1),]:
for offset, axis1, axis2 in [(0, 0, 1), (0, 1, 2), (1, 0, 1)]:
if np.maximum(axis1, axis2) > len(test_val.shape):
continue
adiag_op = AllocDiag(offset=offset,
axis1=axis1,
axis2=axis2)
f = theano.function([x], adiag_op(x))
# AllocDiag and extract the diagonal again
# to check
diag_arr = f(test_val)
rediag = np.diagonal(
f(xv),
diag_arr,
offset=offset,
axis1=axis1,
axis2=axis2
)
assert (rediag == x).all()
assert np.all(rediag == test_val)
class test_numpy_assumptions(unittest.TestCase):
# Verify that some assumptions Theano makes on Numpy's behavior still hold.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论