bugs fixed toward replacing NumpyR

上级 45f56231
差异被折叠。
......@@ -58,6 +58,10 @@ def compute(*nodes):
"""Recursively evaluate each node (in a quick & dirty way)."""
compute_from(nodes, set())
def is_result(obj):
"""Return True iff obj provides the interface of a Result"""
attr_list = 'data', 'owner'
return all([hasattr(obj, attr) for attr in attr_list])
class ForbidConstantOverwrite(features.Listener, features.Constraint):
......@@ -460,8 +464,7 @@ class PythonOp(Op):
Op.__init__(self, inputs, self.gen_outputs())
def __validate__(self):
for input in self.inputs:
assert isinstance(input, ResultValue)
return all([ is_result(i) for i in self.inputs])
def gen_outputs(self):
return [ResultValue() for i in xrange(self.nout)]
......
import gof
from gof.lib import compute_from
from gof.lib import compute_from, is_result
import core
class Grad(object):
......@@ -173,7 +173,7 @@ class update_gradient_via_grad:
"""
inputgs = self.grad(*(self.inputs + [grad_d[output] for output in self.outputs]))
if len(self.inputs) == 1 and isinstance(inputgs, gof.ResultValue):
if len(self.inputs) == 1 and is_result(inputgs):
inputgs = [inputgs]
else:
assert len(inputgs) == len(self.inputs)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论