提交 a5eeb907 authored 作者: Olivier Delalleau's avatar Olivier Delalleau

More generic implementation of get_test_value

This makes 'get_test_value' work on any object that can turned into a variable automatically, not just numpy ndarrays.
上级 ffa4e84c
......@@ -19,6 +19,7 @@ import utils
import warnings
from env import Env
import cc
import theano
class CLinkerObject(object):
......@@ -549,17 +550,16 @@ class Op(utils.object2, PureOp, CLinkerOp):
def get_test_value(v):
"""
Extract test value from variable v. Raises AttributeError if there is none.
Extract test value from `v`. Raises AttributeError if there is none.
If input `v` is not already a variable, it is turned into one by calling
`as_tensor_variable(v)`, so that this function can be applied e.g.
on numpy arrays or Python lists and scalars, considering them as constants.
For an ndarray, the value is the ndarray itself
For a Constant, the test value is v.value.
For a Shared variable, it is the internal value.
For another Variable, it is the content of v.tag.test_value."""
try:
return PureOp._get_test_value(v)
except AttributeError:
if hasattr(v,'__array__'):
return v
raise
For another Variable, it is the content of v.tag.test_value.
"""
v_tensor = theano.tensor.as_tensor_variable(v)
return PureOp._get_test_value(v_tensor)
......@@ -191,7 +191,7 @@ class TestMakeThunk(unittest.TestCase):
def test_test_value_ndarray():
x = numpy.zeros((5,5))
v = op.get_test_value(x)
assert v is x
assert (v == x).all()
def test_test_value_constant():
x = T.as_tensor_variable(numpy.zeros((5,5)))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论