提交 c4266e45 authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #1940 from Hengjean/CheckNDimSubtensor

Added ndim check to subtensor.
...@@ -846,11 +846,21 @@ class Subtensor(Op): ...@@ -846,11 +846,21 @@ class Subtensor(Op):
x = inputs[0] x = inputs[0]
z, = outputs z, = outputs
ndim = node.inputs[0].ndim
view_ndim = node.outputs[0].ndim view_ndim = node.outputs[0].ndim
fail = sub['fail'] fail = sub['fail']
decl = "PyArrayObject * xview = NULL;" decl = "PyArrayObject * xview = NULL;"
checkNDim = """
if (PyArray_NDIM(%(x)s) != %(ndim)s){
PyErr_SetString(PyExc_ValueError,
"Expected %(ndim)s dimensions input"
);
%(fail)s
}
""" % locals()
get_xview = self.helper_c_code(node, name, inputs, outputs, sub, get_xview = self.helper_c_code(node, name, inputs, outputs, sub,
self.idx_list, view_ndim) self.idx_list, view_ndim)
build_view = """ build_view = """
...@@ -887,7 +897,7 @@ class Subtensor(Op): ...@@ -887,7 +897,7 @@ class Subtensor(Op):
%(z)s = xview; %(z)s = xview;
""" % locals() """ % locals()
return decl + get_xview + build_view + finish_view return decl + checkNDim + "{" + get_xview + build_view + finish_view + "}"
def c_code_cache_version(self): def c_code_cache_version(self):
hv = self.helper_c_code_cache_version() hv = self.helper_c_code_cache_version()
...@@ -895,7 +905,7 @@ class Subtensor(Op): ...@@ -895,7 +905,7 @@ class Subtensor(Op):
# have a versioned version of this op's C code. # have a versioned version of this op's C code.
if len(hv) == 0: if len(hv) == 0:
return () return ()
return (3, hv) return (4, hv)
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
# Subtensor is not differentiable wrt to its indices, therefore we # Subtensor is not differentiable wrt to its indices, therefore we
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论