提交 6b98e587，作者：Frédéric Bastien

Merge pull request #4065 from abergeron/lift_dot22scalar

Also lift Dot22Scalar.
......@@ -750,6 +750,16 @@ def local_gpua_dot22(node, context_name):
return gpu_dot22
@register_opt('fast_compile')
@op_lifter([tensor.blas.Dot22Scalar])
def local_gpua_dot22scalar(node, context_name):
    """Replace a ``Dot22Scalar`` node with an out-of-place ``GpuGemm``.

    The node's three inputs are unpacked as ``x, y, a``.  ``x`` and ``y``
    are converted with ``as_gpuarray_variable`` for ``context_name``, an
    output buffer of shape ``(x.shape[0], y.shape[1])`` is created with
    ``GpuAllocEmpty``, and the result of ``GpuGemm`` is returned as a
    one-element list.
    """
    x, y, a = node.inputs
    x = as_gpuarray_variable(x, context_name)
    y = as_gpuarray_variable(y, context_name)
    # The buffer is allocated without initialisation.  The trailing 0 is
    # the last gemm argument (``b`` in ``gemm(z, a, x, y, b)``).
    # NOTE(review): presumably b == 0 keeps the uninitialised contents of
    # ``z`` out of the result -- confirm against GpuGemm's implementation.
    z = GpuAllocEmpty(x.dtype, context_name)(x.shape[0], y.shape[1])
    return [GpuGemm(inplace=False)(z, a, x, y, 0)]
@register_opt('fast_compile')
@op_lifter([tensor.basic.Eye])
def local_gpua_eye(node, context_name):
......
......@@ -11,6 +11,7 @@ from .. import basic_ops
from ..type import GpuArrayType, gpuarray_shared_constructor, get_context
from ..basic_ops import (
GpuAlloc, GpuAllocEmpty, GpuReshape, GpuFromHost, host_from_gpu)
from ..blas import GpuGemm
from ..elemwise import GpuCAReduceCuda, GpuCAReduceCPY, GpuElemwise
from ..subtensor import GpuSubtensor
......@@ -255,6 +256,23 @@ def test_local_gpu_elemwise_careduce():
utt.assert_allclose(f(data), (data * data).sum(axis=1))
def test_local_lift_dot22scalar():
    """Check that Dot22Scalar is replaced by GpuGemm under ``mode_with_gpu``.

    The same expression is compiled twice: once without an explicit mode
    and once with ``mode_with_gpu``.  The GPU graph must contain a
    ``GpuGemm`` node and no ``Dot22Scalar`` node, and both functions must
    agree numerically on random inputs.
    """
    x = tensor.matrix()
    y = tensor.matrix()
    a = tensor.scalar()
    o = tensor.blas.Dot22Scalar()(x, y, a)
    f_cpu = theano.function([x, y, a], o)
    f_gpu = theano.function([x, y, a], o, mode=mode_with_gpu)

    # Structural checks on the compiled GPU graph.
    assert not any(isinstance(n.op, tensor.blas.Dot22Scalar)
                   for n in f_gpu.maker.fgraph.apply_nodes)
    assert any(isinstance(n.op, GpuGemm)
               for n in f_gpu.maker.fgraph.apply_nodes)

    # Numerical check: (2, 3) x (3, 4) scaled by 0.5.
    x_val = numpy.random.random((2, 3)).astype(theano.config.floatX)
    y_val = numpy.random.random((3, 4)).astype(theano.config.floatX)
    a_val = 0.5
    utt.assert_allclose(f_cpu(x_val, y_val, a_val),
                        f_gpu(x_val, y_val, a_val))
def test_local_gpu_subtensor():
# Test shared forced on CPU.
t = tensor._shared(numpy.zeros(20, "float32"))
......
......@@ -473,6 +473,9 @@ class Gemv(Op):
out += y
out_storage[0][0] = numpy.asarray(out, dtype=y.dtype)
def infer_shape(self, node, input_shapes):
    """Return the output shapes: a one-element list holding the shape of
    the first input.

    ``node`` is accepted but not used.
    """
    return [input_shapes[0]]
# Module-level Gemv instances, one for each value of ``inplace``.
gemv_no_inplace = Gemv(inplace=False)
gemv_inplace = Gemv(inplace=True)
# For the user interface. Opt will make them inplace later
......@@ -540,6 +543,9 @@ class Ger(Op):
A += numpy.outer(cx, cy)
cZ[0] = A
def infer_shape(self, node, input_shapes):
    """Return the output shapes: a one-element list holding the shape of
    the first input.

    ``node`` is accepted but not used.
    """
    return [input_shapes[0]]
# Module-level Ger instances, one for each value of ``destructive``.
ger = Ger(destructive=False)
ger_destructive = Ger(destructive=True)
......@@ -1001,6 +1007,8 @@ class Gemm(GemmRelated):
# Message texts about gemm dtype requirements; the code that uses them is
# outside this view.
E_mixed = 'gemm requires matching dtypes'
E_float = 'gemm requires floating-point dtypes'
# NOTE(review): presumably read by the Op base class to derive
# equality/hash from ``inplace`` -- confirm for the Theano version in use.
__props__ = ('inplace',)
def __init__(self, inplace):
self.inplace = inplace
if self.inplace:
......@@ -1009,13 +1017,6 @@ class Gemm(GemmRelated):
else:
self.setup_z_Nz_Sz = self.setup_z_Nz_Sz_outplace
def __eq__(self, other):
    """Two instances are equal when they have exactly the same type and
    the same ``inplace`` value.

    When the types differ, ``other.inplace`` is never accessed.
    """
    return (type(self) == type(other) and
            self.inplace == other.inplace)
def __hash__(self):
    """Hash on the instance's type combined (XOR) with ``inplace``."""
    return hash(type(self)) ^ hash(self.inplace)
def __str__(self):
if self.inplace:
inplace_str = 'inplace'
......@@ -1124,6 +1125,9 @@ class Gemm(GemmRelated):
z += a * numpy.dot(x, y)
zout[0] = z
def infer_shape(self, node, input_shapes):
    """Return the output shapes: a one-element list holding the shape of
    the first input.

    ``node`` is accepted but not used.
    """
    return [input_shapes[0]]
setup_z_Nz_Sz_inplace = """
if (%(_zout)s != %(_z)s)
{
......@@ -1747,8 +1751,8 @@ class Dot22(GemmRelated):
e.args = e.args + (x.shape, y.shape)
raise
def __str__(self):
    """Return the name of the instance's class."""
    return self.__class__.__name__
def infer_shape(self, node, input_shapes):
    """Return the output shapes: one shape made of the first dimension of
    the first input and the second dimension of the second input.

    ``node`` is accepted but not used.
    """
    return [[input_shapes[0][0], input_shapes[1][1]]]
setup_z_Nz_Sz = """
if ((NULL == %(_zout)s)
......@@ -2018,8 +2022,8 @@ class Dot22Scalar(GemmRelated):
e.args = e.args + (x.shape, y.shape)
raise
def __str__(self):
    """Return the name of the instance's class."""
    return self.__class__.__name__
def infer_shape(self, node, input_shapes):
    """Return the output shapes: one shape made of the first dimension of
    the first input and the second dimension of the second input.

    Any further entries of ``input_shapes`` are ignored, and ``node`` is
    accepted but not used.
    """
    return [[input_shapes[0][0], input_shapes[1][1]]]
# Reuse Dot22's ``setup_z_Nz_Sz`` unchanged.
setup_z_Nz_Sz = Dot22.setup_z_Nz_Sz
......
......@@ -2130,3 +2130,62 @@ class TestBlasStrides(TestCase):
self.cmp_ger((0, 1), 0, 1)
self.cmp_ger((1, 0), 1, 0)
self.cmp_ger((0, 0), 0, 0)
class test_infer_shape(unittest_tools.InferShapeTester):
    """``infer_shape`` checks for the BLAS ops.

    Each test builds a small symbolic graph around one op and hands it to
    ``_compile_and_check`` together with matching random inputs and the
    op class under test.
    """

    def test_dot22(self):
        """Dot22 on a (2, 3) by (3, 4) product."""
        x, y = T.matrices('xy')
        self._compile_and_check(
            [x, y], [T.blas._dot22(x, y)],
            [numpy.random.random((2, 3)).astype(config.floatX),
             numpy.random.random((3, 4)).astype(config.floatX)],
            T.blas.Dot22)

    def test_dot22scalar(self):
        """Dot22Scalar on a (2, 3) by (3, 4) product scaled by 0.5."""
        x, y = T.matrices('xy')
        a = T.scalar('a')
        self._compile_and_check(
            [x, y, a], [T.blas._dot22scalar(x, y, a)],
            [numpy.random.random((2, 3)).astype(config.floatX),
             numpy.random.random((3, 4)).astype(config.floatX),
             numpy.asarray(0.5, dtype=config.floatX)],
            T.blas.Dot22Scalar)

    def test_gemm(self):
        """Gemm with a (2, 4) ``z`` and a (2, 3) by (3, 4) product."""
        x, y, z = T.matrices('xyz')
        a = T.scalar('a')
        b = T.scalar('b')
        self._compile_and_check(
            [x, y, a, z, b], [T.blas.gemm(z, a, x, y, b)],
            [numpy.random.random((2, 3)).astype(config.floatX),
             numpy.random.random((3, 4)).astype(config.floatX),
             numpy.asarray(0.5, dtype=config.floatX),
             numpy.random.random((2, 4)).astype(config.floatX),
             numpy.asarray(0.5, dtype=config.floatX)],
            T.blas.Gemm)

    def test_gemv(self):
        """Gemv with a length-2 ``y``, a (2, 3) ``A`` and a length-3 ``x``."""
        A = T.matrix('A')
        x, y = T.vectors('xy')
        a = T.scalar('a')
        b = T.scalar('b')
        self._compile_and_check(
            [y, a, A, x, b], [T.blas.gemv(y, a, A, x, b)],
            [numpy.random.random((2,)).astype(config.floatX),
             numpy.asarray(0.5, dtype=config.floatX),
             numpy.random.random((2, 3)).astype(config.floatX),
             numpy.random.random((3,)).astype(config.floatX),
             numpy.asarray(0.5, dtype=config.floatX)],
            T.blas.Gemv)

    def test_ger(self):
        """Ger with a (2, 3) ``A``, a length-2 ``x`` and a length-3 ``y``."""
        A = T.matrix('A')
        x, y = T.vectors('xy')
        a = T.scalar('a')
        self._compile_and_check(
            [A, a, x, y], [T.blas.ger(A, a, x, y)],
            [numpy.random.random((2, 3)).astype(config.floatX),
             numpy.asarray(0.5, dtype=config.floatX),
             numpy.random.random((2,)).astype(config.floatX),
             numpy.random.random((3,)).astype(config.floatX)],
            T.blas.Ger)
Markdown 格式
0%
您即将向此讨论添加 0 人。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论