提交 48f5e4a9 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Remove deprecated Linker.make_function

上级 9e3c69be
...@@ -10,7 +10,7 @@ from aesara.graph.fg import FunctionGraph ...@@ -10,7 +10,7 @@ from aesara.graph.fg import FunctionGraph
from aesara.graph.type import CType from aesara.graph.type import CType
from aesara.graph.utils import MetaObject from aesara.graph.utils import MetaObject
from aesara.link.utils import gc_helper, map_storage, raise_with_op, streamline from aesara.link.utils import gc_helper, map_storage, raise_with_op, streamline
from aesara.utils import deprecated, difference, to_return_values from aesara.utils import difference
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -203,50 +203,6 @@ class Linker(ABC): ...@@ -203,50 +203,6 @@ class Linker(ABC):
""" """
@deprecated("Marked for deletion. Only tests use it.")
def make_function(self, unpack_single: bool = True, **kwargs) -> Callable:
"""
Returns a function that takes values corresponding to the inputs of the
fgraph used by this L{Linker} and returns values corresponding the the
outputs of that fgraph. If inplace is True, the calculations will
operate in the same storage the fgraph uses, else independent storage
will be allocated for the function.
Parameters
----------
unpack_single : bool
If `unpack_single` is True (default) and that the function has only one
output, then that output will be returned. Else, a list or tuple of
length 1 will be returned.
Examples
--------
e = x + y
fgraph = FunctionGraph([x, y], [e])
fn = MyLinker(fgraph).make_function(inplace)
print fn(1.0, 2.0) # 3.0
print e.data # 3.0 iff inplace == True (else unknown)
"""
thunk, inputs, outputs = self.make_thunk(**kwargs)
def execute(*args):
takes = len(inputs)
got = len(args)
if got != takes:
raise TypeError(
f"Function call takes exactly {takes} args ({got} given)"
)
for arg, variable in zip(args, inputs):
variable.data = arg
thunk()
if unpack_single:
return to_return_values([variable.data for variable in outputs])
else:
return [variable.data for variable in outputs]
return execute
def schedule(self, fgraph: FunctionGraph) -> List[Apply]: def schedule(self, fgraph: FunctionGraph) -> List[Apply]:
"""Runs the scheduler (if set) or the toposort on the FunctionGraph. """Runs the scheduler (if set) or the toposort on the FunctionGraph.
......
...@@ -583,12 +583,8 @@ def struct_variable_codeblocks(fgraph, variable, policies, id, symbol_table, sub ...@@ -583,12 +583,8 @@ def struct_variable_codeblocks(fgraph, variable, policies, id, symbol_table, sub
class CLinker(Linker): class CLinker(Linker):
""" r"""Generates and compiles C code for a :class:`FunctionGraph`.
Creates C code for an fgraph, compiles it and returns callables
through make_thunk and make_function that make use of the compiled
code.
no_recycling can contain a list of Variables that belong to the fgraph.
If a Variable is in no_recycling, CLinker will clear the output storage If a Variable is in no_recycling, CLinker will clear the output storage
associated to it during the computation (to avoid reusing it). associated to it during the computation (to avoid reusing it).
...@@ -599,9 +595,10 @@ class CLinker(Linker): ...@@ -599,9 +595,10 @@ class CLinker(Linker):
super().__init__(scheduler=schedule) super().__init__(scheduler=schedule)
def accept(self, fgraph, no_recycling=None, profile=None): def accept(self, fgraph, no_recycling=None, profile=None):
""" r"""Associate this `Linker` with `fgraph`.
Associate linker with fgraph
The `no_recycling` argument can contain a list of `Variable`\s that
belong to `fgraph`.
""" """
if no_recycling is None: if no_recycling is None:
no_recycling = [] no_recycling = []
...@@ -614,11 +611,7 @@ class CLinker(Linker): ...@@ -614,11 +611,7 @@ class CLinker(Linker):
return self return self
def fetch_variables(self): def fetch_variables(self):
""" """Fills the inputs, outputs, variables, orphans, temps and node_order fields."""
Fills the inputs, outputs, variables, orphans, temps and node_order
fields.
"""
fgraph = self.fgraph fgraph = self.fgraph
self.inputs = fgraph.inputs self.inputs = fgraph.inputs
self.outputs = fgraph.outputs self.outputs = fgraph.outputs
...@@ -678,15 +671,16 @@ class CLinker(Linker): ...@@ -678,15 +671,16 @@ class CLinker(Linker):
) )
def code_gen(self): def code_gen(self):
""" """Construct and populate a C ``struct`` for the generated code.
Generates code for a struct that does the computation of the fgraph and
stores it in the struct_code field of the instance. Generates code for a ``struct`` instance that does the computation of the `FunctionGraph` and
stores it in the ``struct_code`` field of the instance.
If reuse_storage is True, outputs and temporaries will be stored in
the struct so they can be reused each time a function returned by If :attr:`CLinker.reuse_storage` is ``True``, outputs and temporaries
make_function is called, which means that the output of a call will will be stored in the ``struct`` so they can be reused each time the
be invalidated by the next. If reuse_storage is False, that problem generated code is called, which means that the output of a call will be
is avoided. invalidated by the next. If the value is ``False``, that problem is
avoided.
This method caches its computations. This method caches its computations.
......
...@@ -32,7 +32,7 @@ A good, simple way to do it would be to have those commands as methods of a stru ...@@ -32,7 +32,7 @@ A good, simple way to do it would be to have those commands as methods of a stru
>>> a, b, c = Tensor(), Tensor(), Tensor() >>> a, b, c = Tensor(), Tensor(), Tensor()
>>> d = b * c >>> d = b * c
>>> e = a + d >>> e = a + d
>>> debug = DebugLinker(FunctionGraph([a, b, c], [e])).make_function() >>> debug = make_function(DebugLinker(FunctionGraph([a, b, c], [e])))
>>> debug.set_breakpoint(d) >>> debug.set_breakpoint(d)
>>> debug.debug(10, 20, 30) # a, b, c = 10, 20, 30 >>> debug.debug(10, 20, 30) # a, b, c = 10, 20, 30
Now at: Mul(b, c) Now at: Mul(b, c)
...@@ -54,5 +54,3 @@ Finished. ...@@ -54,5 +54,3 @@ Finished.
[630] [630]
>>> >>>
}}} }}}
...@@ -12,6 +12,7 @@ from aesara.graph.type import CType ...@@ -12,6 +12,7 @@ from aesara.graph.type import CType
from aesara.link.basic import PerformLinker from aesara.link.basic import PerformLinker
from aesara.link.c.basic import CLinker, DualLinker, OpWiseCLinker from aesara.link.c.basic import CLinker, DualLinker, OpWiseCLinker
from aesara.tensor.type import iscalar, matrix, vector from aesara.tensor.type import iscalar, matrix, vector
from tests.link.test_link import make_function
def as_variable(x): def as_variable(x):
...@@ -189,7 +190,7 @@ def test_clinker_straightforward(): ...@@ -189,7 +190,7 @@ def test_clinker_straightforward():
x, y, z = inputs() x, y, z = inputs()
e = add(mul(add(x, y), div(x, y)), bad_sub(bad_sub(x, y), z)) e = add(mul(add(x, y), div(x, y)), bad_sub(bad_sub(x, y), z))
lnk = CLinker().accept(FunctionGraph([x, y, z], [e])) lnk = CLinker().accept(FunctionGraph([x, y, z], [e]))
fn = lnk.make_function() fn = make_function(lnk)
assert fn(2.0, 2.0, 2.0) == 2.0 assert fn(2.0, 2.0, 2.0) == 2.0
...@@ -214,7 +215,7 @@ def test_clinker_literal_inlining(): ...@@ -214,7 +215,7 @@ def test_clinker_literal_inlining():
z = Constant(tdouble, 4.12345678) z = Constant(tdouble, 4.12345678)
e = add(mul(add(x, y), div(x, y)), bad_sub(bad_sub(x, y), z)) e = add(mul(add(x, y), div(x, y)), bad_sub(bad_sub(x, y), z))
lnk = CLinker().accept(FunctionGraph([x, y], [e])) lnk = CLinker().accept(FunctionGraph([x, y], [e]))
fn = lnk.make_function() fn = make_function(lnk)
assert abs(fn(2.0, 2.0) + 0.12345678) < 1e-9 assert abs(fn(2.0, 2.0) + 0.12345678) < 1e-9
code = lnk.code_gen() code = lnk.code_gen()
# print "=== Code generated ===" # print "=== Code generated ==="
...@@ -257,7 +258,7 @@ def test_clinker_single_node(): ...@@ -257,7 +258,7 @@ def test_clinker_single_node():
x, y, z = inputs() x, y, z = inputs()
node = add.make_node(x, y) node = add.make_node(x, y)
lnk = CLinker().accept(FunctionGraph(node.inputs, node.outputs)) lnk = CLinker().accept(FunctionGraph(node.inputs, node.outputs))
fn = lnk.make_function() fn = make_function(lnk)
assert fn(2.0, 7.0) == 9 assert fn(2.0, 7.0) == 9
...@@ -269,7 +270,7 @@ def test_clinker_dups(): ...@@ -269,7 +270,7 @@ def test_clinker_dups():
x, y, z = inputs() x, y, z = inputs()
e = add(x, x) e = add(x, x)
lnk = CLinker().accept(FunctionGraph([x, x], [e])) lnk = CLinker().accept(FunctionGraph([x, x], [e]))
fn = lnk.make_function() fn = make_function(lnk)
assert fn(2.0, 2.0) == 4 assert fn(2.0, 2.0) == 4
# note: for now the behavior of fn(2.0, 7.0) is undefined # note: for now the behavior of fn(2.0, 7.0) is undefined
...@@ -282,7 +283,7 @@ def test_clinker_not_used_inputs(): ...@@ -282,7 +283,7 @@ def test_clinker_not_used_inputs():
x, y, z = inputs() x, y, z = inputs()
e = add(x, y) e = add(x, y)
lnk = CLinker().accept(FunctionGraph([x, y, z], [e])) lnk = CLinker().accept(FunctionGraph([x, y, z], [e]))
fn = lnk.make_function() fn = make_function(lnk)
assert fn(2.0, 1.5, 1.0) == 3.5 assert fn(2.0, 1.5, 1.0) == 3.5
...@@ -294,7 +295,7 @@ def test_clinker_dups_inner(): ...@@ -294,7 +295,7 @@ def test_clinker_dups_inner():
x, y, z = inputs() x, y, z = inputs()
e = add(mul(y, y), add(x, z)) e = add(mul(y, y), add(x, z))
lnk = CLinker().accept(FunctionGraph([x, y, z], [e])) lnk = CLinker().accept(FunctionGraph([x, y, z], [e]))
fn = lnk.make_function() fn = make_function(lnk)
assert fn(1.0, 2.0, 3.0) == 8.0 assert fn(1.0, 2.0, 3.0) == 8.0
...@@ -303,7 +304,7 @@ def test_opwiseclinker_straightforward(): ...@@ -303,7 +304,7 @@ def test_opwiseclinker_straightforward():
x, y, z = inputs() x, y, z = inputs()
e = add(mul(add(x, y), div(x, y)), bad_sub(bad_sub(x, y), z)) e = add(mul(add(x, y), div(x, y)), bad_sub(bad_sub(x, y), z))
lnk = OpWiseCLinker().accept(FunctionGraph([x, y, z], [e])) lnk = OpWiseCLinker().accept(FunctionGraph([x, y, z], [e]))
fn = lnk.make_function() fn = make_function(lnk)
if config.cxx: if config.cxx:
assert fn(2.0, 2.0, 2.0) == 2.0 assert fn(2.0, 2.0, 2.0) == 2.0
else: else:
...@@ -316,7 +317,7 @@ def test_opwiseclinker_constant(): ...@@ -316,7 +317,7 @@ def test_opwiseclinker_constant():
x = Constant(tdouble, 7.2, name="x") x = Constant(tdouble, 7.2, name="x")
e = add(mul(x, y), mul(y, z)) e = add(mul(x, y), mul(y, z))
lnk = OpWiseCLinker().accept(FunctionGraph([y, z], [e])) lnk = OpWiseCLinker().accept(FunctionGraph([y, z], [e]))
fn = lnk.make_function() fn = make_function(lnk)
res = fn(1.5, 3.0) res = fn(1.5, 3.0)
assert res == 15.3 assert res == 15.3
...@@ -334,7 +335,7 @@ def test_duallinker_straightforward(): ...@@ -334,7 +335,7 @@ def test_duallinker_straightforward():
x, y, z = inputs() x, y, z = inputs()
e = add(mul(x, y), mul(y, z)) # add and mul are correct in C and in Python e = add(mul(x, y), mul(y, z)) # add and mul are correct in C and in Python
lnk = DualLinker(checker=_my_checker).accept(FunctionGraph([x, y, z], [e])) lnk = DualLinker(checker=_my_checker).accept(FunctionGraph([x, y, z], [e]))
fn = lnk.make_function() fn = make_function(lnk)
res = fn(7.2, 1.5, 3.0) res = fn(7.2, 1.5, 3.0)
assert res == 15.3 assert res == 15.3
...@@ -348,15 +349,15 @@ def test_duallinker_mismatch(): ...@@ -348,15 +349,15 @@ def test_duallinker_mismatch():
e = bad_sub(mul(x, y), mul(y, z)) e = bad_sub(mul(x, y), mul(y, z))
g = FunctionGraph([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
lnk = DualLinker(checker=_my_checker).accept(g) lnk = DualLinker(checker=_my_checker).accept(g)
fn = lnk.make_function() fn = make_function(lnk)
# good # good
assert CLinker().accept(g).make_function()(1.0, 2.0, 3.0) == -4.0 assert make_function(CLinker().accept(g))(1.0, 2.0, 3.0) == -4.0
# good # good
assert OpWiseCLinker().accept(g).make_function()(1.0, 2.0, 3.0) == -4.0 assert make_function(OpWiseCLinker().accept(g))(1.0, 2.0, 3.0) == -4.0
# (purposely) wrong # (purposely) wrong
assert PerformLinker().accept(g).make_function()(1.0, 2.0, 3.0) == -10.0 assert make_function(PerformLinker().accept(g))(1.0, 2.0, 3.0) == -10.0
with pytest.raises(MyExc): with pytest.raises(MyExc):
# this runs OpWiseCLinker and PerformLinker in parallel and feeds # this runs OpWiseCLinker and PerformLinker in parallel and feeds
...@@ -389,7 +390,7 @@ def test_c_fail_error(): ...@@ -389,7 +390,7 @@ def test_c_fail_error():
x = Constant(tdouble, 7.2, name="x") x = Constant(tdouble, 7.2, name="x")
e = add_fail(mul(x, y), mul(y, z)) e = add_fail(mul(x, y), mul(y, z))
lnk = OpWiseCLinker().accept(FunctionGraph([y, z], [e])) lnk = OpWiseCLinker().accept(FunctionGraph([y, z], [e]))
fn = lnk.make_function() fn = make_function(lnk)
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
fn(1.5, 3.0) fn(1.5, 3.0)
......
from copy import deepcopy from copy import deepcopy
from typing import Callable
import numpy as np import numpy as np
...@@ -8,10 +9,52 @@ from aesara.graph import fg ...@@ -8,10 +9,52 @@ from aesara.graph import fg
from aesara.graph.basic import Apply, Constant, Variable, clone from aesara.graph.basic import Apply, Constant, Variable, clone
from aesara.graph.op import Op from aesara.graph.op import Op
from aesara.graph.type import Type from aesara.graph.type import Type
from aesara.link.basic import Container, PerformLinker, WrapLinker from aesara.link.basic import Container, Linker, PerformLinker, WrapLinker
from aesara.link.c.basic import OpWiseCLinker from aesara.link.c.basic import OpWiseCLinker
from aesara.tensor.type import matrix, scalar from aesara.tensor.type import matrix, scalar
from aesara.utils import cmp from aesara.utils import cmp, to_return_values
def make_function(linker: Linker, unpack_single: bool = True, **kwargs) -> Callable:
"""
Returns a function that takes values corresponding to the inputs of the
fgraph used by this L{Linker} and returns values corresponding the the
outputs of that fgraph. If inplace is True, the calculations will
operate in the same storage the fgraph uses, else independent storage
will be allocated for the function.
Parameters
----------
unpack_single : bool
If `unpack_single` is True (default) and that the function has only one
output, then that output will be returned. Else, a list or tuple of
length 1 will be returned.
Examples
--------
e = x + y
fgraph = FunctionGraph([x, y], [e])
fn = make_function(MyLinker(fgraph), inplace)
print fn(1.0, 2.0) # 3.0
print e.data # 3.0 iff inplace == True (else unknown)
"""
thunk, inputs, outputs = linker.make_thunk(**kwargs)
def execute(*args):
takes = len(inputs)
got = len(args)
if got != takes:
raise TypeError(f"Function call takes exactly {takes} args ({got} given)")
for arg, variable in zip(args, inputs):
variable.data = arg
thunk()
if unpack_single:
return to_return_values([variable.data for variable in outputs])
else:
return [variable.data for variable in outputs]
return execute
def as_variable(x): def as_variable(x):
...@@ -103,26 +146,26 @@ class TestPerformLinker: ...@@ -103,26 +146,26 @@ class TestPerformLinker:
def test_function(self): def test_function(self):
x, y, z = inputs() x, y, z = inputs()
e = mul(add(x, y), div(x, y)) e = mul(add(x, y), div(x, y))
fn = perform_linker(FunctionGraph([x, y, z], [e])).make_function() fn = make_function(perform_linker(FunctionGraph([x, y, z], [e])))
assert fn(1.0, 2.0, 3.0) == 1.5 assert fn(1.0, 2.0, 3.0) == 1.5
def test_constant(self): def test_constant(self):
x, y, z = inputs() x, y, z = inputs()
y = Constant(tdouble, 2.0) y = Constant(tdouble, 2.0)
e = mul(add(x, y), div(x, y)) e = mul(add(x, y), div(x, y))
fn = perform_linker(FunctionGraph([x], [e])).make_function() fn = make_function(perform_linker(FunctionGraph([x], [e])))
assert fn(1.0) == 1.5 assert fn(1.0) == 1.5
def test_input_output_same(self): def test_input_output_same(self):
x, y, z = inputs() x, y, z = inputs()
fn = perform_linker(FunctionGraph([x], [x])).make_function() fn = make_function(perform_linker(FunctionGraph([x], [x])))
assert 1.0 == fn(1.0) assert 1.0 == fn(1.0)
def test_input_dependency0(self): def test_input_dependency0(self):
x, y, z = inputs() x, y, z = inputs()
a, d = add(x, y), div(x, y) a, d = add(x, y), div(x, y)
e = mul(a, d) e = mul(a, d)
fn = perform_linker(FunctionGraph(*clone([x, y, a], [e]))).make_function() fn = make_function(perform_linker(FunctionGraph(*clone([x, y, a], [e]))))
assert fn(1.0, 2.0, 9.0) == 4.5 assert fn(1.0, 2.0, 9.0) == 4.5
def test_skiphole(self): def test_skiphole(self):
...@@ -130,7 +173,7 @@ class TestPerformLinker: ...@@ -130,7 +173,7 @@ class TestPerformLinker:
a = add(x, y) a = add(x, y)
r = raise_err(a) r = raise_err(a)
e = add(r, a) e = add(r, a)
fn = perform_linker(FunctionGraph(*clone([x, y, r], [e]))).make_function() fn = make_function(perform_linker(FunctionGraph(*clone([x, y, r], [e]))))
assert fn(1.0, 2.0, 4.5) == 7.5 assert fn(1.0, 2.0, 4.5) == 7.5
......
...@@ -66,13 +66,14 @@ from aesara.scalar.basic import ( ...@@ -66,13 +66,14 @@ from aesara.scalar.basic import (
uint8, uint8,
) )
from aesara.tensor.type import fscalar, imatrix, iscalar, matrix from aesara.tensor.type import fscalar, imatrix, iscalar, matrix
from tests.link.test_link import make_function
def test_mul_add_true(): def test_mul_add_true():
x, y, z = floats("xyz") x, y, z = floats("xyz")
e = mul(add(x, y), true_div(x, y)) e = mul(add(x, y), true_div(x, y))
g = FunctionGraph([x, y], [e]) g = FunctionGraph([x, y], [e])
fn = DualLinker().accept(g).make_function() fn = make_function(DualLinker().accept(g))
assert fn(1.0, 2.0) == 1.5 assert fn(1.0, 2.0) == 1.5
...@@ -121,7 +122,7 @@ class TestComposite: ...@@ -121,7 +122,7 @@ class TestComposite:
c = C.make_node(x, y) c = C.make_node(x, y)
# print c.c_code(['x', 'y'], ['z'], dict(id = 0)) # print c.c_code(['x', 'y'], ['z'], dict(id = 0))
g = FunctionGraph([x, y], [c.out]) g = FunctionGraph([x, y], [c.out])
fn = DualLinker().accept(g).make_function() fn = make_function(DualLinker().accept(g))
assert fn(1.0, 2.0) == 1.5 assert fn(1.0, 2.0) == 1.5
def test_flatten(self): def test_flatten(self):
...@@ -144,7 +145,7 @@ class TestComposite: ...@@ -144,7 +145,7 @@ class TestComposite:
assert "70.0" in c.op.c_code(c, "dummy", ["x", "y"], ["z"], dict(id=0)) assert "70.0" in c.op.c_code(c, "dummy", ["x", "y"], ["z"], dict(id=0))
# print c.c_code(['x', 'y'], ['z'], dict(id = 0)) # print c.c_code(['x', 'y'], ['z'], dict(id = 0))
g = FunctionGraph([x, y], [c.out]) g = FunctionGraph([x, y], [c.out])
fn = DualLinker().accept(g).make_function() fn = make_function(DualLinker().accept(g))
assert fn(1.0, 2.0) == 36.0 assert fn(1.0, 2.0) == 36.0
def test_many_outputs(self): def test_many_outputs(self):
...@@ -156,7 +157,7 @@ class TestComposite: ...@@ -156,7 +157,7 @@ class TestComposite:
c = C.make_node(x, y, z) c = C.make_node(x, y, z)
# print c.c_code(['x', 'y', 'z'], ['out0', 'out1', 'out2'], dict(id = 0)) # print c.c_code(['x', 'y', 'z'], ['out0', 'out1', 'out2'], dict(id = 0))
g = FunctionGraph([x, y, z], c.outputs) g = FunctionGraph([x, y, z], c.outputs)
fn = DualLinker().accept(g).make_function() fn = make_function(DualLinker().accept(g))
assert fn(1.0, 2.0, 3.0) == [6.0, 7.0, 0.5] assert fn(1.0, 2.0, 3.0) == [6.0, 7.0, 0.5]
def test_composite_printing(self): def test_composite_printing(self):
...@@ -172,7 +173,7 @@ class TestComposite: ...@@ -172,7 +173,7 @@ class TestComposite:
C = Composite([x, y, z], [e0, e1, e2, e3, e4, e5, e6, e7]) C = Composite([x, y, z], [e0, e1, e2, e3, e4, e5, e6, e7])
c = C.make_node(x, y, z) c = C.make_node(x, y, z)
g = FunctionGraph([x, y, z], c.outputs) g = FunctionGraph([x, y, z], c.outputs)
DualLinker().accept(g).make_function() make_function(DualLinker().accept(g))
assert str(g) == ( assert str(g) == (
"FunctionGraph(*1 -> Composite{((i0 + i1) + i2)," "FunctionGraph(*1 -> Composite{((i0 + i1) + i2),"
...@@ -206,71 +207,71 @@ class TestComposite: ...@@ -206,71 +207,71 @@ class TestComposite:
class TestLogical: class TestLogical:
def test_gt(self): def test_gt(self):
x, y, z = floats("xyz") x, y, z = floats("xyz")
fn = DualLinker().accept(FunctionGraph([x, y], [x > y])).make_function() fn = make_function(DualLinker().accept(FunctionGraph([x, y], [x > y])))
for a, b in ((3.0, 9), (3, 0.9), (3, 3)): for a, b in ((3.0, 9), (3, 0.9), (3, 3)):
assert fn(a, b) == (a > b) assert fn(a, b) == (a > b)
def test_lt(self): def test_lt(self):
x, y, z = floats("xyz") x, y, z = floats("xyz")
fn = DualLinker().accept(FunctionGraph([x, y], [x < y])).make_function() fn = make_function(DualLinker().accept(FunctionGraph([x, y], [x < y])))
for a, b in ((3.0, 9), (3, 0.9), (3, 3)): for a, b in ((3.0, 9), (3, 0.9), (3, 3)):
assert fn(a, b) == (a < b) assert fn(a, b) == (a < b)
def test_le(self): def test_le(self):
x, y, z = floats("xyz") x, y, z = floats("xyz")
fn = DualLinker().accept(FunctionGraph([x, y], [x <= y])).make_function() fn = make_function(DualLinker().accept(FunctionGraph([x, y], [x <= y])))
for a, b in ((3.0, 9), (3, 0.9), (3, 3)): for a, b in ((3.0, 9), (3, 0.9), (3, 3)):
assert fn(a, b) == (a <= b) assert fn(a, b) == (a <= b)
def test_ge(self): def test_ge(self):
x, y, z = floats("xyz") x, y, z = floats("xyz")
fn = DualLinker().accept(FunctionGraph([x, y], [x >= y])).make_function() fn = make_function(DualLinker().accept(FunctionGraph([x, y], [x >= y])))
for a, b in ((3.0, 9), (3, 0.9), (3, 3)): for a, b in ((3.0, 9), (3, 0.9), (3, 3)):
assert fn(a, b) == (a >= b) assert fn(a, b) == (a >= b)
def test_eq(self): def test_eq(self):
x, y, z = floats("xyz") x, y, z = floats("xyz")
fn = DualLinker().accept(FunctionGraph([x, y], [eq(x, y)])).make_function() fn = make_function(DualLinker().accept(FunctionGraph([x, y], [eq(x, y)])))
for a, b in ((3.0, 9), (3, 0.9), (3, 3)): for a, b in ((3.0, 9), (3, 0.9), (3, 3)):
assert fn(a, b) == (a == b) assert fn(a, b) == (a == b)
def test_neq(self): def test_neq(self):
x, y, z = floats("xyz") x, y, z = floats("xyz")
fn = DualLinker().accept(FunctionGraph([x, y], [neq(x, y)])).make_function() fn = make_function(DualLinker().accept(FunctionGraph([x, y], [neq(x, y)])))
for a, b in ((3.0, 9), (3, 0.9), (3, 3)): for a, b in ((3.0, 9), (3, 0.9), (3, 3)):
assert fn(a, b) == (a != b) assert fn(a, b) == (a != b)
def test_or(self): def test_or(self):
x, y, z = ints("xyz") x, y, z = ints("xyz")
fn = DualLinker().accept(FunctionGraph([x, y], [x | y])).make_function() fn = make_function(DualLinker().accept(FunctionGraph([x, y], [x | y])))
for a, b in ((0, 1), (0, 0), (1, 0), (1, 1)): for a, b in ((0, 1), (0, 0), (1, 0), (1, 1)):
assert fn(a, b) == (a | b), (a, b) assert fn(a, b) == (a | b), (a, b)
def test_xor(self): def test_xor(self):
x, y, z = ints("xyz") x, y, z = ints("xyz")
fn = DualLinker().accept(FunctionGraph([x, y], [x ^ y])).make_function() fn = make_function(DualLinker().accept(FunctionGraph([x, y], [x ^ y])))
for a, b in ((0, 1), (0, 0), (1, 0), (1, 1)): for a, b in ((0, 1), (0, 0), (1, 0), (1, 1)):
assert fn(a, b) == (a ^ b), (a, b) assert fn(a, b) == (a ^ b), (a, b)
def test_and(self): def test_and(self):
x, y, z = ints("xyz") x, y, z = ints("xyz")
fn = DualLinker().accept(FunctionGraph([x, y], [and_(x, y)])).make_function() fn = make_function(DualLinker().accept(FunctionGraph([x, y], [and_(x, y)])))
for a, b in ((0, 1), (0, 0), (1, 0), (1, 1)): for a, b in ((0, 1), (0, 0), (1, 0), (1, 1)):
assert fn(a, b) == (a & b), (a, b) assert fn(a, b) == (a & b), (a, b)
x, y, z = ints("xyz") x, y, z = ints("xyz")
fn = DualLinker().accept(FunctionGraph([x, y], [x & y])).make_function() fn = make_function(DualLinker().accept(FunctionGraph([x, y], [x & y])))
for a, b in ((0, 1), (0, 0), (1, 0), (1, 1)): for a, b in ((0, 1), (0, 0), (1, 0), (1, 1)):
assert fn(a, b) == (a & b), (a, b) assert fn(a, b) == (a & b), (a, b)
def test_not(self): def test_not(self):
x, y, z = ints("xyz") x, y, z = ints("xyz")
fn = DualLinker().accept(FunctionGraph([x, y], [invert(x)])).make_function() fn = make_function(DualLinker().accept(FunctionGraph([x, y], [invert(x)])))
for a, b in ((0, 1), (0, 0), (1, 0), (1, 1)): for a, b in ((0, 1), (0, 0), (1, 0), (1, 1)):
assert fn(a, b) == ~a, (a,) assert fn(a, b) == ~a, (a,)
x, y, z = ints("xyz") x, y, z = ints("xyz")
fn = DualLinker().accept(FunctionGraph([x, y], [~x])).make_function() fn = make_function(DualLinker().accept(FunctionGraph([x, y], [~x])))
for a, b in ((0, 1), (0, 0), (1, 0), (1, 1)): for a, b in ((0, 1), (0, 0), (1, 0), (1, 1)):
assert fn(a, b) == ~a, (a,) assert fn(a, b) == ~a, (a,)
......
import pytest import pytest
import aesara import aesara
from aesara.graph.fg import FunctionGraph
from aesara.link.c.basic import CLinker
from aesara.scalar.basic import floats from aesara.scalar.basic import floats
from aesara.scalar.basic_sympy import SymPyCCode from aesara.scalar.basic_sympy import SymPyCCode
from tests.link.test_link import make_function
sympy = pytest.importorskip("sympy") sympy = pytest.importorskip("sympy")
...@@ -18,8 +21,8 @@ xt, yt = floats("xy") ...@@ -18,8 +21,8 @@ xt, yt = floats("xy")
def test_SymPyCCode(): def test_SymPyCCode():
op = SymPyCCode([xs, ys], xs + ys) op = SymPyCCode([xs, ys], xs + ys)
e = op(xt, yt) e = op(xt, yt)
g = aesara.graph.fg.FunctionGraph([xt, yt], [e]) g = FunctionGraph([xt, yt], [e])
fn = aesara.link.c.basic.CLinker().accept(g).make_function() fn = make_function(CLinker().accept(g))
assert fn(1.0, 2.0) == 3.0 assert fn(1.0, 2.0) == 3.0
......
...@@ -7,6 +7,7 @@ from aesara.compile.mode import Mode ...@@ -7,6 +7,7 @@ from aesara.compile.mode import Mode
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.link.c.basic import CLinker from aesara.link.c.basic import CLinker
from aesara.scalar.math import betainc, betainc_der, gammainc, gammaincc, gammal, gammau from aesara.scalar.math import betainc, betainc_der, gammainc, gammaincc, gammal, gammau
from tests.link.test_link import make_function
def test_gammainc_python(): def test_gammainc_python():
...@@ -21,7 +22,7 @@ def test_gammainc_nan_c(): ...@@ -21,7 +22,7 @@ def test_gammainc_nan_c():
x1 = at.dscalar() x1 = at.dscalar()
x2 = at.dscalar() x2 = at.dscalar()
y = gammainc(x1, x2) y = gammainc(x1, x2)
test_func = CLinker().accept(FunctionGraph([x1, x2], [y])).make_function() test_func = make_function(CLinker().accept(FunctionGraph([x1, x2], [y])))
assert np.isnan(test_func(-1, 1)) assert np.isnan(test_func(-1, 1))
assert np.isnan(test_func(1, -1)) assert np.isnan(test_func(1, -1))
assert np.isnan(test_func(-1, -1)) assert np.isnan(test_func(-1, -1))
...@@ -39,7 +40,7 @@ def test_gammaincc_nan_c(): ...@@ -39,7 +40,7 @@ def test_gammaincc_nan_c():
x1 = at.dscalar() x1 = at.dscalar()
x2 = at.dscalar() x2 = at.dscalar()
y = gammaincc(x1, x2) y = gammaincc(x1, x2)
test_func = CLinker().accept(FunctionGraph([x1, x2], [y])).make_function() test_func = make_function(CLinker().accept(FunctionGraph([x1, x2], [y])))
assert np.isnan(test_func(-1, 1)) assert np.isnan(test_func(-1, 1))
assert np.isnan(test_func(1, -1)) assert np.isnan(test_func(1, -1))
assert np.isnan(test_func(-1, -1)) assert np.isnan(test_func(-1, -1))
...@@ -49,7 +50,7 @@ def test_gammal_nan_c(): ...@@ -49,7 +50,7 @@ def test_gammal_nan_c():
x1 = at.dscalar() x1 = at.dscalar()
x2 = at.dscalar() x2 = at.dscalar()
y = gammal(x1, x2) y = gammal(x1, x2)
test_func = CLinker().accept(FunctionGraph([x1, x2], [y])).make_function() test_func = make_function(CLinker().accept(FunctionGraph([x1, x2], [y])))
assert np.isnan(test_func(-1, 1)) assert np.isnan(test_func(-1, 1))
assert np.isnan(test_func(1, -1)) assert np.isnan(test_func(1, -1))
assert np.isnan(test_func(-1, -1)) assert np.isnan(test_func(-1, -1))
...@@ -59,7 +60,7 @@ def test_gammau_nan_c(): ...@@ -59,7 +60,7 @@ def test_gammau_nan_c():
x1 = at.dscalar() x1 = at.dscalar()
x2 = at.dscalar() x2 = at.dscalar()
y = gammau(x1, x2) y = gammau(x1, x2)
test_func = CLinker().accept(FunctionGraph([x1, x2], [y])).make_function() test_func = make_function(CLinker().accept(FunctionGraph([x1, x2], [y])))
assert np.isnan(test_func(-1, 1)) assert np.isnan(test_func(-1, 1))
assert np.isnan(test_func(1, -1)) assert np.isnan(test_func(1, -1))
assert np.isnan(test_func(-1, -1)) assert np.isnan(test_func(-1, -1))
......
...@@ -29,6 +29,7 @@ from aesara.tensor.type import ( ...@@ -29,6 +29,7 @@ from aesara.tensor.type import (
vectors, vectors,
) )
from tests import unittest_tools from tests import unittest_tools
from tests.link.test_link import make_function
from tests.tensor.test_math import reduce_bitwise_and from tests.tensor.test_math import reduce_bitwise_and
...@@ -181,7 +182,7 @@ class TestBroadcast: ...@@ -181,7 +182,7 @@ class TestBroadcast:
x = type(aesara.config.floatX, [(entry == 1) for entry in xsh])("x") x = type(aesara.config.floatX, [(entry == 1) for entry in xsh])("x")
y = type(aesara.config.floatX, [(entry == 1) for entry in ysh])("y") y = type(aesara.config.floatX, [(entry == 1) for entry in ysh])("y")
e = op(aes.add)(x, y) e = op(aes.add)(x, y)
f = copy(linker).accept(FunctionGraph([x, y], [e])).make_function() f = make_function(copy(linker).accept(FunctionGraph([x, y], [e])))
xv = rand_val(xsh) xv = rand_val(xsh)
yv = rand_val(ysh) yv = rand_val(ysh)
zv = xv + yv zv = xv + yv
...@@ -194,11 +195,7 @@ class TestBroadcast: ...@@ -194,11 +195,7 @@ class TestBroadcast:
x = type(aesara.config.floatX, [(entry == 1) for entry in xsh])("x") x = type(aesara.config.floatX, [(entry == 1) for entry in xsh])("x")
y = type(aesara.config.floatX, [(entry == 1) for entry in ysh])("y") y = type(aesara.config.floatX, [(entry == 1) for entry in ysh])("y")
e = op(aes.add)(x, y) e = op(aes.add)(x, y)
f = ( f = make_function(copy(linker).accept(FunctionGraph([x, y], [e.shape])))
copy(linker)
.accept(FunctionGraph([x, y], [e.shape]))
.make_function()
)
assert tuple(f(xv, yv)) == tuple(zv.shape) assert tuple(f(xv, yv)) == tuple(zv.shape)
def with_linker_inplace(self, linker, op, type, rand_val): def with_linker_inplace(self, linker, op, type, rand_val):
...@@ -215,7 +212,7 @@ class TestBroadcast: ...@@ -215,7 +212,7 @@ class TestBroadcast:
x = type(aesara.config.floatX, [(entry == 1) for entry in xsh])("x") x = type(aesara.config.floatX, [(entry == 1) for entry in xsh])("x")
y = type(aesara.config.floatX, [(entry == 1) for entry in ysh])("y") y = type(aesara.config.floatX, [(entry == 1) for entry in ysh])("y")
e = op(aes.Add(aes.transfer_type(0)), {0: 0})(x, y) e = op(aes.Add(aes.transfer_type(0)), {0: 0})(x, y)
f = copy(linker).accept(FunctionGraph([x, y], [e])).make_function() f = make_function(copy(linker).accept(FunctionGraph([x, y], [e])))
xv = rand_val(xsh) xv = rand_val(xsh)
yv = rand_val(ysh) yv = rand_val(ysh)
zv = xv + yv zv = xv + yv
...@@ -229,11 +226,7 @@ class TestBroadcast: ...@@ -229,11 +226,7 @@ class TestBroadcast:
x = type(aesara.config.floatX, [(entry == 1) for entry in xsh])("x") x = type(aesara.config.floatX, [(entry == 1) for entry in xsh])("x")
y = type(aesara.config.floatX, [(entry == 1) for entry in ysh])("y") y = type(aesara.config.floatX, [(entry == 1) for entry in ysh])("y")
e = op(aes.Add(aes.transfer_type(0)), {0: 0})(x, y) e = op(aes.Add(aes.transfer_type(0)), {0: 0})(x, y)
f = ( f = make_function(copy(linker).accept(FunctionGraph([x, y], [e.shape])))
copy(linker)
.accept(FunctionGraph([x, y], [e.shape]))
.make_function()
)
xv = rand_val(xsh) xv = rand_val(xsh)
yv = rand_val(ysh) yv = rand_val(ysh)
zv = xv + yv zv = xv + yv
...@@ -273,7 +266,7 @@ class TestBroadcast: ...@@ -273,7 +266,7 @@ class TestBroadcast:
x = t(aesara.config.floatX, (False, False))("x") x = t(aesara.config.floatX, (False, False))("x")
y = t(aesara.config.floatX, (True, True))("y") y = t(aesara.config.floatX, (True, True))("y")
e = op(aes.Second(aes.transfer_type(0)), {0: 0})(x, y) e = op(aes.Second(aes.transfer_type(0)), {0: 0})(x, y)
f = linker().accept(FunctionGraph([x, y], [e])).make_function() f = make_function(linker().accept(FunctionGraph([x, y], [e])))
xv = rval((5, 5)) xv = rval((5, 5))
yv = rval((1, 1)) yv = rval((1, 1))
f(xv, yv) f(xv, yv)
...@@ -304,7 +297,7 @@ class TestBroadcast: ...@@ -304,7 +297,7 @@ class TestBroadcast:
x = t(aesara.config.floatX, (False,) * 5)("x") x = t(aesara.config.floatX, (False,) * 5)("x")
y = t(aesara.config.floatX, (False,) * 5)("y") y = t(aesara.config.floatX, (False,) * 5)("y")
e = op(aes.add)(x, y) e = op(aes.add)(x, y)
f = linker().accept(FunctionGraph([x, y], [e])).make_function() f = make_function(linker().accept(FunctionGraph([x, y], [e])))
xv = rval((2, 2, 2, 2, 2)) xv = rval((2, 2, 2, 2, 2))
yv = rval((2, 2, 2, 2, 2)).transpose(4, 0, 3, 1, 2) yv = rval((2, 2, 2, 2, 2)).transpose(4, 0, 3, 1, 2)
zv = xv + yv zv = xv + yv
...@@ -322,7 +315,7 @@ class TestBroadcast: ...@@ -322,7 +315,7 @@ class TestBroadcast:
): ):
x = t(aesara.config.floatX, (False,) * 2)("x") x = t(aesara.config.floatX, (False,) * 2)("x")
e = op(aes.add)(x, x) e = op(aes.add)(x, x)
f = linker().accept(FunctionGraph([x], [e])).make_function() f = make_function(linker().accept(FunctionGraph([x], [e])))
xv = rval((2, 2)) xv = rval((2, 2))
zv = xv + xv zv = xv + xv
assert (f(xv) == zv).all() assert (f(xv) == zv).all()
......
...@@ -139,6 +139,7 @@ from aesara.tensor.type import ( ...@@ -139,6 +139,7 @@ from aesara.tensor.type import (
) )
from aesara.tensor.type_other import NoneConst from aesara.tensor.type_other import NoneConst
from tests import unittest_tools as utt from tests import unittest_tools as utt
from tests.link.test_link import make_function
from tests.tensor.utils import ( from tests.tensor.utils import (
_bad_build_broadcast_binary_normal, _bad_build_broadcast_binary_normal,
_bad_runtime_broadcast_binary_normal, _bad_runtime_broadcast_binary_normal,
...@@ -2387,7 +2388,7 @@ def test_divmod(): ...@@ -2387,7 +2388,7 @@ def test_divmod():
# Confirm that divmod is equivalent to the python version. # Confirm that divmod is equivalent to the python version.
x, y = fscalars("xy") x, y = fscalars("xy")
d, r = divmod(x, y) d, r = divmod(x, y)
fn = DualLinker().accept(FunctionGraph([x, y], [d, r])).make_function() fn = make_function(DualLinker().accept(FunctionGraph([x, y], [d, r])))
for a, b in ( for a, b in (
(0, 1), (0, 1),
(1, 1), (1, 1),
......
...@@ -21,6 +21,7 @@ from aesara.tensor.opt_uncanonicalize import ( ...@@ -21,6 +21,7 @@ from aesara.tensor.opt_uncanonicalize import (
) )
from aesara.tensor.shape import reshape from aesara.tensor.shape import reshape
from aesara.tensor.type import dtensor4, iscalar, matrix, tensor, vector from aesara.tensor.type import dtensor4, iscalar, matrix, tensor, vector
from tests.link.test_link import make_function
class TestMaxAndArgmax: class TestMaxAndArgmax:
...@@ -165,7 +166,7 @@ def test_local_dimshuffle_alloc(): ...@@ -165,7 +166,7 @@ def test_local_dimshuffle_alloc():
l = PerformLinker() l = PerformLinker()
l.accept(g) l.accept(g)
f = l.make_function() f = make_function(l)
assert f([3, 4]).ndim == 4 assert f([3, 4]).ndim == 4
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论