提交 0350b33a authored 作者: khaotik's avatar khaotik

microrefinements

- use `def` syntax for flake8 syntax issue - cleaner condition - use assert instead of raise for internal state
上级 90ba4b94
...@@ -284,13 +284,13 @@ class OpFromGraph(gof.Op): ...@@ -284,13 +284,13 @@ class OpFromGraph(gof.Op):
self.kwargs = kwargs self.kwargs = kwargs
self.input_types = [inp.type for inp in inputs] self.input_types = [inp.type for inp in inputs]
self.output_types = [out.type for out in outputs] self.output_types = [out.type for out in outputs]
cond = int(lop_overrides == 'default') * 2 + int(grad_overrides == 'default') if lop_overrides != 'default':
if cond == 0: if grad_overrides != 'default':
raise ValueError('lop_overrides and grad_overrides are mutually exclusive') raise ValueError('lop_overrides and grad_overrides are mutually exclusive')
elif cond == 1: else:
self.set_lop_overrides(lop_overrides) self.set_lop_overrides(lop_overrides)
self._lop_type = 'lop' self._lop_type = 'lop'
elif cond == 2: elif grad_overrides != 'default':
self.set_lop_overrides(grad_overrides) self.set_lop_overrides(grad_overrides)
self._lop_type = 'grad' self._lop_type = 'grad'
else: else:
...@@ -329,6 +329,8 @@ class OpFromGraph(gof.Op): ...@@ -329,6 +329,8 @@ class OpFromGraph(gof.Op):
if isinstance(lop_op, OpFromGraph): if isinstance(lop_op, OpFromGraph):
if self._lop_op_is_cached: if self._lop_op_is_cached:
return return
assert self._lop_type in ['lop', 'grad'],\
self.LOP_TYPE_ERR_MSG % self._lop_type
if self._lop_type == 'grad': if self._lop_type == 'grad':
needed_ninps = inp_len + len(local_outputs) needed_ninps = inp_len + len(local_outputs)
ninps = len(lop_op.local_inputs) ninps = len(lop_op.local_inputs)
...@@ -336,7 +338,8 @@ class OpFromGraph(gof.Op): ...@@ -336,7 +338,8 @@ class OpFromGraph(gof.Op):
raise ValueError( raise ValueError(
self.OV_INP_LEN_ERR_MSG % (needed_ninps, ninps)) self.OV_INP_LEN_ERR_MSG % (needed_ninps, ninps))
# make a wrapper callable # make a wrapper callable
lop_op = lambda inps, grads: self._lop_op(*(inps + grads)) # flake8: noqa def lop_op(inps, grads):
return self._lop_op(*(inps + grads))
elif self._lop_type == 'lop': elif self._lop_type == 'lop':
# OfG can be directly used in L_op format # OfG can be directly used in L_op format
needed_ninps = inp_len + 2 * len(local_outputs) needed_ninps = inp_len + 2 * len(local_outputs)
...@@ -348,8 +351,6 @@ class OpFromGraph(gof.Op): ...@@ -348,8 +351,6 @@ class OpFromGraph(gof.Op):
self._lop_op_stypes_l = [None] * inp_len self._lop_op_stypes_l = [None] * inp_len
self._lop_op.kwargs['on_unused_input'] = 'ignore' self._lop_op.kwargs['on_unused_input'] = 'ignore'
return return
else:
raise ValueError(self.LOP_TYPE_ERR_MSG % self._lop_type)
output_grads = [out_t() for out_t in self.output_types] output_grads = [out_t() for out_t in self.output_types]
fn_grad = partial( fn_grad = partial(
...@@ -360,12 +361,12 @@ class OpFromGraph(gof.Op): ...@@ -360,12 +361,12 @@ class OpFromGraph(gof.Op):
null_gradients='return', null_gradients='return',
known_grads=OrderedDict(izip(local_outputs, output_grads))) known_grads=OrderedDict(izip(local_outputs, output_grads)))
assert self._lop_type in ['lop', 'grad'],\
self.LOP_TYPE_ERR_MSG % self._lop_type
if self._lop_type == 'lop': if self._lop_type == 'lop':
callable_args = (local_inputs, local_outputs, output_grads) callable_args = (local_inputs, local_outputs, output_grads)
elif self._lop_type == 'grad': elif self._lop_type == 'grad':
callable_args = (local_inputs, output_grads) callable_args = (local_inputs, output_grads)
else:
raise ValueError(self.LOP_TYPE_ERR_MSG % self._lop_type)
# we need to convert _lop_op into an OfG instance # we need to convert _lop_op into an OfG instance
if lop_op == 'default': if lop_op == 'default':
...@@ -640,10 +641,10 @@ class OpFromGraph(gof.Op): ...@@ -640,10 +641,10 @@ class OpFromGraph(gof.Op):
cpmat_self = io_connection_pattern( cpmat_self = io_connection_pattern(
self.local_inputs, self.local_outputs) self.local_inputs, self.local_outputs)
grad_op = self.get_lop_op() lop_op = self.get_lop_op()
cpmat_grad = io_connection_pattern( cpmat_grad = io_connection_pattern(
grad_op.local_inputs[inp_len:], lop_op.local_inputs[inp_len:],
grad_op.local_outputs) lop_op.local_outputs)
# cpmat_self |= cpmat_grad.T # cpmat_self |= cpmat_grad.T
# cpmat_self &= out_is_disconnected # cpmat_self &= out_is_disconnected
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论