提交 1cbb3cb8 authored 作者: nouiz's avatar nouiz

Merge pull request #59 from goodfeli/get_test_value

Added the get_test_value function
...@@ -344,7 +344,7 @@ class PureOp(object): ...@@ -344,7 +344,7 @@ class PureOp(object):
# ensure that the test value is correct # ensure that the test value is correct
return v.type.filter(v.tag.test_value) return v.type.filter(v.tag.test_value)
raise AttributeError('%s has not test value' % v) raise AttributeError('%s has no test value' % v)
def __call__(self, *inputs, **kwargs): def __call__(self, *inputs, **kwargs):
"""Optional: Return some or all output[s] of `make_node`. """Optional: Return some or all output[s] of `make_node`.
...@@ -545,3 +545,21 @@ class Op(utils.object2, PureOp, CLinkerOp): ...@@ -545,3 +545,21 @@ class Op(utils.object2, PureOp, CLinkerOp):
rval.perform = p rval.perform = p
rval.lazy = False rval.lazy = False
return rval return rval
def get_test_value(v):
"""
Extract test value from variable v. Raises AttributeError if there is none.
For an ndarray, the value is the ndarray itself
For a Constant, the test value is v.value.
For a Shared variable, it is the internal value.
For another Variable, it is the content of v.tag.test_value."""
try:
return PureOp._get_test_value(v)
except AttributeError:
if hasattr(v,'__array__'):
return v
raise
...@@ -5,11 +5,17 @@ import numpy ...@@ -5,11 +5,17 @@ import numpy
import theano import theano
from theano.gof.op import * import theano.gof.op as op
from theano.gof.type import Type, Generic from theano.gof.type import Type, Generic
from theano.gof.graph import Apply, Variable from theano.gof.graph import Apply, Variable
import theano.tensor as T
from theano import scalar from theano import scalar
from theano import shared
config = theano.config
Op = op.Op
utils = op.utils
def as_variable(x): def as_variable(x):
assert isinstance(x, Variable) assert isinstance(x, Variable)
...@@ -182,5 +188,33 @@ class TestMakeThunk(unittest.TestCase): ...@@ -182,5 +188,33 @@ class TestMakeThunk(unittest.TestCase):
assert compute_map[o][0] assert compute_map[o][0]
assert storage_map[o][0] == 4 assert storage_map[o][0] == 4
def test_test_value_ndarray():
x = numpy.zeros((5,5))
v = op.get_test_value(x)
assert v is x
def test_test_value_constant():
x = T.as_tensor_variable(numpy.zeros((5,5)))
v = op.get_test_value(x)
assert numpy.all(v == numpy.zeros((5,5)))
def test_test_value_shared():
x = shared(numpy.zeros((5,5)))
v = op.get_test_value(x)
assert numpy.all(v == numpy.zeros((5,5)))
def test_test_value_op():
try:
prev_value = config.compute_test_value
config.compute_test_value = 'raise'
x = T.log(numpy.ones((5,5)))
v = op.get_test_value(x)
assert numpy.allclose(v, numpy.zeros((5,5)))
finally:
config.compute_test_value = prev_value
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论