提交 db8bf96d authored 作者: Frederic's avatar Frederic

make dot22 work with complex.

上级 e41ee606
...@@ -1218,9 +1218,10 @@ class Dot22(GemmRelated): ...@@ -1218,9 +1218,10 @@ class Dot22(GemmRelated):
This is a specialization of the more general Dot() This is a specialization of the more general Dot()
""" """
def make_node(self, x, y): def make_node(self, x, y):
if x.type.ndim != 2 or x.type.dtype not in ('float32', 'float64'): dtypes = ('float32', 'float64', 'complex64', 'complex128')
if x.type.ndim != 2 or x.type.dtype not in dtypes:
raise TypeError(x) raise TypeError(x)
if y.type.ndim != 2 or y.type.dtype not in ('float32', 'float64'): if y.type.ndim != 2 or y.type.dtype not in dtypes:
raise TypeError(y) raise TypeError(y)
if y.type.dtype != x.type.dtype: if y.type.dtype != x.type.dtype:
raise TypeError('dtype mismatch to Dot22') raise TypeError('dtype mismatch to Dot22')
...@@ -1301,7 +1302,7 @@ def local_dot_to_dot22(node): ...@@ -1301,7 +1302,7 @@ def local_dot_to_dot22(node):
_logger.info('Not optimizing dot with inputs %s %s %s %s', x, y, x.type, y.type) _logger.info('Not optimizing dot with inputs %s %s %s %s', x, y, x.type, y.type)
return return
if y.type.dtype.startswith('float'): if y.type.dtype.startswith('float') or y.type.dtype.startswith('complex'):
if x.ndim == 2 and y.ndim == 2: if x.ndim == 2 and y.ndim == 2:
#print "local_dot_to_dot22: MM" #print "local_dot_to_dot22: MM"
return [_dot22(*node.inputs)] return [_dot22(*node.inputs)]
......
...@@ -674,17 +674,23 @@ def test_inplace1(): ...@@ -674,17 +674,23 @@ def test_inplace1():
assert [n.op for n in f.maker.env.nodes] == [gemm_no_inplace] assert [n.op for n in f.maker.env.nodes] == [gemm_no_inplace]
def test_dot22(): def test_dot22():
a=T.matrix() for dtype1 in ['float32', 'float64', 'complex64', 'complex128']:
b=T.matrix() a=T.matrix(dtype = dtype1)
for dtype2 in ['float32', 'float64', 'complex64', 'complex128']:
b=T.matrix(dtype = dtype2)
f = theano.function([a,b],T.dot(a,b),mode=mode_blas_opt) f = theano.function([a,b],T.dot(a,b),mode=mode_blas_opt)
topo = f.maker.env.toposort() topo = f.maker.env.toposort()
assert _dot22 in [x.op for x in topo] if dtype1 == dtype2:
assert _dot22 in [x.op for x in topo], (dtype1,dtype2)
else:
assert T.dot in [x.op for x in topo], (dtype1,dtype2)
rng = numpy.random.RandomState(unittest_tools.fetch_seed()) rng = numpy.random.RandomState(unittest_tools.fetch_seed())
def cmp(a_shp, b_shp): def cmp(a_shp, b_shp):
av=rng.uniform(size=a_shp).astype(config.floatX) av=rng.uniform(size=a_shp).astype(dtype1)
bv=rng.uniform(size=b_shp).astype(config.floatX) bv=rng.uniform(size=b_shp).astype(dtype2)
f(av,bv) f(av,bv)
cmp((3,4),(4,5)) cmp((3,4),(4,5))
cmp((0,4),(4,5)) cmp((0,4),(4,5))
cmp((3,0),(0,5)) cmp((3,0),(0,5))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论