提交 f159d9eb authored 作者: --global's avatar --global

Match broadcastable pattern of original variable

上级 cb08bc11
...@@ -401,6 +401,12 @@ def expand(tensor_var, size): ...@@ -401,6 +401,12 @@ def expand(tensor_var, size):
zeros_shape = [size + shapes[0]] + shapes[1:] zeros_shape = [size + shapes[0]] + shapes[1:]
empty = tensor.zeros(zeros_shape, empty = tensor.zeros(zeros_shape,
dtype=tensor_var.dtype) dtype=tensor_var.dtype)
# Make sure to reuse the broadcast pattern of the original tensor for
# every dimension but the first one.
broadcastable = (False,) + tensor_var.broadcastable[1:]
empty = tensor.patternbroadcast(empty, broadcastable)
return tensor.set_subtensor(empty[:shapes[0]], tensor_var) return tensor.set_subtensor(empty[:shapes[0]], tensor_var)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论