提交 0130b6ff，作者：James Bergstra

bugfixes to local_dot_to_gemm

上级 39743595
......@@ -194,6 +194,7 @@ mode = 'FAST_RUN'
#mode = ProfileMode(optimizer='fast_run', linker=theano.gof.OpWiseCLinker())
mode = Mode(optimizer='fast_run', linker=theano.gof.OpWiseCLinker(nice_errors=True))
mode = Mode(optimizer='fast_run', linker='c')
mode = Mode(optimizer='fast_run', linker='c|py')
print mod.pretty(mode=mode)
m = mod.make(mode=mode)
......
......@@ -279,9 +279,9 @@ class Gemm(GemmRelated):
if zr.intersection(yr):
raise ValueError(Gemm.E_z_uniq, (z, y))
bz, ba, bx, by, bb = [r.type.broadcastable for r in inputs]
if len(bz) != 2: raise ValueError(Gemm.E_rank, len(bz))
if len(bx) != 2: raise ValueError(Gemm.E_rank, len(bx))
if len(by) != 2: raise ValueError(Gemm.E_rank, len(by))
if bz != (False,False): raise ValueError(Gemm.E_rank, bz)
if bx != (False,False): raise ValueError(Gemm.E_rank, bx)
if by != (False,False): raise ValueError(Gemm.E_rank, by)
if len(ba): raise ValueError(Gemm.E_scalar, ba)
if len(bb): raise ValueError(Gemm.E_scalar, bb)
output = z.type()
......@@ -359,7 +359,7 @@ class Dot22(GemmRelated):
This is a specialization of the more general Dot()
"""
def make_node(self, x, y):
assert x.type in T.float_matrix_types #makes sure x is a matrix
assert _is_real_matrix(x)
assert y.type == x.type #makes sure y is a matrix
bz = [x.type.broadcastable[0], y.type.broadcastable[1]]
outputs = [T.tensor(x.type.dtype, bz)]
......@@ -412,7 +412,7 @@ _dot22 = Dot22()
def local_dot_to_dot22(node):
    """Rewrite a generic ``T.dot`` node into the specialised ``_dot22`` op.

    The rewrite is attempted only when ``node.op == T.dot``.  It applies when
    the first input passes ``_is_real_matrix`` and the second input has the
    same type as the first.

    Returns a one-element list ``[_dot22(*node.inputs)]`` when the rewrite
    applies, and ``False`` when the node is a dot whose inputs do not
    qualify.  A node that is not a dot falls off the end of the function.

    NOTE(review): the indentation of this block was lost in the text it was
    recovered from; the ``else`` is assumed to pair with the inner ``if`` --
    confirm against the repository.
    """
    if node.op == T.dot:
        x, y = node.inputs
        # _is_real_matrix() itself checks `x.type in T.float_matrix_types`,
        # so a separate membership test on x.type would be redundant here.
        if _is_real_matrix(x) and y.type == x.type:
            return [_dot22(*node.inputs)]
        else:
            return False
......@@ -434,15 +434,19 @@ def _as_scalar(res):
else:
return None
def _is_real_matrix(res):
    """Tell whether `res` is a float matrix with no broadcastable dimension.

    The result is true only when ``res.type`` is a member of
    ``T.float_matrix_types`` and ``res.broadcastable`` equals
    ``(False, False)``.
    """
    # Guard first: the broadcastable pattern is only inspected for variables
    # whose type is one of the float matrix types.
    if res.type not in T.float_matrix_types:
        return False
    return res.broadcastable == (False, False)
def _as_isolated_scalar_times_matrix(res):
if _is_a(res, T.mul, 1):
if len(res.owner.inputs) == 2:
L, R = res.owner.inputs
sL = _as_scalar(L)
sR = _as_scalar(R)
if sL is not None and R.type in T.float_matrix_types:
if (sL is not None) and _is_real_matrix(R):
return (sL, R)
if sR is not None and L.type in T.float_matrix_types:
if (sR is not None) and _is_real_matrix(L):
return (sR, L)
else:
scalars = []
......@@ -451,7 +455,7 @@ def _as_isolated_scalar_times_matrix(res):
scalar_input = _as_scalar(input)
if scalar_input is not None:
scalars.append(scalar_input)
elif input.type in T.float_matrix_types:
elif _is_real_matrix(input):
matrices.append(input)
else:
return None
......@@ -503,9 +507,9 @@ def beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip = True):
def local_sub_to_gemm(node):
if node.op == T.sub:
L, R = node.inputs
if L.type not in T.float_matrix_types:
if not _is_real_matrix(L):
return False
if R.type not in T.float_matrix_types:
if not _is_real_matrix(R):
return False
tmp = _as_isolated_scalar_times_matrix(L)
......@@ -536,7 +540,10 @@ def local_add_to_gemm(node):
sM_list = []
for input in node.inputs:
tmp = _as_isolated_scalar_times_matrix(input)
sM_list.append(tmp if tmp is not None else (1.0,input))
if tmp:
sM_list.append(tmp)
elif _is_real_matrix(input):
sM_list.append((1.0, input))
if len(sM_list) == 2:
sL, mL = sM_list[0]
......@@ -553,7 +560,6 @@ def local_add_to_gemm(node):
inputs_without_ij = \
[input for k, input in enumerate(node.inputs) if k not in (i,j)]
return [T.add( *(inputs_without_ij + rval))]
return rval
return False
register_specialize(local_add_to_gemm)
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论