提交 77bcd689，作者: Brandon T. Willard，提交者: Brandon T. Willard

Avoid inner-function compilation during Scan construction

Closes #709
上级 74f80840
...@@ -372,25 +372,7 @@ def pfunc( ...@@ -372,25 +372,7 @@ def pfunc(
equivalent to Var1. equivalent to Var1.
""" """
#
# This function works by cloning the graph (except for the
# inputs), and then shipping it off to aesara.compile.function.function
# (There it will be cloned again, unnecessarily, because it doesn't know
# that we already cloned it.)
#
# First, it clones the replacements named in the givens argument,
# and points each Var1 to the clone of Var2. Then it sets the
# inputs in the clone dictionary. After these steps, we are
# assuming that the clone dictionary contains all the inputs to
# the computation graph.
#
# Then it clones the outputs and the update expressions. This
# rebuilds a computation graph from the inputs and the givens.
#
if updates is None:
updates = []
if givens is None:
givens = []
if profile is None: if profile is None:
profile = config.profile or config.print_global_stats profile = config.profile or config.print_global_stats
# profile -> True or False # profile -> True or False
...@@ -405,6 +387,62 @@ def pfunc( ...@@ -405,6 +387,62 @@ def pfunc(
# No need to block other objects being passed through though. It might be # No need to block other objects being passed through though. It might be
# useful. # useful.
inputs, cloned_outputs = construct_pfunc_ins_and_outs(
params,
outputs,
mode,
updates,
givens,
no_default_updates,
rebuild_strict,
allow_input_downcast,
)
return orig_function(
inputs,
cloned_outputs,
mode,
accept_inplace=accept_inplace,
name=name,
profile=profile,
on_unused_input=on_unused_input,
output_keys=output_keys,
)
def construct_pfunc_ins_and_outs(
params,
outputs=None,
mode=None,
updates=None,
givens=None,
no_default_updates=False,
rebuild_strict=True,
allow_input_downcast=None,
):
"""Construct inputs and outputs for `pfunc`.
This function works by cloning the graph (except for the
inputs), and then shipping it off to aesara.compile.function.function
(There it will be cloned again, unnecessarily, because it doesn't know
that we already cloned it.)
First, it clones the replacements named in the `givens` argument,
and points each ``Var1`` to the clone of ``Var2``. Then it sets the
inputs in the clone dictionary. After these steps, we are
assuming that the clone dictionary contains all the inputs to
the computation graph.
Then it clones the outputs and the update expressions. This
rebuilds a computation graph from the inputs and the `givens`.
"""
if updates is None:
updates = []
if givens is None:
givens = []
if not isinstance(params, (list, tuple)): if not isinstance(params, (list, tuple)):
raise Exception("in pfunc() the first argument must be a list or " "a tuple") raise Exception("in pfunc() the first argument must be a list or " "a tuple")
...@@ -520,16 +558,7 @@ def pfunc( ...@@ -520,16 +558,7 @@ def pfunc(
) )
inputs.append(si) inputs.append(si)
return orig_function( return inputs, cloned_outputs
inputs,
cloned_outputs,
mode,
accept_inplace=accept_inplace,
name=name,
profile=profile,
on_unused_input=on_unused_input,
output_keys=output_keys,
)
def _pfunc_param_to_in(param, strict=False, allow_downcast=None): def _pfunc_param_to_in(param, strict=False, allow_downcast=None):
......
...@@ -5,8 +5,7 @@ import numpy as np ...@@ -5,8 +5,7 @@ import numpy as np
import aesara.tensor as at import aesara.tensor as at
from aesara.compile import SharedVariable from aesara.compile import SharedVariable
from aesara.compile.function import function from aesara.compile.function.pfunc import construct_pfunc_ins_and_outs
from aesara.compile.mode import Mode
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import Constant, Variable, clone_replace, graph_inputs from aesara.graph.basic import Constant, Variable, clone_replace, graph_inputs
from aesara.graph.fg import MissingInputError from aesara.graph.fg import MissingInputError
...@@ -764,9 +763,6 @@ def scan( ...@@ -764,9 +763,6 @@ def scan(
# we have and what are their update rules (note that the user has # we have and what are their update rules (note that the user has
# the option not to pass the shared variable to scan, so we need to # the option not to pass the shared variable to scan, so we need to
# pick them manually and add them to scan) # pick them manually and add them to scan)
# make the compilation as fast as possible by not applying any
# optimization or conversion to C [ note this region is not important
# for performance so we can do stuff as unoptimal as we wish ]
# extract still missing inputs (there still might be so) and add them # extract still missing inputs (there still might be so) and add them
# as non sequences at the end of our args # as non sequences at the end of our args
...@@ -794,13 +790,8 @@ def scan( ...@@ -794,13 +790,8 @@ def scan(
# Perform a try-except to provide a meaningful error message to the # Perform a try-except to provide a meaningful error message to the
# user if inputs of the inner function are missing. # user if inputs of the inner function are missing.
try: try:
dummy_f = function( dummy_inputs, dummy_outputs = construct_pfunc_ins_and_outs(
dummy_args, dummy_args, dummy_outs, updates=updates
dummy_outs,
updates=updates,
mode=Mode(linker="py", optimizer=None),
on_unused_input="ignore",
profile=False,
) )
except MissingInputError as err: except MissingInputError as err:
msg = ( msg = (
...@@ -820,7 +811,7 @@ def scan( ...@@ -820,7 +811,7 @@ def scan(
# assumed outputs until now (provided by the user) there can be # assumed outputs until now (provided by the user) there can be
# only one explanation: No information is provided for any of the # only one explanation: No information is provided for any of the
# outputs (i.e. we are dealing with a map) # outputs (i.e. we are dealing with a map)
tmp_dummy_f_outs = len(dummy_f.maker.outputs) tmp_dummy_f_outs = len(dummy_outputs)
if as_while: if as_while:
tmp_dummy_f_outs -= 1 tmp_dummy_f_outs -= 1
if not (tmp_dummy_f_outs == n_outs or outs_info == []): if not (tmp_dummy_f_outs == n_outs or outs_info == []):
...@@ -831,7 +822,7 @@ def scan( ...@@ -831,7 +822,7 @@ def scan(
) )
if outs_info == []: if outs_info == []:
n_outs = len(dummy_f.maker.outputs) n_outs = len(dummy_outputs)
if as_while: if as_while:
n_outs = n_outs - 1 n_outs = n_outs - 1
outs_info = [OrderedDict() for x in range(n_outs)] outs_info = [OrderedDict() for x in range(n_outs)]
...@@ -854,7 +845,7 @@ def scan( ...@@ -854,7 +845,7 @@ def scan(
shared_inner_inputs = [] shared_inner_inputs = []
shared_inner_outputs = [] shared_inner_outputs = []
sit_sot_shared = [] sit_sot_shared = []
for input in dummy_f.maker.expanded_inputs: for input in dummy_inputs:
if isinstance(input.variable, SharedVariable) and input.update: if isinstance(input.variable, SharedVariable) and input.update:
new_var = safe_new(input.variable) new_var = safe_new(input.variable)
if getattr(input.variable, "name", None) is not None: if getattr(input.variable, "name", None) is not None:
...@@ -926,7 +917,7 @@ def scan( ...@@ -926,7 +917,7 @@ def scan(
other_shared_scan_args = [ other_shared_scan_args = [
arg.variable arg.variable
for arg in dummy_f.maker.expanded_inputs for arg in dummy_inputs
if ( if (
isinstance(arg.variable, SharedVariable) isinstance(arg.variable, SharedVariable)
and not arg.update and not arg.update
...@@ -935,7 +926,7 @@ def scan( ...@@ -935,7 +926,7 @@ def scan(
] ]
other_shared_inner_args = [ other_shared_inner_args = [
safe_new(arg.variable, "_copy") safe_new(arg.variable, "_copy")
for arg in dummy_f.maker.expanded_inputs for arg in dummy_inputs
if ( if (
isinstance(arg.variable, SharedVariable) isinstance(arg.variable, SharedVariable)
and not arg.update and not arg.update
...@@ -945,12 +936,12 @@ def scan( ...@@ -945,12 +936,12 @@ def scan(
else: else:
other_shared_scan_args = [ other_shared_scan_args = [
arg.variable arg.variable
for arg in dummy_f.maker.expanded_inputs for arg in dummy_inputs
if (isinstance(arg.variable, SharedVariable) and not arg.update) if (isinstance(arg.variable, SharedVariable) and not arg.update)
] ]
other_shared_inner_args = [ other_shared_inner_args = [
safe_new(arg.variable, "_copy") safe_new(arg.variable, "_copy")
for arg in dummy_f.maker.expanded_inputs for arg in dummy_inputs
if (isinstance(arg.variable, SharedVariable) and not arg.update) if (isinstance(arg.variable, SharedVariable) and not arg.update)
] ]
givens.update(OrderedDict(zip(other_shared_scan_args, other_shared_inner_args))) givens.update(OrderedDict(zip(other_shared_scan_args, other_shared_inner_args)))
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论