gemm passing aliasing tests

上级 6b6cee3b
......@@ -739,7 +739,7 @@ class t_gemm(unittest.TestCase):
Z = astensor(self.rand(2,2))
A = astensor(self.rand(2,2))
try:
gemm(Z, 1.0, A, Z.T, 1.0)
gemm(Z, 1.0, A, transpose_inplace(Z), 1.0)
except ValueError, e:
if e[0] == Gemm.E_z_uniq:
return
......@@ -749,7 +749,7 @@ class t_gemm(unittest.TestCase):
Z = astensor(self.rand(2,2))
A = astensor(self.rand(2,2))
try:
gemm(Z, 1.0, Z.T, A, 1.0)
gemm(Z, 1.0, transpose_inplace(Z), A, 1.0)
except ValueError, e:
if e[0] == Gemm.E_z_uniq:
return
......
......@@ -366,13 +366,16 @@ class Gemm(_Op):
nout=1
E_rank = 'gemm only works for rank 2'
E_scalar = 'gemm requires scalar argument'
E_z_uniq = 'argument z not unique in argument list'
E_z_uniq = 'argument z aliased to x or y'
debug = False
def __init__(self, *args, **kwargs):
_Op.__init__(self, *args, **kwargs)
z, a, x, y, b = self.inputs
if z in self.inputs[1:]:
raise ValueError(Gemm.E_z_uniq, self.inputs)
zr, xr, yr = [set(gof.view_roots(i)) for i in z,x,y]
if zr.intersection(xr):
raise ValueError(Gemm.E_z_uniq, (z, x))
if zr.intersection(yr):
raise ValueError(Gemm.E_z_uniq, (z, y))
def destroy_map(self):
return {self.out:[self.inputs[0]]}
def propagate_broadcastable(self, bz, ba, bx, by, bb):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论