提交 c2092862 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename Function.fn to Function.vm

上级 d39b852a
......@@ -9,7 +9,7 @@ import logging
import time
import warnings
from itertools import chain
from typing import List, Optional, Tuple, Type
from typing import TYPE_CHECKING, List, Optional, Tuple, Type
import numpy as np
......@@ -34,6 +34,10 @@ from aesara.link.basic import Container
from aesara.link.utils import raise_with_op
if TYPE_CHECKING:
from aesara.link.vm import VM
_logger = logging.getLogger("aesara.compile.function.types")
......@@ -271,42 +275,45 @@ DUPLICATE = object()
class Function:
"""
Type of the functions returned by aesara.function or
aesara.FunctionMaker.create.
r"""A class that wraps the execution of a `VM` making it easier for use as a "function".
`Function` is the callable object that does computation. It has the storage
of inputs and outputs, performs the packing and unpacking of inputs and
return values. It implements the square-bracket indexing so that you can
look up the value of a symbolic node.
Functions are copyable via {{{fn.copy()}}} and {{{copy.copy(fn)}}}.
Functions are copyable via `Function.copy` and the `copy.copy` interface.
When a function is copied, this instance is duplicated. Contrast with
self.maker (instance of `FunctionMaker`) that is shared between copies.
The meaning of copying a function is that the containers and their current
values will all be duplicated. This requires that mutable inputs be
copied, whereas immutable inputs may be shared between copies.
A Function instance is hashable, on the basis of its memory
address (its id).
A Function instance is hashable, on the basis of its memory address (its
id).
A Function instance is only equal to itself.
A Function instance may be serialized using the `pickle` or
`cPickle` modules. This will save all default inputs, the graph,
and WRITEME to the pickle file.
A Function instance have a ``trust_input`` field that default to
False. When True, we don't do extra check of the input to give
better error message. In some case, python code will still return
the good results if you pass a python or numpy scalar instead of a
numpy tensor. C code should raise an error if you pass an object
of the wrong type.
A `Function` instance has a `Function.trust_input` field that defaults to
``False``. When ``True``, the `Function` will skip all checks on the
inputs.
Attributes
----------
finder
Dictionary mapping several kinds of things to containers.
We set an entry in finder for:
- the index of the input
- the variable instance the input is based on
- the name of the input
All entries map to the container or to DUPLICATE if an ambiguity
is detected.
inv_finder
Reverse lookup of `finder`. It maps containers to `SymbolicInput`\s.
"""
......@@ -321,111 +328,59 @@ class Function:
If the value is 'raise', then an AliasedMemoryError will be raised
if aliased storage is detected during pickle.dump.
"""
input_storage = None
"""
List of Container instances.
"""
output_storage = None
"""
List of Container instances.
"""
indices = None
"""
List of (SymbolicInput, indices, [SymbolicInput,...]),
one tuple for each input.
The first tuple element is the SymbolicInput object for the corresponding
function input.
The second and third tuple elements are used only by Kits, which
are deprecated.
"""
defaults = None
"""
List of 3-tuples, one 3-tuple for each input.
Tuple element 0: Bool: Is this input required at each function call?
Tuple element 1: Bool: Should this inputs value be reverted after
each call?
Tuple element 2: Any: The value associated with this input.
"""
unpack_single = None
"""
Bool: for outputs lists of length 1, should the 0'th element be
returned directly?
"""
return_none = None
"""
Bool: whether the function should return None or not.
"""
maker = None
"""
FunctionMaker instance.
"""
fn = None
"""
A function that evaluates the graph. Typically a linker's make_thunk method
created this function.
"""
finder = None
"""
Dictionary mapping several kinds of things to containers.
We set an entry in finder for:
- the index of the input
- the variable instance the input is based on
- the name of the input
All entries map to the container or to DUPLICATE if an ambiguity
is detected.
"""
inv_finder = None
"""
Dict. Reverse lookup of `finder`.
It maps container -> SymbolicInput
"""
def __init__(
self,
fn,
vm: "VM",
input_storage,
output_storage,
indices,
outputs,
defaults,
unpack_single,
return_none,
unpack_single: bool,
return_none: bool,
output_keys,
maker,
name=None,
maker: "FunctionMaker",
name: Optional[str] = None,
):
self.fn = fn
"""
Parameters
----------
vm
A `VM` instance that evaluates the graph when called.
input_storage
List of storage cells for each input.
output_storage
List of storage cells for each output.
indices
List of ``(SymbolicInput, indices, [SymbolicInput,...])``, one
tuple for each input. The first tuple element is the `SymbolicInput`
object for the corresponding function input. The second and third
tuple elements are used only by Kits, which are deprecated.
outputs
TODO
defaults
List of 3-tuples, one 3-tuple for each input.
Tuple element 0: ``bool``. Is this input required at each function
call?
Tuple element 1: ``bool``. Should this inputs value be reverted
after each call?
Tuple element 2: ``Any``. The value associated with this input.
unpack_single
For outputs lists of length 1, should the 0'th element be
returned directly?
return_none
Whether the function should return ``None`` or not.
output_keys
TODO
maker
The `FunctionMaker` that created this instance.
name
A string name.
"""
# TODO: Rename to `vm`
self.vm = vm
self.input_storage = input_storage
self.output_storage = output_storage
self.indices = indices
......@@ -441,7 +396,7 @@ class Function:
self.output_keys = output_keys
# See if we have any mutable / borrow inputs
# TODO: this only need to be set if there is more then 1 input
# TODO: this only need to be set if there is more than one input
self._check_for_aliased_inputs = False
for i in maker.inputs:
# If the input is a shared variable, the memory region is
......@@ -575,7 +530,7 @@ class Function:
# TODO: Get rid of all this `expanded_inputs` nonsense
assert len(self.maker.expanded_inputs) == len(self.input_storage)
# This is used only when `fn.need_update_inputs` is `False`, because
# This is used only when `vm.need_update_inputs` is `False`, because
# we're using one of the VM objects and it is putting updates back into
# the input containers all by itself.
self.n_returned_outputs = len(self.output_storage) - sum(
......@@ -752,7 +707,7 @@ class Function:
# Construct new storage_map that map new variable to old storage,
# so that the ensuing function shares storage with the original one
storage_map = self.fn.storage_map
storage_map = self.vm.storage_map
new_storage_map = {}
# TODO: We could share the output storage, but we must make sure
# 2 different function call won't override each other values. This
......@@ -1015,24 +970,24 @@ class Function:
t0_fn = time.time()
try:
outputs = (
self.fn()
self.vm()
if output_subset is None
else self.fn(output_subset=output_subset)
else self.vm(output_subset=output_subset)
)
except Exception:
restore_defaults()
if hasattr(self.fn, "position_of_error"):
if hasattr(self.vm, "position_of_error"):
# this is a new vm-provided function or c linker
# they need this because the exception manipulation
# done by raise_with_op is not implemented in C.
thunk = None
if hasattr(self.fn, "thunks"):
thunk = self.fn.thunks[self.fn.position_of_error]
if hasattr(self.vm, "thunks"):
thunk = self.vm.thunks[self.vm.position_of_error]
raise_with_op(
self.maker.fgraph,
node=self.fn.nodes[self.fn.position_of_error],
node=self.vm.nodes[self.vm.position_of_error],
thunk=thunk,
storage_map=getattr(self.fn, "storage_map", None),
storage_map=getattr(self.vm, "storage_map", None),
)
else:
# old-style linkers raise their own exceptions
......@@ -1056,7 +1011,7 @@ class Function:
# if we are allowing garbage collection, remove the
# output reference from the internal storage cells
if getattr(self.fn, "allow_gc", False):
if getattr(self.vm, "allow_gc", False):
assert len(self.output_storage) == len(self.maker.fgraph.outputs)
for o_container, o_variable in zip(
self.output_storage, self.maker.fgraph.outputs
......@@ -1068,7 +1023,7 @@ class Function:
# TODO: Get rid of this and `expanded_inputs`, since all the VMs now
# perform the updates themselves
if getattr(self.fn, "need_update_inputs", True):
if getattr(self.vm, "need_update_inputs", True):
# Update the inputs that have an update function
for input, storage in reversed(
list(zip(self.maker.expanded_inputs, self.input_storage))
......@@ -1092,8 +1047,8 @@ class Function:
if profile:
profile.fct_callcount += 1
profile.fct_call_time += dt_call
if hasattr(self.fn, "update_profile"):
self.fn.update_profile(profile)
if hasattr(self.vm, "update_profile"):
self.vm.update_profile(profile)
if profile.ignore_first_call:
profile.reset()
profile.ignore_first_call = False
......@@ -1137,10 +1092,10 @@ class Function:
"""
# 1.no allow_gc return False
# 2.has allow_gc, if allow_gc is False, return True
if not getattr(self.fn, "allow_gc", True):
for key in self.fn.storage_map:
if not getattr(self.vm, "allow_gc", True):
for key in self.vm.storage_map:
if not isinstance(key, Constant):
self.fn.storage_map[key][0] = None
self.vm.storage_map[key][0] = None
for node in self.nodes_with_inner_function:
if hasattr(node.fn, "free"):
......
......@@ -217,7 +217,7 @@ class ProfileStats:
#
vm_call_time = 0.0
# Total time spent in Function.fn.__call__
# Total time spent in Function.vm.__call__
#
apply_time = None
......@@ -781,7 +781,7 @@ class ProfileStats:
)
if self.fct_call_time > 0:
print(
f" Time in Function.fn.__call__: {self.vm_call_time}s ({100 * self.vm_call_time / self.fct_call_time:.3f}%)",
f" Time in Function.vm.__call__: {self.vm_call_time}s ({100 * self.vm_call_time / self.fct_call_time:.3f}%)",
file=file,
)
local_time = sum(self.apply_time.values())
......
......@@ -1139,9 +1139,9 @@ def clone_replace(
Parameters
----------
output : Aesara Variables (or Aesara expressions)
output
Aesara expression that represents the computational graph.
replace : dict
replace
Dictionary describing which subgraphs should be replaced by what.
rebuild_kwds
Keywords to `rebuild_collect_shared`.
......
......@@ -59,7 +59,7 @@ def execute(execute=True, verbose=True, M=2000, N=2000, K=2000, iters=10, order=
if any(x.op.__class__.__name__ == "Gemm" for x in f.maker.fgraph.toposort()):
c_impl = [
hasattr(thunk, "cthunk")
for node, thunk in zip(f.fn.nodes, f.fn.thunks)
for node, thunk in zip(f.vm.nodes, f.vm.thunks)
if node.op.__class__.__name__ == "Gemm"
]
assert len(c_impl) == 1
......
......@@ -222,7 +222,7 @@ def debugprint(
results_to_print.extend(obj.maker.fgraph.outputs)
profile_list.extend([obj.profile for item in obj.maker.fgraph.outputs])
if print_storage:
smap.extend([obj.fn.storage_map for item in obj.maker.fgraph.outputs])
smap.extend([obj.vm.storage_map for item in obj.maker.fgraph.outputs])
else:
smap.extend([None for item in obj.maker.fgraph.outputs])
topo = obj.maker.fgraph.toposort()
......
......@@ -75,7 +75,7 @@ def multMatVect(v, A, m1, B, m2):
f.input_storage[3].storage[0] = B
f.input_storage[4].storage[0] = v[3:]
f.input_storage[5].storage[0] = m2
f.fn()
f.vm()
r = f.output_storage[0].storage[0]
return r
......@@ -829,7 +829,7 @@ class MRG_RandomStream:
v = rval[i - 1]
f.input_storage[1].storage[0] = v[:3]
f.input_storage[4].storage[0] = v[3:]
f.fn()
f.vm()
rval[i] = f.output_storage[0].storage[0]
if inc_rstate:
......
......@@ -1594,8 +1594,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
from aesara.scan.utils import InnerFunctionError
# TODO: Extract `Capsule` object and use that
# c_thunk = getattr(self.fn.fn.thunks[0], "cthunk", None)
# if len(self.fn.fn.thunks) == 1 and c_thunk:
# c_thunk = getattr(self.fn.vm.thunks[0], "cthunk", None)
# if len(self.fn.vm.thunks) == 1 and c_thunk:
# thunk_capsule = c_thunk.cthunk
# # We need to perform the following after calling
# # the thunk function:
......@@ -1633,20 +1633,20 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
outputs,
outer_output_dtypes,
outer_output_ndims,
self.fn.fn,
self.fn.vm,
)
except InnerFunctionError as exc:
exc_type = type(exc.args[0])
exc_value = exc.args[0]
exc_trace = exc.args[1]
if hasattr(self.fn.fn, "position_of_error") and hasattr(
self.fn.fn, "thunks"
if hasattr(self.fn.vm, "position_of_error") and hasattr(
self.fn.vm, "thunks"
):
raise_with_op(
self.fn.maker.fgraph,
self.fn.fn.nodes[self.fn.fn.position_of_error],
self.fn.fn.thunks[self.fn.fn.position_of_error],
self.fn.vm.nodes[self.fn.vm.position_of_error],
self.fn.vm.thunks[self.fn.vm.position_of_error],
exc_info=(exc_type, exc_value, exc_trace),
)
else:
......@@ -1661,8 +1661,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
profile.callcount += 1
profile.nbsteps += n_steps
profile.call_time += t_call
if hasattr(self.fn.fn, "update_profile"):
self.fn.fn.update_profile(profile)
if hasattr(self.fn.vm, "update_profile"):
self.fn.vm.update_profile(profile)
except (ImportError, MissingGXX):
p = self.perform
......@@ -1795,7 +1795,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
inner_output_storage = self.fn.output_storage
old_inner_output_storage = [None] * len(inner_output_storage)
old_inner_output_data = [None] * len(inner_output_storage)
fn = self.fn.fn
vm = self.fn.vm
offset = (
info.n_seqs
+ sum(
......@@ -1938,18 +1938,18 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
t0_fn = time.time()
try:
fn()
vm()
except Exception:
if hasattr(fn, "position_of_error"):
if hasattr(vm, "position_of_error"):
# this is a new vm-provided function or c linker
# they need this because the exception manipulation
# done by raise_with_op is not implemented in C.
if hasattr(fn, "thunks"):
if hasattr(vm, "thunks"):
# For the CVM
raise_with_op(
self.fn.maker.fgraph,
fn.nodes[fn.position_of_error],
fn.thunks[fn.position_of_error],
vm.nodes[vm.position_of_error],
vm.thunks[vm.position_of_error],
)
else:
# For the c linker
......@@ -1957,7 +1957,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# temps values So for now, we just don't print
# the extra shapes/strides info
raise_with_op(
self.fn.maker.fgraph, fn.nodes[fn.position_of_error]
self.fn.maker.fgraph, vm.nodes[vm.position_of_error]
)
else:
# old-style linkers raise their own exceptions
......@@ -2200,8 +2200,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
profile.nbsteps += n_steps
profile.call_time += t_call
profile.vm_call_time += t_fn
if hasattr(self.fn.fn, "update_profile"):
self.fn.fn.update_profile(profile)
if hasattr(self.fn.vm, "update_profile"):
self.fn.vm.update_profile(profile)
self.t_call = t_call
self.t_fn = t_fn
......
......@@ -751,6 +751,8 @@ def add_nitsot_outputs(
new_outputs_inner,
) -> Tuple[Apply, Dict[Variable, Variable]]:
assert isinstance(old_scan_node.op, Scan)
nb_new_outs = len(new_outputs_inner)
# Create the initial values for the new nitsot outputs
......
......@@ -141,8 +141,8 @@ with
Also, for small Aesara functions, you can remove more Python overhead by
making an Aesara function that does not take any input. You can use shared
variables to achieve this. Then you can call it like this: ``f.fn()`` or
``f.fn(n_calls=N)`` to speed it up. In the last case, only the last
variables to achieve this. Then you can call it like this: ``f.vm()`` or
``f.vm(n_calls=N)`` to speed it up. In the last case, only the last
function output (out of N calls) is returned.
You can also use the ``C`` linker that will put all nodes in the same C
......
......@@ -140,9 +140,9 @@ Running the above code generates the following error message:
File "test1.py", line 31, in <module>
f(np.random.random((5, 10)))
File "PATH_TO_AESARA/aesara/compile/function/types.py", line 605, in __call__
self.fn.thunks[self.fn.position_of_error])
self.vm.thunks[self.vm.position_of_error])
File "PATH_TO_AESARA/aesara/compile/function/types.py", line 595, in __call__
outputs = self.fn()
outputs = self.vm()
ValueError: Shape mismatch: x has 10 cols (and 5 rows) but y has 20 rows (and 10 cols)
Apply node that caused the error: Dot22(x, DimShuffle{1,0}.0)
Inputs types: [TensorType(float64, (None, None)), TensorType(float64, (None, None))]
......
......@@ -52,8 +52,8 @@ function. aesara.function() has an optional parameter ``name`` that
defaults to None. Change it to something else to help you profile many
Aesara functions. In that section, we also see the number of times the
function was called (1) and the total time spent in all those
calls. The time spent in Function.fn.__call__ and in thunks is useful
to understand Aesara overhead.
calls. The time spent in :meth:`Function.vm.__call__` and in thunks is useful
to understand Aesara's overhead.
Also, we see the time spent in the two parts of the compilation
process: optimization (modify the graph to make it more stable/faster)
......
......@@ -2,7 +2,7 @@ Function profiling
==================
Message: None
Time in 1 calls to Function.__call__: 5.698204e-05s
Time in Function.fn.__call__: 1.192093e-05s (20.921%)
Time in Function.vm.__call__: 1.192093e-05s (20.921%)
Time in thunks: 6.198883e-06s (10.879%)
Total compile time: 3.642474e+00s
Aesara Optimizer time: 7.326508e-02s
......
......@@ -346,8 +346,8 @@ class TestFunction:
cpy = ori.copy(share_memory=True)
# Test if memories shared
storage_map_ori = ori.fn.storage_map
storage_map_cpy = cpy.fn.storage_map
storage_map_ori = ori.vm.storage_map
storage_map_cpy = cpy.vm.storage_map
fgraph_cpy = cpy.maker.fgraph
# Assert intermediate and Constants storages are shared.
......@@ -424,11 +424,11 @@ class TestFunction:
# 2. SharedVariable is updatable -> values did update(z == 5)
# 1. sharedvariable is swap -> Rpl sharedvariables share storage
names = map_SV.keys()
for key in cpy.fn.storage_map:
for key in cpy.vm.storage_map:
if key.name in names:
assert (
map_SV[key.name].container.storage[0]
== cpy.fn.storage_map[key][0]
== cpy.vm.storage_map[key][0]
)
second_time = True
......@@ -688,18 +688,18 @@ class TestFunction:
x = vector("x")
func = function([x], x + 1)
func.fn.allow_gc = False
func.vm.allow_gc = False
func([1])
check_list = []
for key, val in func.fn.storage_map.items():
for key, val in func.vm.storage_map.items():
if not isinstance(key, Constant):
check_list.append(val)
assert any(val[0] for val in check_list)
func.free()
for key, val in func.fn.storage_map.items():
for key, val in func.vm.storage_map.items():
if not isinstance(key, Constant):
assert val[0] is None
......
......@@ -3505,7 +3505,7 @@ def test_config_options_parallel():
with config.change_flags(numba__vectorize_target="parallel"):
aesara_numba_fn = function([x], x * 2, mode=numba_mode)
numba_mul_fn = aesara_numba_fn.fn.jit_fn.py_func.__globals__["mul"]
numba_mul_fn = aesara_numba_fn.vm.jit_fn.py_func.__globals__["mul"]
assert numba_mul_fn.targetoptions["parallel"] is True
......@@ -3514,7 +3514,7 @@ def test_config_options_fastmath():
with config.change_flags(numba__fastmath=True):
aesara_numba_fn = function([x], x * 2, mode=numba_mode)
numba_mul_fn = aesara_numba_fn.fn.jit_fn.py_func.__globals__["mul"]
numba_mul_fn = aesara_numba_fn.vm.jit_fn.py_func.__globals__["mul"]
assert numba_mul_fn.targetoptions["fastmath"] is True
......@@ -3523,12 +3523,12 @@ def test_config_options_cached():
with config.change_flags(numba__cache=True):
aesara_numba_fn = function([x], x * 2, mode=numba_mode)
numba_mul_fn = aesara_numba_fn.fn.jit_fn.py_func.__globals__["mul"]
numba_mul_fn = aesara_numba_fn.vm.jit_fn.py_func.__globals__["mul"]
assert not isinstance(
numba_mul_fn._dispatcher.cache, numba.core.caching.NullCache
)
with config.change_flags(numba__cache=False):
aesara_numba_fn = function([x], x * 2, mode=numba_mode)
numba_mul_fn = aesara_numba_fn.fn.jit_fn.py_func.__globals__["mul"]
numba_mul_fn = aesara_numba_fn.vm.jit_fn.py_func.__globals__["mul"]
assert isinstance(numba_mul_fn._dispatcher.cache, numba.core.caching.NullCache)
......@@ -52,11 +52,11 @@ def test_careduce_performance(careduce_fn, numpy_fn, axis, inputs, input_vals):
assert np.array_equal(numba_res, numpy_res)
# FYI: To test the Numba JITed function directly, use `aesara_numba_fn.fn.jit_fn`
# FYI: To test the Numba JITed function directly, use `aesara_numba_fn.vm.jit_fn`
numpy_timer = timeit.Timer("numpy_fn(*input_vals)", "pass", globals=locals())
numba_timer = timeit.Timer(
"aesara_numba_fn.fn.jit_fn(*input_vals)", "pass", globals=locals()
"aesara_numba_fn.vm.jit_fn(*input_vals)", "pass", globals=locals()
)
# c_timer = timeit.Timer("aesara_c_fn(*input_vals)", "pass", globals=locals())
......
......@@ -86,7 +86,7 @@ def test_use_c_thunks():
),
)
assert np.array_equal(a * b, f(a, b))
assert any(hasattr(t, "cthunk") for t in f.fn.thunks) == use_c_thunks
assert any(hasattr(t, "cthunk") for t in f.vm.thunks) == use_c_thunks
@pytest.mark.skipif(
......@@ -215,9 +215,9 @@ def test_partial_function(linker):
if linker == "cvm":
from aesara.link.c.cvm import CVM
assert isinstance(f.fn, CVM)
assert isinstance(f.vm, CVM)
else:
assert isinstance(f.fn, Stack)
assert isinstance(f.vm, Stack)
assert f(3, output_subset=[0, 1, 2]) == f(3)
assert f(4, output_subset=[0, 2]) == [f(4)[0], f(4)[2]]
......@@ -277,17 +277,17 @@ def test_allow_gc_cvm():
f([1])
n = list(f.maker.fgraph.apply_nodes)[0].outputs[0]
assert f.fn.storage_map[n][0] is None
assert f.fn.allow_gc is True
assert f.vm.storage_map[n][0] is None
assert f.vm.allow_gc is True
f.fn.allow_gc = False
assert f.fn.allow_gc is False
f.vm.allow_gc = False
assert f.vm.allow_gc is False
f([1])
assert f.fn.storage_map[n][0] is not None
f.fn.allow_gc = True
assert f.fn.allow_gc is True
assert f.vm.storage_map[n][0] is not None
f.vm.allow_gc = True
assert f.vm.allow_gc is True
f([1])
assert f.fn.storage_map[n][0] is None
assert f.vm.storage_map[n][0] is None
class RunOnce(Op):
......@@ -334,7 +334,7 @@ def test_reallocation():
f = function([x, y], z, name="test_reduce_memory", mode=m)
output = f(1, 2)
assert output
storage_map = f.fn.storage_map
storage_map = f.vm.storage_map
def check_storage(storage_map):
for i in storage_map:
......@@ -365,8 +365,8 @@ def test_no_recycling():
mode = Mode(optimizer="fast_compile", linker=lnk)
f = function([x], x + 1, mode=mode)
f2 = function([x], (x + 1) * 2, mode=mode)
m1 = f.fn.thunks[0].thunk.module
m2 = f2.fn.thunks[0].thunk.module
m1 = f.vm.thunks[0].thunk.module
m2 = f2.vm.thunks[0].thunk.module
assert m1 is m2
......@@ -381,7 +381,7 @@ def test_VMLinker_make_vm_cvm():
linker = VMLinker(allow_gc=False, use_cloop=True)
f = function([a], a, mode=Mode(optimizer=None, linker=linker))
assert isinstance(f.fn, CVM)
assert isinstance(f.vm, CVM)
def test_VMLinker_make_vm_no_cvm():
......@@ -405,7 +405,7 @@ def test_VMLinker_make_vm_no_cvm():
import aesara.link.c.cvm
f = function([a], a, mode=Mode(optimizer=None, linker=linker))
assert isinstance(f.fn, Loop)
assert isinstance(f.vm, Loop)
def test_VMLinker_exception():
......
......@@ -916,7 +916,7 @@ def test_multMatVect():
r_a1 = rng_mrg.matVecModM(A1, s1, m1)
r_a2 = rng_mrg.matVecModM(A2, s2, m2)
f0.fn()
f0.vm()
r_b = f0.output_storage[0].value
assert np.allclose(r_a1, r_b[:3])
......
......@@ -2702,8 +2702,8 @@ def test_profile_info():
assert profile.callcount == 0
assert profile.nbsteps == 0
assert profile.call_time == 0.0
assert fn.fn.call_times == [0.0]
assert fn.fn.call_counts == [0]
assert fn.vm.call_times == [0.0]
assert fn.vm.call_counts == [0]
z_fn = function([], z)
......@@ -2716,8 +2716,8 @@ def test_profile_info():
# Confirm that `VM.update_profile` was called
assert profile.apply_time
assert fn.fn.call_times == [0.0]
assert fn.fn.call_counts == [0]
assert fn.vm.call_times == [0.0]
assert fn.vm.call_counts == [0]
class TestExamples:
......
......@@ -616,7 +616,7 @@ class TestConv2D(utt.InferShapeTester):
)
aesara_conv = aesara.function([], output, mode=mode)
t1 = time.time()
aesara_conv.fn(n_calls=n_calls)
aesara_conv.vm(n_calls=n_calls)
t2 = time.time()
print(t2 - t1, end=" ")
print()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论