提交 39743595 authored 作者: james@crane's avatar james@crane

stricter check on dot22

上级 0de78646
...@@ -411,6 +411,8 @@ _dot22 = Dot22() ...@@ -411,6 +411,8 @@ _dot22 = Dot22()
@local_optimizer([T.dot]) @local_optimizer([T.dot])
def local_dot_to_dot22(node): def local_dot_to_dot22(node):
if node.op == T.dot: if node.op == T.dot:
x,y = node.inputs
if x.type in T.float_matrix_types and y.type == x.type:
return [_dot22(*node.inputs)] return [_dot22(*node.inputs)]
else: else:
return False return False
...@@ -536,8 +538,7 @@ def local_add_to_gemm(node): ...@@ -536,8 +538,7 @@ def local_add_to_gemm(node):
tmp = _as_isolated_scalar_times_matrix(input) tmp = _as_isolated_scalar_times_matrix(input)
sM_list.append(tmp if tmp is not None else (1.0,input)) sM_list.append(tmp if tmp is not None else (1.0,input))
#print sM_list if len(sM_list) == 2:
if len(node.inputs) == 2:
sL, mL = sM_list[0] sL, mL = sM_list[0]
sR, mR = sM_list[1] sR, mR = sM_list[1]
return beta_L_plus_alpha_M(sL, mL, sR, mR) return beta_L_plus_alpha_M(sL, mL, sR, mR)
......
...@@ -292,3 +292,12 @@ class T_gemm_opt(TestCase): ...@@ -292,3 +292,12 @@ class T_gemm_opt(TestCase):
self.just_gemm([X,Y,Z,a,b], [Z + T.dot(X,Y) + T.dot(X,Y)]) self.just_gemm([X,Y,Z,a,b], [Z + T.dot(X,Y) + T.dot(X,Y)])
self.just_gemm([X,Y,Z,a,b], [(b * b) * Z * a + (a * a) * T.dot(X,Y) + b * T.dot(X,Y)]) self.just_gemm([X,Y,Z,a,b], [(b * b) * Z * a + (a * a) * T.dot(X,Y) + b * T.dot(X,Y)])
def test_vector_stuff(self):
X,Y,Z,a,b = T.dmatrix(), T.dmatrix(), T.dmatrix(), T.dscalar(), T.dscalar()
u,v = T.dvector(), T.dvector()
f = function([a, u, v], a + T.dot(u,v), mode='FAST_RUN')
print f.maker.env.nodes
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论