提交 50605d2c authored 作者: Frederic Bastien's avatar Frederic Bastien

Optimized the Python implementation of gemv and ger without SciPy.

上级 acf20ff6
...@@ -205,9 +205,14 @@ class Gemv(Op): ...@@ -205,9 +205,14 @@ class Gemv(Op):
#out_storage[0][0] = gemv(alpha, A, x, beta, y, overwrite_y=self.inplace) #out_storage[0][0] = gemv(alpha, A, x, beta, y, overwrite_y=self.inplace)
out_storage[0][0] = gemv(alpha, A.T, x, beta, y, overwrite_y=self.inplace, trans=True) out_storage[0][0] = gemv(alpha, A.T, x, beta, y, overwrite_y=self.inplace, trans=True)
else: else:
out_storage[0][0] = numpy.asarray( out = numpy.dot(A, x)
beta * y + alpha * numpy.dot(A, x) if alpha != 1:
, dtype=y.dtype) out *= alpha
if beta != 1:
out += beta * y
else:
out += y
out_storage[0][0] = numpy.asarray(out, dtype=y.dtype)
gemv_no_inplace = Gemv(inplace=False) gemv_no_inplace = Gemv(inplace=False)
gemv_inplace = Gemv(inplace=True) gemv_inplace = Gemv(inplace=True)
...@@ -276,7 +281,10 @@ class Ger(Op): ...@@ -276,7 +281,10 @@ class Ger(Op):
A = cA[0] A = cA[0]
else: else:
A = cA[0].copy() A = cA[0].copy()
A += calpha[0] * numpy.outer(cx[0], cy[0]) if calpha[0] != 1:
A += calpha[0] * numpy.outer(cx[0], cy[0])
else:
A += numpy.outer(cx[0], cy[0])
cZ[0] = A cZ[0] = A
#TODO: If this is currently an unofficial part of the thunk API, #TODO: If this is currently an unofficial part of the thunk API,
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论