提交 a5eeb907 authored 作者: Olivier Delalleau's avatar Olivier Delalleau

More generic implementation of get_test_value

This makes 'get_test_value' work on any object that can turned into a variable automatically, not just numpy ndarrays.
上级 ffa4e84c
...@@ -19,6 +19,7 @@ import utils ...@@ -19,6 +19,7 @@ import utils
import warnings import warnings
from env import Env from env import Env
import cc import cc
import theano
class CLinkerObject(object): class CLinkerObject(object):
...@@ -549,17 +550,16 @@ class Op(utils.object2, PureOp, CLinkerOp): ...@@ -549,17 +550,16 @@ class Op(utils.object2, PureOp, CLinkerOp):
def get_test_value(v): def get_test_value(v):
""" """
Extract test value from variable v. Raises AttributeError if there is none. Extract test value from `v`. Raises AttributeError if there is none.
If input `v` is not already a variable, it is turned into one by calling
`as_tensor_variable(v)`, so that this function can be applied e.g.
on numpy arrays or Python lists and scalars, considering them as constants.
For an ndarray, the value is the ndarray itself
For a Constant, the test value is v.value. For a Constant, the test value is v.value.
For a Shared variable, it is the internal value. For a Shared variable, it is the internal value.
For another Variable, it is the content of v.tag.test_value.""" For another Variable, it is the content of v.tag.test_value.
"""
try: v_tensor = theano.tensor.as_tensor_variable(v)
return PureOp._get_test_value(v) return PureOp._get_test_value(v_tensor)
except AttributeError:
if hasattr(v,'__array__'):
return v
raise
...@@ -191,7 +191,7 @@ class TestMakeThunk(unittest.TestCase): ...@@ -191,7 +191,7 @@ class TestMakeThunk(unittest.TestCase):
def test_test_value_ndarray(): def test_test_value_ndarray():
x = numpy.zeros((5,5)) x = numpy.zeros((5,5))
v = op.get_test_value(x) v = op.get_test_value(x)
assert v is x assert (v == x).all()
def test_test_value_constant(): def test_test_value_constant():
x = T.as_tensor_variable(numpy.zeros((5,5))) x = T.as_tensor_variable(numpy.zeros((5,5)))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论