提交 a8646fdc authored 作者: James Bergstra's avatar James Bergstra

Merge pull request #77 from nouiz/dot22_complex

Dot22 complex still w pep8 errors but Fred's busy.
...@@ -1218,9 +1218,10 @@ class Dot22(GemmRelated): ...@@ -1218,9 +1218,10 @@ class Dot22(GemmRelated):
This is a specialization of the more general Dot() This is a specialization of the more general Dot()
""" """
def make_node(self, x, y): def make_node(self, x, y):
if x.type.ndim != 2 or x.type.dtype not in ('float32', 'float64'): dtypes = ('float32', 'float64', 'complex64', 'complex128')
if x.type.ndim != 2 or x.type.dtype not in dtypes:
raise TypeError(x) raise TypeError(x)
if y.type.ndim != 2 or y.type.dtype not in ('float32', 'float64'): if y.type.ndim != 2 or y.type.dtype not in dtypes:
raise TypeError(y) raise TypeError(y)
if y.type.dtype != x.type.dtype: if y.type.dtype != x.type.dtype:
raise TypeError('dtype mismatch to Dot22') raise TypeError('dtype mismatch to Dot22')
...@@ -1301,7 +1302,7 @@ def local_dot_to_dot22(node): ...@@ -1301,7 +1302,7 @@ def local_dot_to_dot22(node):
_logger.info('Not optimizing dot with inputs %s %s %s %s', x, y, x.type, y.type) _logger.info('Not optimizing dot with inputs %s %s %s %s', x, y, x.type, y.type)
return return
if y.type.dtype.startswith('float'): if y.type.dtype.startswith('float') or y.type.dtype.startswith('complex'):
if x.ndim == 2 and y.ndim == 2: if x.ndim == 2 and y.ndim == 2:
#print "local_dot_to_dot22: MM" #print "local_dot_to_dot22: MM"
return [_dot22(*node.inputs)] return [_dot22(*node.inputs)]
...@@ -1502,6 +1503,9 @@ class Dot22Scalar(GemmRelated): ...@@ -1502,6 +1503,9 @@ class Dot22Scalar(GemmRelated):
def c_code(self, node, name, inp, out, sub): #DEBUG def c_code(self, node, name, inp, out, sub): #DEBUG
_x, _y, _a = inp _x, _y, _a = inp
_zout, = out _zout, = out
if node.inputs[0].type.dtype.startswith('complex'):
raise utils.MethodNotDefined('%s.c_code' \
% self.__class__.__name__)
if len(self.c_libraries())<=0: if len(self.c_libraries())<=0:
return super(Dot22Scalar, self).c_code(node, name, (_x, _y), (_zout, ), sub) return super(Dot22Scalar, self).c_code(node, name, (_x, _y), (_zout, ), sub)
full_code = self.build_gemm_call() % dict(locals(), **sub) full_code = self.build_gemm_call() % dict(locals(), **sub)
...@@ -1550,12 +1554,18 @@ def local_dot22_to_dot22scalar(node): ...@@ -1550,12 +1554,18 @@ def local_dot22_to_dot22scalar(node):
m = node.inputs[mul_idx] m = node.inputs[mul_idx]
if len(m.owner.inputs)==2 and any([_as_scalar(x) for x in m.owner.inputs]): if len(m.owner.inputs)==2 and any([_as_scalar(x) for x in m.owner.inputs]):
scalar_idx = 0 scalar_idx = -1
for i,x in enumerate(m.owner.inputs): for i,x in enumerate(m.owner.inputs):
if _as_scalar(x): if _as_scalar(x) and (theano.scalar.upcast(x.type.dtype,d.type.dtype)
== d.type.dtype):
scalar_idx=i scalar_idx=i
break break
if scalar_idx<0:
_logger.info('Not optimizing dot22 with inputs %s %s, as the type '
'of the scalar cannot be upcasted to the matrix type',
node.inputs, [x.type for x in node.inputs])
return False
a = T.cast(_as_scalar(m.owner.inputs[scalar_idx]), d.type.dtype) a = T.cast(_as_scalar(m.owner.inputs[scalar_idx]), d.type.dtype)
assert not a.type.ndim assert not a.type.ndim
dot=_dot22scalar(d.owner.inputs[0], d.owner.inputs[1], a) dot=_dot22scalar(d.owner.inputs[0], d.owner.inputs[1], a)
......
...@@ -674,23 +674,29 @@ def test_inplace1(): ...@@ -674,23 +674,29 @@ def test_inplace1():
assert [n.op for n in f.maker.env.nodes] == [gemm_no_inplace] assert [n.op for n in f.maker.env.nodes] == [gemm_no_inplace]
def test_dot22(): def test_dot22():
a=T.matrix() for dtype1 in ['float32', 'float64', 'complex64', 'complex128']:
b=T.matrix() a=T.matrix(dtype = dtype1)
f = theano.function([a,b],T.dot(a,b),mode=mode_blas_opt) for dtype2 in ['float32', 'float64', 'complex64', 'complex128']:
topo = f.maker.env.toposort() b=T.matrix(dtype = dtype2)
assert _dot22 in [x.op for x in topo] f = theano.function([a,b],T.dot(a,b),mode=mode_blas_opt)
rng = numpy.random.RandomState(unittest_tools.fetch_seed()) topo = f.maker.env.toposort()
if dtype1 == dtype2:
def cmp(a_shp, b_shp): assert _dot22 in [x.op for x in topo], (dtype1,dtype2)
av=rng.uniform(size=a_shp).astype(config.floatX) else:
bv=rng.uniform(size=b_shp).astype(config.floatX) assert T.dot in [x.op for x in topo], (dtype1,dtype2)
f(av,bv) rng = numpy.random.RandomState(unittest_tools.fetch_seed())
cmp((3,4),(4,5))
cmp((0,4),(4,5)) def cmp(a_shp, b_shp):
cmp((3,0),(0,5)) av=rng.uniform(size=a_shp).astype(dtype1)
cmp((3,4),(4,0)) bv=rng.uniform(size=b_shp).astype(dtype2)
cmp((0,4),(4,0)) f(av,bv)
cmp((0,0),(0,0))
cmp((3,4),(4,5))
cmp((0,4),(4,5))
cmp((3,0),(0,5))
cmp((3,4),(4,0))
cmp((0,4),(4,0))
cmp((0,0),(0,0))
def test_dot22scalar(): def test_dot22scalar():
## including does not seem to work for 'local_dot_to_dot22' and ## including does not seem to work for 'local_dot_to_dot22' and
...@@ -698,75 +704,102 @@ def test_dot22scalar(): ...@@ -698,75 +704,102 @@ def test_dot22scalar():
## TODO: exclude other optimizations in BlasOpt? ## TODO: exclude other optimizations in BlasOpt?
#m = theano.compile.get_default_mode().including('local_dot_to_dot22','local_dot22_to_dot22scalar','specialize') #m = theano.compile.get_default_mode().including('local_dot_to_dot22','local_dot22_to_dot22scalar','specialize')
#m = theano.compile.get_default_mode().including('BlasOpt', 'specialize') #m = theano.compile.get_default_mode().including('BlasOpt', 'specialize')
a=T.matrix()
b=T.matrix()
c=T.matrix()
rng = numpy.random.RandomState(unittest_tools.fetch_seed()) rng = numpy.random.RandomState(unittest_tools.fetch_seed())
for dtype1 in ['complex64', 'complex128']:
def cmp(a_shp, b_shp, c_shp, sqr_shp=(5,5)): a=T.matrix('a', dtype = dtype1)
av=rng.uniform(size=a_shp).astype(config.floatX) for dtype2 in ['complex64', 'complex128']:
bv=rng.uniform(size=b_shp).astype(config.floatX) b=T.matrix('b', dtype = dtype2)
cv=rng.uniform(size=c_shp).astype(config.floatX) for dtype3 in ['complex64', 'complex128']:
sv=rng.uniform(size=sqr_shp).astype(config.floatX) c=T.matrix('c', dtype = dtype3)
for dtype4 in ['complex64', 'complex128']:
if True: cst = theano.tensor.basic.constant(.2, dtype=dtype4)
f = theano.function([a,b],0.2*T.dot(a,b),mode=mode_blas_opt) cst2 = theano.tensor.basic.constant(.1, dtype=dtype4)
topo = f.maker.env.toposort()
assert _dot22scalar in [x.op for x in topo] def check_dot22scalar(func, len_topo_scalar=-1):
assert len(topo)==1 topo = func.maker.env.toposort()
f(av,bv) ops = [x.op for x in topo]
dtype4_upcast = theano.scalar.upcast(dtype4, dtype1, dtype2)
if True: if dtype1 == dtype2 == dtype3 == dtype4_upcast:
f = theano.function([a,b,c],0.2*c*T.dot(a,b),mode=mode_blas_opt) if len_topo_scalar>0:
topo = f.maker.env.toposort() assert len(topo) == len_topo_scalar
assert _dot22scalar in [x.op for x in topo] assert _dot22scalar in ops, (dtype1, dtype2, dtype3, dtype4)
assert len(topo)==2 elif dtype1 == dtype2 == dtype4_upcast:
f(av,bv,cv) if not (len_topo_scalar > 0):
assert len(topo) == len_topo_scalar
f = theano.function([a,b,c],c * 0.2*T.dot(a,b),mode=mode_blas_opt) assert _dot22scalar in ops, (dtype1, dtype2, dtype3, dtype4)
topo = f.maker.env.toposort() else:
assert _dot22scalar in [x.op for x in topo] # Currently there is a problem of optimization order
assert len(topo)==2 # The constant get upcasted to float64 before we try to merge it
f(av,bv,cv) # with the dot22 of float32. So this prevent the merge.
assert _dot22scalar in ops or _dot22 in ops, (dtype1, dtype2, dtype3, dtype4)
## Here, canonicalize also seems needed
## TODO: add only the optimizations needed? elif dtype1 == dtype2:
m2 = mode_blas_opt.including('canonicalize') assert _dot22 in ops, (dtype1, dtype2, dtype3, dtype4)
f = theano.function([a,b,c],0.1*c * 0.2*T.dot(a,b),mode=m2) else:
topo = f.maker.env.toposort() assert T.dot in ops, (dtype1, dtype2, dtype3, dtype4)
assert _dot22scalar in [x.op for x in topo]
assert len(topo)==2
f(av,bv,cv) def cmp(a_shp, b_shp, c_shp, sqr_shp=(5,5)):
av=rng.uniform(size=a_shp).astype(dtype1)
f = theano.function([a,b,c],c * 0.2*a*T.dot(a,b),mode=m2) bv=rng.uniform(size=b_shp).astype(dtype2)
topo = f.maker.env.toposort() cv=rng.uniform(size=c_shp).astype(dtype3)
assert _dot22scalar in [x.op for x in topo] sv=rng.uniform(size=sqr_shp).astype(dtype1)
assert len(topo)==2
f(sv,sv,sv) if False:
f = theano.function([a,b],cst*T.dot(a,b),mode=mode_blas_opt)
f = theano.function([a,b,c],0.2*c *a*T.dot(a,b),mode=mode_blas_opt) topo = f.maker.env.toposort()
topo = f.maker.env.toposort() check_dot22scalar(f, 1)
#currently the canonizer don't always merge all Mul together...
# dot22scalar optimizer does not do a recursive search f(av,bv)
# therefore, it doesn't find potential matches of the scalar.
# TODO: combine with the 'canonicalization' that is part of the Gemm optimizer. if True:
# f = theano.function([a,b,c],cst*c*T.dot(a,b),mode=mode_blas_opt)
# assert _dot22scalar in [x.op for x in topo] topo = f.maker.env.toposort()
# assert len(topo)==2 check_dot22scalar(f, 2)
f(sv,sv,sv)
f(av,bv,cv)
f = theano.function([a,b,c],c * a*0.2*T.dot(a,b),mode=m2)
topo = f.maker.env.toposort() f = theano.function([a,b,c],c * cst*T.dot(a,b),mode=mode_blas_opt)
assert _dot22scalar in [x.op for x in topo] topo = f.maker.env.toposort()
assert len(topo)==2 check_dot22scalar(f, 2)
f(sv,sv,sv) f(av,bv,cv)
cmp((3,4),(4,5),(3,5)) ## Here, canonicalize also seems needed
cmp((0,4),(4,5),(0,5)) ## TODO: add only the optimizations needed?
cmp((3,0),(0,5),(3,5)) m2 = mode_blas_opt.including('canonicalize')
cmp((3,4),(4,0),(3,0),(0,0)) f = theano.function([a,b,c],cst2 *c * cst*T.dot(a,b),mode=m2)
cmp((0,4),(4,0),(0,0)) topo = f.maker.env.toposort()
cmp((0,0),(0,0),(0,0)) check_dot22scalar(f, 2)
f(av,bv,cv)
if dtype1 == dtype2 == dtype3:
f = theano.function([a,b,c],c * cst*a*T.dot(a,b),mode=m2)
topo = f.maker.env.toposort()
check_dot22scalar(f, 2)
f(sv,sv,sv)
f = theano.function([a,b,c],cst*c *a*T.dot(a,b),mode=mode_blas_opt)
topo = f.maker.env.toposort()
#currently the canonizer don't always merge all Mul together...
# dot22scalar optimizer does not do a recursive search
# therefore, it doesn't find potential matches of the scalar.
# TODO: combine with the 'canonicalization' that is part of the Gemm optimizer.
#
# assert _dot22scalar in [x.op for x in topo]
# assert len(topo)==2
f(sv,sv,sv)
f = theano.function([a,b,c],c * a*cst*T.dot(a,b),mode=m2)
topo = f.maker.env.toposort()
check_dot22scalar(f, 2)
f(sv,sv,sv)
cmp((3,4),(4,5),(3,5))
cmp((0,4),(4,5),(0,5))
cmp((3,0),(0,5),(3,5))
cmp((3,4),(4,0),(3,0),(0,0))
cmp((0,4),(4,0),(0,0))
cmp((0,0),(0,0),(0,0))
def test_dot_w_self(): def test_dot_w_self():
# This can trigger problems in the optimization because what would normally be a gemm must # This can trigger problems in the optimization because what would normally be a gemm must
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论