提交 21daf4e0 authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Validate test values using a tensor's type

上级 74ee82b5
......@@ -167,14 +167,16 @@ class TestComputeTestValue:
@theano.change_flags(compute_test_value="raise")
def test_incorrect_type(self):
x = tt.fmatrix("x")
# Incorrect dtype (float64) for test_value
x.tag.test_value = np.random.rand(3, 4)
y = tt.dmatrix("y")
y.tag.test_value = np.random.rand(4, 5)
x = tt.vector("x")
with pytest.raises(TypeError):
tt.dot(x, y)
# Incorrect shape for test value
x.tag.test_value = np.empty((2, 2))
x = tt.fmatrix("x")
with pytest.raises(TypeError):
# Incorrect dtype (float64) for test value
x.tag.test_value = np.random.rand(3, 4)
@theano.change_flags(compute_test_value="raise")
def test_overided_function(self):
......
......@@ -383,7 +383,7 @@ class Variable(Node):
def __init__(self, type, owner=None, index=None, name=None):
super(Variable, self).__init__()
self.tag = utils.Scratchpad()
self.tag = utils.ValidatingScratchpad("test_value", type.filter)
self.type = type
if owner is not None and not isinstance(owner, Apply):
......
......@@ -259,6 +259,23 @@ class Scratchpad(object):
print(" %s: %s" % (k, v))
class ValidatingScratchpad(Scratchpad):
"""This `Scratchpad` validates attribute values."""
def __init__(self, attr, attr_filter):
super().__init__()
object.__setattr__(self, "attr", attr)
object.__setattr__(self, "attr_filter", attr_filter)
def __setattr__(self, attr, obj):
if getattr(self, "attr", None) == attr:
obj = self.attr_filter(obj)
return object.__setattr__(self, attr, obj)
class D:
def __init__(self, **d):
self.__dict__.update(d)
......
......@@ -7743,7 +7743,10 @@ def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None
if tv.size > 0:
tmp.tag.test_value = tv.flatten()[0]
else:
tmp.tag.test_value = tv
_logger.warning(
"Cannot construct a scalar test value"
" from a test value with no size: {}".format(ii)
)
except AttributeError:
pass
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论