提交 f291b362 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge pull request #1665 from nouiz/join

Remove Rebroadcast in the input of Join.
......@@ -3506,13 +3506,7 @@ class Join(Op):
except IndexError:
raise ValueError('Join argument "axis" is out of range'
' (given input dimensions)')
as_tensor_variable_args = [unbroadcast(x, axis)
for x in as_tensor_variable_args]
else:
# These unbroadcasts are for the gradient... not sure exactly
# why...
as_tensor_variable_args = [unbroadcast(x, *range(x.type.ndim))
for x in as_tensor_variable_args]
# When the axis may vary, no dimension can be guaranteed to be
# broadcastable.
bcastable = [False] * len(
......@@ -3590,7 +3584,11 @@ class Join(Op):
# If there is only one split, it might not be in a list.
if not isinstance(split_gz, list):
split_gz = [split_gz]
# Split.make_node isn't always able to infer the right
# broadcast. As the grad need to keep the information,
# readd it if needed.
split_gz = [patternbroadcast(g, t.broadcastable)
for t, g in zip(tensors, split_gz)]
rval = rval + split_gz
else:
# the output has integer type, so the gradient through it
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论