提交 ab720fb5 authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #1735 from rizar/fix_eval

Fix Variable.eval() (issue #1724 on github)
......@@ -415,18 +415,21 @@ class Variable(Node):
if inputs_to_values is None:
inputs_to_values = {}
if not hasattr(self, '_fn'):
self._fn_inputs = inputs_to_values.keys()
self._fn = theano.function(self._fn_inputs, self)
args = [inputs_to_values[param] for param in self._fn_inputs]
if not hasattr(self, '_fn_cache'):
self._fn_cache = dict()
rval = self._fn(*args)
inputs = tuple(sorted(inputs_to_values.keys(), key=id))
if not inputs in self._fn_cache:
self._fn_cache[inputs] = theano.function(inputs, self)
args = [inputs_to_values[param] for param in inputs]
rval = self._fn_cache[inputs](*args)
return rval
def __getstate__(self):
d = self.__dict__.copy()
d.pop("_fn", None)
d.pop("_fn_cache", None)
return d
env = property(env_getter, env_setter, env_deleter)
......
......@@ -297,14 +297,17 @@ class TestIsSameGraph(unittest.TestCase):
# eval #
################
def test_eval():
x = tensor.scalar()
y = tensor.scalar()
z = x + y
result = z.eval({x : 1., y : 2.})
assert result == 3.
# We don't want to pickle the tmp function.
assert not hasattr(pickle.loads(pickle.dumps(z)), '_fn')
class TestEval(unittest.TestCase):
def setUp(self):
self.x, self.y = tensor.scalars('x', 'y')
self.z = self.x + self.y
self.w = 2 * self.z
def test_eval(self):
self.assertEquals(self.w.eval({self.x : 1., self.y : 2.}), 6.)
self.assertEquals(self.w.eval({self.z : 3}), 6.)
self.assertTrue(hasattr(self.w, "_fn_cache"),
"variable must have cache after eval")
self.assertFalse(hasattr(pickle.loads(pickle.dumps(self.w)), '_fn_cache'),
"temporary functions must not be serialized")
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论