Commit ca9845f9 authored by Frederic

Remove Rebroadcast in the input of Join.

Add it only when needed in the grad.
Parent 7df21906
@@ -3506,13 +3506,7 @@ class Join(Op):
             except IndexError:
                 raise ValueError('Join argument "axis" is out of range'
                                  ' (given input dimensions)')
-            as_tensor_variable_args = [unbroadcast(x, axis)
-                                       for x in as_tensor_variable_args]
         else:
-            # These unbroadcasts are for the gradient... not sure exactly
-            # why...
-            as_tensor_variable_args = [unbroadcast(x, *range(x.type.ndim))
-                                       for x in as_tensor_variable_args]
             # When the axis may vary, no dimension can be guaranteed to be
             # broadcastable.
             bcastable = [False] * len(
@@ -3590,7 +3584,11 @@ class Join(Op):
                 # If there is only one split, it might not be in a list.
                 if not isinstance(split_gz, list):
                     split_gz = [split_gz]
+                # Split.make_node isn't always able to infer the right
+                # broadcast. As the grad need to keep the information,
+                # readd it if needed.
+                split_gz = [patternbroadcast(g, t.broadcastable)
+                            for t, g in zip(tensors, split_gz)]
                 rval = rval + split_gz
             else:
                 # the output has integer type, so the gradient through it
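
As a reading aid for the gradient change, here is a minimal pure-Python sketch of why the split gradient pieces need their broadcast pattern restored. It is not Theano code: a broadcastable pattern is modeled as a plain tuple of booleans, and the helpers `join_pattern`, `split_patterns` and `restore_patterns` are hypothetical names introduced only for this illustration. The sketch assumes that the joined output is never marked broadcastable along the join axis and that each split piece simply inherits the pattern of the array being split.

```python
def join_pattern(patterns, axis):
    """Simplified model of the broadcast pattern of a joined output."""
    ndim = len(patterns[0])
    # Off the join axis all inputs must agree in length, so one input
    # known to have length 1 there fixes that dimension for the output.
    out = [any(p[d] for p in patterns) for d in range(ndim)]
    # Assumption: the join axis is never marked broadcastable.
    out[axis] = False
    return tuple(out)


def split_patterns(joined, n_pieces):
    """Assumption: every split piece inherits the joined pattern."""
    return [tuple(joined) for _ in range(n_pieces)]


def restore_patterns(inputs, pieces):
    """Give each gradient piece the pattern of its matching input,
    mirroring the shape of the list comprehension added in the grad."""
    return [tuple(t) for t, g in zip(inputs, pieces)]


# Two inputs joined along axis 1; the first has length 1 on that axis.
inputs = [(False, True), (False, False)]
joined = join_pattern(inputs, axis=1)
pieces = split_patterns(joined, len(inputs))
restored = restore_patterns(inputs, pieces)

# Without the fix, the first piece disagrees with its input's pattern.
mismatch_before = pieces[0] != inputs[0]
```

In this model the first gradient piece loses the broadcastable flag on the join axis, which is the mismatch the added `patternbroadcast` call appears intended to repair now that the inputs are no longer unbroadcast up front.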