提交 1201efa1 authored 作者: Frederic's avatar Frederic

correctly check input variable in Fourrier.make_node

上级 5f89d1ec
...@@ -55,25 +55,32 @@ class Fourier(gof.Op): ...@@ -55,25 +55,32 @@ class Fourier(gof.Op):
if axis is None: if axis is None:
axis = a.ndim - 1 axis = a.ndim - 1
axis = tensor.as_tensor_variable(axis) axis = tensor.as_tensor_variable(axis)
elif (not axis.dtype.startswith('int')) and \ else:
(not axis.dtype.startswith('uint')): axis = tensor.as_tensor_variable(axis)
raise TypeError('%s: index of the transformed axis must be' if (not axis.dtype.startswith('int')) and \
' of type integer' % self.__class__.__name__) (not axis.dtype.startswith('uint')):
elif axis.ndim != 0 or axis < 0 or axis > a.ndim - 1: raise TypeError('%s: index of the transformed axis must be'
raise TypeError('%s: index of the transformed axis must be' ' of type integer' % self.__class__.__name__)
' a scalar not smaller than 0 and smaller than' elif axis.ndim != 0 or (isinstance(axis, tensor.TensorConstant) and
' dimension of array' % self.__class__.__name__) (axis.data < 0 or axis.data > a.ndim - 1)):
raise TypeError('%s: index of the transformed axis must be'
' a scalar not smaller than 0 and smaller than'
' dimension of array' % self.__class__.__name__)
if n is None: if n is None:
n = a.shape[axis] n = a.shape[axis]
n = tensor.as_tensor_variable(n) n = tensor.as_tensor_variable(n)
elif (not n.dtypestartswith('int')) and \ else:
(not n.dtypestartswith('uint')): n = tensor.as_tensor_variable(n)
raise TypeError('%s: length of the transformed axis must be' if (not n.dtype.startswith('int')) and \
' of type integer' % self.__class__.__name__) (not n.dtype.startswith('uint')):
elif n.ndim != 0 or n < 1: raise TypeError('%s: length of the transformed axis must be'
raise TypeError('%s: length of the transformed axis must be a' ' of type integer' % self.__class__.__name__)
' strictly positive scalar' elif n.ndim != 0 or (isinstance(n, tensor.TensorConstant) and
% self.__class__.__name__) n.data < 1):
raise TypeError('%s: length of the transformed axis must be a'
' strictly positive scalar'
% self.__class__.__name__)
return gof.Apply(self, [a, n, axis], [tensor.TensorType('complex128', return gof.Apply(self, [a, n, axis], [tensor.TensorType('complex128',
a.type.broadcastable)()]) a.type.broadcastable)()])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论