提交 a4a54be6 authored 作者: Frederic Bastien's avatar Frederic Bastien

merge backport.

......@@ -1857,6 +1857,47 @@ class Split(Op):
"""Join the gradients along the axis that was used to split x."""
return [join(axis, *g_outputs), None, None]
class Rebroadcast(Op):
"""
Change the input's broadcastable fields in
some predetermined way.
e.g.: Rebroadcast((0, True), (1, False))(x)
would make x broadcastable in axis 0
and not broadcastable in axis 1
See also the unbroadcast function.
"""
view_map = {0: [0]}
def __init__(self, *axis):
self.axis = dict(axis)
def make_node(self, x):
t = TensorType(dtype = x.type.dtype,
broadcastable = [self.axis.get(i, b)
for i, b in enumerate(x.type.broadcastable)])
return Apply(self, [x], [t()])
def perform(self, node, (x, ), (out, )):
for axis, value in self.axis.iteritems():
if value and x.shape[axis] != 1:
raise ValueError('Dimension %s in Rebroadcast\'s input was supposed to be 1 (got %s instead)' % (axis, x.shape[axis]))
out[0] = x
def grad(self, (x, ), (gz,)):
# restore the broadcasting pattern of the input
return Rebroadcast(*[(axis, x.type.broadcastable[axis]) for axis, value in self.axis.iteritems()])(gz),
def addbroadcast(x, *axes):
"""
Make the input broadcastable in the specified axes.
"""
return Rebroadcast(*[(axis, True) for axis in axes])(x)
def unbroadcast(x, *axes):
"""
Make the input impossible to broadcast in the specified axes.
"""
return Rebroadcast(*[(axis, False) for axis in axes])(x)
class Join(Op):
"""
Concatenate multiple `TensorVariable`s along some axis.
......@@ -1919,6 +1960,9 @@ class Join(Op):
bcastable[axis] = False
except IndexError, e:
raise ValueError('Join argument "axis" is out of range (given input dimensions)')
as_tensor_variable_args = [unbroadcast(x, axis) for x in as_tensor_variable_args]
else:
as_tensor_variable_args = [unbroadcast(x, *range(x.type.ndim)) for x in as_tensor_variable_args]
inputs = [as_tensor_variable(axis)] + as_tensor_variable_args
if inputs[0].type not in int_types:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论