提交 51408b9b authored 作者: Benjamin Scellier's avatar Benjamin Scellier 提交者: Nicolas Ballas

file theano/gof/tests/test_graph_opt_caching.py

上级 77c42664
from __future__ import absolute_import, print_function, division
import os
import numpy
import numpy as np
import theano
import theano.tensor as T
......@@ -19,20 +19,20 @@ def test_graph_opt_caching():
theano.config.cache_optimizations = True
a = T.fmatrix('a')
b = T.fmatrix('b')
c = theano.shared(numpy.ones((10, 10), dtype=floatX))
d = theano.shared(numpy.ones((10, 10), dtype=floatX))
c = theano.shared(np.ones((10, 10), dtype=floatX))
d = theano.shared(np.ones((10, 10), dtype=floatX))
e = T.sum(T.sum(T.sum(a ** 2 + b) + c) + d)
f1 = theano.function([a, b], e, mode=mode)
m = T.fmatrix('x1')
n = T.fmatrix('x2')
p = theano.shared(numpy.ones((10, 10), dtype=floatX))
q = theano.shared(numpy.ones((10, 10), dtype=floatX))
p = theano.shared(np.ones((10, 10), dtype=floatX))
q = theano.shared(np.ones((10, 10), dtype=floatX))
j = T.sum(T.sum(T.sum(m ** 2 + n) + p) + q)
f2 = theano.function([m, n], j, mode=mode)
in1 = numpy.ones((10, 10), dtype=floatX)
in2 = numpy.ones((10, 10), dtype=floatX)
in1 = np.ones((10, 10), dtype=floatX)
in2 = np.ones((10, 10), dtype=floatX)
assert f1(in1, in2) == f2(in1, in2)
finally:
theano.config.cache_optimizations = default
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论