提交 7d8bcfed authored 作者: nouiz's avatar nouiz

Merge pull request #940 from goodfeli/eval

added eval method and tests
...@@ -116,6 +116,31 @@ is a single Variable *or* a list of Variables. For either case, the second ...@@ -116,6 +116,31 @@ is a single Variable *or* a list of Variables. For either case, the second
argument is what we want to see as output when we apply the function. *f* may argument is what we want to see as output when we apply the function. *f* may
then be used like a normal Python function. then be used like a normal Python function.
.. note::
As a shortcut, you can skip step 3, and just use a variable's
:func:`eval` method. The :func:`eval` method is not as flexible
as :func:`function` but it can do everything we've covered in
the tutorial so far. It has the added benefit of not requiring
you to import :func:`function` . Here is how :func:`eval` works:
>>> import theano.tensor as T
>>> x = T.dscalar('x')
>>> y = T.dscalar('y')
>>> z = x + y
>>> z.eval({x : 16.3, y : 12.1})
array(28.4)
We passed :func:`eval` a dictionary mapping symbolic theano
variables to the values to substitute for them, and it returned
the numerical value of the expression.
:func:`eval` will be slow the first time you call it on a variable--
it needs to call :func:`function` to compile the expression behind
the scenes. Subsequent calls to :func:`eval` on that same variable
will be fast, because the variable caches the compiled function.
Adding two Matrices Adding two Matrices
=================== ===================
......
...@@ -405,6 +405,26 @@ class Variable(Node): ...@@ -405,6 +405,26 @@ class Variable(Node):
stacklevel=2) stacklevel=2)
del self.fgraph del self.fgraph
def eval(self, inputs_to_values = None):
""" Evaluates this variable.
inputs_to_values: a dictionary mapping theano Variables to values.
"""
if inputs_to_values is None:
inputs_to_values = {}
if not hasattr(self, '_fn'):
self._fn_inputs = inputs_to_values.keys()
self._fn = theano.function(self._fn_inputs, self)
args = [ inputs_to_values[param] for param in self._fn_inputs ]
rval = self._fn(*args)
return rval
env = property(env_getter, env_setter, env_deleter) env = property(env_getter, env_setter, env_deleter)
......
...@@ -82,7 +82,7 @@ class X: ...@@ -82,7 +82,7 @@ class X:
return as_string(inputs, outputs, return as_string(inputs, outputs,
leaf_formatter = self.leaf_formatter, leaf_formatter = self.leaf_formatter,
node_formatter = self.node_formatter) node_formatter = self.node_formatter)
class TestStr(X): class TestStr(X):
...@@ -151,7 +151,7 @@ class TestClone(X): ...@@ -151,7 +151,7 @@ class TestClone(X):
############ ############
def prenode(obj): def prenode(obj):
if isinstance(obj, Variable): if isinstance(obj, Variable):
if obj.owner: if obj.owner:
return [obj.owner] return [obj.owner]
if isinstance(obj, Apply): if isinstance(obj, Apply):
...@@ -290,3 +290,18 @@ class TestIsSameGraph(unittest.TestCase): ...@@ -290,3 +290,18 @@ class TestIsSameGraph(unittest.TestCase):
({y: x, t: z}, True))), ({y: x, t: z}, True))),
], ],
debug=False) debug=False)
################
# eval #
################
def test_eval():
x = tensor.scalar()
y = tensor.scalar()
z = x + y
result = z.eval({x : 1., y : 2.})
assert z == 3.
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论