提交 6b42efe1 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

R operator for the join op

上级 ae2df05d
......@@ -4011,6 +4011,11 @@ class Join(Op):
out[0] = theano._asarray(numpy.concatenate(tensors, axis = axis),
dtype=node.outputs[0].type.dtype)
def R_op(self, inputs, eval_points):
if None in eval_points[1:]:
return [None]
return self.make_node(inputs[0], *eval_points[1:]).outputs
def grad(self, axis_and_tensors, grads):
""" The gradient wrt a join op is a `Split`, used to partition the gradient along the
`axis` which was used for joining.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论