提交 24d00ecd authored 作者: James Bergstra's avatar James Bergstra

fixed dot22 optimization for types with None in shape

上级 96a6a869
...@@ -18,6 +18,7 @@ from theano import compile #to register the optimizer built by this file ...@@ -18,6 +18,7 @@ from theano import compile #to register the optimizer built by this file
from theano.tensor.blas_headers import cblas_header_text, blas_header_text from theano.tensor.blas_headers import cblas_header_text, blas_header_text
_logger = logging.getLogger('theano.tensor.blas') _logger = logging.getLogger('theano.tensor.blas')
_logger.setLevel(logging.INFO)
def debug(*msg): _logger.debug(' '.join(str(m) for m in msg)) def debug(*msg): _logger.debug(' '.join(str(m) for m in msg))
def info(*msg): _logger.info(' '.join(str(m) for m in msg)) def info(*msg): _logger.info(' '.join(str(m) for m in msg))
def warn(*msg): _logger.warn(' '.join(str(m) for m in msg)) def warn(*msg): _logger.warn(' '.join(str(m) for m in msg))
...@@ -604,10 +605,15 @@ class Dot22(GemmRelated): ...@@ -604,10 +605,15 @@ class Dot22(GemmRelated):
This is a specialization of the more general Dot() This is a specialization of the more general Dot()
""" """
def make_node(self, x, y): def make_node(self, x, y):
assert _is_real_matrix(x) if not _is_real_matrix(x):
assert y.type == x.type #makes sure y is a matrix raise TypeError(x)
if not _is_real_matrix(x):
raise TypeError(y)
if y.type.dtype != x.type.dtype:
raise TypeError('dtype mismatch to Dot22')
out_shape = (x.type.shape[0], y.type.shape[1])
bz = [False, False] bz = [False, False]
outputs = [T.tensor(x.type.dtype, bz)] outputs = [T.tensor(x.type.dtype, bz, shape=out_shape)]
return Apply(self, [x,y], outputs) return Apply(self, [x,y], outputs)
def perform(self, node, (x, y), (z, )): def perform(self, node, (x, y), (z, )):
...@@ -660,10 +666,10 @@ _dot22 = Dot22() ...@@ -660,10 +666,10 @@ _dot22 = Dot22()
def local_dot_to_dot22(node): def local_dot_to_dot22(node):
if node.op == T.dot: if node.op == T.dot:
x,y = node.inputs x,y = node.inputs
if _is_real_matrix(x) and y.type == x.type: if _is_real_matrix(x) and _is_real_matrix(y) and y.type.dtype == x.type.dtype:
return [_dot22(*node.inputs)] return [_dot22(*node.inputs)]
else: else:
info('Not optimizing dot with inputs', x, y) info('Not optimizing dot with inputs', x, y, x.type, y.type)
else: else:
return False return False
register_specialize(local_dot_to_dot22) register_specialize(local_dot_to_dot22)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论