提交 8d761b6c authored 作者: James Bergstra's avatar James Bergstra

blas - removed shape parameter from TensorType constructor calls

上级 18a962c9
...@@ -793,9 +793,8 @@ class Dot22(GemmRelated): ...@@ -793,9 +793,8 @@ class Dot22(GemmRelated):
raise TypeError(y) raise TypeError(y)
if y.type.dtype != x.type.dtype: if y.type.dtype != x.type.dtype:
raise TypeError('dtype mismatch to Dot22') raise TypeError('dtype mismatch to Dot22')
out_shape = (x.type.shape[0], y.type.shape[1])
bz = [False, False] bz = [False, False]
outputs = [T.tensor(x.type.dtype, bz, shape=out_shape)] outputs = [T.tensor(x.type.dtype, bz)]
return Apply(self, [x,y], outputs) return Apply(self, [x,y], outputs)
def perform(self, node, (x, y), (z, )): def perform(self, node, (x, y), (z, )):
...@@ -904,9 +903,8 @@ class Dot22Scalar(GemmRelated): ...@@ -904,9 +903,8 @@ class Dot22Scalar(GemmRelated):
raise TypeError(scalar) raise TypeError(scalar)
if y.type.dtype != x.type.dtype and y.type.dtype != scalar.type.dtype: if y.type.dtype != x.type.dtype and y.type.dtype != scalar.type.dtype:
raise TypeError('dtype mismatch to Dot22Scalar') raise TypeError('dtype mismatch to Dot22Scalar')
out_shape = (x.type.shape[0], y.type.shape[1])
bz = [False, False] bz = [False, False]
outputs = [T.tensor(x.type.dtype, bz, shape=out_shape)] outputs = [T.tensor(x.type.dtype, bz)]
return Apply(self, [x,y,scalar], outputs) return Apply(self, [x,y,scalar], outputs)
def perform(self, node, (x, y, scalar), (z, )): def perform(self, node, (x, y, scalar), (z, )):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论