提交 adb884e9 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

added eval method and tests

上级 85f9247a
...@@ -405,6 +405,26 @@ class Variable(Node): ...@@ -405,6 +405,26 @@ class Variable(Node):
stacklevel=2) stacklevel=2)
del self.fgraph del self.fgraph
def eval(self, inputs_to_values = None):
""" Evaluates this variable.
inputs_to_values: a dictionary mapping theano Variables to values.
"""
if inputs_to_values is None:
inputs_to_values = {}
if not hasattr(self, '_fn'):
self._fn_inputs = inputs_to_values.keys()
self._fn = theano.function(self._fn_inputs, self)
args = [ inputs_to_values[param] for param in self._fn_inputs ]
rval = self._fn(*args)
return rval
env = property(env_getter, env_setter, env_deleter) env = property(env_getter, env_setter, env_deleter)
......
...@@ -82,7 +82,7 @@ class X: ...@@ -82,7 +82,7 @@ class X:
return as_string(inputs, outputs, return as_string(inputs, outputs,
leaf_formatter = self.leaf_formatter, leaf_formatter = self.leaf_formatter,
node_formatter = self.node_formatter) node_formatter = self.node_formatter)
class TestStr(X): class TestStr(X):
...@@ -151,7 +151,7 @@ class TestClone(X): ...@@ -151,7 +151,7 @@ class TestClone(X):
############ ############
def prenode(obj): def prenode(obj):
if isinstance(obj, Variable): if isinstance(obj, Variable):
if obj.owner: if obj.owner:
return [obj.owner] return [obj.owner]
if isinstance(obj, Apply): if isinstance(obj, Apply):
...@@ -290,3 +290,18 @@ class TestIsSameGraph(unittest.TestCase): ...@@ -290,3 +290,18 @@ class TestIsSameGraph(unittest.TestCase):
({y: x, t: z}, True))), ({y: x, t: z}, True))),
], ],
debug=False) debug=False)
################
# eval #
################
def test_eval():
x = tensor.scalar()
y = tensor.scalar()
z = x + y
result = z.eval({x : 1., y : 2.})
assert z == 3.
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论