提交 82f6a14f authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Cleanup Function.__call__

上级 f0a9ec25
...@@ -128,9 +128,6 @@ class DisconnectedType(Type): ...@@ -128,9 +128,6 @@ class DisconnectedType(Type):
" a symbolic placeholder." " a symbolic placeholder."
) )
def may_share_memory(a, b):
return False
def value_eq(a, b, force_same_dtype=True): def value_eq(a, b, force_same_dtype=True):
raise AssertionError( raise AssertionError(
"If you're assigning to a DisconnectedType you're" "If you're assigning to a DisconnectedType you're"
......
...@@ -26,9 +26,6 @@ class NullType(Type): ...@@ -26,9 +26,6 @@ class NullType(Type):
def filter_variable(self, other, allow_convert=True): def filter_variable(self, other, allow_convert=True):
raise ValueError("No values may be assigned to a NullType") raise ValueError("No values may be assigned to a NullType")
def may_share_memory(a, b):
return False
def values_eq(self, a, b, force_same_dtype=True): def values_eq(self, a, b, force_same_dtype=True):
raise ValueError("NullType has no values to compare") raise ValueError("NullType has no values to compare")
......
...@@ -48,10 +48,7 @@ class Type(MetaObject, Generic[D]): ...@@ -48,10 +48,7 @@ class Type(MetaObject, Generic[D]):
unique element (i.e. it uses `self.__eq__`). unique element (i.e. it uses `self.__eq__`).
""" """
if self == otype: return self == otype
return True
return False
def is_super(self, otype: "Type") -> bool | None: def is_super(self, otype: "Type") -> bool | None:
"""Determine if `self` is a supertype of `otype`. """Determine if `self` is a supertype of `otype`.
......
...@@ -303,13 +303,6 @@ class ScalarType(CType, HasDataType, HasShape): ...@@ -303,13 +303,6 @@ class ScalarType(CType, HasDataType, HasShape):
dtype = self.dtype dtype = self.dtype
return type(self)(dtype) return type(self)(dtype)
@staticmethod
def may_share_memory(a, b):
# This class represent basic c type, represented in python
# with numpy.scalar. They are read only. So from python, they
# can never share memory.
return False
def filter(self, data, strict=False, allow_downcast=None): def filter(self, data, strict=False, allow_downcast=None):
py_type = self.dtype_specs()[0] py_type = self.dtype_specs()[0]
if strict and not isinstance(data, py_type): if strict and not isinstance(data, py_type):
......
...@@ -126,12 +126,6 @@ class NoneTypeT(Generic): ...@@ -126,12 +126,6 @@ class NoneTypeT(Generic):
else: else:
raise TypeError("Expected None!") raise TypeError("Expected None!")
@staticmethod
def may_share_memory(a, b):
# None never share memory between object, in the sense of DebugMode.
# Python None are singleton
return False
none_type_t = NoneTypeT() none_type_t = NoneTypeT()
......
...@@ -730,6 +730,8 @@ class TestFunction: ...@@ -730,6 +730,8 @@ class TestFunction:
s1 = shared(b) s1 = shared(b)
s2 = shared(b) s2 = shared(b)
x1 = vector() x1 = vector()
x2 = vector(shape=(3,))
x3 = vector(shape=(1,))
# Assert cases we should not check for aliased inputs # Assert cases we should not check for aliased inputs
for d in [ for d in [
...@@ -737,27 +739,29 @@ class TestFunction: ...@@ -737,27 +739,29 @@ class TestFunction:
dict(outputs=[s1 + 1, s2 + 3]), dict(outputs=[s1 + 1, s2 + 3]),
dict(outputs=[s1 + 1], updates=[(s2, s2 + 3)]), dict(outputs=[s1 + 1], updates=[(s2, s2 + 3)]),
dict(inputs=[x1], outputs=[x1 + 1], updates=[(s2, s2 + 3)]), dict(inputs=[x1], outputs=[x1 + 1], updates=[(s2, s2 + 3)]),
dict(
inputs=[In(x1, mutable=True)], outputs=[x1 + 1], updates=[(s2, s2 + 3)]
),
dict(
inputs=[In(x2, mutable=True), In(x3, mutable=True)],
outputs=[x2 + 2, x3 + 3],
),
]: ]:
if "inputs" not in d: if "inputs" not in d:
d["inputs"] = [] d["inputs"] = []
f = function(**d) f = function(**d)
assert not f._check_for_aliased_inputs, d assert not f._potential_aliased_input_groups, d
# Assert cases we should check for aliased inputs # Assert cases we should check for aliased inputs
for d in [ for d in [
dict( dict(
inputs=[In(x1, borrow=True)], inputs=[In(x1, mutable=True), In(x2, mutable=True)],
outputs=[x1 + 1], outputs=[x1 + 1, x2 + 2],
updates=[(s2, s2 + 3)],
),
dict(
inputs=[In(x1, borrow=True, mutable=True)],
outputs=[x1 + 1],
updates=[(s2, s2 + 3)], updates=[(s2, s2 + 3)],
), ),
dict( dict(
inputs=[In(x1, mutable=True)], inputs=[In(x1, mutable=True), In(x3, mutable=True)],
outputs=[x1 + 1], outputs=[x1 + 1, x3 + 3],
updates=[(s2, s2 + 3)], updates=[(s2, s2 + 3)],
), ),
]: ]:
...@@ -765,7 +769,7 @@ class TestFunction: ...@@ -765,7 +769,7 @@ class TestFunction:
d["inputs"] = [] d["inputs"] = []
f = function(**d) f = function(**d)
assert f._check_for_aliased_inputs, d assert f._potential_aliased_input_groups, d
def test_output_dictionary(self): def test_output_dictionary(self):
# Tests that function works when outputs is a dictionary # Tests that function works when outputs is a dictionary
...@@ -879,7 +883,7 @@ class TestPicklefunction: ...@@ -879,7 +883,7 @@ class TestPicklefunction:
f = function( f = function(
[ [
x, x,
In(a, value=1.0, name="a"), In(a, value=1.0, name="a", mutable=True),
In(s, value=0.0, update=s + a * x, mutable=True), In(s, value=0.0, update=s + a * x, mutable=True),
], ],
s + a * x, s + a * x,
...@@ -901,7 +905,12 @@ class TestPicklefunction: ...@@ -901,7 +905,12 @@ class TestPicklefunction:
assert x not in g.container assert x not in g.container
assert x not in g.value assert x not in g.value
assert len(f.defaults) == len(g.defaults) assert len(f.defaults) == len(g.defaults)
assert f._check_for_aliased_inputs is g._check_for_aliased_inputs # Shared variable is the first input
assert (
f._potential_aliased_input_groups
== g._potential_aliased_input_groups
== ((1, 2),)
)
assert f.name == g.name assert f.name == g.name
assert f.maker.fgraph.name == g.maker.fgraph.name assert f.maker.fgraph.name == g.maker.fgraph.name
# print(f"{f.defaults = }") # print(f"{f.defaults = }")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论