Commit c4596efb authored by Frederic Bastien

New op Dot22Scalar. It works like Dot22, but multiplies each element of the result by a scalar constant.

Parent ba0aedbb
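In effect, the new op computes scalar * dot(x, y) in a single node. A minimal NumPy sketch of the intended semantics (illustration only, not part of the commit):

import numpy

x = numpy.random.rand(3, 4)
y = numpy.random.rand(4, 2)
scalar = 0.2
# Dot22Scalar.perform does the equivalent of:
z = scalar * numpy.dot(x, y)
assert z.shape == (3, 2)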
"""Ops and optimizations for using BLAS function calls to evaluate linear algebra expressions""" """Ops and optimizations for using BLAS function calls to evaluate linear algebra expressions"""
import sys, traceback, logging import sys, traceback, logging, copy
import numpy import numpy
import numpy.distutils import numpy.distutils
from theano.configparser import config, AddConfigVar, StrParam from theano.configparser import config, AddConfigVar, StrParam
...@@ -9,7 +9,8 @@ from theano.gof import (utils, Op, Apply, view_roots, PatternSub, DestroyHandler ...@@ -9,7 +9,8 @@ from theano.gof import (utils, Op, Apply, view_roots, PatternSub, DestroyHandler
InconsistencyError, toolbox, SequenceDB, EquilibriumOptimizer) InconsistencyError, toolbox, SequenceDB, EquilibriumOptimizer)
from theano.printing import pprint, FunctionPrinter from theano.printing import pprint, FunctionPrinter
from theano.compile.mode import optdb from theano.compile.mode import optdb
from theano.gof.python25 import any
import theano.scalar
import basic as T import basic as T
#NB: this clobbers the builtin 'compile' symbol #NB: this clobbers the builtin 'compile' symbol
...@@ -762,3 +763,159 @@ optdb.register('InplaceBlasOpt', ...@@ -762,3 +763,159 @@ optdb.register('InplaceBlasOpt',
max_use_ratio=5), max_use_ratio=5),
70.0, 'fast_run', 'inplace') 70.0, 'fast_run', 'inplace')
class Dot22Scalar(GemmRelated):
    """Compute scalar * dot(x, y), a matrix-matrix product scaled by a constant.

    This is a specialization of the more general Dot(), used to call an
    optimized gemm implementation. It is also used to generate a Gemm later.
    """
    def make_node(self, x, y, scalar):
        if not _is_real_matrix(x):
            raise TypeError(x)
        if not _is_real_matrix(y):
            raise TypeError(y)
        if not _as_scalar(scalar):
            raise TypeError(scalar)
        if y.type.dtype != x.type.dtype or y.type.dtype != scalar.type.dtype:
            raise TypeError('dtype mismatch to Dot22Scalar')
        out_shape = (x.type.shape[0], y.type.shape[1])
        bz = [False, False]
        outputs = [T.tensor(x.type.dtype, bz, shape=out_shape)]
        return Apply(self, [x, y, scalar], outputs)

    def perform(self, node, (x, y, scalar), (z, )):
        try:
            z[0] = scalar * numpy.asarray(numpy.dot(x, y))
        except ValueError, e:
            # The error raised by numpy has no shape information; add it here.
            e.args = e.args + (x.shape, y.shape)
            raise

    def __str__(self):
        return "_dot22scalar"
    setup_z_Nz_Sz = """
        if ((NULL == %(_z)s)
            || (%(_z)s->dimensions[0] != %(_x)s->dimensions[0])
            || (%(_z)s->dimensions[1] != %(_y)s->dimensions[1]))
        {
            if (NULL != %(_z)s) Py_XDECREF(%(_z)s);
            npy_intp dims[2];
            dims[0] = %(_x)s->dimensions[0];
            dims[1] = %(_y)s->dimensions[1];
            %(_z)s = (PyArrayObject*)PyArray_SimpleNew(2, dims, type_num_%(_x)s);
            if (!%(_z)s) {
                PyErr_SetString(PyExc_MemoryError, "failed to alloc dot22scalar output");
                %(fail)s
            }
        }
        Nz = %(_z)s->dimensions;
        Sz = %(_z)s->strides;
        """

    check_ab_double_or_float = """
        if ((%(_a)s->descr->type_num != PyArray_DOUBLE)
            && (%(_a)s->descr->type_num != PyArray_FLOAT))
        {PyErr_SetString(PyExc_NotImplementedError, "type(a) is not double or float"); %(fail)s;}
        """

    case_float_ab_constants = """
        #define REAL float
        float a = (%(_a)s->descr->type_num == PyArray_FLOAT)
            ? (REAL)(((float*)%(_a)s->data)[0])
            : (REAL)(((double*)%(_a)s->data)[0]);
        #undef REAL
        float b = 0.0;
        """

    case_double_ab_constants = """
        #define REAL double
        double a = (%(_a)s->descr->type_num == PyArray_FLOAT)
            ? (REAL)(((float*)%(_a)s->data)[0])
            : (REAL)(((double*)%(_a)s->data)[0]);
        #undef REAL
        double b = 0.0;
        """
    def c_code(self, node, name, (_x, _y, _a), (_z, ), sub):
        if len(self.c_libraries()) <= 0:
            return super(Dot22Scalar, self).c_code(node, name, (_x, _y), (_z, ), sub)
        full_code = self.build_gemm_call() % dict(locals(), **sub)
        return full_code

    def c_code_cache_version(self):
        return (1,) + self.build_gemm_version()
_dot22scalar = Dot22Scalar()
@local_optimizer([T.mul])
def local_dot22_to_dot22scalar(node):
    """
    :note: We upcast the scalar only if, after the multiplication with the
        dot result, it gives back the same type as the matrices.

    .. note::
        We run this optimizer after the gemm optimizer. This gives priority
        to gemm, which provides a bigger speed-up than this optimizer, while
        still allowing the gemm optimizer to ignore this op.

    TODO: support the case where the mul can be reordered to generate a
        dot22scalar, or fix the canonizer to merge the muls (one mul with
        multiple inputs).
    """
    if node.op != T.mul:
        return False
    i_dot22 = [x.owner and x.owner.op == _dot22 for x in node.inputs]
    if not any(i_dot22):
        return False  # no dot22
    if i_dot22.count(True) > 1:
        return False  # TODO: handle more than one dot22
    # We take the first _dot22 found. TODO: check the others!
    dot22_idx = i_dot22.index(True)
    d = node.inputs[dot22_idx]
    i_scalar = [_as_scalar(x) for x in node.inputs]
    if not any(i_scalar) and not any([x.owner and x.owner.op == T.mul
                                      for x in node.inputs]):
        # No scalar in the inputs and no multiplication.
        # If there were a multiplication, we could reorder the graph
        # by the associativity of multiplication.
        return False
    if not any(i_scalar):
        # Maybe we can reorder the graph, as this mul has a mul as input.
        # The canonizer should have merged those muls together.
        # We support only one additional level of mul.
        i_mul = [x.owner and x.owner.op == T.mul for x in node.inputs]
        mul_idx = i_mul.index(True)  # we take the first mul!
        m = node.inputs[mul_idx]
        if len(m.owner.inputs) == 2 and any([_as_scalar(x) for x in m.owner.inputs]):
            scalar_idx = 0
            for i, x in enumerate(m.owner.inputs):
                if _as_scalar(x):
                    scalar_idx = i
                    break
            dot = _dot22scalar(d.owner.inputs[0], d.owner.inputs[1],
                               m.owner.inputs[scalar_idx])
            return [T.mul(m.owner.inputs[1 - scalar_idx], dot)]
        elif m.owner and m.owner.op == T.mul:
            info('Not optimizing dot22 with inputs', d, m, d.type, m.type,
                 'we would need to search the mul recursively to reorder'
                 ' the graph. The canonizer should have done this.')
        else:
            return False
    scalar_idx = -1
    for i, x in enumerate(node.inputs):
        if i_scalar[i] and theano.scalar.upcast(x.type.dtype, d.type.dtype) == d.type.dtype:
            scalar_idx = i
            break
    if scalar_idx < 0:
        info('Not optimizing dot22 with inputs', node.inputs,
             [x.type for x in node.inputs],
             'as the scalar\'s type cannot be upcast to the matrix type')
        return False
    assert scalar_idx < len(node.inputs)
    s = node.inputs[scalar_idx]
    o = copy.copy(node.inputs)
    o.remove(d)
    o.remove(s)
    if len(o) == 0:
        return [_dot22scalar(d.owner.inputs[0], d.owner.inputs[1], s)]
    else:
        return [T.mul(_dot22scalar(d.owner.inputs[0], d.owner.inputs[1], s), *o)]

# This must be registered after gemm, as the gemm optimizer does not
# understand dot22scalar, and gemm gives a bigger speed-up than dot22scalar.
blas_optdb.register('local_dot22_to_dot22scalar',
                    EquilibriumOptimizer([local_dot22_to_dot22scalar],
                                         max_use_ratio=5),
                    11, 'fast_run')
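As a quick usage illustration (a hypothetical session; the tests below exercise the same path), the optimizer collapses a scaled matrix product into a single _dot22scalar node:

import numpy, theano
import theano.tensor as T

a, b = T.matrix(), T.matrix()
m = theano.compile.get_default_mode().including(
    'local_dot_to_dot22', 'local_dot22_to_dot22scalar', 'specialize')
f = theano.function([a, b], 0.2 * T.dot(a, b), mode=m)
# After optimization the graph should contain a single _dot22scalar apply node.
print [n.op for n in f.maker.env.toposort()]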
theano/tensor/tests/test_blas.py
@@ -4,7 +4,7 @@ from theano.gof import Env
from theano.printing import pp
import numpy, theano
from theano.tensor.blas import *
from theano.tensor.blas import _dot22, _dot22scalar, res_is_a, _as_scalar, _is_real_matrix
from unittest import TestCase
from theano.tests import unittest_tools
from copy import copy
@@ -440,3 +440,58 @@ def test_dot22():
    bv = numpy.random.rand(5, 5)
    f(av, bv)
def test_dot22scalar():
    m = theano.compile.get_default_mode().including(
        'local_dot_to_dot22', 'local_dot22_to_dot22scalar', 'specialize')
    a = T.matrix()
    b = T.matrix()
    c = T.matrix()
    av = numpy.random.rand(5, 5)
    bv = numpy.random.rand(5, 5)
    cv = numpy.random.rand(5, 5)

    f = theano.function([a, b], 0.2 * T.dot(a, b), mode=m)
    topo = f.maker.env.toposort()
    assert _dot22scalar in [x.op for x in topo]
    assert len(topo) == 1
    f(av, bv)

    f = theano.function([a, b, c], 0.2 * c * T.dot(a, b), mode=m)
    topo = f.maker.env.toposort()
    assert _dot22scalar in [x.op for x in topo]
    assert len(topo) == 2
    f(av, bv, cv)

    f = theano.function([a, b, c], c * 0.2 * T.dot(a, b), mode=m)
    topo = f.maker.env.toposort()
    assert _dot22scalar in [x.op for x in topo]
    assert len(topo) == 2
    f(av, bv, cv)

    f = theano.function([a, b, c], 0.1 * c * 0.2 * T.dot(a, b), mode=m)
    topo = f.maker.env.toposort()
    assert _dot22scalar in [x.op for x in topo]
    assert len(topo) == 2
    f(av, bv, cv)

    f = theano.function([a, b, c], c * 0.2 * a * T.dot(a, b), mode=m)
    topo = f.maker.env.toposort()
    assert _dot22scalar in [x.op for x in topo]
    assert len(topo) == 2
    f(av, bv, cv)

    f = theano.function([a, b, c], 0.2 * c * a * T.dot(a, b), mode=m)
    topo = f.maker.env.toposort()
    # Currently the canonizer does not always merge all Muls together,
    # which would force the optimizer to do a recursive search that it
    # does not do now; it handles only one level of recursion.
    # assert _dot22scalar in [x.op for x in topo]
    # assert len(topo) == 2
    f(av, bv, cv)

    f = theano.function([a, b, c], c * a * 0.2 * T.dot(a, b), mode=m)
    topo = f.maker.env.toposort()
    assert _dot22scalar in [x.op for x in topo]
    assert len(topo) == 2
    f(av, bv, cv)