提交 4f1c2697 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Type context for subtensor.py.

上级 c887bc14
......@@ -26,10 +26,12 @@ class GpuSubtensor(HideC, Subtensor):
_f16_ok = True
def make_node(self, x, *inputs):
ctx_name = infer_context_name(x)
rval = tensor.Subtensor.make_node(self, x, *inputs)
otype = GpuArrayType(dtype=rval.outputs[0].type.dtype,
broadcastable=rval.outputs[0].type.broadcastable)
x = as_gpuarray_variable(x)
broadcastable=rval.outputs[0].type.broadcastable,
context_name=ctx_name)
x = as_gpuarray_variable(x, ctx_name)
return gof.Apply(self, [x] + rval.inputs[1:], [otype()])
def perform(self, node, inputs, out_):
......@@ -190,14 +192,18 @@ class GpuIncSubtensor(GpuKernelBase, IncSubtensor):
return self.iadd_node.op.gpu_kernels(self.iadd_node, subname)
def make_node(self, x, y, *inputs):
x = as_gpuarray_variable(x)
y = as_gpuarray_variable(y)
ctx_name = infer_context_name(x, y)
x = as_gpuarray_variable(x, ctx_name)
y = as_gpuarray_variable(y, ctx_name)
rval = tensor.IncSubtensor.make_node(self, x, y, *inputs)
op = copy.copy(self)
ret = gof.Apply(op, [x, y] + rval.inputs[2:], [x.type()])
op.create_iadd_node(ret)
return ret
def get_context(self, node):
return node.outputs[0].type.context
def create_iadd_node(self, node):
# We store a iadd_node in the op that contain the info needed
# for the inplace add.
......@@ -365,10 +371,10 @@ class GpuIncSubtensor(GpuKernelBase, IncSubtensor):
""" % locals()
inputs = ["dst", "src"]
outputs = ["ret"]
sub = {"fail": "return NULL;"}
sub = {"fail": "return NULL;", "context": "dst->ctx"}
ret += gop.c_code(self.iadd_node, sub_name, inputs, outputs, sub)
ret += """
return dst;
return ret;
}
"""
return ret
......@@ -398,7 +404,8 @@ class GpuIncSubtensor(GpuKernelBase, IncSubtensor):
class GpuAdvancedSubtensor1(HideC, tensor.AdvancedSubtensor1):
def make_node(self, x, ilist):
x_ = as_gpuarray_variable(x)
ctx_name = infer_context_name(x, ilist)
x_ = as_gpuarray_variable(x, ctx_name)
ilist__ = tensor.as_tensor_variable(ilist)
if ilist__.type.dtype[:3] not in ('int', 'uin'):
......@@ -406,7 +413,7 @@ class GpuAdvancedSubtensor1(HideC, tensor.AdvancedSubtensor1):
if ilist__.type.dtype != 'int64':
ilist__ = tensor.cast(ilist__, 'int64')
ilist_ = as_gpuarray_variable(ilist__)
ilist_ = as_gpuarray_variable(ilist__, ctx_name)
if ilist_.type.dtype != 'int64':
raise TypeError('index must be int64')
......@@ -418,6 +425,7 @@ class GpuAdvancedSubtensor1(HideC, tensor.AdvancedSubtensor1):
bcast = ilist_.broadcastable + x_.broadcastable[1:]
return gof.Apply(self, [x_, ilist_],
[GpuArrayType(dtype=x.dtype,
context_name=ctx_name,
broadcastable=bcast)()])
def perform(self, node, inp, out_):
......@@ -474,8 +482,9 @@ class GpuAdvancedIncSubtensor1(HideC, tensor.AdvancedIncSubtensor1):
"""
def make_node(self, x, y, ilist):
x_ = as_gpuarray_variable(x)
y_ = as_gpuarray_variable(y)
ctx_name = infer_context_name(x, y)
x_ = as_gpuarray_variable(x, ctx_name)
y_ = as_gpuarray_variable(y, ctx_name)
ilist_ = tensor.as_tensor_variable(ilist)
assert x_.type.dtype == y_.type.dtype
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论