提交 82f6a14f authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Cleanup Function.__call__

上级 f0a9ec25
......@@ -128,9 +128,6 @@ class DisconnectedType(Type):
" a symbolic placeholder."
)
def may_share_memory(a, b):
return False
def value_eq(a, b, force_same_dtype=True):
raise AssertionError(
"If you're assigning to a DisconnectedType you're"
......
......@@ -26,9 +26,6 @@ class NullType(Type):
def filter_variable(self, other, allow_convert=True):
raise ValueError("No values may be assigned to a NullType")
def may_share_memory(a, b):
return False
def values_eq(self, a, b, force_same_dtype=True):
raise ValueError("NullType has no values to compare")
......
......@@ -48,10 +48,7 @@ class Type(MetaObject, Generic[D]):
unique element (i.e. it uses `self.__eq__`).
"""
if self == otype:
return True
return False
return self == otype
def is_super(self, otype: "Type") -> bool | None:
"""Determine if `self` is a supertype of `otype`.
......
......@@ -303,13 +303,6 @@ class ScalarType(CType, HasDataType, HasShape):
dtype = self.dtype
return type(self)(dtype)
@staticmethod
def may_share_memory(a, b):
# This class represent basic c type, represented in python
# with numpy.scalar. They are read only. So from python, they
# can never share memory.
return False
def filter(self, data, strict=False, allow_downcast=None):
py_type = self.dtype_specs()[0]
if strict and not isinstance(data, py_type):
......
......@@ -126,12 +126,6 @@ class NoneTypeT(Generic):
else:
raise TypeError("Expected None!")
@staticmethod
def may_share_memory(a, b):
# None never share memory between object, in the sense of DebugMode.
# Python None are singleton
return False
none_type_t = NoneTypeT()
......
......@@ -730,6 +730,8 @@ class TestFunction:
s1 = shared(b)
s2 = shared(b)
x1 = vector()
x2 = vector(shape=(3,))
x3 = vector(shape=(1,))
# Assert cases we should not check for aliased inputs
for d in [
......@@ -737,27 +739,29 @@ class TestFunction:
dict(outputs=[s1 + 1, s2 + 3]),
dict(outputs=[s1 + 1], updates=[(s2, s2 + 3)]),
dict(inputs=[x1], outputs=[x1 + 1], updates=[(s2, s2 + 3)]),
dict(
inputs=[In(x1, mutable=True)], outputs=[x1 + 1], updates=[(s2, s2 + 3)]
),
dict(
inputs=[In(x2, mutable=True), In(x3, mutable=True)],
outputs=[x2 + 2, x3 + 3],
),
]:
if "inputs" not in d:
d["inputs"] = []
f = function(**d)
assert not f._check_for_aliased_inputs, d
assert not f._potential_aliased_input_groups, d
# Assert cases we should check for aliased inputs
for d in [
dict(
inputs=[In(x1, borrow=True)],
outputs=[x1 + 1],
updates=[(s2, s2 + 3)],
),
dict(
inputs=[In(x1, borrow=True, mutable=True)],
outputs=[x1 + 1],
inputs=[In(x1, mutable=True), In(x2, mutable=True)],
outputs=[x1 + 1, x2 + 2],
updates=[(s2, s2 + 3)],
),
dict(
inputs=[In(x1, mutable=True)],
outputs=[x1 + 1],
inputs=[In(x1, mutable=True), In(x3, mutable=True)],
outputs=[x1 + 1, x3 + 3],
updates=[(s2, s2 + 3)],
),
]:
......@@ -765,7 +769,7 @@ class TestFunction:
d["inputs"] = []
f = function(**d)
assert f._check_for_aliased_inputs, d
assert f._potential_aliased_input_groups, d
def test_output_dictionary(self):
# Tests that function works when outputs is a dictionary
......@@ -879,7 +883,7 @@ class TestPicklefunction:
f = function(
[
x,
In(a, value=1.0, name="a"),
In(a, value=1.0, name="a", mutable=True),
In(s, value=0.0, update=s + a * x, mutable=True),
],
s + a * x,
......@@ -901,7 +905,12 @@ class TestPicklefunction:
assert x not in g.container
assert x not in g.value
assert len(f.defaults) == len(g.defaults)
assert f._check_for_aliased_inputs is g._check_for_aliased_inputs
# Shared variable is the first input
assert (
f._potential_aliased_input_groups
== g._potential_aliased_input_groups
== ((1, 2),)
)
assert f.name == g.name
assert f.maker.fgraph.name == g.maker.fgraph.name
# print(f"{f.defaults = }")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论