提交 62a7e19a authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add infer_shape to the blas ops.

上级 480ca3ad
......@@ -473,6 +473,9 @@ class Gemv(Op):
out += y
out_storage[0][0] = numpy.asarray(out, dtype=y.dtype)
def infer_shape(self, node, input_shapes):
return [input_shapes[0]]
gemv_no_inplace = Gemv(inplace=False)
gemv_inplace = Gemv(inplace=True)
# For the user interface. Opt will make them inplace later
......@@ -540,6 +543,9 @@ class Ger(Op):
A += numpy.outer(cx, cy)
cZ[0] = A
def infer_shape(self, node, input_shapes):
return [input_shapes[0]]
ger = Ger(destructive=False)
ger_destructive = Ger(destructive=True)
......@@ -1001,6 +1007,8 @@ class Gemm(GemmRelated):
E_mixed = 'gemm requires matching dtypes'
E_float = 'gemm requires floating-point dtypes'
__props__ = ('inplace',)
def __init__(self, inplace):
self.inplace = inplace
if self.inplace:
......@@ -1009,13 +1017,6 @@ class Gemm(GemmRelated):
else:
self.setup_z_Nz_Sz = self.setup_z_Nz_Sz_outplace
def __eq__(self, other):
return (type(self) == type(other) and
self.inplace == other.inplace)
def __hash__(self):
return hash(type(self)) ^ hash(self.inplace)
def __str__(self):
if self.inplace:
inplace_str = 'inplace'
......@@ -1124,6 +1125,9 @@ class Gemm(GemmRelated):
z += a * numpy.dot(x, y)
zout[0] = z
def infer_shape(self, node, input_shapes):
return [inputs_shapes[0]]
setup_z_Nz_Sz_inplace = """
if (%(_zout)s != %(_z)s)
{
......@@ -1747,8 +1751,8 @@ class Dot22(GemmRelated):
e.args = e.args + (x.shape, y.shape)
raise
def __str__(self):
return self.__class__.__name__
def infer_shape(self, node, input_shapes):
return [[inputs_shapes[0][0], inputs_shapes[1][1]]]
setup_z_Nz_Sz = """
if ((NULL == %(_zout)s)
......@@ -2018,8 +2022,8 @@ class Dot22Scalar(GemmRelated):
e.args = e.args + (x.shape, y.shape)
raise
def __str__(self):
return self.__class__.__name__
def infer_shape(self, node, input_shapes):
return [[input_shapes[0][0], input_shapes[1][1]]]
setup_z_Nz_Sz = Dot22.setup_z_Nz_Sz
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论