提交 f8d377f3 authored 作者: Benjamin Scellier's avatar Benjamin Scellier

file theano/compile/tests/test_builders.py

上级 4deb04c9
from __future__ import absolute_import, print_function, division from __future__ import absolute_import, print_function, division
import numpy import numpy as np
from theano import config, shared from theano import config, shared
...@@ -23,14 +23,14 @@ class T_OpFromGraph(unittest_tools.InferShapeTester): ...@@ -23,14 +23,14 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
f = op(x, y, z) - op(y, z, x) f = op(x, y, z) - op(y, z, x)
fn = function([x, y, z], f) fn = function([x, y, z], f)
xv = numpy.ones((2, 2), dtype=config.floatX) xv = np.ones((2, 2), dtype=config.floatX)
yv = numpy.ones((2, 2), dtype=config.floatX) * 3 yv = np.ones((2, 2), dtype=config.floatX) * 3
zv = numpy.ones((2, 2), dtype=config.floatX) * 5 zv = np.ones((2, 2), dtype=config.floatX) * 5
# print function, function.__module__ # print function, function.__module__
# print fn.maker.fgraph.toposort() # print fn.maker.fgraph.toposort()
fn(xv, yv, zv) fn(xv, yv, zv)
assert numpy.all(8.0 == fn(xv, yv, zv)) assert np.all(8.0 == fn(xv, yv, zv))
assert numpy.all(8.0 == fn(xv, yv, zv)) assert np.all(8.0 == fn(xv, yv, zv))
def test_size_changes(self): def test_size_changes(self):
x, y, z = T.matrices('xyz') x, y, z = T.matrices('xyz')
...@@ -38,15 +38,15 @@ class T_OpFromGraph(unittest_tools.InferShapeTester): ...@@ -38,15 +38,15 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
op = OpFromGraph([x, y], [e]) op = OpFromGraph([x, y], [e])
f = op(x, op(y, z)) f = op(x, op(y, z))
fn = function([x, y, z], f) fn = function([x, y, z], f)
xv = numpy.ones((2, 3), dtype=config.floatX) xv = np.ones((2, 3), dtype=config.floatX)
yv = numpy.ones((3, 4), dtype=config.floatX) * 3 yv = np.ones((3, 4), dtype=config.floatX) * 3
zv = numpy.ones((4, 5), dtype=config.floatX) * 5 zv = np.ones((4, 5), dtype=config.floatX) * 5
res = fn(xv, yv, zv) res = fn(xv, yv, zv)
assert res.shape == (2, 5) assert res.shape == (2, 5)
assert numpy.all(180.0 == res) assert np.all(180.0 == res)
res = fn(xv, yv, zv) res = fn(xv, yv, zv)
assert res.shape == (2, 5) assert res.shape == (2, 5)
assert numpy.all(180.0 == res) assert np.all(180.0 == res)
def test_grad(self): def test_grad(self):
x, y, z = T.matrices('xyz') x, y, z = T.matrices('xyz')
...@@ -55,10 +55,10 @@ class T_OpFromGraph(unittest_tools.InferShapeTester): ...@@ -55,10 +55,10 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
f = op(x, y, z) f = op(x, y, z)
f = f - T.grad(T.sum(f), y) f = f - T.grad(T.sum(f), y)
fn = function([x, y, z], f) fn = function([x, y, z], f)
xv = numpy.ones((2, 2), dtype=config.floatX) xv = np.ones((2, 2), dtype=config.floatX)
yv = numpy.ones((2, 2), dtype=config.floatX) * 3 yv = np.ones((2, 2), dtype=config.floatX) * 3
zv = numpy.ones((2, 2), dtype=config.floatX) * 5 zv = np.ones((2, 2), dtype=config.floatX) * 5
assert numpy.all(11.0 == fn(xv, yv, zv)) assert np.all(11.0 == fn(xv, yv, zv))
def test_grad_grad(self): def test_grad_grad(self):
x, y, z = T.matrices('xyz') x, y, z = T.matrices('xyz')
...@@ -68,46 +68,46 @@ class T_OpFromGraph(unittest_tools.InferShapeTester): ...@@ -68,46 +68,46 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
f = f - T.grad(T.sum(f), y) f = f - T.grad(T.sum(f), y)
f = f - T.grad(T.sum(f), y) f = f - T.grad(T.sum(f), y)
fn = function([x, y, z], f) fn = function([x, y, z], f)
xv = numpy.ones((2, 2), dtype=config.floatX) xv = np.ones((2, 2), dtype=config.floatX)
yv = numpy.ones((2, 2), dtype=config.floatX) * 3 yv = np.ones((2, 2), dtype=config.floatX) * 3
zv = numpy.ones((2, 2), dtype=config.floatX) * 5 zv = np.ones((2, 2), dtype=config.floatX) * 5
assert numpy.allclose(6.0, fn(xv, yv, zv)) assert np.allclose(6.0, fn(xv, yv, zv))
def test_shared(self): def test_shared(self):
x, y, z = T.matrices('xyz') x, y, z = T.matrices('xyz')
s = shared(numpy.random.rand(2, 2).astype(config.floatX)) s = shared(np.random.rand(2, 2).astype(config.floatX))
e = x + y * z + s e = x + y * z + s
op = OpFromGraph([x, y, z], [e]) op = OpFromGraph([x, y, z], [e])
# (1+3*5=array of 16) - (3+1*5=array of 8) # (1+3*5=array of 16) - (3+1*5=array of 8)
f = op(x, y, z) - op(y, z, x) f = op(x, y, z) - op(y, z, x)
fn = function([x, y, z], f) fn = function([x, y, z], f)
xv = numpy.ones((2, 2), dtype=config.floatX) xv = np.ones((2, 2), dtype=config.floatX)
yv = numpy.ones((2, 2), dtype=config.floatX) * 3 yv = np.ones((2, 2), dtype=config.floatX) * 3
zv = numpy.ones((2, 2), dtype=config.floatX) * 5 zv = np.ones((2, 2), dtype=config.floatX) * 5
# print function, function.__module__ # print function, function.__module__
# print fn.maker.fgraph.toposort() # print fn.maker.fgraph.toposort()
assert numpy.allclose(8.0, fn(xv, yv, zv)) assert np.allclose(8.0, fn(xv, yv, zv))
assert numpy.allclose(8.0, fn(xv, yv, zv)) assert np.allclose(8.0, fn(xv, yv, zv))
def test_shared_grad(self): def test_shared_grad(self):
x, y, z = T.matrices('xyz') x, y, z = T.matrices('xyz')
s = shared(numpy.random.rand(2, 2).astype(config.floatX)) s = shared(np.random.rand(2, 2).astype(config.floatX))
e = x + y * z + s e = x + y * z + s
op = OpFromGraph([x, y, z], [e]) op = OpFromGraph([x, y, z], [e])
f = op(x, y, z) f = op(x, y, z)
f = f - T.grad(T.sum(f), y) f = f - T.grad(T.sum(f), y)
fn = function([x, y, z], f) fn = function([x, y, z], f)
xv = numpy.ones((2, 2), dtype=config.floatX) xv = np.ones((2, 2), dtype=config.floatX)
yv = numpy.ones((2, 2), dtype=config.floatX) * 3 yv = np.ones((2, 2), dtype=config.floatX) * 3
zv = numpy.ones((2, 2), dtype=config.floatX) * 5 zv = np.ones((2, 2), dtype=config.floatX) * 5
assert numpy.allclose(11.0 + s.get_value(), fn(xv, yv, zv)) assert np.allclose(11.0 + s.get_value(), fn(xv, yv, zv))
# grad again the shared variable # grad again the shared variable
f = op(x, y, z) f = op(x, y, z)
f = f - T.grad(T.sum(f), s) f = f - T.grad(T.sum(f), s)
fn = function([x, y, z], f) fn = function([x, y, z], f)
assert numpy.allclose(15.0 + s.get_value(), assert np.allclose(15.0 + s.get_value(),
fn(xv, yv, zv)) fn(xv, yv, zv))
def test_connection_pattern(self): def test_connection_pattern(self):
...@@ -163,6 +163,6 @@ class T_OpFromGraph(unittest_tools.InferShapeTester): ...@@ -163,6 +163,6 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
p = T.matrix('p') p = T.matrix('p')
self._compile_and_check([q, p], self._compile_and_check([q, p],
op_graph(q, p), op_graph(q, p),
[numpy.ones([3, 4], dtype=config.floatX), [np.ones([3, 4], dtype=config.floatX),
numpy.ones([3, 4], dtype=config.floatX)], np.ones([3, 4], dtype=config.floatX)],
OpFromGraph) OpFromGraph)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论