提交 46eaf9df authored 作者: Reyhane Askari's avatar Reyhane Askari

unstable commit-fixed perform on gpujoin

上级 98a120b4
...@@ -1245,8 +1245,27 @@ class GpuJoin(HideC, Join): ...@@ -1245,8 +1245,27 @@ class GpuJoin(HideC, Join):
""" """
_f16_ok = True _f16_ok = True
__props__ = ("view",)
params_type = gpu_context_type params_type = gpu_context_type
def __init__(self, view=-1):
self.view = view
if view != -1:
# since the first input is always the axis, the tensors
# start from index 1.
self.view_map = {0: [1 + view]}
# def __str__(self):
# if self.view == -1:
# return "Join"
# else:
# return super(Join, self).__str__()
# def __setstate__(self, d):
# self.__dict__.update(d)
# if not hasattr(self, "view"):
# self.view = -1
def make_node(self, axis, *tensors): def make_node(self, axis, *tensors):
node = Join.make_node(self, axis, *tensors) node = Join.make_node(self, axis, *tensors)
...@@ -1265,26 +1284,36 @@ class GpuJoin(HideC, Join): ...@@ -1265,26 +1284,36 @@ class GpuJoin(HideC, Join):
def perform(self, node, axis_and_tensors, out_, ctx): def perform(self, node, axis_and_tensors, out_, ctx):
out, = out_ out, = out_
view = self.view
axis = int(axis_and_tensors[0]) axis = int(axis_and_tensors[0])
tensors = axis_and_tensors[1:]
if axis < -axis_and_tensors[1].ndim: if axis < -axis_and_tensors[1].ndim:
raise IndexError raise IndexError
if axis < 0: if axis < 0:
axis += axis_and_tensors[1].ndim axis += axis_and_tensors[1].ndim
tensors = axis_and_tensors[1:] # we check these tensors for being empty.
out[0] = pygpu.concatenate(tensors, axis=axis, context=ctx).astype( if (view != -1) and numpy.all(
node.outputs[0].dtype) [tensor.shape[axis] == 0 for tensor in
tensors[0:view] + tensors[view + 1:]]):
import ipdb; ipdb.set_trace()
out[0] = tensors[view]
else:
out[0] = pygpu.concatenate(tensors, axis=axis, context=ctx).astype(
node.outputs[0].dtype)
def c_code_cache_version(self): def c_code_cache_version(self):
return
return (2,) return (2,)
def c_support_code(self): def c_support_code_(self):
return """ return """
#if PY_MAJOR_VERSION >= 3 #if PY_MAJOR_VERSION >= 3
#define PyInt_AsLong PyLong_AsLong #define PyInt_AsLong PyLong_AsLong
#endif #endif
""" """
def c_code(self, node, name, inputs, out_, sub): def c_code_(self, node, name, inputs, out_, sub):
copy_to_list = [] copy_to_list = []
restype = pygpu.gpuarray.dtype_to_typecode(node.outputs[0].dtype) restype = pygpu.gpuarray.dtype_to_typecode(node.outputs[0].dtype)
for i, inp in enumerate(inputs[1:]): for i, inp in enumerate(inputs[1:]):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论