bugs fixed toward replacing NumpyR

上级 45f56231
差异被折叠。
...@@ -58,6 +58,10 @@ def compute(*nodes): ...@@ -58,6 +58,10 @@ def compute(*nodes):
"""Recursively evaluate each node (in a quick & dirty way).""" """Recursively evaluate each node (in a quick & dirty way)."""
compute_from(nodes, set()) compute_from(nodes, set())
def is_result(obj):
"""Return True iff obj provides the interface of a Result"""
attr_list = 'data', 'owner'
return all([hasattr(obj, attr) for attr in attr_list])
class ForbidConstantOverwrite(features.Listener, features.Constraint): class ForbidConstantOverwrite(features.Listener, features.Constraint):
...@@ -460,8 +464,7 @@ class PythonOp(Op): ...@@ -460,8 +464,7 @@ class PythonOp(Op):
Op.__init__(self, inputs, self.gen_outputs()) Op.__init__(self, inputs, self.gen_outputs())
def __validate__(self): def __validate__(self):
for input in self.inputs: return all([ is_result(i) for i in self.inputs])
assert isinstance(input, ResultValue)
def gen_outputs(self): def gen_outputs(self):
return [ResultValue() for i in xrange(self.nout)] return [ResultValue() for i in xrange(self.nout)]
......
import gof import gof
from gof.lib import compute_from from gof.lib import compute_from, is_result
import core import core
class Grad(object): class Grad(object):
...@@ -173,7 +173,7 @@ class update_gradient_via_grad: ...@@ -173,7 +173,7 @@ class update_gradient_via_grad:
""" """
inputgs = self.grad(*(self.inputs + [grad_d[output] for output in self.outputs])) inputgs = self.grad(*(self.inputs + [grad_d[output] for output in self.outputs]))
if len(self.inputs) == 1 and isinstance(inputgs, gof.ResultValue): if len(self.inputs) == 1 and is_result(inputgs):
inputgs = [inputgs] inputgs = [inputgs]
else: else:
assert len(inputgs) == len(self.inputs) assert len(inputgs) == len(self.inputs)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论