提交 896fdba3 作者: Ian Goodfellow

connection pattern and grad method for several ops

上级 f90b7022
......@@ -3980,6 +3980,15 @@ class Subtensor(Op):
return ([IncSubtensor(self.idx_list)(zeros_like(x), gz, *rest)]
+ [DisconnectedType()()] * len(rest))
def connection_pattern(self, node):
    """Describe which inputs are connected to the single output.

    Returns a list that starts with a ``[True]`` row and is followed by
    one ``[False]`` row for each entry of ``node.inputs`` after the first.
    """
    index_rows = [[False] for _ in node.inputs[1:]]
    return [[True]] + index_rows
def __eq__(self, other):
    """Compare equal only to objects of the same type with an equal ``idx_list``."""
    if type(self) != type(other):
        return False
    return self.idx_list == other.idx_list
......@@ -4636,7 +4645,7 @@ class IncSubtensor(Op):
rval = [ [True] ]
for ipt in node.inputs[1:]:
for ipt in node.inputs[2:]:
rval.append([False])
return rval
......@@ -4654,7 +4663,7 @@ class IncSubtensor(Op):
gx = g_output
gy = Subtensor(idx_list=self.idx_list)(g_output, *idx_list)
return [gx, gy] + [None] * len(idx_list)
return [gx, gy] + [DisconnectedType()()] * len(idx_list)
def split(x, splits_size, n_splits, axis=0):
......@@ -4772,8 +4781,10 @@ class Split(Op):
def grad(self, inputs, g_outputs):
    """Join the gradients along the axis that was used to split x.

    Returns a three-element list: the joined output gradients for the
    first input, followed by ``grad_undefined`` terms for the second
    input (the axis) and the third input.
    """
    _, axis, n = inputs
    # Only the first input receives a real gradient; the other two are
    # reported as undefined rather than as ``None``.
    return [join(axis, *g_outputs),
            grad_undefined(self, 1, axis),
            grad_undefined(self, 2, n)]
def R_op(self, inputs, eval_points):
if eval_points[0] is None:
......@@ -5041,6 +5052,9 @@ class Join(Op):
"""
gz, = grads
axis, tensors = axis_and_tensors[0], axis_and_tensors[1:]
rval = [ grad_undefined(self, 0, axis) ]
if 'float' in tensors[0].dtype or 'complex' in tensors[0].dtype:
# assume that this is differentiable
split = Split(len(tensors))
......@@ -5049,10 +5063,14 @@ class Join(Op):
# If there is only one split, it might not be in a list.
if not isinstance(split_gz, list):
split_gz = [split_gz]
return [None] + split_gz
rval = rval + split_gz
else:
# assume that this isn't differentiable
return [None] * (1 + len(tensors))
# the output has integer type, so the gradient through it
# is 0
rval = rval + [ tensor.zeros_like() for tensor in tensors ]
return rval
def _native_grad(self, axis_and_tensors, grads):
"""WRITEME"""
......@@ -5781,9 +5799,21 @@ class ARange(Op):
step = step.item()
out[0] = numpy.arange(start, stop, step, dtype=self.dtype)
def connection_pattern(self, node):
    """Return the fixed connection pattern for the three inputs.

    The rows are ``[True]``, ``[False]``, ``[True]`` in input order;
    ``node`` is not consulted.
    """
    connected, disconnected = True, False
    return [[connected], [disconnected], [connected]]
def grad(self, inputs, grads):
    """Return the gradient terms for ``start``, ``stop`` and ``step``.

    ``start`` and ``step`` get zeros shaped like themselves, and ``stop``
    gets a ``DisconnectedType`` variable.
    """
    start, stop, step = inputs
    # Unpacking also checks that exactly one output gradient was supplied.
    gz, = grads
    # start and step affect the output values
    # but the outputs are integers so there's
    # no gradient through them
    # stop does not affect the output values,
    # just the output shape, so it is disconnected
    return [start.zeros_like(),
            DisconnectedType()(),
            step.zeros_like()]
def R_op(self, inputs, eval_points):
    """Return ``[None]`` regardless of ``inputs`` and ``eval_points``."""
    no_r_op = [None]
    return no_r_op
......@@ -6082,11 +6112,21 @@ class AdvancedSubtensor1(Op):
out[0] = x.take(i, axis=0, out=o)
def connection_pattern(self, node):
    """Describe which inputs are connected to the single output.

    The first row is ``[True]``; one ``[False]`` row is added for each
    remaining entry of ``node.inputs``.
    """
    rows = [[True]]
    rows.extend([False] for _ in node.inputs[1:])
    return rows
def grad(self, inputs, grads):
    """Return the gradient terms for the two inputs.

    The first term increments a zero tensor shaped like ``inputs[0]`` by
    the output gradient at positions ``inputs[1]``. The second input gets
    a ``DisconnectedType`` variable instead of ``None``.
    """
    gz, = grads
    assert len(inputs) == 2
    rval1 = [advanced_inc_subtensor1(zeros_like(inputs[0]), gz, inputs[1])]
    return rval1 + [DisconnectedType()()] * (len(inputs) - 1)
def R_op(self, inputs, eval_points):
if eval_points[0] is None:
......@@ -6185,6 +6225,16 @@ class AdvancedIncSubtensor1(Op):
return self.make_node(eval_points[0], eval_points[1],
*inputs[2:]).outputs
def connection_pattern(self, node):
    """Describe which inputs are connected to the single output.

    Returns one row per entry of ``node.inputs``: ``[True]`` for each of
    the first two inputs and ``[False]`` for every input after them.
    """
    # Both leading inputs need a row of their own. Starting from a single
    # [True] row would leave the pattern one row shorter than the inputs.
    rval = [[True], [True]]
    rval.extend([False] for _ in node.inputs[2:])
    return rval
def grad(self, inputs, grads):
g_output, = grads
x, y = inputs[:2]
......@@ -6193,7 +6243,7 @@ class AdvancedIncSubtensor1(Op):
gx = g_output
gy = advanced_subtensor1(g_output, *idx_list)
return [gx, gy] + [None] * len(idx_list)
return [gx, gy] + [ DisconnectedType()()] * len(idx_list)
advanced_inc_subtensor1 = AdvancedIncSubtensor1()
......@@ -6282,12 +6332,22 @@ class AdvancedSubtensor(Op):
# return
#raise NotImplementedError()
def connection_pattern(self, node):
    """Describe which inputs are connected to the single output.

    Produces ``[True]`` for the first input, then ``[False]`` once for
    each further entry of ``node.inputs``.
    """
    pattern = [[True]]
    pattern += [[False] for _ in node.inputs[1:]]
    return pattern
def grad(self, inputs, grads):
    """Return the gradient terms for ``x`` and for each remaining input.

    The term for ``x`` increments a zero tensor shaped like ``x`` by the
    output gradient at the positions given by the remaining inputs. Each
    remaining input gets a ``DisconnectedType`` variable instead of ``None``.
    """
    gz, = grads
    x = inputs[0]
    rest = inputs[1:]
    return [advanced_inc_subtensor(zeros_like(x), gz,
                                   *rest)] + \
        [DisconnectedType()()] * len(rest)
class AdvancedIncSubtensor(Op):
......@@ -6372,13 +6432,23 @@ class AdvancedIncSubtensor(Op):
def infer_shape(self, node, ishapes):
    """Return a one-element list holding the first entry of ``ishapes``."""
    first_shape = ishapes[0]
    return [first_shape]
def connection_pattern(self, node):
    """Describe which inputs are connected to the single output.

    Returns one row per entry of ``node.inputs``: ``[True]`` for each of
    the first two inputs and ``[False]`` for every input after them.
    """
    # The first two inputs each get their own [True] row. A single leading
    # row would make the pattern shorter than the input list.
    rval = [[True], [True]]
    rval.extend([False] for _ in node.inputs[2:])
    return rval
def grad(self, inpt, output_gradients):
    """Return the gradient terms for ``x``, ``y`` and each index input.

    ``x`` receives the output gradient unchanged, and ``y`` receives the
    output gradient indexed by the index inputs. Each index input gets a
    ``DisconnectedType`` variable instead of ``None``.
    """
    x, y = inpt[:2]
    idxs = inpt[2:]
    outgrad, = output_gradients
    d_x_wrt_C = outgrad
    d_y_wrt_C = AdvancedSubtensor()(outgrad, *idxs)
    return [d_x_wrt_C, d_y_wrt_C] + \
        [DisconnectedType()() for _ in idxs]
def R_op(self, inputs, eval_points):
if None in eval_points[:2]:
......
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论