Commit d011f68e authored by Pascal Lamblin

Make it easier to use test values with Rop and optimizations

In particular, use op.__call__ rather than make_node, so that tag.test_value gets computed automatically. The weird syntax with **dict(return_list=True) is needed because in older Python versions (including 2.4), you cannot pass a named argument (return_list=True) after an expanded list of unnamed arguments (*args). You can, however, pass an expanded dictionary of named arguments (**kwargs).
Parent 77f61060
...@@ -2579,7 +2579,7 @@ class Alloc(gof.Op): ...@@ -2579,7 +2579,7 @@ class Alloc(gof.Op):
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
if eval_points[0] is None: if eval_points[0] is None:
return [None] return [None]
return self.make_node(eval_points[0], *inputs[1:]).outputs return self(eval_points[0], *inputs[1:], **dict(return_list=True))
def do_constant_folding(self, node): def do_constant_folding(self, node):
if not getattr(node.outputs[0], 'clients', []): if not getattr(node.outputs[0], 'clients', []):
...@@ -3275,7 +3275,7 @@ class Rebroadcast(Op): ...@@ -3275,7 +3275,7 @@ class Rebroadcast(Op):
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
if eval_points[0] is None: if eval_points[0] is None:
return [None] return [None]
return self.make_node(*eval_points).outputs return self(*eval_points, **dict(return_list=True))
def addbroadcast(x, *axes): def addbroadcast(x, *axes):
...@@ -3805,7 +3805,7 @@ class Reshape(Op): ...@@ -3805,7 +3805,7 @@ class Reshape(Op):
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
if eval_points[0] is None: if eval_points[0] is None:
return [None] return [None]
return self.make_node(eval_points[0], *inputs[1:]).outputs return self(eval_points[0], *inputs[1:], **dict(return_list=True))
def infer_shape(self, node, ishapes): def infer_shape(self, node, ishapes):
# inputs[1] can contain at most one value of '-1', meaning the actual # inputs[1] can contain at most one value of '-1', meaning the actual
...@@ -4600,7 +4600,7 @@ class Dot(Op): ...@@ -4600,7 +4600,7 @@ class Dot(Op):
eval_point_values = [ev0, ev1] eval_point_values = [ev0, ev1]
for i in xrange(2): for i in xrange(2):
if eval_point_values[i] and \ if eval_point_values[i] is not None and \
input_values[i].shape != eval_point_values[i].shape: input_values[i].shape != eval_point_values[i].shape:
raise ValueError('input ' + str(i) + ' and eval_point ' + raise ValueError('input ' + str(i) + ' and eval_point ' +
str(i) + ' to Dot.R_op ' str(i) + ' to Dot.R_op '
......
...@@ -273,7 +273,7 @@ class DimShuffle(Op): ...@@ -273,7 +273,7 @@ class DimShuffle(Op):
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
if None in eval_points: if None in eval_points:
return [None] return [None]
return self.make_node(*eval_points).outputs return self(*eval_points, **dict(return_list=True))
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
input, = inp input, = inp
...@@ -616,7 +616,7 @@ class Elemwise(Op): ...@@ -616,7 +616,7 @@ class Elemwise(Op):
return self.name return self.name
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
outs = self.make_node(*inputs).outputs outs = self(*inputs, **dict(return_list=True))
rval = [None for x in outs] rval = [None for x in outs]
# For each output # For each output
for idx, out in enumerate(outs): for idx, out in enumerate(outs):
...@@ -1882,7 +1882,7 @@ class Sum(CAReduceDtype): ...@@ -1882,7 +1882,7 @@ class Sum(CAReduceDtype):
# part of self # part of self
if None in eval_points: if None in eval_points:
return [None] return [None]
return self.make_node(*eval_points).outputs return self(*eval_points, **dict(return_list=True))
def __str__(self): def __str__(self):
if self.axis is None: if self.axis is None:
......
...@@ -263,10 +263,11 @@ def inplace_elemwise_optimizer_op(OP): ...@@ -263,10 +263,11 @@ def inplace_elemwise_optimizer_op(OP):
scalar.transfer_type( scalar.transfer_type(
*[inplace_pattern.get(i, None) \ *[inplace_pattern.get(i, None) \
for i in xrange(len(node.outputs))])) for i in xrange(len(node.outputs))]))
new = OP(new_scal, inplace_pattern).make_node( new_outputs = OP(new_scal, inplace_pattern)(
*node.inputs) *node.inputs, **dict(return_list=True))
new_node = new_outputs[0].owner
for r, new_r in zip(node.outputs, new.outputs): for r, new_r in zip(node.outputs, new_outputs):
fgraph.replace(r, new_r, fgraph.replace(r, new_r,
reason="inplace_elemwise_optimizer") reason="inplace_elemwise_optimizer")
nb_change_no_validate += 1 nb_change_no_validate += 1
...@@ -284,7 +285,7 @@ def inplace_elemwise_optimizer_op(OP): ...@@ -284,7 +285,7 @@ def inplace_elemwise_optimizer_op(OP):
fgraph.revert(chk) fgraph.revert(chk)
continue continue
candidate_inputs.remove(candidate_input) candidate_inputs.remove(candidate_input)
node = new node = new_node
baseline = inplace_pattern baseline = inplace_pattern
break break
......
...@@ -883,7 +883,7 @@ class Subtensor(Op): ...@@ -883,7 +883,7 @@ class Subtensor(Op):
# (they should be defaulted to zeros_like by the global R_op) # (they should be defaulted to zeros_like by the global R_op)
if eval_points[0] is None: if eval_points[0] is None:
return [None] return [None]
return self.make_node(eval_points[0], *inputs[1:]).outputs return self(eval_points[0], *inputs[1:], **dict(return_list=True))
class SubtensorPrinter: class SubtensorPrinter:
...@@ -1414,8 +1414,8 @@ class IncSubtensor(Op): ...@@ -1414,8 +1414,8 @@ class IncSubtensor(Op):
return [None] return [None]
# Again we ignore eval points for indices because incsubtensor is # Again we ignore eval points for indices because incsubtensor is
# not differentiable wrt to those # not differentiable wrt to those
return self.make_node(eval_points[0], eval_points[1], return self(eval_points[0], eval_points[1], *inputs[2:],
*inputs[2:]).outputs **dict(return_list=True))
def connection_pattern(self, node): def connection_pattern(self, node):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论