提交 95deb922 作者: Brandon T. Willard 提交者: Brandon T. Willard

Preserve broadcastable dimensions in params_broadcast_shapes

上级 f191a07e
...@@ -36,8 +36,16 @@ def params_broadcast_shapes(param_shapes, ndims_params, use_aesara=True): ...@@ -36,8 +36,16 @@ def params_broadcast_shapes(param_shapes, ndims_params, use_aesara=True):
# We need this in order to use `len` # We need this in order to use `len`
param_shape = tuple(param_shape) param_shape = tuple(param_shape)
extras = tuple(param_shape[: (len(param_shape) - ndim_param)]) extras = tuple(param_shape[: (len(param_shape) - ndim_param)])
def max_bcast(x, y):
    """Combine two dimension lengths, passing the other one through when one is 1.

    Each argument is compared against 1 through its ``value`` attribute when
    it has one, and directly otherwise (presumably ``value`` is the static
    value of a graph constant -- confirm against the caller).

    Parameters
    ----------
    x, y
        Dimension lengths to combine.

    Returns
    -------
    ``y`` when ``x`` is 1, ``x`` when ``y`` is 1, and ``max_fn(x, y)``
    otherwise.  The surviving argument is returned as-is rather than being
    routed through ``max_fn``.
    """
    # A length of 1 contributes nothing to the result, so hand back the other
    # operand untouched instead of wrapping it in a `max_fn` call.
    if getattr(x, "value", x) == 1:
        return y
    if getattr(y, "value", y) == 1:
        return x
    # Neither side is known to be 1: defer to the `max_fn` in the enclosing scope.
    return max_fn(x, y)
rev_extra_dims = [ rev_extra_dims = [
max_fn(a, b) max_bcast(a, b)
for a, b in zip_longest(reversed(extras), rev_extra_dims, fillvalue=1) for a, b in zip_longest(reversed(extras), rev_extra_dims, fillvalue=1)
] ]
...@@ -84,7 +92,10 @@ def broadcast_params(params, ndims_params): ...@@ -84,7 +92,10 @@ def broadcast_params(params, ndims_params):
use_aesara = False use_aesara = False
param_shapes = [] param_shapes = []
for p in params: for p in params:
param_shape = p.shape param_shape = tuple(
1 if bcast else s
for s, bcast in zip(p.shape, getattr(p, "broadcastable", (False,) * p.ndim))
)
use_aesara |= isinstance(p, Variable) use_aesara |= isinstance(p, Variable)
param_shapes.append(param_shape) param_shapes.append(param_shape)
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论