提交 0ea10588 authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Move get_test_value method to the Variable (sub)classes

The test value computation steps were also extracted from `PureOp.__call__` and put into a stand-alone `compute_test_value` function.
上级 03b82c26
...@@ -150,6 +150,9 @@ class SharedVariable(Variable): ...@@ -150,6 +150,9 @@ class SharedVariable(Variable):
else: else:
self.container.value = copy.deepcopy(new_value) self.container.value = copy.deepcopy(new_value)
def get_test_value(self):
return self.get_value(borrow=True, return_internal_type=True)
def zero(self, borrow=False): def zero(self, borrow=False):
""" """
Set the values of a shared variable to 0. Set the values of a shared variable to 0.
......
...@@ -58,7 +58,14 @@ from theano.gof.link import ( ...@@ -58,7 +58,14 @@ from theano.gof.link import (
WrapLinkerMany, WrapLinkerMany,
) )
from theano.gof.op import Op, OpenMPOp, PureOp, COp, ops_with_inner_function from theano.gof.op import (
Op,
OpenMPOp,
PureOp,
COp,
ops_with_inner_function,
get_test_value,
)
from theano.gof.type import EnumType, EnumList, CEnumType from theano.gof.type import EnumType, EnumList, CEnumType
......
...@@ -383,19 +383,39 @@ class Variable(Node): ...@@ -383,19 +383,39 @@ class Variable(Node):
self.tag = utils.ValidatingScratchpad("test_value", type.filter) self.tag = utils.ValidatingScratchpad("test_value", type.filter)
self.type = type self.type = type
if owner is not None and not isinstance(owner, Apply): if owner is not None and not isinstance(owner, Apply):
raise TypeError("owner must be an Apply instance", owner) raise TypeError("owner must be an Apply instance", owner)
self.owner = owner self.owner = owner
if index is not None and not isinstance(index, integer_types): if index is not None and not isinstance(index, integer_types):
raise TypeError("index must be an int", index) raise TypeError("index must be an int", index)
self.index = index self.index = index
if name is not None and not isinstance(name, string_types): if name is not None and not isinstance(name, string_types):
raise TypeError("name must be a string", name) raise TypeError("name must be a string", name)
self.name = name self.name = name
self.auto_name = "auto_" + str(next(self.__count__)) self.auto_name = "auto_" + str(next(self.__count__))
Variable.notify_construction_observers(self) Variable.notify_construction_observers(self)
def get_test_value(self):
"""Get the test value.
Raises
------
AttributeError
"""
if not hasattr(self.tag, "test_value"):
detailed_err_msg = utils.get_variable_trace_string(self)
raise AttributeError(
"{} has no test value {}".format(self, detailed_err_msg)
)
return self.tag.test_value
def __str__(self): def __str__(self):
"""Return a str representation of the Variable.""" """Return a str representation of the Variable."""
if self.name is not None: if self.name is not None:
...@@ -583,6 +603,9 @@ class Constant(Variable): ...@@ -583,6 +603,9 @@ class Constant(Variable):
self.data = type.filter(data) self.data = type.filter(data)
utils.add_tag_trace(self) utils.add_tag_trace(self)
def get_test_value(self):
return self.data
def equals(self, other): def equals(self, other):
# this does what __eq__ should do, but Variable and Apply should always be hashable by id # this does what __eq__ should do, but Variable and Apply should always be hashable by id
return isinstance(other, Constant) and self.signature() == other.signature() return isinstance(other, Constant) and self.signature() == other.signature()
......
差异被折叠。
...@@ -523,7 +523,7 @@ def scan( ...@@ -523,7 +523,7 @@ def scan(
# Try to transfer test_value to the new variable # Try to transfer test_value to the new variable
if config.compute_test_value != "off": if config.compute_test_value != "off":
try: try:
nw_slice.tag.test_value = gof.Op._get_test_value(_seq_val_slice) nw_slice.tag.test_value = gof.get_test_value(_seq_val_slice)
except AttributeError as e: except AttributeError as e:
if config.compute_test_value != "ignore": if config.compute_test_value != "ignore":
# No need to print a warning or raise an error now, # No need to print a warning or raise an error now,
...@@ -655,7 +655,7 @@ def scan( ...@@ -655,7 +655,7 @@ def scan(
# Try to transfer test_value to the new variable # Try to transfer test_value to the new variable
if config.compute_test_value != "off": if config.compute_test_value != "off":
try: try:
arg.tag.test_value = gof.Op._get_test_value(actual_arg) arg.tag.test_value = gof.get_test_value(actual_arg)
except AttributeError as e: except AttributeError as e:
if config.compute_test_value != "ignore": if config.compute_test_value != "ignore":
# No need to print a warning or raise an error now, # No need to print a warning or raise an error now,
...@@ -716,7 +716,7 @@ def scan( ...@@ -716,7 +716,7 @@ def scan(
# Try to transfer test_value to the new variable # Try to transfer test_value to the new variable
if config.compute_test_value != "off": if config.compute_test_value != "off":
try: try:
nw_slice.tag.test_value = gof.Op._get_test_value( nw_slice.tag.test_value = gof.get_test_value(
_init_out_var_slice _init_out_var_slice
) )
except AttributeError as e: except AttributeError as e:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论