提交 4b439eaf authored 作者: Frederic's avatar Frederic

pep8

上级 6e3e96e4
...@@ -9,13 +9,12 @@ from cStringIO import StringIO ...@@ -9,13 +9,12 @@ from cStringIO import StringIO
from theano import config from theano import config
from theano import shared from theano import shared
def sharedX(x, name=None): def sharedX(x, name=None):
x = np.cast[config.floatX](x) x = np.cast[config.floatX](x)
return shared(x, name) return shared(x, name)
def test_determinism_1(): def test_determinism_1():
# Tests that repeatedly running a script that compiles and # Tests that repeatedly running a script that compiles and
...@@ -24,7 +23,7 @@ def test_determinism_1(): ...@@ -24,7 +23,7 @@ def test_determinism_1():
# change. # change.
# This specific script is capable of catching a bug where # This specific script is capable of catching a bug where
# FunctionGraph.toposort was non-deterministic. # FunctionGraph.toposort was non-deterministic.
def run(replay, log = None): def run(replay, log=None):
if not replay: if not replay:
log = StringIO() log = StringIO()
...@@ -32,7 +31,6 @@ def test_determinism_1(): ...@@ -32,7 +31,6 @@ def test_determinism_1():
log = StringIO(log) log = StringIO(log)
record = Record(replay=replay, file_object=log) record = Record(replay=replay, file_object=log)
disturb_mem.disturb_mem() disturb_mem.disturb_mem()
mode = RecordMode(record=record) mode = RecordMode(record=record)
...@@ -53,19 +51,20 @@ def test_determinism_1(): ...@@ -53,19 +51,20 @@ def test_determinism_1():
v_range.max(), v_range.max(),
]): ]):
disturb_mem.disturb_mem() disturb_mem.disturb_mem()
s = sharedX(0., name='s_'+str(i)) s = sharedX(0., name='s_' + str(i))
updates.append((s, val)) updates.append((s, val))
for var in theano.gof.graph.ancestors(update for var, update in updates): for var in theano.gof.graph.ancestors(update for _, update in updates):
if var.name is not None and var.name is not 'b': if var.name is not None and var.name is not 'b':
if var.name[0] != 's' or len(var.name) != 2: if var.name[0] != 's' or len(var.name) != 2:
var.name = None var.name = None
for key in channels: for key in channels:
updates.append((s, channels[key])) updates.append((s, channels[key]))
f = theano.function([], mode=mode, updates=updates, on_unused_input='ignore', name='f') f = theano.function([], mode=mode, updates=updates,
on_unused_input='ignore', name='f')
for output in f.maker.fgraph.outputs: for output in f.maker.fgraph.outputs:
mode.record.handle_line(var_descriptor(output)+'\n') mode.record.handle_line(var_descriptor(output) + '\n')
disturb_mem.disturb_mem() disturb_mem.disturb_mem()
f() f()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论