提交 0ea10588 authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Move get_test_value method to the Variable (sub)classes

The test value computation steps were also extracted from `PureOp.__call__` and put into a stand-alone `compute_test_value` function.
上级 03b82c26
......@@ -150,6 +150,9 @@ class SharedVariable(Variable):
else:
self.container.value = copy.deepcopy(new_value)
def get_test_value(self):
return self.get_value(borrow=True, return_internal_type=True)
def zero(self, borrow=False):
"""
Set the values of a shared variable to 0.
......
......@@ -58,7 +58,14 @@ from theano.gof.link import (
WrapLinkerMany,
)
from theano.gof.op import Op, OpenMPOp, PureOp, COp, ops_with_inner_function
from theano.gof.op import (
Op,
OpenMPOp,
PureOp,
COp,
ops_with_inner_function,
get_test_value,
)
from theano.gof.type import EnumType, EnumList, CEnumType
......
......@@ -383,19 +383,39 @@ class Variable(Node):
self.tag = utils.ValidatingScratchpad("test_value", type.filter)
self.type = type
if owner is not None and not isinstance(owner, Apply):
raise TypeError("owner must be an Apply instance", owner)
self.owner = owner
if index is not None and not isinstance(index, integer_types):
raise TypeError("index must be an int", index)
self.index = index
if name is not None and not isinstance(name, string_types):
raise TypeError("name must be a string", name)
self.name = name
self.auto_name = "auto_" + str(next(self.__count__))
Variable.notify_construction_observers(self)
def get_test_value(self):
"""Get the test value.
Raises
------
AttributeError
"""
if not hasattr(self.tag, "test_value"):
detailed_err_msg = utils.get_variable_trace_string(self)
raise AttributeError(
"{} has no test value {}".format(self, detailed_err_msg)
)
return self.tag.test_value
def __str__(self):
"""Return a str representation of the Variable."""
if self.name is not None:
......@@ -583,6 +603,9 @@ class Constant(Variable):
self.data = type.filter(data)
utils.add_tag_trace(self)
def get_test_value(self):
return self.data
def equals(self, other):
# this does what __eq__ should do, but Variable and Apply should always be hashable by id
return isinstance(other, Constant) and self.signature() == other.signature()
......
差异被折叠。
......@@ -523,7 +523,7 @@ def scan(
# Try to transfer test_value to the new variable
if config.compute_test_value != "off":
try:
nw_slice.tag.test_value = gof.Op._get_test_value(_seq_val_slice)
nw_slice.tag.test_value = gof.get_test_value(_seq_val_slice)
except AttributeError as e:
if config.compute_test_value != "ignore":
# No need to print a warning or raise an error now,
......@@ -655,7 +655,7 @@ def scan(
# Try to transfer test_value to the new variable
if config.compute_test_value != "off":
try:
arg.tag.test_value = gof.Op._get_test_value(actual_arg)
arg.tag.test_value = gof.get_test_value(actual_arg)
except AttributeError as e:
if config.compute_test_value != "ignore":
# No need to print a warning or raise an error now,
......@@ -716,7 +716,7 @@ def scan(
# Try to transfer test_value to the new variable
if config.compute_test_value != "off":
try:
nw_slice.tag.test_value = gof.Op._get_test_value(
nw_slice.tag.test_value = gof.get_test_value(
_init_out_var_slice
)
except AttributeError as e:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论