提交 0cc47636 authored 作者: James Bergstra's avatar James Bergstra

fix output broadcastable flags in GpuDot22, GpuGemm

上级 6f5a1844
......@@ -19,7 +19,9 @@ class GpuDot22(Op):
raise TypeError(x)
if y.type.ndim != 2:
raise TypeError(y)
return Apply(self, [x,y], [x.type()])
otype = CudaNdarrayType(
(x.type.broadcastable[0], y.type.broadcastable[1]))
return Apply(self, [x,y], [otype()])
def c_code_cache_version(self):
return (1,1)
......@@ -87,7 +89,9 @@ class GpuDot22Scalar(Op):
raise TypeError(y)
if not tensor.blas._as_scalar(a):
raise TypeError(a)
return Apply(self, [x,y,a], [x.type()])
otype = CudaNdarrayType(
(x.type.broadcastable[0], y.type.broadcastable[1]))
return Apply(self, [x,y,a], [otype()])
def c_code_cache_version(self):
return (1,1)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论