added arg check to gemm.__init__

上级 b11405ec
......@@ -715,6 +715,11 @@ class Gemm(_Op):
nout=1
E_rank = 'gemm only works for rank 2'
E_scalar = 'gemm requires scalar argument'
def __init__(self, *args, **kwargs):
z, a, x, y, b = args
if z in args[1:]:
raise ValueError('argument z not unique in argument list', args)
_Op.__init__(self, *args, **kwargs)
def destroy_map(self):
return {self.out:[self.inputs[0]]}
def propagate_broadcastable(self, bz, ba, bx, by, bb):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论