gemm passing aliasing tests

上级 6b6cee3b
...@@ -739,7 +739,7 @@ class t_gemm(unittest.TestCase): ...@@ -739,7 +739,7 @@ class t_gemm(unittest.TestCase):
Z = astensor(self.rand(2,2)) Z = astensor(self.rand(2,2))
A = astensor(self.rand(2,2)) A = astensor(self.rand(2,2))
try: try:
gemm(Z, 1.0, A, Z.T, 1.0) gemm(Z, 1.0, A, transpose_inplace(Z), 1.0)
except ValueError, e: except ValueError, e:
if e[0] == Gemm.E_z_uniq: if e[0] == Gemm.E_z_uniq:
return return
...@@ -749,7 +749,7 @@ class t_gemm(unittest.TestCase): ...@@ -749,7 +749,7 @@ class t_gemm(unittest.TestCase):
Z = astensor(self.rand(2,2)) Z = astensor(self.rand(2,2))
A = astensor(self.rand(2,2)) A = astensor(self.rand(2,2))
try: try:
gemm(Z, 1.0, Z.T, A, 1.0) gemm(Z, 1.0, transpose_inplace(Z), A, 1.0)
except ValueError, e: except ValueError, e:
if e[0] == Gemm.E_z_uniq: if e[0] == Gemm.E_z_uniq:
return return
......
...@@ -366,13 +366,16 @@ class Gemm(_Op): ...@@ -366,13 +366,16 @@ class Gemm(_Op):
nout=1 nout=1
E_rank = 'gemm only works for rank 2' E_rank = 'gemm only works for rank 2'
E_scalar = 'gemm requires scalar argument' E_scalar = 'gemm requires scalar argument'
E_z_uniq = 'argument z not unique in argument list' E_z_uniq = 'argument z aliased to x or y'
debug = False debug = False
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
_Op.__init__(self, *args, **kwargs) _Op.__init__(self, *args, **kwargs)
z, a, x, y, b = self.inputs z, a, x, y, b = self.inputs
if z in self.inputs[1:]: zr, xr, yr = [set(gof.view_roots(i)) for i in z,x,y]
raise ValueError(Gemm.E_z_uniq, self.inputs) if zr.intersection(xr):
raise ValueError(Gemm.E_z_uniq, (z, x))
if zr.intersection(yr):
raise ValueError(Gemm.E_z_uniq, (z, y))
def destroy_map(self): def destroy_map(self):
return {self.out:[self.inputs[0]]} return {self.out:[self.inputs[0]]}
def propagate_broadcastable(self, bz, ba, bx, by, bb): def propagate_broadcastable(self, bz, ba, bx, by, bb):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论