提交 85a1ae51 authored 作者: khaotik's avatar khaotik 提交者: khaotik

added nested/grad_override test and params

上级 5e6cefc2
...@@ -8,17 +8,22 @@ from theano.compile import function ...@@ -8,17 +8,22 @@ from theano.compile import function
from theano import tensor as T from theano import tensor as T
from theano.tensor.shared_randomstreams import RandomStreams from theano.tensor.shared_randomstreams import RandomStreams
from theano.compile.builders import OpFromGraph from theano.compile.builders import OpFromGraphInline, OpFromGraphPrecompiled
from theano.tests import unittest_tools from theano.tests import unittest_tools
test_params = unittest_tools.parameterized.expand(
[(OpFromGraphInline,), (OpFromGraphPrecompiled,)])
class T_OpFromGraph(unittest_tools.InferShapeTester): class T_OpFromGraph(unittest_tools.InferShapeTester):
def test_straightforward(self):
@test_params
def test_straightforward(self, cls_ofg):
x, y, z = T.matrices('xyz') x, y, z = T.matrices('xyz')
e = x + y * z e = x + y * z
op = OpFromGraph([x, y, z], [e]) op = cls_ofg([x, y, z], [e])
# (1+3*5=array of 16) - (3+1*5=array of 8) # (1+3*5=array of 16) - (3+1*5=array of 8)
f = op(x, y, z) - op(y, z, x) f = op(x, y, z) - op(y, z, x)
...@@ -32,10 +37,11 @@ class T_OpFromGraph(unittest_tools.InferShapeTester): ...@@ -32,10 +37,11 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
assert np.all(8.0 == fn(xv, yv, zv)) assert np.all(8.0 == fn(xv, yv, zv))
assert np.all(8.0 == fn(xv, yv, zv)) assert np.all(8.0 == fn(xv, yv, zv))
def test_size_changes(self): @test_params
def test_size_changes(self, cls_ofg):
x, y, z = T.matrices('xyz') x, y, z = T.matrices('xyz')
e = T.dot(x, y) e = T.dot(x, y)
op = OpFromGraph([x, y], [e]) op = cls_ofg([x, y], [e])
f = op(x, op(y, z)) f = op(x, op(y, z))
fn = function([x, y, z], f) fn = function([x, y, z], f)
xv = np.ones((2, 3), dtype=config.floatX) xv = np.ones((2, 3), dtype=config.floatX)
...@@ -48,10 +54,11 @@ class T_OpFromGraph(unittest_tools.InferShapeTester): ...@@ -48,10 +54,11 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
assert res.shape == (2, 5) assert res.shape == (2, 5)
assert np.all(180.0 == res) assert np.all(180.0 == res)
def test_grad(self): @test_params
def test_grad(self, cls_ofg):
x, y, z = T.matrices('xyz') x, y, z = T.matrices('xyz')
e = x + y * z e = x + y * z
op = OpFromGraph([x, y, z], [e]) op = cls_ofg([x, y, z], [e])
f = op(x, y, z) f = op(x, y, z)
f = f - T.grad(T.sum(f), y) f = f - T.grad(T.sum(f), y)
fn = function([x, y, z], f) fn = function([x, y, z], f)
...@@ -60,10 +67,11 @@ class T_OpFromGraph(unittest_tools.InferShapeTester): ...@@ -60,10 +67,11 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
zv = np.ones((2, 2), dtype=config.floatX) * 5 zv = np.ones((2, 2), dtype=config.floatX) * 5
assert np.all(11.0 == fn(xv, yv, zv)) assert np.all(11.0 == fn(xv, yv, zv))
def test_grad_grad(self): @test_params
def test_grad_grad(self, cls_ofg):
x, y, z = T.matrices('xyz') x, y, z = T.matrices('xyz')
e = x + y * z e = x + y * z
op = OpFromGraph([x, y, z], [e]) op = cls_ofg([x, y, z], [e])
f = op(x, y, z) f = op(x, y, z)
f = f - T.grad(T.sum(f), y) f = f - T.grad(T.sum(f), y)
f = f - T.grad(T.sum(f), y) f = f - T.grad(T.sum(f), y)
...@@ -73,11 +81,12 @@ class T_OpFromGraph(unittest_tools.InferShapeTester): ...@@ -73,11 +81,12 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
zv = np.ones((2, 2), dtype=config.floatX) * 5 zv = np.ones((2, 2), dtype=config.floatX) * 5
assert np.allclose(6.0, fn(xv, yv, zv)) assert np.allclose(6.0, fn(xv, yv, zv))
def test_shared(self): @test_params
def test_shared(self, cls_ofg):
x, y, z = T.matrices('xyz') x, y, z = T.matrices('xyz')
s = shared(np.random.rand(2, 2).astype(config.floatX)) s = shared(np.random.rand(2, 2).astype(config.floatX))
e = x + y * z + s e = x + y * z + s
op = OpFromGraph([x, y, z], [e]) op = cls_ofg([x, y, z], [e])
# (1+3*5=array of 16) - (3+1*5=array of 8) # (1+3*5=array of 16) - (3+1*5=array of 8)
f = op(x, y, z) - op(y, z, x) f = op(x, y, z) - op(y, z, x)
...@@ -90,11 +99,12 @@ class T_OpFromGraph(unittest_tools.InferShapeTester): ...@@ -90,11 +99,12 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
assert np.allclose(8.0, fn(xv, yv, zv)) assert np.allclose(8.0, fn(xv, yv, zv))
assert np.allclose(8.0, fn(xv, yv, zv)) assert np.allclose(8.0, fn(xv, yv, zv))
def test_shared_grad(self): @test_params
def test_shared_grad(self, cls_ofg):
x, y, z = T.matrices('xyz') x, y, z = T.matrices('xyz')
s = shared(np.random.rand(2, 2).astype(config.floatX)) s = shared(np.random.rand(2, 2).astype(config.floatX))
e = x + y * z + s e = x + y * z + s
op = OpFromGraph([x, y, z], [e]) op = cls_ofg([x, y, z], [e])
f = op(x, y, z) f = op(x, y, z)
f = f - T.grad(T.sum(f), y) f = f - T.grad(T.sum(f), y)
fn = function([x, y, z], f) fn = function([x, y, z], f)
...@@ -110,13 +120,76 @@ class T_OpFromGraph(unittest_tools.InferShapeTester): ...@@ -110,13 +120,76 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
assert np.allclose(15.0 + s.get_value(), assert np.allclose(15.0 + s.get_value(),
fn(xv, yv, zv)) fn(xv, yv, zv))
def test_connection_pattern(self): @test_params
def test_grad_override(self, cls_ofg):
x,y = T.vectors('xy')
def go(args):
x, y, g = args
return [g*y*2, g*x*1.5]
# no override is coverd in "grad" test
# single override
op_mul = cls_ofg([x, y], [x*y], grad_overrides=go)
xx,yy = T.vector('xx'), T.vector('yy')
zz = T.sum(op_mul(xx,yy))
dx, dy = T.grad(zz, [xx, yy])
fn = function([xx, yy], [dx, dy])
xv = numpy.random.rand(16).astype(config.floatX)
yv = numpy.random.rand(16).astype(config.floatX)
dxv, dyv = fn(xv, yv)
assert numpy.allclose(yv*2, dxv)
assert numpy.allclose(xv*1.5, dyv)
# list override
def go1(args):
x, w, b, g = args
return g*w*2
def go2(args):
x, w, b, g = args
return g*x*1.5
w, b = T.vectors('wb')
# we make the 3rd gradient default (no override)
op_linear = cls_ofg([x, w, b], [x*w+b], grad_overrides=[go1, go2])
xx, ww, bb = T.vector('xx'), T.vector('yy'), T.vector('bb')
zz = T.sum(op_linear(xx, ww, bb))
dx, dw, db = T.grad(zz, [xx, ww, bb])
fn = function([xx, ww, bb], [dx, dw, db])
xv = numpy.random.rand(16).astype(config.floatX)
wv = numpy.random.rand(16).astype(config.floatX)
bv = numpy.random.rand(16).astype(config.floatX)
dxv, dwv, dbv = fn(xv, wv, bv)
assert numpy.allclose(wv*2, dxv)
assert numpy.allclose(xv*1.5, dwv)
assert numpy.allclose(numpy.ones(16, dtype=config.floatX), dbv)
@test_params
def test_nested(self, cls_ofg):
x, y = T.vectors('xy')
u, v = x+y, x-y
op_ft = cls_ofg([x, y], [u, v])
op_ift = cls_ofg([x, y], [u/2, v/2])
xx, yy = T.vector('xx'), T.vector('yy')
xx2, yy2 = op_ift(*op_ft(xx, yy))
fn = function([xx, yy], [xx2, yy2])
xv = numpy.random.rand(16).astype(config.floatX)
yv = numpy.random.rand(16).astype(config.floatX)
xv2, yv2 = fn(xv, yv)
assert numpy.allclose(xv, xv2)
assert numpy.allclose(yv, yv2)
@test_params
def test_connection_pattern(self, cls_ofg):
# Basic case # Basic case
x, y, z = T.matrices('xyz') x, y, z = T.matrices('xyz')
out1 = x * y out1 = x * y
out2 = y * z out2 = y * z
op1 = OpFromGraph([x, y, z], [out1, out2]) op1 = cls_ofg([x, y, z], [out1, out2])
results = op1.connection_pattern(None) results = op1.connection_pattern(None)
expect_result = [[True, False], expect_result = [[True, False],
[True, True], [True, True],
...@@ -128,7 +201,7 @@ class T_OpFromGraph(unittest_tools.InferShapeTester): ...@@ -128,7 +201,7 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
m, n, p, q = T.matrices('mnpq') m, n, p, q = T.matrices('mnpq')
o1, o2 = op1(m, n, p) o1, o2 = op1(m, n, p)
out1, out2 = op1(o1, q, o2) out1, out2 = op1(o1, q, o2)
op2 = OpFromGraph([m, n, p, q], [out1, out2]) op2 = cls_ofg([m, n, p, q], [out1, out2])
results = op2.connection_pattern(None) results = op2.connection_pattern(None)
expect_result = [[True, False], expect_result = [[True, False],
...@@ -144,7 +217,7 @@ class T_OpFromGraph(unittest_tools.InferShapeTester): ...@@ -144,7 +217,7 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
out1 = x + rv_u out1 = x + rv_u
out2 = y + 3 out2 = y + 3
out3 = 3 + rv_u out3 = 3 + rv_u
op3 = OpFromGraph([x, y], [out1, out2, out3]) op3 = cls_ofg([x, y], [out1, out2, out3])
results = op3.connection_pattern(None) results = op3.connection_pattern(None)
expect_result = [[True, False, False], expect_result = [[True, False, False],
...@@ -152,17 +225,23 @@ class T_OpFromGraph(unittest_tools.InferShapeTester): ...@@ -152,17 +225,23 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
[True, False, True]] [True, False, True]]
assert results == expect_result assert results == expect_result
def test_infer_shape(self): @test_params
def test_infer_shape(self, cls_ofg):
x = T.matrix('x') x = T.matrix('x')
y = T.matrix('y') y = T.matrix('y')
o1 = x + y o1 = x + y
o2 = x * y o2 = x * y
op_graph = OpFromGraph([x, y], [o1, o2]) op_graph = cls_ofg([x, y], [o1, o2])
q = T.matrix('q') q = T.matrix('q')
p = T.matrix('p') p = T.matrix('p')
# we don't want check_topo for inline ops
# since the inline op is replaced during optimization
is_compile = not issubclass(cls_ofg, OpFromGraphInline)
self._compile_and_check([q, p], self._compile_and_check([q, p],
op_graph(q, p), op_graph(q, p),
[np.ones([3, 4], dtype=config.floatX), [np.ones([3, 4], dtype=config.floatX),
np.ones([3, 4], dtype=config.floatX)], np.ones([3, 4], dtype=config.floatX)],
OpFromGraph) cls_ofg,
check_topo=is_compile)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论