提交 9728b350 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

generalized IncSubtensor type checking

上级 be49e662
...@@ -2422,6 +2422,14 @@ class GpuIncSubtensor(tensor.IncSubtensor, GpuOp): ...@@ -2422,6 +2422,14 @@ class GpuIncSubtensor(tensor.IncSubtensor, GpuOp):
rval = tensor.IncSubtensor.make_node(self, x, y, *inputs) rval = tensor.IncSubtensor.make_node(self, x, y, *inputs)
return Apply(self, [x, y] + rval.inputs[2:], [x.type()]) return Apply(self, [x, y] + rval.inputs[2:], [x.type()])
def do_type_checking(self, node):
""" Should raise NotImplementedError if c_code does not support
the types involved in this node.
"""
if not isinstance(node.inputs[0].type, CudaNdarrayType):
raise NotImplementedError()
def copy_of_x(self, x): def copy_of_x(self, x):
return """(CudaNdarray*) CudaNdarray_Copy(%(x)s)""" % locals() return """(CudaNdarray*) CudaNdarray_Copy(%(x)s)""" % locals()
......
...@@ -4539,6 +4539,8 @@ class IncSubtensor(Op): ...@@ -4539,6 +4539,8 @@ class IncSubtensor(Op):
raise NotImplementedError() raise NotImplementedError()
assert response == 'y' assert response == 'y'
self.do_type_checking(node)
if self.inplace: # convert bool to int if self.inplace: # convert bool to int
inplace = 1 inplace = 1
else: else:
...@@ -4631,6 +4633,16 @@ class IncSubtensor(Op): ...@@ -4631,6 +4633,16 @@ class IncSubtensor(Op):
+ "Py_DECREF(xview);" + "Py_DECREF(xview);"
) )
def do_type_checking(self, node):
""" Should raise NotImplementedError if c_code does not support
the types involved in this node.
"""
if not isinstance(node.inputs[0].type, TensorType):
raise NotImplementedError()
def c_code_cache_version(self): def c_code_cache_version(self):
hv = Subtensor.helper_c_code_cache_version() hv = Subtensor.helper_c_code_cache_version()
if hv: if hv:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论