提交 1c973723 authored 作者: James Bergstra's avatar James Bergstra

merged

......@@ -156,6 +156,15 @@ class T_Function(unittest.TestCase):
assert eval_outputs([e]) == 14.0
assert fast_compute(e) == 14.0
def test_closure(self):
x, y, z = tensor.scalars('xyz')
v = tensor.value(numpy.zeros(()))
e = x + tensor.add_inplace(v, 1)
f = function([x], [e])
assert f(1.) == 2.
assert f(1.) == 3.
assert f(1.) == 4.
def test_borrow_true(self):
x, y, z = tensor.scalars('xyz')
e = x + y + z
......
......@@ -91,6 +91,9 @@ class FunctionFactory:
def __init__(self, inputs, outputs, linker = 'py', optimizer = std_opt, borrow_outputs = False):
if len(inputs) != len(set(inputs)):
print >>sys.stderr, "Warning: duplicate inputs"
for r in list(inputs) + list(outputs):
if not isinstance(r, gof.Result):
raise TypeError("All inputs and outputs to FunctionFactory should be Result instances. Received:", type(r), r)
env = std_env(inputs, outputs)
if None is not optimizer:
optimizer(env)
......
......@@ -7,23 +7,6 @@ from op import Op
from type import Type
from graph import Result
def inputs(result_list):
"""
@type result_list: list of L{Result}
@param result_list: output L{Result}s (from which to search backward through owners)
@returns: the list of L{Result}s with no owner, in the order found by a
left-recursive depth-first search started at the L{Result}s in result_list.
"""
def expand(r):
if r.owner:
l = list(r.owner.inputs)
l.reverse()
return l
dfs_results = stack_search(deque(result_list), expand, 'dfs')
rval = [r for r in dfs_results if r.owner is None]
#print rval, _orig_inputs(o)
return rval
if 1:
testcase = unittest.TestCase
......
......@@ -43,7 +43,6 @@ class Apply(object2):
self.outputs.append(output)
else:
raise TypeError("The 'outputs' argument to Apply must contain Result instances with no owner, not %s" % output)
@deprecated
def default_output(self):
"""
Returns the default output for this Node, typically self.outputs[0].
......@@ -90,8 +89,14 @@ class Result(object2):
#__slots__ = ['type', 'owner', 'index', 'name']
def __init__(self, type, owner = None, index = None, name = None):
self.type = type
if owner is not None and not isinstance(owner, Apply):
raise TypeError("owner must be an Apply instance", owner)
self.owner = owner
if index is not None and not isinstance(index, int):
raise TypeError("index must be an int", index)
self.index = index
if name is not None and not isinstance(name, str):
raise TypeError("name must be a string", name)
self.name = name
def __str__(self):
if self.name is not None:
......@@ -165,27 +170,6 @@ def as_apply(x):
else:
raise TypeError("Cannot map %s to Apply" % x)
@deprecated
def inputs(o):
"""
@type o: list
@param o: output L{Result}s
Returns the set of inputs necessary to compute the outputs in o
such that input.owner is None.
"""
results = set()
def seek(r):
op = r.owner
if op is None:
results.add(r)
else:
for input in op.inputs:
seek(input)
for output in o:
seek(output)
return results
def stack_search(start, expand, mode='bfs', build_inv = False):
"""Search through L{Result}s, either breadth- or depth-first
@type start: deque
......@@ -227,7 +211,6 @@ def stack_search(start, expand, mode='bfs', build_inv = False):
return rval_list
@utils.deprecated('gof.graph', 'is this function ever used?')
def inputs(result_list):
"""
@type result_list: list of L{Result}
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论