提交 430dd1f8 authored 作者: khaotik's avatar khaotik

fixes based on Fred's review

- safer check for int(0) - only use string for error message
上级 6de8aec7
...@@ -213,7 +213,7 @@ class OpFromGraph(gof.Op): ...@@ -213,7 +213,7 @@ class OpFromGraph(gof.Op):
name=None, **kwargs name=None, **kwargs
): ):
if not isinstance(outputs, list): if not isinstance(outputs, list):
raise TypeError('outputs must be list, got %s' % type(outputs), outputs) raise TypeError('outputs must be list, got %s' % type(outputs))
for i in inputs + outputs: for i in inputs + outputs:
if not isinstance(i, gof.Variable): if not isinstance(i, gof.Variable):
raise TypeError( raise TypeError(
...@@ -297,7 +297,7 @@ class OpFromGraph(gof.Op): ...@@ -297,7 +297,7 @@ class OpFromGraph(gof.Op):
elif grad_op is None: elif grad_op is None:
all_grads_l = [inp.zeros_like() for inp in local_inputs] all_grads_l = [inp.zeros_like() for inp in local_inputs]
all_grads_ov_l = [self.ofg_null_t()] * inp_len all_grads_ov_l = [self.ofg_null_t()] * inp_len
elif grad_op is 0: elif type(grad_op) is int and grad_op == 0:
all_grads_l = [inp.zeros_like() for inp in local_inputs] all_grads_l = [inp.zeros_like() for inp in local_inputs]
all_grads_ov_l = [self.ofg_discon_t()] * inp_len all_grads_ov_l = [self.ofg_discon_t()] * inp_len
elif isinstance(grad_op, list): elif isinstance(grad_op, list):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论