提交 452e9c92 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Fix the blas ops so that they really don't need to inherit from the

tensor ones and make GpuGer conform.
上级 5cecb66b
......@@ -31,6 +31,11 @@ class BlasOp(HideC):
class GpuGemv(BlasOp):
def __init__(self, inplace=False):
self.inplace = inplace
if self.inplace:
self.destroy_map = {0: [0]}
def make_node(self, y, alpha, A, x, beta):
ctx_name = infer_context_name(y, A, x)
A = as_gpuarray_variable(A, ctx_name)
......@@ -104,6 +109,11 @@ gpugemv_inplace = GpuGemv(inplace=True)
class GpuGemm(BlasOp):
_f16_ok = True
def __init__(self, inplace=False):
self.inplace = inplace
if self.inplace:
self.destroy_map = {0: [0]}
def make_node(self, C, alpha, A, B, beta):
ctx_name = infer_context_name(C, A, B)
A = as_gpuarray_variable(A, ctx_name)
......@@ -175,6 +185,11 @@ gpugemm_inplace = GpuGemm(inplace=True)
class GpuGer(BlasOp):
def __init__(self, inplace=False):
self.inplace = inplace
if self.inplace:
self.destroy_map = {0: [0]}
def make_node(self, A, alpha, x, y):
ctx_name = infer_context_name(A, x, y)
A = as_gpuarray_variable(A, ctx_name)
......@@ -236,8 +251,8 @@ class GpuGer(BlasOp):
return (3,)
gpuger_no_inplace = GpuGer(destructive=False)
gpuger_inplace = GpuGer(destructive=True)
gpuger_no_inplace = GpuGer(inplace=False)
gpuger_inplace = GpuGer(inplace=True)
class GpuDot22(BlasOp):
......
......@@ -728,7 +728,7 @@ def local_gpuagemm_output_merge(node, *inputs):
@register_opt('fast_compile')
@op_lifter([tensor.blas.Ger, tensor.blas_c.CGer, tensor.blas_scipy.ScipyGer])
def local_gpua_ger(node, context_name):
return GpuGer(destructive=node.op.destructive)
return GpuGer(inplace=node.op.destructive)
@register_opt('fast_compile')
......
......@@ -100,7 +100,7 @@ class TestGpuGer_OpContract(TestCase, utt.T_OpContractMixin):
self.ops = [gpuger_no_inplace, gpuger_inplace]
def clone(self, op):
return GpuGer(destructive=op.destructive)
return GpuGer(inplace=op.inplace)
GpuDot22Tester = makeTester(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论