提交 00697549 authored 作者: James Bergstra's avatar James Bergstra

added optimization in tensor/blas to use dot22 for broadcasted tensors

上级 33becd6e
...@@ -571,6 +571,10 @@ def _is_real_matrix(res): ...@@ -571,6 +571,10 @@ def _is_real_matrix(res):
and res.type.ndim == 2 \ and res.type.ndim == 2 \
and res.type.broadcastable[0] == False \ and res.type.broadcastable[0] == False \
and res.type.broadcastable[1] == False #cope with tuple vs. list and res.type.broadcastable[1] == False #cope with tuple vs. list
def _is_real_vector(res):
return res.type.dtype in ('float32', 'float64') \
and res.type.ndim == 1 \
and res.type.broadcastable[0] == False
def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip = True): def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip = True):
#print 'BETA L + ALPHA M', beta, L, alpha, M, recurse_flip #print 'BETA L + ALPHA M', beta, L, alpha, M, recurse_flip
...@@ -805,13 +809,13 @@ class Dot22(GemmRelated): ...@@ -805,13 +809,13 @@ class Dot22(GemmRelated):
This is a specialization of the more general Dot() This is a specialization of the more general Dot()
""" """
def make_node(self, x, y): def make_node(self, x, y):
if not _is_real_matrix(x): if x.type.ndim != 2 or x.type.dtype not in ('float32', 'float64'):
raise TypeError(x) raise TypeError(x)
if not _is_real_matrix(x): if y.type.ndim != 2 or y.type.dtype not in ('float32', 'float64'):
raise TypeError(y) raise TypeError(y)
if y.type.dtype != x.type.dtype: if y.type.dtype != x.type.dtype:
raise TypeError('dtype mismatch to Dot22') raise TypeError('dtype mismatch to Dot22')
bz = [False, False] bz = (x.type.broadcastable[0], y.type.broadcastable[1])
outputs = [T.tensor(x.type.dtype, bz)] outputs = [T.tensor(x.type.dtype, bz)]
return Apply(self, [x,y], outputs) return Apply(self, [x,y], outputs)
...@@ -870,14 +874,26 @@ _dot22 = Dot22() ...@@ -870,14 +874,26 @@ _dot22 = Dot22()
@local_optimizer([T.dot]) @local_optimizer([T.dot])
def local_dot_to_dot22(node): def local_dot_to_dot22(node):
if node.op == T.dot: if node.op != T.dot:
x,y = node.inputs return
if _is_real_matrix(x) and _is_real_matrix(y) and y.type.dtype == x.type.dtype:
x,y = node.inputs
if y.type.dtype != x.type.dtype:
# TODO: upcast one so the types match
info('Not optimizing dot with inputs', x, y, x.type, y.type)
return
print 'asdfasdf'
if y.type.dtype.startswith('float'):
if _is_real_matrix(x) and _is_real_matrix(y):
return [_dot22(*node.inputs)] return [_dot22(*node.inputs)]
else: if _is_real_matrix(x) and _is_real_vector(y):
info('Not optimizing dot with inputs', x, y, x.type, y.type) return [_dot22(x, y.dimshuffle(0,'x')).dimshuffle(0)]
else: if _is_real_vector(x) and _is_real_matrix(y):
return False return [_dot22(x.dimshuffle('x',0), y).dimshuffle(1)]
if _is_real_vector(x) and _is_real_vector(x):
return [_dot22(x.dimshuffle('x',0), y.dimshuffle(0,'x')).dimshuffle()]
info('Not optimizing dot with inputs', x, y, x.type, y.type)
@local_optimizer([gemm_no_inplace]) @local_optimizer([gemm_no_inplace])
def local_inplace_gemm(node): def local_inplace_gemm(node):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论