提交 2c290d27 authored 作者: James Bergstra's avatar James Bergstra

stricter perform for DimShuffle

上级 7e1ca580
...@@ -160,6 +160,8 @@ class DimShuffle(Op): ...@@ -160,6 +160,8 @@ class DimShuffle(Op):
def perform(self, node, (input, ), (storage, )): def perform(self, node, (input, ), (storage, )):
# drop # drop
res = input res = input
if type(res) != numpy.ndarray:
raise TypeError(res)
shape = list(res.shape) shape = list(res.shape)
for drop in reversed(self.drop): for drop in reversed(self.drop):
shape.pop(drop) shape.pop(drop)
...@@ -178,7 +180,7 @@ class DimShuffle(Op): ...@@ -178,7 +180,7 @@ class DimShuffle(Op):
if not self.inplace: if not self.inplace:
res = numpy.copy(res) res = numpy.copy(res)
storage[0] = res storage[0] = numpy.asarray(res) #asarray puts scalars back into array
def c_code(self, node, name, (input,), (res,), sub): def c_code(self, node, name, (input,), (res,), sub):
def statements(lst): def statements(lst):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论