提交 0130b6ff authored 作者: James Bergstra's avatar James Bergstra

bugfixes to local_dot_to_gemm

上级 39743595
...@@ -194,6 +194,7 @@ mode = 'FAST_RUN' ...@@ -194,6 +194,7 @@ mode = 'FAST_RUN'
#mode = ProfileMode(optimizer='fast_run', linker=theano.gof.OpWiseCLinker()) #mode = ProfileMode(optimizer='fast_run', linker=theano.gof.OpWiseCLinker())
mode = Mode(optimizer='fast_run', linker=theano.gof.OpWiseCLinker(nice_errors=True)) mode = Mode(optimizer='fast_run', linker=theano.gof.OpWiseCLinker(nice_errors=True))
mode = Mode(optimizer='fast_run', linker='c') mode = Mode(optimizer='fast_run', linker='c')
mode = Mode(optimizer='fast_run', linker='c|py')
print mod.pretty(mode=mode) print mod.pretty(mode=mode)
m = mod.make(mode=mode) m = mod.make(mode=mode)
......
...@@ -279,9 +279,9 @@ class Gemm(GemmRelated): ...@@ -279,9 +279,9 @@ class Gemm(GemmRelated):
if zr.intersection(yr): if zr.intersection(yr):
raise ValueError(Gemm.E_z_uniq, (z, y)) raise ValueError(Gemm.E_z_uniq, (z, y))
bz, ba, bx, by, bb = [r.type.broadcastable for r in inputs] bz, ba, bx, by, bb = [r.type.broadcastable for r in inputs]
if len(bz) != 2: raise ValueError(Gemm.E_rank, len(bz)) if bz != (False,False): raise ValueError(Gemm.E_rank, bz)
if len(bx) != 2: raise ValueError(Gemm.E_rank, len(bx)) if bx != (False,False): raise ValueError(Gemm.E_rank, bx)
if len(by) != 2: raise ValueError(Gemm.E_rank, len(by)) if by != (False,False): raise ValueError(Gemm.E_rank, by)
if len(ba): raise ValueError(Gemm.E_scalar, ba) if len(ba): raise ValueError(Gemm.E_scalar, ba)
if len(bb): raise ValueError(Gemm.E_scalar, bb) if len(bb): raise ValueError(Gemm.E_scalar, bb)
output = z.type() output = z.type()
...@@ -359,7 +359,7 @@ class Dot22(GemmRelated): ...@@ -359,7 +359,7 @@ class Dot22(GemmRelated):
This is a specialization of the more general Dot() This is a specialization of the more general Dot()
""" """
def make_node(self, x, y): def make_node(self, x, y):
assert x.type in T.float_matrix_types #makes sure x is a matrix assert _is_real_matrix(x)
assert y.type == x.type #makes sure y is a matrix assert y.type == x.type #makes sure y is a matrix
bz = [x.type.broadcastable[0], y.type.broadcastable[1]] bz = [x.type.broadcastable[0], y.type.broadcastable[1]]
outputs = [T.tensor(x.type.dtype, bz)] outputs = [T.tensor(x.type.dtype, bz)]
...@@ -412,7 +412,7 @@ _dot22 = Dot22() ...@@ -412,7 +412,7 @@ _dot22 = Dot22()
def local_dot_to_dot22(node):
    """Graph rewrite: specialize a general T.dot into _dot22.

    Fires only when both operands are real (float) non-broadcastable
    matrices of the same type.  Returns a one-element replacement list
    on success, False when the node is not a dot at all, and None
    (implicitly) when it is a dot whose inputs do not qualify.
    """
    if node.op != T.dot:
        return False
    x, y = node.inputs
    if _is_real_matrix(x) and y.type == x.type:
        return [_dot22(*node.inputs)]
...@@ -434,15 +434,19 @@ def _as_scalar(res): ...@@ -434,15 +434,19 @@ def _as_scalar(res):
else: else:
return None return None
def _is_real_matrix(res):
    """Return True iff `res` is a real (float) matrix with no
    broadcastable dimensions, i.e. a plain rank-2 tensor suitable
    as a gemm operand.
    """
    # NOTE(review): checks res.broadcastable (not res.type.broadcastable) —
    # presumably the Variable delegates to its type; confirm against T.
    is_float_matrix = res.type in T.float_matrix_types
    return is_float_matrix and res.broadcastable == (False, False)
def _as_isolated_scalar_times_matrix(res): def _as_isolated_scalar_times_matrix(res):
if _is_a(res, T.mul, 1): if _is_a(res, T.mul, 1):
if len(res.owner.inputs) == 2: if len(res.owner.inputs) == 2:
L, R = res.owner.inputs L, R = res.owner.inputs
sL = _as_scalar(L) sL = _as_scalar(L)
sR = _as_scalar(R) sR = _as_scalar(R)
if sL is not None and R.type in T.float_matrix_types: if (sL is not None) and _is_real_matrix(R):
return (sL, R) return (sL, R)
if sR is not None and L.type in T.float_matrix_types: if (sR is not None) and _is_real_matrix(L):
return (sR, L) return (sR, L)
else: else:
scalars = [] scalars = []
...@@ -451,7 +455,7 @@ def _as_isolated_scalar_times_matrix(res): ...@@ -451,7 +455,7 @@ def _as_isolated_scalar_times_matrix(res):
scalar_input = _as_scalar(input) scalar_input = _as_scalar(input)
if scalar_input is not None: if scalar_input is not None:
scalars.append(scalar_input) scalars.append(scalar_input)
elif input.type in T.float_matrix_types: elif _is_real_matrix(input):
matrices.append(input) matrices.append(input)
else: else:
return None return None
...@@ -503,9 +507,9 @@ def beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip = True): ...@@ -503,9 +507,9 @@ def beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip = True):
def local_sub_to_gemm(node): def local_sub_to_gemm(node):
if node.op == T.sub: if node.op == T.sub:
L, R = node.inputs L, R = node.inputs
if L.type not in T.float_matrix_types: if not _is_real_matrix(L):
return False return False
if R.type not in T.float_matrix_types: if not _is_real_matrix(R):
return False return False
tmp = _as_isolated_scalar_times_matrix(L) tmp = _as_isolated_scalar_times_matrix(L)
...@@ -536,7 +540,10 @@ def local_add_to_gemm(node): ...@@ -536,7 +540,10 @@ def local_add_to_gemm(node):
sM_list = [] sM_list = []
for input in node.inputs: for input in node.inputs:
tmp = _as_isolated_scalar_times_matrix(input) tmp = _as_isolated_scalar_times_matrix(input)
sM_list.append(tmp if tmp is not None else (1.0,input)) if tmp:
sM_list.append(tmp)
elif _is_real_matrix(input):
sM_list.append((1.0, input))
if len(sM_list) == 2: if len(sM_list) == 2:
sL, mL = sM_list[0] sL, mL = sM_list[0]
...@@ -553,7 +560,6 @@ def local_add_to_gemm(node): ...@@ -553,7 +560,6 @@ def local_add_to_gemm(node):
inputs_without_ij = \ inputs_without_ij = \
[input for k, input in enumerate(node.inputs) if k not in (i,j)] [input for k, input in enumerate(node.inputs) if k not in (i,j)]
return [T.add( *(inputs_without_ij + rval))] return [T.add( *(inputs_without_ij + rval))]
return rval
return False return False
register_specialize(local_add_to_gemm) register_specialize(local_add_to_gemm)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论