提交 0350b33a authored 作者: khaotik's avatar khaotik

microrefinements

- use `def` syntax for flake8 syntax issue - cleaner condition - use assert instead of raise for internal state
上级 90ba4b94
......@@ -284,13 +284,13 @@ class OpFromGraph(gof.Op):
self.kwargs = kwargs
self.input_types = [inp.type for inp in inputs]
self.output_types = [out.type for out in outputs]
cond = int(lop_overrides == 'default') * 2 + int(grad_overrides == 'default')
if cond == 0:
raise ValueError('lop_overrides and grad_overrides are mutually exclusive')
elif cond == 1:
self.set_lop_overrides(lop_overrides)
self._lop_type = 'lop'
elif cond == 2:
if lop_overrides != 'default':
if grad_overrides != 'default':
raise ValueError('lop_overrides and grad_overrides are mutually exclusive')
else:
self.set_lop_overrides(lop_overrides)
self._lop_type = 'lop'
elif grad_overrides != 'default':
self.set_lop_overrides(grad_overrides)
self._lop_type = 'grad'
else:
......@@ -329,6 +329,8 @@ class OpFromGraph(gof.Op):
if isinstance(lop_op, OpFromGraph):
if self._lop_op_is_cached:
return
assert self._lop_type in ['lop', 'grad'],\
self.LOP_TYPE_ERR_MSG % self._lop_type
if self._lop_type == 'grad':
needed_ninps = inp_len + len(local_outputs)
ninps = len(lop_op.local_inputs)
......@@ -336,7 +338,8 @@ class OpFromGraph(gof.Op):
raise ValueError(
self.OV_INP_LEN_ERR_MSG % (needed_ninps, ninps))
# make a wrapper callable
lop_op = lambda inps, grads: self._lop_op(*(inps + grads)) # flake8: noqa
def lop_op(inps, grads):
return self._lop_op(*(inps + grads))
elif self._lop_type == 'lop':
# OfG can be directly used in L_op format
needed_ninps = inp_len + 2 * len(local_outputs)
......@@ -348,8 +351,6 @@ class OpFromGraph(gof.Op):
self._lop_op_stypes_l = [None] * inp_len
self._lop_op.kwargs['on_unused_input'] = 'ignore'
return
else:
raise ValueError(self.LOP_TYPE_ERR_MSG % self._lop_type)
output_grads = [out_t() for out_t in self.output_types]
fn_grad = partial(
......@@ -360,12 +361,12 @@ class OpFromGraph(gof.Op):
null_gradients='return',
known_grads=OrderedDict(izip(local_outputs, output_grads)))
assert self._lop_type in ['lop', 'grad'],\
self.LOP_TYPE_ERR_MSG % self._lop_type
if self._lop_type == 'lop':
callable_args = (local_inputs, local_outputs, output_grads)
elif self._lop_type == 'grad':
callable_args = (local_inputs, output_grads)
else:
raise ValueError(self.LOP_TYPE_ERR_MSG % self._lop_type)
# we need to convert _lop_op into an OfG instance
if lop_op == 'default':
......@@ -640,10 +641,10 @@ class OpFromGraph(gof.Op):
cpmat_self = io_connection_pattern(
self.local_inputs, self.local_outputs)
grad_op = self.get_lop_op()
lop_op = self.get_lop_op()
cpmat_grad = io_connection_pattern(
grad_op.local_inputs[inp_len:],
grad_op.local_outputs)
lop_op.local_inputs[inp_len:],
lop_op.local_outputs)
# cpmat_self |= cpmat_grad.T
# cpmat_self &= out_is_disconnected
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论