提交 35aff174 authored 作者: Frederic's avatar Frederic

Do the same change to GpuDimshuffle then what was done to Dimshuffle

上级 03924847
...@@ -327,10 +327,19 @@ class GpuDimShuffle(GpuOp): ...@@ -327,10 +327,19 @@ class GpuDimShuffle(GpuOp):
def make_node(self, input): def make_node(self, input):
ib = tuple(input.type.broadcastable) ib = tuple(input.type.broadcastable)
if not ib == self.input_broadcastable: if not ib == self.input_broadcastable:
raise TypeError( if len(ib) != len(self.input_broadcastable):
"The number of dimensions and/or broadcastable pattern of the" raise TypeError((
" input is incorrect for this op. Expected %s, got %s." % "The number of dimensions of the "
(self.input_broadcastable, ib)) "input is incorrect for this op. Expected %s, got %s."
% (self.input_broadcastable, ib)))
for expected, b in zip(self.input_broadcastable, ib):
if expected is True and b is False:
raise TypeError((
"The broadcastable pattern of the "
"input is incorrect for this op. Expected %s, got %s."
% (self.input_broadcastable, ib)))
#else, expected == b or expected is False and b is True
# Both case are good.
ob = [] ob = []
if not isinstance(input.type, CudaNdarrayType): if not isinstance(input.type, CudaNdarrayType):
raise TypeError("The input of a GpuDimshuffle must" raise TypeError("The input of a GpuDimshuffle must"
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论