提交 452e9c92 作者: Arnaud Bergeron

Fix the BLAS ops so that they really don't need to inherit from the
tensor ones, and make GpuGer conform.
上级 5cecb66b
...@@ -31,6 +31,11 @@ class BlasOp(HideC): ...@@ -31,6 +31,11 @@ class BlasOp(HideC):
class GpuGemv(BlasOp): class GpuGemv(BlasOp):
def __init__(self, inplace=False):
self.inplace = inplace
if self.inplace:
self.destroy_map = {0: [0]}
def make_node(self, y, alpha, A, x, beta): def make_node(self, y, alpha, A, x, beta):
ctx_name = infer_context_name(y, A, x) ctx_name = infer_context_name(y, A, x)
A = as_gpuarray_variable(A, ctx_name) A = as_gpuarray_variable(A, ctx_name)
...@@ -104,6 +109,11 @@ gpugemv_inplace = GpuGemv(inplace=True) ...@@ -104,6 +109,11 @@ gpugemv_inplace = GpuGemv(inplace=True)
class GpuGemm(BlasOp): class GpuGemm(BlasOp):
_f16_ok = True _f16_ok = True
def __init__(self, inplace=False):
self.inplace = inplace
if self.inplace:
self.destroy_map = {0: [0]}
def make_node(self, C, alpha, A, B, beta): def make_node(self, C, alpha, A, B, beta):
ctx_name = infer_context_name(C, A, B) ctx_name = infer_context_name(C, A, B)
A = as_gpuarray_variable(A, ctx_name) A = as_gpuarray_variable(A, ctx_name)
...@@ -175,6 +185,11 @@ gpugemm_inplace = GpuGemm(inplace=True) ...@@ -175,6 +185,11 @@ gpugemm_inplace = GpuGemm(inplace=True)
class GpuGer(BlasOp): class GpuGer(BlasOp):
def __init__(self, inplace=False):
self.inplace = inplace
if self.inplace:
self.destroy_map = {0: [0]}
def make_node(self, A, alpha, x, y): def make_node(self, A, alpha, x, y):
ctx_name = infer_context_name(A, x, y) ctx_name = infer_context_name(A, x, y)
A = as_gpuarray_variable(A, ctx_name) A = as_gpuarray_variable(A, ctx_name)
...@@ -236,8 +251,8 @@ class GpuGer(BlasOp): ...@@ -236,8 +251,8 @@ class GpuGer(BlasOp):
return (3,) return (3,)
gpuger_no_inplace = GpuGer(destructive=False) gpuger_no_inplace = GpuGer(inplace=False)
gpuger_inplace = GpuGer(destructive=True) gpuger_inplace = GpuGer(inplace=True)
class GpuDot22(BlasOp): class GpuDot22(BlasOp):
......
...@@ -728,7 +728,7 @@ def local_gpuagemm_output_merge(node, *inputs): ...@@ -728,7 +728,7 @@ def local_gpuagemm_output_merge(node, *inputs):
@register_opt('fast_compile') @register_opt('fast_compile')
@op_lifter([tensor.blas.Ger, tensor.blas_c.CGer, tensor.blas_scipy.ScipyGer]) @op_lifter([tensor.blas.Ger, tensor.blas_c.CGer, tensor.blas_scipy.ScipyGer])
def local_gpua_ger(node, context_name): def local_gpua_ger(node, context_name):
return GpuGer(destructive=node.op.destructive) return GpuGer(inplace=node.op.destructive)
@register_opt('fast_compile') @register_opt('fast_compile')
......
...@@ -100,7 +100,7 @@ class TestGpuGer_OpContract(TestCase, utt.T_OpContractMixin): ...@@ -100,7 +100,7 @@ class TestGpuGer_OpContract(TestCase, utt.T_OpContractMixin):
self.ops = [gpuger_no_inplace, gpuger_inplace] self.ops = [gpuger_no_inplace, gpuger_inplace]
def clone(self, op): def clone(self, op):
return GpuGer(destructive=op.destructive) return GpuGer(inplace=op.inplace)
GpuDot22Tester = makeTester( GpuDot22Tester = makeTester(
......
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论