Commit 342d5018 authored by Arnaud Bergeron

Add tests for alpha and output merge, and do some more fixes.

Parent cd849526
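For context: alpha_merge folds a scalar multiplication surrounding a GEMM into the op's alpha parameter, and output_merge folds an addition of an existing array into its beta/output slot. A minimal sketch of the equivalence these rewrites rely on, in plain numpy (the gemm helper below is illustrative, not part of the commit):

import numpy

def gemm(alpha, A, B, beta, C):
    # BLAS-style GEMM: alpha * dot(A, B) + beta * C
    return alpha * numpy.dot(A, B) + beta * C

rng = numpy.random.RandomState(0)
A = rng.rand(3, 4).astype('float32')
B = rng.rand(4, 5).astype('float32')
C = rng.rand(3, 5).astype('float32')

# Unfused graph, as written in the new test: 0.05 * (dot(A, B) + C).
unfused = 0.05 * (numpy.dot(A, B) + C)
# After alpha_merge and output_merge: one GEMM with alpha = beta = 0.05.
fused = gemm(0.05, A, B, 0.05, C)
numpy.testing.assert_allclose(unfused, fused, rtol=1e-5)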
@@ -624,13 +624,13 @@ def local_gpua_hgemm(node):
 @register_opt()
-@alpha_merge(GpuGemm, alpha_in=1, beta_in=2, nd=2)
+@alpha_merge(GpuGemm, alpha_in=1, beta_in=4, nd=2)
 def local_gpuagemm_alpha_merge(node, *inputs):
     return [gpugemm_no_inplace(*inputs)]


 @register_opt()
-@output_merge(GpuGemm, alpha_in=1, beta_in=2, out_in=0, nd=2)
+@output_merge(GpuGemm, alpha_in=1, beta_in=4, out_in=0, nd=2)
 def local_gpuagemm_output_merge(node, *inputs):
     return [gpugemm_no_inplace(*inputs)]
...
@@ -73,7 +73,8 @@ def alpha_merge(cls, alpha_in, beta_in, nd):
             lr = grab_cpu_scalar(node.inputs[0], nd=nd)
         else:
             lr = grab_cpu_scalar(node.inputs[1], nd=nd)
-        if lr is None or targ is None:
+        if (lr is None or targ is None or
+                lr.dtype != targ.outputs[0].dtype):
             return None
         inputs = list(targ.inputs)
         try:
@@ -110,6 +111,8 @@ def output_merge(cls, alpha_in, beta_in, out_in, nd):
         W = node.inputs[0]
         if targ is None:
             return None
+        if W.dtype != targ.outputs[0].dtype:
+            return None
         if not is_equal(targ.inputs[beta_in], 0.0):
             # other cases are too complex for now
             return None
...
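The dtype guards added to both helpers refuse the merge when the folded scalar (or output array) does not match the op's output dtype, presumably because handing, say, a float32 value to a float16 GEMM would change the graph's output type or be unsupported by the kernel. A minimal sketch of the guards' intent (illustrative names, not the actual helpers):

def safe_to_merge(scalar_dtype, op_output_dtype):
    # Mirrors the added checks: only merge when the dtypes already agree.
    return scalar_dtype == op_output_dtype

assert safe_to_merge('float16', 'float16')
assert not safe_to_merge('float32', 'float16')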
@@ -148,3 +148,19 @@ def test_hgemm_swap():
     on = numpy.dot(v1, v2)
     utt.assert_allclose(of, on)
+
+
+def test_hgemm_alpha_output_merge():
+    from theano.sandbox.cuda import nvcc_compiler
+    if nvcc_compiler.nvcc_version < '7.5':
+        raise SkipTest("SgemmEx is only available on cuda 7.5+")
+
+    m1 = tensor.matrix(dtype='float16')
+    m2 = tensor.matrix(dtype='float16')
+    b = tensor.matrix(dtype='float16')
+
+    hgemm = numpy.asarray(0.05, dtype='float16') * (tensor.dot(m1, m2) + b)
+
+    f = theano.function([m1, m2, b], hgemm, mode=mode_with_gpu)
+
+    # there should be 3 gpu_from_host, 1 hgemm and 1 host_from_gpu
+    assert len(f.maker.fgraph.apply_nodes) == 5
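Once both merges fire, the whole expression compiles down to a single GpuGemm apply plus the transfers for the three inputs and the one output, which is where the count of 5 comes from. If the assertion fails, a standard way to see what was actually compiled is to list the ops in the optimized graph:

for node in f.maker.fgraph.toposort():
    print(node.op)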
@@ -36,7 +36,7 @@ class GpuArrayType(Type):
         return self.__class__(dtype=dtype, broadcastable=broadcastable,
                               name=self.name)

-    def __str__(self):
+    def __repr__(self):
         return "GpuArrayType(%s, %s)" % (self.dtype, self.broadcastable)

     def filter(self, data, strict=False, allow_downcast=None):
...