Commit 480ca3ad, authored by Arnaud Bergeron

Replace Dot22Scalar by a GpuGemm instead and add a test.

Parent commit: 622f1d51
......@@ -753,9 +753,11 @@ def local_gpua_dot22(node, context_name):
@register_opt('fast_compile')
@op_lifter([tensor.blas.Dot22Scalar])
def local_gpua_dot22scalar(node, context_name):
    """Lift ``Dot22Scalar`` (``a * dot(x, y)``) to a single GpuGemm call.

    Instead of running a GPU dot22 followed by a separate elemwise
    scalar multiplication, allocate an (uninitialized) output buffer
    and let GEMM compute ``a * dot(x, y) + 0 * z`` in one kernel.

    Parameters
    ----------
    node : Apply
        The ``Dot22Scalar`` node being lifted; its inputs are
        ``(x, y, a)`` with ``x``, ``y`` matrices and ``a`` a scalar.
    context_name : str
        Name of the GPU context to move the computation to.

    Returns
    -------
    list
        A one-element list holding the GpuGemm output variable.
    """
    x, y, a = node.inputs
    x = as_gpuarray_variable(x, context_name)
    y = as_gpuarray_variable(y, context_name)
    # Output buffer for GEMM; since beta is 0 below, its uninitialized
    # contents are never read, only overwritten.
    z = GpuAllocEmpty(x.dtype, context_name)(x.shape[0], y.shape[1])
    # GEMM computes a * (x @ y) + 0 * z; inplace=False keeps the
    # rewrite safe, later passes may introduce the inplace version.
    return [GpuGemm(inplace=False)(z, a, x, y, 0)]
@register_opt('fast_compile')
......
......@@ -11,6 +11,7 @@ from .. import basic_ops
from ..type import GpuArrayType, gpuarray_shared_constructor, get_context
from ..basic_ops import (
GpuAlloc, GpuAllocEmpty, GpuReshape, GpuFromHost, host_from_gpu)
from ..blas import GpuGemm
from ..elemwise import GpuCAReduceCuda, GpuCAReduceCPY, GpuElemwise
from ..subtensor import GpuSubtensor
......@@ -254,6 +255,17 @@ def test_local_gpu_elemwise_careduce():
assert topo[1].op.pre_scalar_op == theano.scalar.sqr
utt.assert_allclose(f(data), (data * data).sum(axis=1))
def test_local_lift_dot22scalar():
    """Check that the lifter replaces Dot22Scalar with GpuGemm.

    Builds ``Dot22Scalar()(x, y, a)`` directly and compiles it in GPU
    mode, then inspects the optimized graph.
    """
    x = tensor.matrix()
    y = tensor.matrix()
    a = tensor.scalar()
    o = tensor.blas.Dot22Scalar()(x, y, a)
    f = theano.function([x, y, a], o, mode=mode_with_gpu)
    # The CPU op must be gone from the compiled graph...
    assert not any(isinstance(n.op, tensor.blas.Dot22Scalar)
                   for n in f.maker.fgraph.apply_nodes)
    # ...and a GpuGemm must have taken its place.
    assert any(isinstance(n.op, GpuGemm)
               for n in f.maker.fgraph.apply_nodes)
def test_local_gpu_subtensor():
# Test shared forced on CPU.
......
Markdown is supported
0% of attachment uploaded
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment