提交 47ac5f99 authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #6376 from nouiz/broadcast_corrmm

[FIX] Broadcast of corrmm
...@@ -1033,7 +1033,7 @@ class GpuCorrMM_gradWeights(BaseGpuCorrMM): ...@@ -1033,7 +1033,7 @@ class GpuCorrMM_gradWeights(BaseGpuCorrMM):
assert shape[1].ndim == 0 assert shape[1].ndim == 0
if self.unshared: if self.unshared:
broadcastable = [topgrad.type.broadcastable[0], False, False, broadcastable = [topgrad.type.broadcastable[1], False, False,
img.type.broadcastable[1], False, False] img.type.broadcastable[1], False, False]
else: else:
broadcastable = [topgrad.type.broadcastable[1], img.type.broadcastable[1], broadcastable = [topgrad.type.broadcastable[1], img.type.broadcastable[1],
......
...@@ -695,10 +695,10 @@ class CorrMM_gradWeights(BaseCorrMM): ...@@ -695,10 +695,10 @@ class CorrMM_gradWeights(BaseCorrMM):
height_width = [as_tensor_variable(shape[0]).astype('int64'), as_tensor_variable(shape[1]).astype('int64')] height_width = [as_tensor_variable(shape[0]).astype('int64'), as_tensor_variable(shape[1]).astype('int64')]
if self.unshared is True: if self.unshared is True:
broadcastable = [topgrad.type.broadcastable[0], False, False, broadcastable = [topgrad.type.broadcastable[1], False, False,
img.type.broadcastable[1], False, False] img.type.broadcastable[1], False, False]
else: else:
broadcastable = [topgrad.type.broadcastable[0], img.type.broadcastable[1], broadcastable = [topgrad.type.broadcastable[1], img.type.broadcastable[1],
False, False] False, False]
dtype = img.type.dtype dtype = img.type.dtype
return Apply(self, [img, topgrad] + height_width, return Apply(self, [img, topgrad] + height_width,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论