提交 7b448dce authored 作者: notoraptor's avatar notoraptor

Update wrapper doc.

Simplify values_eq and values_eq_approx. Then, update test for values_eq_approx. Fix some exception type. Remove useless comment.
上级 25040e3f
...@@ -217,10 +217,12 @@ class TestWrapper(TestCase): ...@@ -217,10 +217,12 @@ class TestWrapper(TestCase):
assert w.values_eq_approx(o1, o2) assert w.values_eq_approx(o1, o2)
# Check value_eq_approx. # Check value_eq_approx.
# NB: I don't know exactly which kind of differences is rejected by values_eq but accepted by values_eq_approx.
# So, I just play a little with float values.
o3 = Wrap(w, o3 = Wrap(w,
a1=numpy.asarray([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]]).astype('float32'), a1=numpy.asarray([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]]).astype('int32'),
a2=random_tensor.astype('float64'), a2=(random_tensor.astype('float32') * 10 / 2.2 * 2.19999999999 / 10).astype('float64'),
a3=2000.0) a3=2000.0 - 0.00000000000000001)
assert w.values_eq_approx(o1, o3) assert w.values_eq_approx(o1, o3)
def test_op_params(self): def test_op_params(self):
......
...@@ -142,16 +142,16 @@ class Wrapper(Type): ...@@ -142,16 +142,16 @@ class Wrapper(Type):
structObject.key structObject.key
In a C code, attributes created to represent an instance of the type associated to ``key`` will be available via: In a C code, any attribute named ``key`` will be available via:
.. code-block:: c .. code-block:: c
structObject->key; structObject->key;
structObject->dtype_key; // e.g. from TensorType C code.
structObject->other_attribute_named_from_key;
/* etc. */
**NB**: This Type is not a complete type and should never be used for regular graph operations. .. note::
This Type is not complete and should never be used for regular graph operations.
""" """
def __init__(self, **kwargs): def __init__(self, **kwargs):
...@@ -160,7 +160,7 @@ class Wrapper(Type): ...@@ -160,7 +160,7 @@ class Wrapper(Type):
for attribute_name in kwargs: for attribute_name in kwargs:
if re.match('^[A-Za-z_][A-Za-z0-9_]*$', attribute_name) is None: if re.match('^[A-Za-z_][A-Za-z0-9_]*$', attribute_name) is None:
raise SyntaxError('Wrapper: attribute "%s" should be a valid identifier.' % attribute_name) raise AttributeError('Wrapper: attribute "%s" should be a valid identifier.' % attribute_name)
if attribute_name in c_cpp_keywords: if attribute_name in c_cpp_keywords:
print(len(c_cpp_keywords)) print(len(c_cpp_keywords))
raise SyntaxError('Wrapper: "%s" is a potential C/C++ keyword and should not be used as attribute name.' raise SyntaxError('Wrapper: "%s" is a potential C/C++ keyword and should not be used as attribute name.'
...@@ -180,7 +180,6 @@ class Wrapper(Type): ...@@ -180,7 +180,6 @@ class Wrapper(Type):
return 'Wrapper<%s>' % ', '.join([('%s:%s' % (self.fields[i], self.types[i])) for i in range(self.length)]) return 'Wrapper<%s>' % ', '.join([('%s:%s' % (self.fields[i], self.types[i])) for i in range(self.length)])
def __eq__(self, other): def __eq__(self, other):
# To be checked.
return (type(self) == type(other) and self.fields == other.fields and self.types == other.types) return (type(self) == type(other) and self.fields == other.fields and self.types == other.types)
def __hash__(self): def __hash__(self):
...@@ -211,18 +210,10 @@ class Wrapper(Type): ...@@ -211,18 +210,10 @@ class Wrapper(Type):
return self.wrap_data(data, strict, allow_downcast) return self.wrap_data(data, strict, allow_downcast)
def values_eq(self, a, b): def values_eq(self, a, b):
# We check that a and b have expected attributes and strict values.
a = self.filter(a, strict=True)
b = self.filter(b, strict=True)
# Then we compare.
return all(self.types[i].values_eq(getattr(a, self.fields[i]), getattr(b, self.fields[i])) return all(self.types[i].values_eq(getattr(a, self.fields[i]), getattr(b, self.fields[i]))
for i in range(self.length)) for i in range(self.length))
def values_eq_approx(self, a, b): def values_eq_approx(self, a, b):
# We check, wrap and round a and b if necessary.
a = self.filter(a, strict=False, allow_downcast=True)
b = self.filter(b, strict=False, allow_downcast=True)
# Then we compare.
return all(self.types[i].values_eq_approx(getattr(a, self.fields[i]), getattr(b, self.fields[i])) return all(self.types[i].values_eq_approx(getattr(a, self.fields[i]), getattr(b, self.fields[i]))
for i in range(self.length)) for i in range(self.length))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论