提交 9728b350 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

generalized IncSubtensor type checking

上级 be49e662
......@@ -2422,6 +2422,14 @@ class GpuIncSubtensor(tensor.IncSubtensor, GpuOp):
rval = tensor.IncSubtensor.make_node(self, x, y, *inputs)
return Apply(self, [x, y] + rval.inputs[2:], [x.type()])
def do_type_checking(self, node):
""" Should raise NotImplementedError if c_code does not support
the types involved in this node.
"""
if not isinstance(node.inputs[0].type, CudaNdarrayType):
raise NotImplementedError()
def copy_of_x(self, x):
return """(CudaNdarray*) CudaNdarray_Copy(%(x)s)""" % locals()
......
......@@ -4539,6 +4539,8 @@ class IncSubtensor(Op):
raise NotImplementedError()
assert response == 'y'
self.do_type_checking(node)
if self.inplace: # convert bool to int
inplace = 1
else:
......@@ -4631,6 +4633,16 @@ class IncSubtensor(Op):
+ "Py_DECREF(xview);"
)
def do_type_checking(self, node):
""" Should raise NotImplementedError if c_code does not support
the types involved in this node.
"""
if not isinstance(node.inputs[0].type, TensorType):
raise NotImplementedError()
def c_code_cache_version(self):
hv = Subtensor.helper_c_code_cache_version()
if hv:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论