提交 68db91e0 authored 作者: James Bergstra's avatar James Bergstra

added some pickling test cases to test_function

上级 96e48a0e
......@@ -13,7 +13,7 @@ from .. import gof
import sys
import copy
from mode import *
import mode as mode_module
from io import *
def infer_reuse_pattern(env, outputs_to_disown):
......@@ -551,7 +551,7 @@ class FunctionMaker(object):
raise TypeError("Unknown output type: %s (%s)", type(output), output)
def __init__(self, inputs, outputs,
mode = default_mode, accept_inplace = False, function_builder = Function):
mode = None, accept_inplace = False, function_builder = Function):
"""
:type inputs: a list of SymbolicInput instances
......@@ -560,12 +560,14 @@ class FunctionMaker(object):
case the functions produced by FunctionMaker will return
their output value directly
:param mode: a Mode instance telling FunctionMaker how to optimize and link
:param mode: a Mode instance telling FunctionMaker how to optimize and link. None
means to use the `default_mode`.
:param accept_inplace: True iff it is acceptable to have inplace operations
in the graph from the inputs to the outputs
"""
mode = mode if mode is not None else mode_module.default_mode
# Handle the case where inputs and/or outputs is a single Variable (not in a list)
unpack_single = False
......@@ -586,7 +588,7 @@ class FunctionMaker(object):
self.env = env
# Fetch the mode and then the optimizer and linker
mode = predefined_modes.get(mode, mode)
mode = mode_module.predefined_modes.get(mode, mode)
optimizer, linker = mode.optimizer, copy.copy(mode.linker)
# optimize the env
......@@ -595,7 +597,7 @@ class FunctionMaker(object):
# initialize the linker
if not hasattr(linker, 'accept'):
raise ValueError("'linker' parameter of FunctionFactory should be a Linker with an accept method " \
"or one of %s" % predefined_linkers.keys())
"or one of %s" % mode_module.predefined_linkers.keys())
#the 'no_borrow' outputs are the ones for which that we can't return the internal storage pointer.
no_borrow = [output for output, spec in zip(env.outputs, outputs+additional_outputs) if not spec.borrow]
......@@ -742,7 +744,7 @@ def register_checker(checker):
def function(inputs, outputs, mode=default_mode, accept_inplace = False):
def function(inputs, outputs, mode=None, accept_inplace = False):
"""
Return a function calculating the outputs from the inputs.
......@@ -752,8 +754,8 @@ def function(inputs, outputs, mode=default_mode, accept_inplace = False):
value of the returned function will match the format of this argument (either the value
itself or a list of one or more return values)
:param mode: a descriptive string or a Mode instance. (See below for descriptive string
list).
:param mode: a descriptive string or a Mode instance. (Default of None means to use
`mode.default_mode` (See below for descriptive string list).
Currently, the library provides the following mode strings:
......@@ -789,6 +791,7 @@ def function(inputs, outputs, mode=default_mode, accept_inplace = False):
f[<kitname>] = seed #re-seed the elements of a RandomKit
"""
mode = mode if mode is not None else mode_module.default_mode
inputs = map(convert_function_input, inputs)
if outputs is None:
......@@ -798,7 +801,7 @@ def function(inputs, outputs, mode=default_mode, accept_inplace = False):
defaults = [getattr(input, 'value', None) for input in inputs]
mode = predefined_modes.get(mode, mode)
mode = mode_module.predefined_modes.get(mode, mode)
if isinstance(mode, (list, tuple)): # "mode comparison" semantics
if not mode:
raise ValueError("Please provide at least one mode.")
......
......@@ -408,6 +408,69 @@ class T_picklefunction(unittest.TestCase):
f(1,2) # put them out of sync
self.failIf(f(1, 2) == g(1, 2)) #they should not be equal anymore.
def test_pickle(self):
a = T.scalar() # the a is for 'anonymous' (un-named).
x,s = T.scalars('xs')
f = function([x, In(a, value=1.0,name='a'), In(s, value=0.0, update=s+a*x, mutable=True)], s+a*x)
try:
g = cPickle.loads(cPickle.dumps(f))
except NotImplementedError, e:
if e[0].startswith('DebugMode is not picklable'):
return
else:
raise
#if they both return, assume that they return equivalent things.
#print [(k,id(k)) for k in f.finder.keys()]
#print [(k,id(k)) for k in g.finder.keys()]
self.failIf(g.container[0].storage is f.container[0].storage)
self.failIf(g.container[1].storage is f.container[1].storage)
self.failIf(g.container[2].storage is f.container[2].storage)
self.failIf(x in g.container)
self.failIf(x in g.value)
self.failIf(g.value[1] is f.value[1]) # should not have been copied
self.failIf(g.value[2] is f.value[2]) # should have been copied because it is mutable.
self.failIf((g.value[2] != f.value[2]).any()) # its contents should be identical
self.failUnless(f(2, 1) == g(2)) #they should be in sync, default value should be copied.
self.failUnless(f(2, 1) == g(2)) #they should be in sync, default value should be copied.
f(1,2) # put them out of sync
self.failIf(f(1, 2) == g(1, 2)) #they should not be equal anymore.
def test_optimizations_preserved(self):
a = T.dvector() # the a is for 'anonymous' (un-named).
x = T.dvector('x')
s = T.dvector('s')
xm = T.dmatrix('x')
sm = T.dmatrix('s')
f = function([a, x, s, xm, sm], ((a.T.T)*(tensor.dot(xm, (sm.T.T.T)) + x).T * (x/x) + s))
old_default_mode = compile.mode.default_mode
try:
str_f = cPickle.dumps(f)
compile.mode.default_mode = mode_module.Mode(linker='py', optimizer=None)
g = cPickle.loads(str_f)
#print g.maker.mode
#print compile.mode.default_mode
finally:
compile.mode.default_mode = old_default_mode
assert f.maker is not g.maker
assert f.maker.env is not g.maker.env
tf = f.maker.env.toposort()
tg = f.maker.env.toposort()
assert len(tf) == len(tg)
for nf, ng in zip(tf, tg):
assert nf.op == ng.op
assert len(nf.inputs) == len(ng.inputs)
assert len(nf.outputs) == len(ng.outputs)
assert [i.type for i in nf.inputs] == [i.type for i in ng.inputs]
assert [i.type for i in nf.outputs] == [i.type for i in ng.outputs]
# class T_function_examples(unittest.TestCase):
# def test_accumulator(self):
# """Test low-level interface with state."""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论