提交 1b101ffc authored 作者: Tim Cooijmans's avatar Tim Cooijmans 提交者: Reyhane Askari

whoops

上级 ff9e2b38
......@@ -259,7 +259,7 @@ def op_lifter(OP, cuda_only=False):
new_outputs = new_op
to_cpu_fn = safe_to_cpu
else: # suppose it is a variable on the GPU
new_outputs = new_op]
new_outputs = [new_op]
to_cpu_fn = lambda x: x.transfer('cpu')
# copy stack traces onto gpu outputs
for old_output, new_output in zip(node.outputs, new_outputs):
......@@ -1400,11 +1400,12 @@ def local_gpua_dot22(op, context_name, inputs, outputs):
@op_lifter([tensor.blas.Dot22Scalar])
@register_opt2([tensor.blas.Dot22Scalar], 'fast_compile')
def local_gpua_dot22scalar(op, context_name, inputs, outputs):
x, y, a = inputs
x = as_gpuarray_variable(x, context_name)
y = as_gpuarray_variable(y, context_name)
z = GpuAllocEmpty(x.dtype, context_name)(x.shape[0], y.shape[1])
return [gpugemm_no_inplace(z, a, x, y, 0)]
with inherit_stack_trace(outputs):
x, y, a = inputs
x = as_gpuarray_variable(x, context_name)
y = as_gpuarray_variable(y, context_name)
z = GpuAllocEmpty(x.dtype, context_name)(x.shape[0], y.shape[1])
return [gpugemm_no_inplace(z, a, x, y, 0)]
@register_opt('fast_compile')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论