提交 47548c0a authored 作者: Pascal Lamblin's avatar Pascal Lamblin 提交者: GitHub

Merge pull request #5255 from khaotik/new_opfromgraph

Upgraded OpFromGraph with inline support and gradient override
......@@ -68,7 +68,8 @@ from theano.compile import (
SymbolicOutput, Out,
Mode,
predefined_modes, predefined_linkers, predefined_optimizers,
FunctionMaker, function, function_dump, OpFromGraph,
FunctionMaker, function, function_dump,
OpFromGraph,
ProfileStats,
Param, shared, as_op)
......
差异被折叠。
from __future__ import absolute_import, print_function, division
from functools import partial
import numpy as np
from theano import config, shared
from theano.gradient import DisconnectedType
from theano.gof.null_type import NullType
from theano.compile import function
from theano import tensor as T
......@@ -12,13 +15,17 @@ from theano.compile.builders import OpFromGraph
from theano.tests import unittest_tools
test_params = unittest_tools.parameterized.expand(
[(OpFromGraph,), (partial(OpFromGraph, inline=True),)])
class T_OpFromGraph(unittest_tools.InferShapeTester):
def test_straightforward(self):
@test_params
def test_straightforward(self, cls_ofg):
x, y, z = T.matrices('xyz')
e = x + y * z
op = OpFromGraph([x, y, z], [e])
op = cls_ofg([x, y, z], [e])
# (1+3*5=array of 16) - (3+1*5=array of 8)
f = op(x, y, z) - op(y, z, x)
......@@ -32,10 +39,11 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
assert np.all(8.0 == fn(xv, yv, zv))
assert np.all(8.0 == fn(xv, yv, zv))
def test_size_changes(self):
@test_params
def test_size_changes(self, cls_ofg):
x, y, z = T.matrices('xyz')
e = T.dot(x, y)
op = OpFromGraph([x, y], [e])
op = cls_ofg([x, y], [e])
f = op(x, op(y, z))
fn = function([x, y, z], f)
xv = np.ones((2, 3), dtype=config.floatX)
......@@ -48,10 +56,11 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
assert res.shape == (2, 5)
assert np.all(180.0 == res)
def test_grad(self):
@test_params
def test_grad(self, cls_ofg):
x, y, z = T.matrices('xyz')
e = x + y * z
op = OpFromGraph([x, y, z], [e])
op = cls_ofg([x, y, z], [e])
f = op(x, y, z)
f = f - T.grad(T.sum(f), y)
fn = function([x, y, z], f)
......@@ -60,10 +69,11 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
zv = np.ones((2, 2), dtype=config.floatX) * 5
assert np.all(11.0 == fn(xv, yv, zv))
def test_grad_grad(self):
@test_params
def test_grad_grad(self, cls_ofg):
x, y, z = T.matrices('xyz')
e = x + y * z
op = OpFromGraph([x, y, z], [e])
op = cls_ofg([x, y, z], [e])
f = op(x, y, z)
f = f - T.grad(T.sum(f), y)
f = f - T.grad(T.sum(f), y)
......@@ -73,11 +83,12 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
zv = np.ones((2, 2), dtype=config.floatX) * 5
assert np.allclose(6.0, fn(xv, yv, zv))
def test_shared(self):
@test_params
def test_shared(self, cls_ofg):
x, y, z = T.matrices('xyz')
s = shared(np.random.rand(2, 2).astype(config.floatX))
e = x + y * z + s
op = OpFromGraph([x, y, z], [e])
op = cls_ofg([x, y, z], [e])
# (1+3*5=array of 16) - (3+1*5=array of 8)
f = op(x, y, z) - op(y, z, x)
......@@ -90,11 +101,12 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
assert np.allclose(8.0, fn(xv, yv, zv))
assert np.allclose(8.0, fn(xv, yv, zv))
def test_shared_grad(self):
@test_params
def test_shared_grad(self, cls_ofg):
x, y, z = T.matrices('xyz')
s = shared(np.random.rand(2, 2).astype(config.floatX))
e = x + y * z + s
op = OpFromGraph([x, y, z], [e])
op = cls_ofg([x, y, z], [e])
f = op(x, y, z)
f = f - T.grad(T.sum(f), y)
fn = function([x, y, z], f)
......@@ -110,13 +122,146 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
assert np.allclose(15.0 + s.get_value(),
fn(xv, yv, zv))
def test_connection_pattern(self):
@test_params
def test_grad_override(self, cls_ofg):
    """Check the three ways of overriding OpFromGraph gradients.

    Covers: a single override given as a callable, a single override
    given as another OpFromGraph instance, a per-input list of
    overrides (with 'default' passthrough), and NullType /
    DisconnectedType entries in that list.
    """
    x, y = T.vectors('xy')

    def go(inps, gs):
        x, y = inps
        g, = gs
        # deliberately NOT the true gradient of x*y, so the test can
        # tell the override apart from the default gradient
        return [g * y * 2, g * x * 1.5]

    dedz = T.vector('dedz')
    op_mul_grad = cls_ofg([x, y, dedz], go([x, y], [dedz]))

    op_mul = cls_ofg([x, y], [x * y], grad_overrides=go)
    op_mul2 = cls_ofg([x, y], [x * y], grad_overrides=op_mul_grad)

    # single override case (function or OfG instance)
    xx, yy = T.vector('xx'), T.vector('yy')
    for op in [op_mul, op_mul2]:
        zz = T.sum(op(xx, yy))
        dx, dy = T.grad(zz, [xx, yy])
        fn = function([xx, yy], [dx, dy])
        xv = np.random.rand(16).astype(config.floatX)
        yv = np.random.rand(16).astype(config.floatX)
        dxv, dyv = fn(xv, yv)
        assert np.allclose(yv * 2, dxv)
        assert np.allclose(xv * 1.5, dyv)

    # list override case
    def go1(inps, gs):
        x, w, b = inps
        g = gs[0]
        return g * w * 2

    def go2(inps, gs):
        x, w, b = inps
        g = gs[0]
        return g * x * 1.5

    w, b = T.vectors('wb')
    # we make the 3rd gradient default (no override)
    op_linear = cls_ofg([x, w, b], [x * w + b],
                        grad_overrides=[go1, go2, 'default'])
    # fix: the variable bound to ``ww`` was mislabeled 'yy', which made
    # debug printouts of the graph misleading
    xx, ww, bb = T.vector('xx'), T.vector('ww'), T.vector('bb')
    zz = T.sum(op_linear(xx, ww, bb))
    dx, dw, db = T.grad(zz, [xx, ww, bb])
    fn = function([xx, ww, bb], [dx, dw, db])
    xv = np.random.rand(16).astype(config.floatX)
    wv = np.random.rand(16).astype(config.floatX)
    bv = np.random.rand(16).astype(config.floatX)
    dxv, dwv, dbv = fn(xv, wv, bv)
    assert np.allclose(wv * 2, dxv)
    assert np.allclose(xv * 1.5, dwv)
    # third input uses the default gradient of x*w+b wrt b, i.e. ones
    assert np.allclose(np.ones(16, dtype=config.floatX), dbv)

    # NullType and DisconnectedType entries in the override list
    op_linear2 = cls_ofg(
        [x, w, b], [x * w + b],
        grad_overrides=[go1, NullType()(), DisconnectedType()()])
    zz2 = T.sum(op_linear2(xx, ww, bb))
    dx2, dw2, db2 = T.grad(
        zz2, [xx, ww, bb],
        return_disconnected='Disconnected',
        disconnected_inputs='ignore',
        null_gradients='return')
    assert isinstance(dx2.type, T.TensorType)
    assert dx2.ndim == 1
    assert isinstance(dw2.type, NullType)
    assert isinstance(db2.type, DisconnectedType)
@test_params
def test_rop(self, cls_ofg):
    """R-op through an OpFromGraph wrapping a vector-matrix dot."""
    vec = T.vector()
    mat = T.matrix()
    op_matmul = cls_ofg([vec, mat], [T.dot(vec, mat)])

    x = T.vector()
    W = T.matrix()
    y = op_matmul(x, W)
    du = T.vector()
    dv = T.Rop(y, x, du)
    fn = function([x, W, du], dv)

    # numeric check: Rop of dot(x, W) wrt x along du is dot(du, W)
    xval = np.random.rand(16).astype(config.floatX)
    Wval = np.random.rand(16, 16).astype(config.floatX)
    duval = np.random.rand(16).astype(config.floatX)
    expected = np.dot(duval, Wval)
    assert np.allclose(fn(xval, Wval, duval), expected)
@test_params
def test_rop_override(self, cls_ofg):
    """Check R-op overriding, both as a callable and as an OfG instance.

    The override deliberately differs from the true R-op of x*y
    (factors 2 and 1.5) so the test can detect whether the override
    was actually used.
    """
    x, y = T.vectors('xy')

    def ro(inps, epts):
        x, y = inps
        u, v = epts
        return [u * y * 2. + x * v * 1.5]

    u, v = T.vectors('uv')
    op_mul_rop = cls_ofg([x, y, u, v], ro([x, y], [u, v]))
    op_mul = cls_ofg([x, y], [x * y], rop_overrides=ro)
    op_mul2 = cls_ofg([x, y], [x * y], rop_overrides=op_mul_rop)

    # single override case
    xx, yy = T.vector('xx'), T.vector('yy')
    du, dv = T.vector('du'), T.vector('dv')
    for op in [op_mul, op_mul2]:
        # BUG FIX: was ``op_mul(xx, yy)``, which ignored the loop
        # variable and never exercised the op_mul2 (OfG-instance) case
        zz = op(xx, yy)
        dw = T.Rop(zz, [xx, yy], [du, dv])
        fn = function([xx, yy, du, dv], dw)
        vals = np.random.rand(4, 32).astype(config.floatX)
        dwval = fn(*vals)
        assert np.allclose(
            dwval, vals[0] * vals[3] * 1.5 + vals[1] * vals[2] * 2.)

    # TODO list override case
@test_params
def test_nested(self, cls_ofg):
    """Composing an OfG transform with its inverse is the identity."""
    x, y = T.vectors('xy')
    fwd = (x + y, x - y)
    op_ft = cls_ofg([x, y], list(fwd))
    op_ift = cls_ofg([x, y], [fwd[0] / 2, fwd[1] / 2])

    xx, yy = T.vector('xx'), T.vector('yy')
    xx2, yy2 = op_ift(*op_ft(xx, yy))
    fn = function([xx, yy], [xx2, yy2])

    xv = np.random.rand(16).astype(config.floatX)
    yv = np.random.rand(16).astype(config.floatX)
    roundtrip = fn(xv, yv)
    assert np.allclose(xv, roundtrip[0])
    assert np.allclose(yv, roundtrip[1])
@test_params
def test_connection_pattern(self, cls_ofg):
# Basic case
x, y, z = T.matrices('xyz')
out1 = x * y
out2 = y * z
op1 = OpFromGraph([x, y, z], [out1, out2])
op1 = cls_ofg([x, y, z], [out1, out2])
results = op1.connection_pattern(None)
expect_result = [[True, False],
[True, True],
......@@ -128,7 +273,7 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
m, n, p, q = T.matrices('mnpq')
o1, o2 = op1(m, n, p)
out1, out2 = op1(o1, q, o2)
op2 = OpFromGraph([m, n, p, q], [out1, out2])
op2 = cls_ofg([m, n, p, q], [out1, out2])
results = op2.connection_pattern(None)
expect_result = [[True, False],
......@@ -144,7 +289,7 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
out1 = x + rv_u
out2 = y + 3
out3 = 3 + rv_u
op3 = OpFromGraph([x, y], [out1, out2, out3])
op3 = cls_ofg([x, y], [out1, out2, out3])
results = op3.connection_pattern(None)
expect_result = [[True, False, False],
......@@ -153,6 +298,8 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
assert results == expect_result
def test_infer_shape(self):
# test_infer_shape does not need to be run against the inline case,
# since the Op is removed during the optimization phase
x = T.matrix('x')
y = T.matrix('y')
o1 = x + y
......
......@@ -244,15 +244,14 @@ class PyDotFormatter(object):
# Inputs mapping
ext_inputs = [self.__node_id(x) for x in node.inputs]
int_inputs = [gf.__node_id(x)
for x in node.op.fn.maker.fgraph.inputs]
for x in node.op.local_inputs]
assert len(ext_inputs) == len(int_inputs)
h = format_map(zip(ext_inputs, int_inputs))
pd_node.get_attributes()['subg_map_inputs'] = h
# Outputs mapping
ext_outputs = [self.__node_id(x) for x in node.outputs]
int_outputs = node.op.fn.maker.fgraph.outputs
int_outputs = [gf.__node_id(x) for x in int_outputs]
int_outputs = [gf.__node_id(x) for x in node.op.local_outputs]
assert len(ext_outputs) == len(int_outputs)
h = format_map(zip(int_outputs, ext_outputs))
pd_node.get_attributes()['subg_map_outputs'] = h
......
......@@ -1099,7 +1099,10 @@ def io_connection_pattern(inputs, outputs):
# connection patterns of the individual outputs
global_connection_pattern = [[] for o in range(len(inputs))]
for out in outputs:
out_connection_pattern = connect_pattern_by_var[out]
out_connection_pattern = connect_pattern_by_var.get(out)
if out_connection_pattern is None:
# the output is completely isolated from inputs
out_connection_pattern = [False] * len(inputs)
for i in range(len(inputs)):
global_connection_pattern[i].append(out_connection_pattern[i])
......
......@@ -340,14 +340,22 @@ class TestAutoName:
assert r2.auto_name == "auto_" + str(autoname_id + 1)
def test_constant(self):
# Make sure the value we will use for the test aren't yet in the cache.
r1 = tensor.constant(1.5)
del tensor.constant_cache[r1.signature()]
r1 = tensor.constant(1.6)
del tensor.constant_cache[r1.signature()]
# Get counter value
autoname_id = next(Variable.__count__)
Variable.__count__ = count(autoname_id)
r1 = tensor.constant(1.5)
r2 = tensor.constant(1.5)
assert r1.auto_name == "auto_" + str(autoname_id)
assert r1.auto_name == "auto_" + str(autoname_id), (
r1.auto_name, "auto_" + str(autoname_id))
# We reuse the same variable
assert r2.auto_name == "auto_" + str(autoname_id)
assert r2.auto_name == "auto_" + str(autoname_id), (
r2.auto_name, "auto_" + str(autoname_id))
assert r1 is r2
r3 = tensor.constant(1.6)
......
......@@ -1201,58 +1201,56 @@ def _populate_grad_dict(var_to_app_to_idx,
is_zero = _is_zero(term)
assert is_zero in ['yes', 'no', 'maybe']
if is_zero == 'maybe':
msg = "%s.grad returned %s of type %s for input"
msg += " %d. This input's only connections to "
msg += "the cost through this op are via "
msg += "integer-valued outputs so it should be "
msg += "NullType, DisconnectedType, or some form "
msg += "of zeros. It is not NullType or "
msg += "DisconnectedType and theano can't "
msg += "simplify it to a constant, so it's not "
msg += "verifiably zeros."
msg = msg % (str(node.op), str(term),
str(type(term)), i)
if is_zero == 'no':
msg = "%s.grad returned %s of type %s for input"
msg += " %d. Since this input is only connected "
msg += "to integer-valued outputs, it should "
msg += "evaluate to zeros, but it evaluates to"
msg += "%s."
msg % (node.op, term, type(term), i,
theano.get_scalar_constant_value(term))
msg = ("%s.grad returned %s of type %s for input"
" %d. This input's only connections to "
"the cost through this op are via "
"integer-valued outputs so it should be "
"NullType, DisconnectedType, or some form "
"of zeros. It is not NullType or "
"DisconnectedType and theano can't "
"simplify it to a constant, so it's not "
"verifiably zeros.")
msg %= (node.op, term, type(term), i)
elif is_zero == 'no':
msg = ("%s.grad returned %s of type %s for input"
" %d. Since this input is only connected "
"to integer-valued outputs, it should "
"evaluate to zeros, but it evaluates to"
"%s.")
msg %= (node.op, term, type(term), i,
theano.get_scalar_constant_value(term))
raise ValueError(msg)
# Check that op.connection_pattern matches the connectivity
# logic driving the op.grad method
for i, packed in enumerate(zip(inputs, input_grads,
inputs_connected)):
ipt, ig, connected = packed
for i, (ipt, ig, connected) in enumerate(
zip(inputs, input_grads, inputs_connected)
):
actually_connected = \
not isinstance(ig.type, DisconnectedType)
if actually_connected and not connected:
msg = "%s.grad returned %s of type %s for input %d."
msg += " Expected DisconnectedType instance based on "
msg += " the output of the op's connection_pattern "
msg += "method."
msg = msg % (str(node.op), str(ig), str(ig.type), i)
msg = ("%s.grad returned %s of type %s for input %d."
" Expected DisconnectedType instance based on "
" the output of the op's connection_pattern "
"method.")
msg %= (str(node.op), str(ig), str(ig.type), i)
raise TypeError(msg)
if connected and not actually_connected:
msg = "%s.grad returned DisconnectedType for input"
msg += " %d."
msg = msg % (str(node.op), i)
elif connected and not actually_connected:
msg = "%s.grad returned DisconnectedType for input %d."
msg %= (str(node.op), i)
if hasattr(node.op, 'connection_pattern'):
msg += ' Its connection_pattern method does not'
msg += ' allow this.'
msg += (' Its connection_pattern method does not'
' allow this.')
raise TypeError(msg)
else:
msg += ' You may want to implement a '
msg += 'connection_pattern method for it.'
msg += (' You may want to implement a '
'connection_pattern method for it.')
warnings.warn(msg)
# cache the result
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论