提交 a9fc8c3e authored 作者: Olivier Breuleux's avatar Olivier Breuleux

ported Subtensor, fixed some tests

上级 eec75e98
差异被折叠。
...@@ -126,14 +126,14 @@ class _test_inputs(testcase): ...@@ -126,14 +126,14 @@ class _test_inputs(testcase):
# self.fail() # self.fail()
class _test_orphans(testcase): # class _test_orphans(testcase):
def test_straightforward(self): # def test_straightforward(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5) # r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
node = MyOp.make_node(r1, r2) # node = MyOp.make_node(r1, r2)
node2 = MyOp.make_node(node.outputs[0], r5) # node2 = MyOp.make_node(node.outputs[0], r5)
orph = orphans([r1, r2], node2.outputs) # orph = orphans([r1, r2], node2.outputs)
self.failUnless(orph == [r5], orph) # self.failUnless(orph == [r5], orph)
class _test_as_string(testcase): class _test_as_string(testcase):
...@@ -215,15 +215,15 @@ def prenode(obj): ...@@ -215,15 +215,15 @@ def prenode(obj):
if isinstance(obj, Result): if isinstance(obj, Result):
if obj.owner: if obj.owner:
return [obj.owner] return [obj.owner]
if isinstance(obj, Op): if isinstance(obj, Apply):
return obj.inputs return obj.inputs
class _test_toposort(testcase): class _test_toposort(testcase):
def test0(self): def test0(self):
"""Test a simple graph""" """Test a simple graph"""
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5) r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
o = MyOp(r1, r2) o = MyOp.make_node(r1, r2)
o2 = MyOp(o.outputs[0], r5) o2 = MyOp.make_node(o.outputs[0], r5)
all = general_toposort(o2.outputs, prenode) all = general_toposort(o2.outputs, prenode)
self.failUnless(all == [r5, r2, r1, o, o.outputs[0], o2, o2.outputs[0]], all) self.failUnless(all == [r5, r2, r1, o, o.outputs[0], o2, o2.outputs[0]], all)
...@@ -234,45 +234,45 @@ class _test_toposort(testcase): ...@@ -234,45 +234,45 @@ class _test_toposort(testcase):
def test1(self): def test1(self):
"""Test a graph with double dependencies""" """Test a graph with double dependencies"""
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5) r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
o = MyOp(r1, r1) o = MyOp.make_node(r1, r1)
o2 = MyOp(o.outputs[0], r5) o2 = MyOp.make_node(o.outputs[0], r5)
all = general_toposort(o2.outputs, prenode) all = general_toposort(o2.outputs, prenode)
self.failUnless(all == [r5, r1, o, o.outputs[0], o2, o2.outputs[0]], all) self.failUnless(all == [r5, r1, o, o.outputs[0], o2, o2.outputs[0]], all)
def test2(self): def test2(self):
"""Test a graph where the inputs have owners""" """Test a graph where the inputs have owners"""
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5) r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
o = MyOp(r1, r1) o = MyOp.make_node(r1, r1)
r2b = o.outputs[0] r2b = o.outputs[0]
o2 = MyOp(r2b, r2b) o2 = MyOp.make_node(r2b, r2b)
all = io_toposort([r2b], o2.outputs) all = io_toposort([r2b], o2.outputs)
self.failUnless(all == [o2], all) self.failUnless(all == [o2], all)
o2 = MyOp(r2b, r5) o2 = MyOp.make_node(r2b, r5)
all = io_toposort([r2b], o2.outputs) all = io_toposort([r2b], o2.outputs)
self.failUnless(all == [o2], all) self.failUnless(all == [o2], all)
def test3(self): def test3(self):
"""Test a graph which is not connected""" """Test a graph which is not connected"""
r1, r2, r3, r4 = MyResult(1), MyResult(2), MyResult(3), MyResult(4) r1, r2, r3, r4 = MyResult(1), MyResult(2), MyResult(3), MyResult(4)
o0 = MyOp(r1, r2) o0 = MyOp.make_node(r1, r2)
o1 = MyOp(r3, r4) o1 = MyOp.make_node(r3, r4)
all = io_toposort([r1, r2, r3, r4], o0.outputs + o1.outputs) all = io_toposort([r1, r2, r3, r4], o0.outputs + o1.outputs)
self.failUnless(all == [o1,o0], all) self.failUnless(all == [o1,o0], all)
def test4(self): def test4(self):
"""Test inputs and outputs mixed together in a chain graph""" """Test inputs and outputs mixed together in a chain graph"""
r1, r2, r3, r4 = MyResult(1), MyResult(2), MyResult(3), MyResult(4) r1, r2, r3, r4 = MyResult(1), MyResult(2), MyResult(3), MyResult(4)
o0 = MyOp(r1, r2) o0 = MyOp.make_node(r1, r2)
o1 = MyOp(o0.outputs[0], r1) o1 = MyOp.make_node(o0.outputs[0], r1)
all = io_toposort([r1, o0.outputs[0]], [o0.outputs[0], o1.outputs[0]]) all = io_toposort([r1, o0.outputs[0]], [o0.outputs[0], o1.outputs[0]])
self.failUnless(all == [o1], all) self.failUnless(all == [o1], all)
def test5(self): def test5(self):
"""Test when outputs have clients""" """Test when outputs have clients"""
r1, r2, r3, r4 = MyResult(1), MyResult(2), MyResult(3), MyResult(4) r1, r2, r3, r4 = MyResult(1), MyResult(2), MyResult(3), MyResult(4)
o0 = MyOp(r1, r2) o0 = MyOp.make_node(r1, r2)
o1 = MyOp(o0.outputs[0], r4) o1 = MyOp.make_node(o0.outputs[0], r4)
all = io_toposort([], o0.outputs) all = io_toposort([], o0.outputs)
self.failUnless(all == [o0], all) self.failUnless(all == [o0], all)
......
...@@ -18,26 +18,26 @@ import traceback ...@@ -18,26 +18,26 @@ import traceback
def compile_dir(): def compile_dir():
"""Return the directory in which scipy.weave should store code objects. """Return the directory in which scipy.weave should store code objects.
If the environment variable OMEGA_COMPILEDIR is set, its value is returned. If the environment variable THEANO_COMPILEDIR is set, its value is returned.
If not, a directory of the form $HOME/.omega/compiledir_<platform Id>. If not, a directory of the form $HOME/.theano/compiledir_<platform Id>.
As a test, this function touches the file __init__.py in the returned As a test, this function touches the file __init__.py in the returned
directory, and raises OSError if there's a problem. directory, and raises OSError if there's a problem.
A directory coming from OMEGA_COMPILEDIR is not created automatically, but A directory coming from THEANO_COMPILEDIR is not created automatically, but
a directory in $HOME/.omega is created automatically. a directory in $HOME/.theano is created automatically.
This directory is appended to the sys.path search path before being This directory is appended to the sys.path search path before being
returned, if the touch was successful. returned, if the touch was successful.
""" """
if os.getenv('OMEGA_COMPILEDIR'): if os.getenv('THEANO_COMPILEDIR'):
cachedir = os.getenv('OMEGA_COMPILEDIR') cachedir = os.getenv('THEANO_COMPILEDIR')
else: else:
# use (and possibly create) a default code cache location # use (and possibly create) a default code cache location
platform_id = platform.platform() + '-' + platform.processor() platform_id = platform.platform() + '-' + platform.processor()
import re import re
platform_id = re.sub("[\(\)\s]+", "_", platform_id) platform_id = re.sub("[\(\)\s]+", "_", platform_id)
cachedir = os.path.join(os.getenv('HOME'), '.omega', 'compiledir_'+platform_id) cachedir = os.path.join(os.getenv('HOME'), '.theano', 'compiledir_'+platform_id)
if not os.access(cachedir, os.R_OK | os.W_OK): if not os.access(cachedir, os.R_OK | os.W_OK):
#this may raise a number of problems, I think all of which are serious. #this may raise a number of problems, I think all of which are serious.
os.makedirs(cachedir, 7<<6) os.makedirs(cachedir, 7<<6)
...@@ -345,7 +345,7 @@ class CLinker(Linker): ...@@ -345,7 +345,7 @@ class CLinker(Linker):
env = self.env env = self.env
self.inputs = env.inputs self.inputs = env.inputs
self.outputs = env.outputs self.outputs = env.outputs
self.results = list(env.results) self.results = graph.results(self.inputs, self.outputs) # list(env.results)
# The orphans field is listified to ensure a consistent order. # The orphans field is listified to ensure a consistent order.
self.orphans = list(r for r in self.results if isinstance(r, Value) and r not in self.inputs) #list(env.orphans.difference(self.outputs)) self.orphans = list(r for r in self.results if isinstance(r, Value) and r not in self.inputs) #list(env.orphans.difference(self.outputs))
self.temps = list(set(self.results).difference(self.inputs).difference(self.outputs).difference(self.orphans)) self.temps = list(set(self.results).difference(self.inputs).difference(self.outputs).difference(self.orphans))
...@@ -390,8 +390,8 @@ class CLinker(Linker): ...@@ -390,8 +390,8 @@ class CLinker(Linker):
id = 1 id = 1
sub = dict(failure_var = failure_var) sub = dict(failure_var = failure_var)
for result in set(self.results): for result in self.results:
# it might be possible to inline constant results as C literals # it might be possible to inline constant results as C literals
## if getattr(result, 'constant', False): ## if getattr(result, 'constant', False):
......
...@@ -294,7 +294,7 @@ class Env(object): #(graph.Graph): ...@@ -294,7 +294,7 @@ class Env(object): #(graph.Graph):
for feature in env._features: for feature in env._features:
if hasattr(feature, 'orderings'): if hasattr(feature, 'orderings'):
for op, prereqs in feature.orderings(env).items(): for op, prereqs in feature.orderings(env).items():
ords.setdefault(op, set()).update(prereqs) ords.setdefault(op, []).extend(prereqs)
order = graph.io_toposort(env.inputs, env.outputs, ords) order = graph.io_toposort(env.inputs, env.outputs, ords)
return order return order
......
...@@ -283,23 +283,36 @@ def inputs(result_list): ...@@ -283,23 +283,36 @@ def inputs(result_list):
def results_and_orphans(i, o): def results_and_orphans(i, o):
results = set() """
orphans = set() """
def helper(r): def expand(r):
if r in results: if r.owner and r not in i:
return l = list(r.owner.inputs)
results.add(r) l.reverse()
if r.owner is None: return l
if r not in i: results = stack_search(deque(o), expand, 'dfs')
orphans.add(r) orphans = [r for r in results if r.owner is None and r not in i]
else:
for r2 in r.owner.inputs:
helper(r2)
for output in o:
helper(output)
return results, orphans return results, orphans
#def results_and_orphans(i, o):
# results = set()
# orphans = set()
# def helper(r):
# if r in results:
# return
# results.add(r)
# if r.owner is None:
# if r not in i:
# orphans.add(r)
# else:
# for r2 in r.owner.inputs:
# helper(r2)
# for output in o:
# helper(output)
# return results, orphans
def ops(i, o): def ops(i, o):
""" """
@type i: list @type i: list
...@@ -469,17 +482,17 @@ def io_toposort(i, o, orderings = {}): ...@@ -469,17 +482,17 @@ def io_toposort(i, o, orderings = {}):
def deps(obj): def deps(obj):
rval = [] rval = []
if obj not in iset: if obj not in iset:
if isinstance(obj, result.Result): if isinstance(obj, Result):
if obj.owner: if obj.owner:
rval = [obj.owner] rval = [obj.owner]
if isinstance(obj, op.Op): if isinstance(obj, Apply):
rval = list(obj.inputs) rval = list(obj.inputs)
rval.extend(orderings.get(obj, [])) rval.extend(orderings.get(obj, []))
else: else:
assert not orderings.get(obj, []) assert not orderings.get(obj, [])
return rval return rval
topo = general_toposort(o, deps) topo = general_toposort(o, deps)
return [o for o in topo if isinstance(o, op.Op)] return [o for o in topo if isinstance(o, Apply)]
......
差异被折叠。
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论