提交 896fdba3 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

connection pattern and grad method for several ops

上级 f90b7022
...@@ -3980,6 +3980,15 @@ class Subtensor(Op): ...@@ -3980,6 +3980,15 @@ class Subtensor(Op):
return ([IncSubtensor(self.idx_list)(zeros_like(x), gz, *rest)] return ([IncSubtensor(self.idx_list)(zeros_like(x), gz, *rest)]
+ [DisconnectedType()()] * len(rest)) + [DisconnectedType()()] * len(rest))
def connection_pattern(self, node):
rval = [ [True] ]
for ipt in node.inputs[1:]:
rval.append([False])
return rval
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) and self.idx_list == other.idx_list return type(self) == type(other) and self.idx_list == other.idx_list
...@@ -4636,7 +4645,7 @@ class IncSubtensor(Op): ...@@ -4636,7 +4645,7 @@ class IncSubtensor(Op):
rval = [ [True] ] rval = [ [True] ]
for ipt in node.inputs[1:]: for ipt in node.inputs[2:]:
rval.append([False]) rval.append([False])
return rval return rval
...@@ -4654,7 +4663,7 @@ class IncSubtensor(Op): ...@@ -4654,7 +4663,7 @@ class IncSubtensor(Op):
gx = g_output gx = g_output
gy = Subtensor(idx_list=self.idx_list)(g_output, *idx_list) gy = Subtensor(idx_list=self.idx_list)(g_output, *idx_list)
return [gx, gy] + [None] * len(idx_list) return [gx, gy] + [DisconnectedType()()] * len(idx_list)
def split(x, splits_size, n_splits, axis=0): def split(x, splits_size, n_splits, axis=0):
...@@ -4772,8 +4781,10 @@ class Split(Op): ...@@ -4772,8 +4781,10 @@ class Split(Op):
def grad(self, inputs, g_outputs): def grad(self, inputs, g_outputs):
"""Join the gradients along the axis that was used to split x.""" """Join the gradients along the axis that was used to split x."""
_, axis, _ = inputs _, axis, n = inputs
return [join(axis, *g_outputs), None, None] return [join(axis, *g_outputs),
grad_undefined(self, 1, axis),
grad_undefined(self, 2, n)]
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
if eval_points[0] is None: if eval_points[0] is None:
...@@ -5041,6 +5052,9 @@ class Join(Op): ...@@ -5041,6 +5052,9 @@ class Join(Op):
""" """
gz, = grads gz, = grads
axis, tensors = axis_and_tensors[0], axis_and_tensors[1:] axis, tensors = axis_and_tensors[0], axis_and_tensors[1:]
rval = [ grad_undefined(self, 0, axis) ]
if 'float' in tensors[0].dtype or 'complex' in tensors[0].dtype: if 'float' in tensors[0].dtype or 'complex' in tensors[0].dtype:
# assume that this is differentiable # assume that this is differentiable
split = Split(len(tensors)) split = Split(len(tensors))
...@@ -5049,10 +5063,14 @@ class Join(Op): ...@@ -5049,10 +5063,14 @@ class Join(Op):
# If there is only one split, it might not be in a list. # If there is only one split, it might not be in a list.
if not isinstance(split_gz, list): if not isinstance(split_gz, list):
split_gz = [split_gz] split_gz = [split_gz]
return [None] + split_gz
rval = rval + split_gz
else: else:
# assume that this isn't differentiable # the output has integer type, so the gradient through it
return [None] * (1 + len(tensors)) # is 0
rval = rval + [ tensor.zeros_like() for tensor in tensors ]
return rval
def _native_grad(self, axis_and_tensors, grads): def _native_grad(self, axis_and_tensors, grads):
"""WRITEME""" """WRITEME"""
...@@ -5781,9 +5799,21 @@ class ARange(Op): ...@@ -5781,9 +5799,21 @@ class ARange(Op):
step = step.item() step = step.item()
out[0] = numpy.arange(start, stop, step, dtype=self.dtype) out[0] = numpy.arange(start, stop, step, dtype=self.dtype)
def connection_pattern(self, node):
return [ [True], [False], [True] ]
def grad(self, inputs, grads): def grad(self, inputs, grads):
start, stop, step = inputs
gz, = grads gz, = grads
return [None] * len(inputs) # start and step affect the output values
# but the outputs are integers so there's
# no gradient through them
# stop does not affect the output values,
# just the output shape, so it is disconnected
return [ start.zeros_like(),
DisconnectedType()(),
step.zeros_like() ]
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
return [None] return [None]
...@@ -6082,11 +6112,21 @@ class AdvancedSubtensor1(Op): ...@@ -6082,11 +6112,21 @@ class AdvancedSubtensor1(Op):
out[0] = x.take(i, axis=0, out=o) out[0] = x.take(i, axis=0, out=o)
def connection_pattern(self, node):
rval = [ [True ] ]
for ipt in node.inputs[1:]:
rval.append([ False ])
return rval
def grad(self, inputs, grads): def grad(self, inputs, grads):
gz, = grads gz, = grads
assert len(inputs) == 2 assert len(inputs) == 2
rval1 = [advanced_inc_subtensor1(zeros_like(inputs[0]), gz, inputs[1])] rval1 = [advanced_inc_subtensor1(zeros_like(inputs[0]), gz, inputs[1])]
return rval1 + [None] * (len(inputs) - 1) return rval1 + [DisconnectedType()()] * (len(inputs) - 1)
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
if eval_points[0] is None: if eval_points[0] is None:
...@@ -6185,6 +6225,16 @@ class AdvancedIncSubtensor1(Op): ...@@ -6185,6 +6225,16 @@ class AdvancedIncSubtensor1(Op):
return self.make_node(eval_points[0], eval_points[1], return self.make_node(eval_points[0], eval_points[1],
*inputs[2:]).outputs *inputs[2:]).outputs
def connection_pattern(self, node):
rval = [ [True] ]
for ipt in node.inputs[2:]:
rval.append([False])
return rval
def grad(self, inputs, grads): def grad(self, inputs, grads):
g_output, = grads g_output, = grads
x, y = inputs[:2] x, y = inputs[:2]
...@@ -6193,7 +6243,7 @@ class AdvancedIncSubtensor1(Op): ...@@ -6193,7 +6243,7 @@ class AdvancedIncSubtensor1(Op):
gx = g_output gx = g_output
gy = advanced_subtensor1(g_output, *idx_list) gy = advanced_subtensor1(g_output, *idx_list)
return [gx, gy] + [None] * len(idx_list) return [gx, gy] + [ DisconnectedType()()] * len(idx_list)
advanced_inc_subtensor1 = AdvancedIncSubtensor1() advanced_inc_subtensor1 = AdvancedIncSubtensor1()
...@@ -6282,12 +6332,22 @@ class AdvancedSubtensor(Op): ...@@ -6282,12 +6332,22 @@ class AdvancedSubtensor(Op):
# return # return
#raise NotImplementedError() #raise NotImplementedError()
def connection_pattern(self, node):
rval = [ [True] ]
for ipt in node.inputs[1:]:
rval.append([False])
return rval
def grad(self, inputs, grads): def grad(self, inputs, grads):
gz, = grads gz, = grads
x = inputs[0] x = inputs[0]
rest = inputs[1:] rest = inputs[1:]
return [advanced_inc_subtensor(zeros_like(x), gz, return [advanced_inc_subtensor(zeros_like(x), gz,
*rest)] + [None] * len(rest) *rest)] + \
[DisconnectedType()()] * len(rest)
class AdvancedIncSubtensor(Op): class AdvancedIncSubtensor(Op):
...@@ -6372,13 +6432,23 @@ class AdvancedIncSubtensor(Op): ...@@ -6372,13 +6432,23 @@ class AdvancedIncSubtensor(Op):
def infer_shape(self, node, ishapes): def infer_shape(self, node, ishapes):
return [ishapes[0]] return [ishapes[0]]
def connection_pattern(self, node):
rval = [ [True] ]
for ipt in node.inputs[2:]:
rval.append([False])
return rval
def grad(self, inpt, output_gradients): def grad(self, inpt, output_gradients):
x, y = inpt[:2] x, y = inpt[:2]
idxs = inpt[2:] idxs = inpt[2:]
outgrad, = output_gradients outgrad, = output_gradients
d_x_wrt_C = outgrad d_x_wrt_C = outgrad
d_y_wrt_C = AdvancedSubtensor()(outgrad, *idxs) d_y_wrt_C = AdvancedSubtensor()(outgrad, *idxs)
return [d_x_wrt_C, d_y_wrt_C] + [None for _ in idxs] return [d_x_wrt_C, d_y_wrt_C] + \
[DisconnectedType()() for _ in idxs]
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
if None in eval_points[:2]: if None in eval_points[:2]:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论