提交 21daf4e0 authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Validate test values using a tensor's type

上级 74ee82b5
...@@ -167,14 +167,16 @@ class TestComputeTestValue: ...@@ -167,14 +167,16 @@ class TestComputeTestValue:
@theano.change_flags(compute_test_value="raise") @theano.change_flags(compute_test_value="raise")
def test_incorrect_type(self): def test_incorrect_type(self):
x = tt.fmatrix("x")
# Incorrect dtype (float64) for test_value
x.tag.test_value = np.random.rand(3, 4)
y = tt.dmatrix("y")
y.tag.test_value = np.random.rand(4, 5)
x = tt.vector("x")
with pytest.raises(TypeError): with pytest.raises(TypeError):
tt.dot(x, y) # Incorrect shape for test value
x.tag.test_value = np.empty((2, 2))
x = tt.fmatrix("x")
with pytest.raises(TypeError):
# Incorrect dtype (float64) for test value
x.tag.test_value = np.random.rand(3, 4)
@theano.change_flags(compute_test_value="raise") @theano.change_flags(compute_test_value="raise")
def test_overided_function(self): def test_overided_function(self):
......
...@@ -383,7 +383,7 @@ class Variable(Node): ...@@ -383,7 +383,7 @@ class Variable(Node):
def __init__(self, type, owner=None, index=None, name=None): def __init__(self, type, owner=None, index=None, name=None):
super(Variable, self).__init__() super(Variable, self).__init__()
self.tag = utils.Scratchpad() self.tag = utils.ValidatingScratchpad("test_value", type.filter)
self.type = type self.type = type
if owner is not None and not isinstance(owner, Apply): if owner is not None and not isinstance(owner, Apply):
......
...@@ -259,6 +259,23 @@ class Scratchpad(object): ...@@ -259,6 +259,23 @@ class Scratchpad(object):
print(" %s: %s" % (k, v)) print(" %s: %s" % (k, v))
class ValidatingScratchpad(Scratchpad):
"""This `Scratchpad` validates attribute values."""
def __init__(self, attr, attr_filter):
super().__init__()
object.__setattr__(self, "attr", attr)
object.__setattr__(self, "attr_filter", attr_filter)
def __setattr__(self, attr, obj):
if getattr(self, "attr", None) == attr:
obj = self.attr_filter(obj)
return object.__setattr__(self, attr, obj)
class D: class D:
def __init__(self, **d): def __init__(self, **d):
self.__dict__.update(d) self.__dict__.update(d)
......
...@@ -7743,7 +7743,10 @@ def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None ...@@ -7743,7 +7743,10 @@ def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None
if tv.size > 0: if tv.size > 0:
tmp.tag.test_value = tv.flatten()[0] tmp.tag.test_value = tv.flatten()[0]
else: else:
tmp.tag.test_value = tv _logger.warning(
"Cannot construct a scalar test value"
" from a test value with no size: {}".format(ii)
)
except AttributeError: except AttributeError:
pass pass
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论