提交 1c973723 authored 作者: James Bergstra's avatar James Bergstra

merged

...@@ -156,6 +156,15 @@ class T_Function(unittest.TestCase): ...@@ -156,6 +156,15 @@ class T_Function(unittest.TestCase):
assert eval_outputs([e]) == 14.0 assert eval_outputs([e]) == 14.0
assert fast_compute(e) == 14.0 assert fast_compute(e) == 14.0
def test_closure(self):
x, y, z = tensor.scalars('xyz')
v = tensor.value(numpy.zeros(()))
e = x + tensor.add_inplace(v, 1)
f = function([x], [e])
assert f(1.) == 2.
assert f(1.) == 3.
assert f(1.) == 4.
def test_borrow_true(self): def test_borrow_true(self):
x, y, z = tensor.scalars('xyz') x, y, z = tensor.scalars('xyz')
e = x + y + z e = x + y + z
......
...@@ -91,6 +91,9 @@ class FunctionFactory: ...@@ -91,6 +91,9 @@ class FunctionFactory:
def __init__(self, inputs, outputs, linker = 'py', optimizer = std_opt, borrow_outputs = False): def __init__(self, inputs, outputs, linker = 'py', optimizer = std_opt, borrow_outputs = False):
if len(inputs) != len(set(inputs)): if len(inputs) != len(set(inputs)):
print >>sys.stderr, "Warning: duplicate inputs" print >>sys.stderr, "Warning: duplicate inputs"
for r in list(inputs) + list(outputs):
if not isinstance(r, gof.Result):
raise TypeError("All inputs and outputs to FunctionFactory should be Result instances. Received:", type(r), r)
env = std_env(inputs, outputs) env = std_env(inputs, outputs)
if None is not optimizer: if None is not optimizer:
optimizer(env) optimizer(env)
......
...@@ -7,23 +7,6 @@ from op import Op ...@@ -7,23 +7,6 @@ from op import Op
from type import Type from type import Type
from graph import Result from graph import Result
def inputs(result_list):
"""
@type result_list: list of L{Result}
@param result_list: output L{Result}s (from which to search backward through owners)
@returns: the list of L{Result}s with no owner, in the order found by a
left-recursive depth-first search started at the L{Result}s in result_list.
"""
def expand(r):
if r.owner:
l = list(r.owner.inputs)
l.reverse()
return l
dfs_results = stack_search(deque(result_list), expand, 'dfs')
rval = [r for r in dfs_results if r.owner is None]
#print rval, _orig_inputs(o)
return rval
if 1: if 1:
testcase = unittest.TestCase testcase = unittest.TestCase
......
...@@ -43,7 +43,6 @@ class Apply(object2): ...@@ -43,7 +43,6 @@ class Apply(object2):
self.outputs.append(output) self.outputs.append(output)
else: else:
raise TypeError("The 'outputs' argument to Apply must contain Result instances with no owner, not %s" % output) raise TypeError("The 'outputs' argument to Apply must contain Result instances with no owner, not %s" % output)
@deprecated
def default_output(self): def default_output(self):
""" """
Returns the default output for this Node, typically self.outputs[0]. Returns the default output for this Node, typically self.outputs[0].
...@@ -90,8 +89,14 @@ class Result(object2): ...@@ -90,8 +89,14 @@ class Result(object2):
#__slots__ = ['type', 'owner', 'index', 'name'] #__slots__ = ['type', 'owner', 'index', 'name']
def __init__(self, type, owner = None, index = None, name = None): def __init__(self, type, owner = None, index = None, name = None):
self.type = type self.type = type
if owner is not None and not isinstance(owner, Apply):
raise TypeError("owner must be an Apply instance", owner)
self.owner = owner self.owner = owner
if index is not None and not isinstance(index, int):
raise TypeError("index must be an int", index)
self.index = index self.index = index
if name is not None and not isinstance(name, str):
raise TypeError("name must be a string", name)
self.name = name self.name = name
def __str__(self): def __str__(self):
if self.name is not None: if self.name is not None:
...@@ -165,27 +170,6 @@ def as_apply(x): ...@@ -165,27 +170,6 @@ def as_apply(x):
else: else:
raise TypeError("Cannot map %s to Apply" % x) raise TypeError("Cannot map %s to Apply" % x)
@deprecated
def inputs(o):
"""
@type o: list
@param o: output L{Result}s
Returns the set of inputs necessary to compute the outputs in o
such that input.owner is None.
"""
results = set()
def seek(r):
op = r.owner
if op is None:
results.add(r)
else:
for input in op.inputs:
seek(input)
for output in o:
seek(output)
return results
def stack_search(start, expand, mode='bfs', build_inv = False): def stack_search(start, expand, mode='bfs', build_inv = False):
"""Search through L{Result}s, either breadth- or depth-first """Search through L{Result}s, either breadth- or depth-first
@type start: deque @type start: deque
...@@ -227,7 +211,6 @@ def stack_search(start, expand, mode='bfs', build_inv = False): ...@@ -227,7 +211,6 @@ def stack_search(start, expand, mode='bfs', build_inv = False):
return rval_list return rval_list
@utils.deprecated('gof.graph', 'is this function ever used?')
def inputs(result_list): def inputs(result_list):
""" """
@type result_list: list of L{Result} @type result_list: list of L{Result}
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论