提交 a87126d4 authored 作者: Frederic Bastien's avatar Frederic Bastien

Small change.

上级 ed509d18
...@@ -149,7 +149,7 @@ class Gemv(Op): ...@@ -149,7 +149,7 @@ class Gemv(Op):
A is matrix A is matrix
x, y are vectors x, y are vectors
alpha, beta are scalars alpha, beta are scalars
output is a vector that can be inplace on y
""" """
def __init__(self, inplace): def __init__(self, inplace):
self.inplace=inplace self.inplace=inplace
...@@ -1362,13 +1362,13 @@ def local_gemm_to_ger(node): ...@@ -1362,13 +1362,13 @@ def local_gemm_to_ger(node):
return return
if bval == 1: # best case a natural GER if bval == 1: # best case a natural GER
rval = Ger(destructive=False)(z, a, xv, yv) rval = ger(z, a, xv, yv)
return [rval] return [rval]
elif bval == 0: # GER on zeros_like should be faster than GEMM elif bval == 0: # GER on zeros_like should be faster than GEMM
zeros = T.alloc( zeros = T.alloc(
numpy.asarray(0, dtype=x.dtype), numpy.asarray(0, dtype=x.dtype),
x.shape[0], y.shape[1]) x.shape[0], y.shape[1])
rval = Ger(destructive=False)(zeros, a, xv, yv) rval = ger(zeros, a, xv, yv)
return [rval] return [rval]
else: else:
# if bval is another constant, then z is being usefully # if bval is another constant, then z is being usefully
...@@ -1391,7 +1391,7 @@ def local_dot22_to_ger(node): ...@@ -1391,7 +1391,7 @@ def local_dot22_to_ger(node):
one = T.as_tensor_variable(numpy.asarray(1, dtype=x.dtype)) one = T.as_tensor_variable(numpy.asarray(1, dtype=x.dtype))
zeros = T.alloc(numpy.asarray(0, dtype=x.dtype), x.shape[0], y.shape[1]) zeros = T.alloc(numpy.asarray(0, dtype=x.dtype), x.shape[0], y.shape[1])
rval = Ger(destructive=False)(zeros, one, xv, yv) rval = ger(zeros, one, xv, yv)
return [rval] return [rval]
################################# #################################
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论