提交 c2092862 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename Function.fn to Function.vm

上级 d39b852a
...@@ -217,7 +217,7 @@ class ProfileStats: ...@@ -217,7 +217,7 @@ class ProfileStats:
# #
vm_call_time = 0.0 vm_call_time = 0.0
# Total time spent in Function.fn.__call__ # Total time spent in Function.vm.__call__
# #
apply_time = None apply_time = None
...@@ -781,7 +781,7 @@ class ProfileStats: ...@@ -781,7 +781,7 @@ class ProfileStats:
) )
if self.fct_call_time > 0: if self.fct_call_time > 0:
print( print(
f" Time in Function.fn.__call__: {self.vm_call_time}s ({100 * self.vm_call_time / self.fct_call_time:.3f}%)", f" Time in Function.vm.__call__: {self.vm_call_time}s ({100 * self.vm_call_time / self.fct_call_time:.3f}%)",
file=file, file=file,
) )
local_time = sum(self.apply_time.values()) local_time = sum(self.apply_time.values())
......
...@@ -1139,9 +1139,9 @@ def clone_replace( ...@@ -1139,9 +1139,9 @@ def clone_replace(
Parameters Parameters
---------- ----------
output : Aesara Variables (or Aesara expressions) output
Aesara expression that represents the computational graph. Aesara expression that represents the computational graph.
replace : dict replace
Dictionary describing which subgraphs should be replaced by what. Dictionary describing which subgraphs should be replaced by what.
rebuild_kwds rebuild_kwds
Keywords to `rebuild_collect_shared`. Keywords to `rebuild_collect_shared`.
......
...@@ -59,7 +59,7 @@ def execute(execute=True, verbose=True, M=2000, N=2000, K=2000, iters=10, order= ...@@ -59,7 +59,7 @@ def execute(execute=True, verbose=True, M=2000, N=2000, K=2000, iters=10, order=
if any(x.op.__class__.__name__ == "Gemm" for x in f.maker.fgraph.toposort()): if any(x.op.__class__.__name__ == "Gemm" for x in f.maker.fgraph.toposort()):
c_impl = [ c_impl = [
hasattr(thunk, "cthunk") hasattr(thunk, "cthunk")
for node, thunk in zip(f.fn.nodes, f.fn.thunks) for node, thunk in zip(f.vm.nodes, f.vm.thunks)
if node.op.__class__.__name__ == "Gemm" if node.op.__class__.__name__ == "Gemm"
] ]
assert len(c_impl) == 1 assert len(c_impl) == 1
......
...@@ -222,7 +222,7 @@ def debugprint( ...@@ -222,7 +222,7 @@ def debugprint(
results_to_print.extend(obj.maker.fgraph.outputs) results_to_print.extend(obj.maker.fgraph.outputs)
profile_list.extend([obj.profile for item in obj.maker.fgraph.outputs]) profile_list.extend([obj.profile for item in obj.maker.fgraph.outputs])
if print_storage: if print_storage:
smap.extend([obj.fn.storage_map for item in obj.maker.fgraph.outputs]) smap.extend([obj.vm.storage_map for item in obj.maker.fgraph.outputs])
else: else:
smap.extend([None for item in obj.maker.fgraph.outputs]) smap.extend([None for item in obj.maker.fgraph.outputs])
topo = obj.maker.fgraph.toposort() topo = obj.maker.fgraph.toposort()
......
...@@ -75,7 +75,7 @@ def multMatVect(v, A, m1, B, m2): ...@@ -75,7 +75,7 @@ def multMatVect(v, A, m1, B, m2):
f.input_storage[3].storage[0] = B f.input_storage[3].storage[0] = B
f.input_storage[4].storage[0] = v[3:] f.input_storage[4].storage[0] = v[3:]
f.input_storage[5].storage[0] = m2 f.input_storage[5].storage[0] = m2
f.fn() f.vm()
r = f.output_storage[0].storage[0] r = f.output_storage[0].storage[0]
return r return r
...@@ -829,7 +829,7 @@ class MRG_RandomStream: ...@@ -829,7 +829,7 @@ class MRG_RandomStream:
v = rval[i - 1] v = rval[i - 1]
f.input_storage[1].storage[0] = v[:3] f.input_storage[1].storage[0] = v[:3]
f.input_storage[4].storage[0] = v[3:] f.input_storage[4].storage[0] = v[3:]
f.fn() f.vm()
rval[i] = f.output_storage[0].storage[0] rval[i] = f.output_storage[0].storage[0]
if inc_rstate: if inc_rstate:
......
...@@ -1594,8 +1594,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1594,8 +1594,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
from aesara.scan.utils import InnerFunctionError from aesara.scan.utils import InnerFunctionError
# TODO: Extract `Capsule` object and use that # TODO: Extract `Capsule` object and use that
# c_thunk = getattr(self.fn.fn.thunks[0], "cthunk", None) # c_thunk = getattr(self.fn.vm.thunks[0], "cthunk", None)
# if len(self.fn.fn.thunks) == 1 and c_thunk: # if len(self.fn.vm.thunks) == 1 and c_thunk:
# thunk_capsule = c_thunk.cthunk # thunk_capsule = c_thunk.cthunk
# # We need to perform the following after calling # # We need to perform the following after calling
# # the thunk function: # # the thunk function:
...@@ -1633,20 +1633,20 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1633,20 +1633,20 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
outputs, outputs,
outer_output_dtypes, outer_output_dtypes,
outer_output_ndims, outer_output_ndims,
self.fn.fn, self.fn.vm,
) )
except InnerFunctionError as exc: except InnerFunctionError as exc:
exc_type = type(exc.args[0]) exc_type = type(exc.args[0])
exc_value = exc.args[0] exc_value = exc.args[0]
exc_trace = exc.args[1] exc_trace = exc.args[1]
if hasattr(self.fn.fn, "position_of_error") and hasattr( if hasattr(self.fn.vm, "position_of_error") and hasattr(
self.fn.fn, "thunks" self.fn.vm, "thunks"
): ):
raise_with_op( raise_with_op(
self.fn.maker.fgraph, self.fn.maker.fgraph,
self.fn.fn.nodes[self.fn.fn.position_of_error], self.fn.vm.nodes[self.fn.vm.position_of_error],
self.fn.fn.thunks[self.fn.fn.position_of_error], self.fn.vm.thunks[self.fn.vm.position_of_error],
exc_info=(exc_type, exc_value, exc_trace), exc_info=(exc_type, exc_value, exc_trace),
) )
else: else:
...@@ -1661,8 +1661,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1661,8 +1661,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
profile.callcount += 1 profile.callcount += 1
profile.nbsteps += n_steps profile.nbsteps += n_steps
profile.call_time += t_call profile.call_time += t_call
if hasattr(self.fn.fn, "update_profile"): if hasattr(self.fn.vm, "update_profile"):
self.fn.fn.update_profile(profile) self.fn.vm.update_profile(profile)
except (ImportError, MissingGXX): except (ImportError, MissingGXX):
p = self.perform p = self.perform
...@@ -1795,7 +1795,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1795,7 +1795,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
inner_output_storage = self.fn.output_storage inner_output_storage = self.fn.output_storage
old_inner_output_storage = [None] * len(inner_output_storage) old_inner_output_storage = [None] * len(inner_output_storage)
old_inner_output_data = [None] * len(inner_output_storage) old_inner_output_data = [None] * len(inner_output_storage)
fn = self.fn.fn vm = self.fn.vm
offset = ( offset = (
info.n_seqs info.n_seqs
+ sum( + sum(
...@@ -1938,18 +1938,18 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1938,18 +1938,18 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
t0_fn = time.time() t0_fn = time.time()
try: try:
fn() vm()
except Exception: except Exception:
if hasattr(fn, "position_of_error"): if hasattr(vm, "position_of_error"):
# this is a new vm-provided function or c linker # this is a new vm-provided function or c linker
# they need this because the exception manipulation # they need this because the exception manipulation
# done by raise_with_op is not implemented in C. # done by raise_with_op is not implemented in C.
if hasattr(fn, "thunks"): if hasattr(vm, "thunks"):
# For the CVM # For the CVM
raise_with_op( raise_with_op(
self.fn.maker.fgraph, self.fn.maker.fgraph,
fn.nodes[fn.position_of_error], vm.nodes[vm.position_of_error],
fn.thunks[fn.position_of_error], vm.thunks[vm.position_of_error],
) )
else: else:
# For the c linker # For the c linker
...@@ -1957,7 +1957,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1957,7 +1957,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# temps values So for now, we just don't print # temps values So for now, we just don't print
# the extra shapes/strides info # the extra shapes/strides info
raise_with_op( raise_with_op(
self.fn.maker.fgraph, fn.nodes[fn.position_of_error] self.fn.maker.fgraph, vm.nodes[vm.position_of_error]
) )
else: else:
# old-style linkers raise their own exceptions # old-style linkers raise their own exceptions
...@@ -2200,8 +2200,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2200,8 +2200,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
profile.nbsteps += n_steps profile.nbsteps += n_steps
profile.call_time += t_call profile.call_time += t_call
profile.vm_call_time += t_fn profile.vm_call_time += t_fn
if hasattr(self.fn.fn, "update_profile"): if hasattr(self.fn.vm, "update_profile"):
self.fn.fn.update_profile(profile) self.fn.vm.update_profile(profile)
self.t_call = t_call self.t_call = t_call
self.t_fn = t_fn self.t_fn = t_fn
......
...@@ -751,6 +751,8 @@ def add_nitsot_outputs( ...@@ -751,6 +751,8 @@ def add_nitsot_outputs(
new_outputs_inner, new_outputs_inner,
) -> Tuple[Apply, Dict[Variable, Variable]]: ) -> Tuple[Apply, Dict[Variable, Variable]]:
assert isinstance(old_scan_node.op, Scan)
nb_new_outs = len(new_outputs_inner) nb_new_outs = len(new_outputs_inner)
# Create the initial values for the new nitsot outputs # Create the initial values for the new nitsot outputs
......
...@@ -141,8 +141,8 @@ with ...@@ -141,8 +141,8 @@ with
Also, for small Aesara functions, you can remove more Python overhead by Also, for small Aesara functions, you can remove more Python overhead by
making an Aesara function that does not take any input. You can use shared making an Aesara function that does not take any input. You can use shared
variables to achieve this. Then you can call it like this: ``f.fn()`` or variables to achieve this. Then you can call it like this: ``f.vm()`` or
``f.fn(n_calls=N)`` to speed it up. In the last case, only the last ``f.vm(n_calls=N)`` to speed it up. In the last case, only the last
function output (out of N calls) is returned. function output (out of N calls) is returned.
You can also use the ``C`` linker that will put all nodes in the same C You can also use the ``C`` linker that will put all nodes in the same C
......
...@@ -140,9 +140,9 @@ Running the above code generates the following error message: ...@@ -140,9 +140,9 @@ Running the above code generates the following error message:
File "test1.py", line 31, in <module> File "test1.py", line 31, in <module>
f(np.random.random((5, 10))) f(np.random.random((5, 10)))
File "PATH_TO_AESARA/aesara/compile/function/types.py", line 605, in __call__ File "PATH_TO_AESARA/aesara/compile/function/types.py", line 605, in __call__
self.fn.thunks[self.fn.position_of_error]) self.vm.thunks[self.vm.position_of_error])
File "PATH_TO_AESARA/aesara/compile/function/types.py", line 595, in __call__ File "PATH_TO_AESARA/aesara/compile/function/types.py", line 595, in __call__
outputs = self.fn() outputs = self.vm()
ValueError: Shape mismatch: x has 10 cols (and 5 rows) but y has 20 rows (and 10 cols) ValueError: Shape mismatch: x has 10 cols (and 5 rows) but y has 20 rows (and 10 cols)
Apply node that caused the error: Dot22(x, DimShuffle{1,0}.0) Apply node that caused the error: Dot22(x, DimShuffle{1,0}.0)
Inputs types: [TensorType(float64, (None, None)), TensorType(float64, (None, None))] Inputs types: [TensorType(float64, (None, None)), TensorType(float64, (None, None))]
......
...@@ -52,8 +52,8 @@ function. aesara.function() has an optional parameter ``name`` that ...@@ -52,8 +52,8 @@ function. aesara.function() has an optional parameter ``name`` that
defaults to None. Change it to something else to help you profile many defaults to None. Change it to something else to help you profile many
Aesara functions. In that section, we also see the number of times the Aesara functions. In that section, we also see the number of times the
function was called (1) and the total time spent in all those function was called (1) and the total time spent in all those
calls. The time spent in Function.fn.__call__ and in thunks is useful calls. The time spent in :meth:`Function.vm.__call__` and in thunks is useful
to understand Aesara overhead. to understand Aesara's overhead.
Also, we see the time spent in the two parts of the compilation Also, we see the time spent in the two parts of the compilation
process: optimization (modify the graph to make it more stable/faster) process: optimization (modify the graph to make it more stable/faster)
......
...@@ -2,7 +2,7 @@ Function profiling ...@@ -2,7 +2,7 @@ Function profiling
================== ==================
Message: None Message: None
Time in 1 calls to Function.__call__: 5.698204e-05s Time in 1 calls to Function.__call__: 5.698204e-05s
Time in Function.fn.__call__: 1.192093e-05s (20.921%) Time in Function.vm.__call__: 1.192093e-05s (20.921%)
Time in thunks: 6.198883e-06s (10.879%) Time in thunks: 6.198883e-06s (10.879%)
Total compile time: 3.642474e+00s Total compile time: 3.642474e+00s
Aesara Optimizer time: 7.326508e-02s Aesara Optimizer time: 7.326508e-02s
......
...@@ -346,8 +346,8 @@ class TestFunction: ...@@ -346,8 +346,8 @@ class TestFunction:
cpy = ori.copy(share_memory=True) cpy = ori.copy(share_memory=True)
# Test if memories shared # Test if memories shared
storage_map_ori = ori.fn.storage_map storage_map_ori = ori.vm.storage_map
storage_map_cpy = cpy.fn.storage_map storage_map_cpy = cpy.vm.storage_map
fgraph_cpy = cpy.maker.fgraph fgraph_cpy = cpy.maker.fgraph
# Assert intermediate and Constants storages are shared. # Assert intermediate and Constants storages are shared.
...@@ -424,11 +424,11 @@ class TestFunction: ...@@ -424,11 +424,11 @@ class TestFunction:
# 2. SharedVariable is updatable -> values did update(z == 5) # 2. SharedVariable is updatable -> values did update(z == 5)
# 1. sharedvariable is swap -> Rpl sharedvariables share storage # 1. sharedvariable is swap -> Rpl sharedvariables share storage
names = map_SV.keys() names = map_SV.keys()
for key in cpy.fn.storage_map: for key in cpy.vm.storage_map:
if key.name in names: if key.name in names:
assert ( assert (
map_SV[key.name].container.storage[0] map_SV[key.name].container.storage[0]
== cpy.fn.storage_map[key][0] == cpy.vm.storage_map[key][0]
) )
second_time = True second_time = True
...@@ -688,18 +688,18 @@ class TestFunction: ...@@ -688,18 +688,18 @@ class TestFunction:
x = vector("x") x = vector("x")
func = function([x], x + 1) func = function([x], x + 1)
func.fn.allow_gc = False func.vm.allow_gc = False
func([1]) func([1])
check_list = [] check_list = []
for key, val in func.fn.storage_map.items(): for key, val in func.vm.storage_map.items():
if not isinstance(key, Constant): if not isinstance(key, Constant):
check_list.append(val) check_list.append(val)
assert any(val[0] for val in check_list) assert any(val[0] for val in check_list)
func.free() func.free()
for key, val in func.fn.storage_map.items(): for key, val in func.vm.storage_map.items():
if not isinstance(key, Constant): if not isinstance(key, Constant):
assert val[0] is None assert val[0] is None
......
...@@ -3505,7 +3505,7 @@ def test_config_options_parallel(): ...@@ -3505,7 +3505,7 @@ def test_config_options_parallel():
with config.change_flags(numba__vectorize_target="parallel"): with config.change_flags(numba__vectorize_target="parallel"):
aesara_numba_fn = function([x], x * 2, mode=numba_mode) aesara_numba_fn = function([x], x * 2, mode=numba_mode)
numba_mul_fn = aesara_numba_fn.fn.jit_fn.py_func.__globals__["mul"] numba_mul_fn = aesara_numba_fn.vm.jit_fn.py_func.__globals__["mul"]
assert numba_mul_fn.targetoptions["parallel"] is True assert numba_mul_fn.targetoptions["parallel"] is True
...@@ -3514,7 +3514,7 @@ def test_config_options_fastmath(): ...@@ -3514,7 +3514,7 @@ def test_config_options_fastmath():
with config.change_flags(numba__fastmath=True): with config.change_flags(numba__fastmath=True):
aesara_numba_fn = function([x], x * 2, mode=numba_mode) aesara_numba_fn = function([x], x * 2, mode=numba_mode)
numba_mul_fn = aesara_numba_fn.fn.jit_fn.py_func.__globals__["mul"] numba_mul_fn = aesara_numba_fn.vm.jit_fn.py_func.__globals__["mul"]
assert numba_mul_fn.targetoptions["fastmath"] is True assert numba_mul_fn.targetoptions["fastmath"] is True
...@@ -3523,12 +3523,12 @@ def test_config_options_cached(): ...@@ -3523,12 +3523,12 @@ def test_config_options_cached():
with config.change_flags(numba__cache=True): with config.change_flags(numba__cache=True):
aesara_numba_fn = function([x], x * 2, mode=numba_mode) aesara_numba_fn = function([x], x * 2, mode=numba_mode)
numba_mul_fn = aesara_numba_fn.fn.jit_fn.py_func.__globals__["mul"] numba_mul_fn = aesara_numba_fn.vm.jit_fn.py_func.__globals__["mul"]
assert not isinstance( assert not isinstance(
numba_mul_fn._dispatcher.cache, numba.core.caching.NullCache numba_mul_fn._dispatcher.cache, numba.core.caching.NullCache
) )
with config.change_flags(numba__cache=False): with config.change_flags(numba__cache=False):
aesara_numba_fn = function([x], x * 2, mode=numba_mode) aesara_numba_fn = function([x], x * 2, mode=numba_mode)
numba_mul_fn = aesara_numba_fn.fn.jit_fn.py_func.__globals__["mul"] numba_mul_fn = aesara_numba_fn.vm.jit_fn.py_func.__globals__["mul"]
assert isinstance(numba_mul_fn._dispatcher.cache, numba.core.caching.NullCache) assert isinstance(numba_mul_fn._dispatcher.cache, numba.core.caching.NullCache)
...@@ -52,11 +52,11 @@ def test_careduce_performance(careduce_fn, numpy_fn, axis, inputs, input_vals): ...@@ -52,11 +52,11 @@ def test_careduce_performance(careduce_fn, numpy_fn, axis, inputs, input_vals):
assert np.array_equal(numba_res, numpy_res) assert np.array_equal(numba_res, numpy_res)
# FYI: To test the Numba JITed function directly, use `aesara_numba_fn.fn.jit_fn` # FYI: To test the Numba JITed function directly, use `aesara_numba_fn.vm.jit_fn`
numpy_timer = timeit.Timer("numpy_fn(*input_vals)", "pass", globals=locals()) numpy_timer = timeit.Timer("numpy_fn(*input_vals)", "pass", globals=locals())
numba_timer = timeit.Timer( numba_timer = timeit.Timer(
"aesara_numba_fn.fn.jit_fn(*input_vals)", "pass", globals=locals() "aesara_numba_fn.vm.jit_fn(*input_vals)", "pass", globals=locals()
) )
# c_timer = timeit.Timer("aesara_c_fn(*input_vals)", "pass", globals=locals()) # c_timer = timeit.Timer("aesara_c_fn(*input_vals)", "pass", globals=locals())
......
...@@ -86,7 +86,7 @@ def test_use_c_thunks(): ...@@ -86,7 +86,7 @@ def test_use_c_thunks():
), ),
) )
assert np.array_equal(a * b, f(a, b)) assert np.array_equal(a * b, f(a, b))
assert any(hasattr(t, "cthunk") for t in f.fn.thunks) == use_c_thunks assert any(hasattr(t, "cthunk") for t in f.vm.thunks) == use_c_thunks
@pytest.mark.skipif( @pytest.mark.skipif(
...@@ -215,9 +215,9 @@ def test_partial_function(linker): ...@@ -215,9 +215,9 @@ def test_partial_function(linker):
if linker == "cvm": if linker == "cvm":
from aesara.link.c.cvm import CVM from aesara.link.c.cvm import CVM
assert isinstance(f.fn, CVM) assert isinstance(f.vm, CVM)
else: else:
assert isinstance(f.fn, Stack) assert isinstance(f.vm, Stack)
assert f(3, output_subset=[0, 1, 2]) == f(3) assert f(3, output_subset=[0, 1, 2]) == f(3)
assert f(4, output_subset=[0, 2]) == [f(4)[0], f(4)[2]] assert f(4, output_subset=[0, 2]) == [f(4)[0], f(4)[2]]
...@@ -277,17 +277,17 @@ def test_allow_gc_cvm(): ...@@ -277,17 +277,17 @@ def test_allow_gc_cvm():
f([1]) f([1])
n = list(f.maker.fgraph.apply_nodes)[0].outputs[0] n = list(f.maker.fgraph.apply_nodes)[0].outputs[0]
assert f.fn.storage_map[n][0] is None assert f.vm.storage_map[n][0] is None
assert f.fn.allow_gc is True assert f.vm.allow_gc is True
f.fn.allow_gc = False f.vm.allow_gc = False
assert f.fn.allow_gc is False assert f.vm.allow_gc is False
f([1]) f([1])
assert f.fn.storage_map[n][0] is not None assert f.vm.storage_map[n][0] is not None
f.fn.allow_gc = True f.vm.allow_gc = True
assert f.fn.allow_gc is True assert f.vm.allow_gc is True
f([1]) f([1])
assert f.fn.storage_map[n][0] is None assert f.vm.storage_map[n][0] is None
class RunOnce(Op): class RunOnce(Op):
...@@ -334,7 +334,7 @@ def test_reallocation(): ...@@ -334,7 +334,7 @@ def test_reallocation():
f = function([x, y], z, name="test_reduce_memory", mode=m) f = function([x, y], z, name="test_reduce_memory", mode=m)
output = f(1, 2) output = f(1, 2)
assert output assert output
storage_map = f.fn.storage_map storage_map = f.vm.storage_map
def check_storage(storage_map): def check_storage(storage_map):
for i in storage_map: for i in storage_map:
...@@ -365,8 +365,8 @@ def test_no_recycling(): ...@@ -365,8 +365,8 @@ def test_no_recycling():
mode = Mode(optimizer="fast_compile", linker=lnk) mode = Mode(optimizer="fast_compile", linker=lnk)
f = function([x], x + 1, mode=mode) f = function([x], x + 1, mode=mode)
f2 = function([x], (x + 1) * 2, mode=mode) f2 = function([x], (x + 1) * 2, mode=mode)
m1 = f.fn.thunks[0].thunk.module m1 = f.vm.thunks[0].thunk.module
m2 = f2.fn.thunks[0].thunk.module m2 = f2.vm.thunks[0].thunk.module
assert m1 is m2 assert m1 is m2
...@@ -381,7 +381,7 @@ def test_VMLinker_make_vm_cvm(): ...@@ -381,7 +381,7 @@ def test_VMLinker_make_vm_cvm():
linker = VMLinker(allow_gc=False, use_cloop=True) linker = VMLinker(allow_gc=False, use_cloop=True)
f = function([a], a, mode=Mode(optimizer=None, linker=linker)) f = function([a], a, mode=Mode(optimizer=None, linker=linker))
assert isinstance(f.fn, CVM) assert isinstance(f.vm, CVM)
def test_VMLinker_make_vm_no_cvm(): def test_VMLinker_make_vm_no_cvm():
...@@ -405,7 +405,7 @@ def test_VMLinker_make_vm_no_cvm(): ...@@ -405,7 +405,7 @@ def test_VMLinker_make_vm_no_cvm():
import aesara.link.c.cvm import aesara.link.c.cvm
f = function([a], a, mode=Mode(optimizer=None, linker=linker)) f = function([a], a, mode=Mode(optimizer=None, linker=linker))
assert isinstance(f.fn, Loop) assert isinstance(f.vm, Loop)
def test_VMLinker_exception(): def test_VMLinker_exception():
......
...@@ -916,7 +916,7 @@ def test_multMatVect(): ...@@ -916,7 +916,7 @@ def test_multMatVect():
r_a1 = rng_mrg.matVecModM(A1, s1, m1) r_a1 = rng_mrg.matVecModM(A1, s1, m1)
r_a2 = rng_mrg.matVecModM(A2, s2, m2) r_a2 = rng_mrg.matVecModM(A2, s2, m2)
f0.fn() f0.vm()
r_b = f0.output_storage[0].value r_b = f0.output_storage[0].value
assert np.allclose(r_a1, r_b[:3]) assert np.allclose(r_a1, r_b[:3])
......
...@@ -2702,8 +2702,8 @@ def test_profile_info(): ...@@ -2702,8 +2702,8 @@ def test_profile_info():
assert profile.callcount == 0 assert profile.callcount == 0
assert profile.nbsteps == 0 assert profile.nbsteps == 0
assert profile.call_time == 0.0 assert profile.call_time == 0.0
assert fn.fn.call_times == [0.0] assert fn.vm.call_times == [0.0]
assert fn.fn.call_counts == [0] assert fn.vm.call_counts == [0]
z_fn = function([], z) z_fn = function([], z)
...@@ -2716,8 +2716,8 @@ def test_profile_info(): ...@@ -2716,8 +2716,8 @@ def test_profile_info():
# Confirm that `VM.update_profile` was called # Confirm that `VM.update_profile` was called
assert profile.apply_time assert profile.apply_time
assert fn.fn.call_times == [0.0] assert fn.vm.call_times == [0.0]
assert fn.fn.call_counts == [0] assert fn.vm.call_counts == [0]
class TestExamples: class TestExamples:
......
...@@ -616,7 +616,7 @@ class TestConv2D(utt.InferShapeTester): ...@@ -616,7 +616,7 @@ class TestConv2D(utt.InferShapeTester):
) )
aesara_conv = aesara.function([], output, mode=mode) aesara_conv = aesara.function([], output, mode=mode)
t1 = time.time() t1 = time.time()
aesara_conv.fn(n_calls=n_calls) aesara_conv.vm(n_calls=n_calls)
t2 = time.time() t2 = time.time()
print(t2 - t1, end=" ") print(t2 - t1, end=" ")
print() print()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论