提交 9c1d6e29 authored 作者: Gijs van Tulder's avatar Gijs van Tulder

Better checks for ravel_index/unravel_index. Make ndim a property.

上级 eb5565b0
......@@ -2589,12 +2589,16 @@ def nonzero_values(a):
class UnravelIndex(gof.Op):
__props__ = ('order',)
__props__ = ('ndim', 'order')
def __init__(self, order='C'):
def __init__(self, ndim, order='C'):
assert order in ('C', 'F')
if not isinstance(ndim, int) or ndim < 1:
raise ValueError('ndim must be an integer greater than 0')
self.ndim = int(ndim)
self.order = order
def make_node(self, indices, dims, ndim):
def make_node(self, indices, dims):
indices = as_tensor_variable(indices)
dims = as_tensor_variable(dims)
......@@ -2604,13 +2608,11 @@ class UnravelIndex(gof.Op):
raise TypeError("'%s' object cannot be interpreted as an index" % str(dims.dtype))
if dims.ndim != 1:
raise TypeError("dims must be a 1D array")
if not isinstance(ndim, int):
raise TypeError("ndim must be an integer")
return gof.Apply(
self, [indices, dims],
[TensorType(dtype=indices.dtype, broadcastable=(False,) * indices.ndim)()
for i in xrange(ndim)])
for i in xrange(self.ndim)])
def infer_shape(self, node, input_shapes):
return [input_shapes[0]] * len(node.outputs)
......@@ -2675,7 +2677,7 @@ def unravel_index(indices, dims, order='C', ndim=None):
"index will be. You can provide the 'ndim' keyword "
"argument to 'unravel_index' to avoid this problem." % str(dims))
res = UnravelIndex(order=order)(indices, dims, ndim)
res = UnravelIndex(ndim=ndim, order=order)(indices, dims)
if ndim == 1:
return (res,)
else:
......@@ -2686,6 +2688,8 @@ class RavelMultiIndex(gof.Op):
__props__ = ('mode', 'order')
def __init__(self, mode='raise', order='C'):
assert mode in ('raise', 'wrap', 'clip')
assert order in ('C', 'F')
self.mode = mode
self.order = order
......
......@@ -2810,7 +2810,7 @@ class test_unravel_index(utt.InferShapeTester):
# must provide integers
self.assertRaises(TypeError, unravel_index, fvector(), (3, 4))
self.assertRaises(TypeError, unravel_index, (3, 4), (3.4, 3.2))
self.assertRaises(TypeError, unravel_index, (3, 4), (3, 3), ndim=5.4)
self.assertRaises(ValueError, unravel_index, (3, 4), (3, 3), ndim=5.4)
# dims must be a 1D sequence
self.assertRaises(TypeError, unravel_index, (3, 4), 3)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论