提交 1e73a0bf authored 作者: Iban Harlouchet's avatar Iban Harlouchet

flake8 for theano/gof/tests/test_graph_opt_caching.py

上级 7cb9d724
import unittest, os import os
import numpy import numpy
import six.moves.cPickle as pickle
from theano.compat import DictMixin, OrderedDict
import theano import theano
import theano.tensor as T import theano.tensor as T
floatX = 'float32' floatX = 'float32'
def test_graph_opt_caching(): def test_graph_opt_caching():
opt_db_file = theano.config.compiledir+'/optimized_graphs.pkl' opt_db_file = theano.config.compiledir + '/optimized_graphs.pkl'
os.system('rm %s'%opt_db_file) os.system('rm %s' % opt_db_file)
mode = theano.config.mode mode = theano.config.mode
if mode in ["DEBUG_MODE", "DebugMode"]: if mode in ["DEBUG_MODE", "DebugMode"]:
mode = "FAST_RUN" mode = "FAST_RUN"
...@@ -30,12 +29,12 @@ def test_graph_opt_caching(): ...@@ -30,12 +29,12 @@ def test_graph_opt_caching():
q = theano.shared(numpy.ones((10, 10), dtype=floatX)) q = theano.shared(numpy.ones((10, 10), dtype=floatX))
j = T.sum(T.sum(T.sum(m ** 2 + n) + p) + q) j = T.sum(T.sum(T.sum(m ** 2 + n) + p) + q)
f2 = theano.function([m, n], j, mode=mode) f2 = theano.function([m, n], j, mode=mode)
in1 = numpy.ones((10, 10), dtype=floatX) in1 = numpy.ones((10, 10), dtype=floatX)
in2 = numpy.ones((10, 10), dtype=floatX) in2 = numpy.ones((10, 10), dtype=floatX)
assert f1(in1, in2) == f2(in1, in2) assert f1(in1, in2) == f2(in1, in2)
finally: finally:
theano.config.cache_optimizations = default theano.config.cache_optimizations = default
if __name__ == '__main__': if __name__ == '__main__':
test_graph_opt_caching() test_graph_opt_caching()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论