提交 0c572c25 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

fixed tensor_scalar_op to check that the scalar argument is indeed a scalar

上级 20f280ff
...@@ -793,7 +793,12 @@ class tensor_scalar_op(elemwise): ...@@ -793,7 +793,12 @@ class tensor_scalar_op(elemwise):
def loop_variables(cls): def loop_variables(cls):
return (['x', ], ['z', ]) return (['x', ], ['z', ])
def c_init((x, _a), (z, )): def c_init((x, _a), (z, )):
return "_a_dtype a = ((_a_dtype*)PyArray_DATA(_a))[0];" return """
if (PyArray_SIZE(_a) != 1) {
PyErr_SetString(PyExc_ValueError, \"The size of the scalar argument is not 1.\");
}
_a_dtype a = ((_a_dtype*)PyArray_DATA(_a))[0];
"""
def _c_foreach(self): def _c_foreach(self):
return "z_i = %s;" % self.c_expr return "z_i = %s;" % self.c_expr
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论