提交 77bcd689 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Avoid inner-function compilation during Scan construction

Closes #709
上级 74f80840
......@@ -372,25 +372,7 @@ def pfunc(
equivalent to Var1.
"""
#
# This function works by cloning the graph (except for the
# inputs), and then shipping it off to aesara.compile.function.function
# (There it will be cloned again, unnecessarily, because it doesn't know
# that we already cloned it.)
#
# First, it clones the replacements named in the givens argument,
# and points each Var1 to the clone of Var2. Then it sets the
# inputs in the clone dictionary. After these steps, we are
# assuming that the clone dictionary contains all the inputs to
# the computation graph.
#
# Then it clones the outputs and the update expressions. This
# rebuilds a computation graph from the inputs and the givens.
#
if updates is None:
updates = []
if givens is None:
givens = []
if profile is None:
profile = config.profile or config.print_global_stats
# profile -> True or False
......@@ -405,6 +387,62 @@ def pfunc(
# No need to block other objects being passed through though. It might be
# useful.
inputs, cloned_outputs = construct_pfunc_ins_and_outs(
params,
outputs,
mode,
updates,
givens,
no_default_updates,
rebuild_strict,
allow_input_downcast,
)
return orig_function(
inputs,
cloned_outputs,
mode,
accept_inplace=accept_inplace,
name=name,
profile=profile,
on_unused_input=on_unused_input,
output_keys=output_keys,
)
def construct_pfunc_ins_and_outs(
params,
outputs=None,
mode=None,
updates=None,
givens=None,
no_default_updates=False,
rebuild_strict=True,
allow_input_downcast=None,
):
"""Construct inputs and outputs for `pfunc`.
This function works by cloning the graph (except for the
inputs), and then shipping it off to aesara.compile.function.function
(There it will be cloned again, unnecessarily, because it doesn't know
that we already cloned it.)
First, it clones the replacements named in the `givens` argument,
and points each ``Var1`` to the clone of ``Var2``. Then it sets the
inputs in the clone dictionary. After these steps, we are
assuming that the clone dictionary contains all the inputs to
the computation graph.
Then it clones the outputs and the update expressions. This
rebuilds a computation graph from the inputs and the `givens`.
"""
if updates is None:
updates = []
if givens is None:
givens = []
if not isinstance(params, (list, tuple)):
raise Exception("in pfunc() the first argument must be a list or " "a tuple")
......@@ -520,16 +558,7 @@ def pfunc(
)
inputs.append(si)
return orig_function(
inputs,
cloned_outputs,
mode,
accept_inplace=accept_inplace,
name=name,
profile=profile,
on_unused_input=on_unused_input,
output_keys=output_keys,
)
return inputs, cloned_outputs
def _pfunc_param_to_in(param, strict=False, allow_downcast=None):
......
......@@ -5,8 +5,7 @@ import numpy as np
import aesara.tensor as at
from aesara.compile import SharedVariable
from aesara.compile.function import function
from aesara.compile.mode import Mode
from aesara.compile.function.pfunc import construct_pfunc_ins_and_outs
from aesara.configdefaults import config
from aesara.graph.basic import Constant, Variable, clone_replace, graph_inputs
from aesara.graph.fg import MissingInputError
......@@ -764,9 +763,6 @@ def scan(
# we have and what are their update rules (note that the user has
# the option not to pass the shared variable to scan, so we need to
# pick them manually and add them to scan)
# make the compilation as fast as possible by not applying any
# optimization or conversion to C [ note this region is not important
# for performance so we can do stuff as unoptimal as we wish ]
# extract still missing inputs (there still might be so) and add them
# as non sequences at the end of our args
......@@ -794,13 +790,8 @@ def scan(
# Perform a try-except to provide a meaningful error message to the
# user if inputs of the inner function are missing.
try:
dummy_f = function(
dummy_args,
dummy_outs,
updates=updates,
mode=Mode(linker="py", optimizer=None),
on_unused_input="ignore",
profile=False,
dummy_inputs, dummy_outputs = construct_pfunc_ins_and_outs(
dummy_args, dummy_outs, updates=updates
)
except MissingInputError as err:
msg = (
......@@ -820,7 +811,7 @@ def scan(
# assumed outputs until now (provided by the user) there can be
# only one explanation: No information is provided for any of the
# outputs (i.e. we are dealing with a map)
tmp_dummy_f_outs = len(dummy_f.maker.outputs)
tmp_dummy_f_outs = len(dummy_outputs)
if as_while:
tmp_dummy_f_outs -= 1
if not (tmp_dummy_f_outs == n_outs or outs_info == []):
......@@ -831,7 +822,7 @@ def scan(
)
if outs_info == []:
n_outs = len(dummy_f.maker.outputs)
n_outs = len(dummy_outputs)
if as_while:
n_outs = n_outs - 1
outs_info = [OrderedDict() for x in range(n_outs)]
......@@ -854,7 +845,7 @@ def scan(
shared_inner_inputs = []
shared_inner_outputs = []
sit_sot_shared = []
for input in dummy_f.maker.expanded_inputs:
for input in dummy_inputs:
if isinstance(input.variable, SharedVariable) and input.update:
new_var = safe_new(input.variable)
if getattr(input.variable, "name", None) is not None:
......@@ -926,7 +917,7 @@ def scan(
other_shared_scan_args = [
arg.variable
for arg in dummy_f.maker.expanded_inputs
for arg in dummy_inputs
if (
isinstance(arg.variable, SharedVariable)
and not arg.update
......@@ -935,7 +926,7 @@ def scan(
]
other_shared_inner_args = [
safe_new(arg.variable, "_copy")
for arg in dummy_f.maker.expanded_inputs
for arg in dummy_inputs
if (
isinstance(arg.variable, SharedVariable)
and not arg.update
......@@ -945,12 +936,12 @@ def scan(
else:
other_shared_scan_args = [
arg.variable
for arg in dummy_f.maker.expanded_inputs
for arg in dummy_inputs
if (isinstance(arg.variable, SharedVariable) and not arg.update)
]
other_shared_inner_args = [
safe_new(arg.variable, "_copy")
for arg in dummy_f.maker.expanded_inputs
for arg in dummy_inputs
if (isinstance(arg.variable, SharedVariable) and not arg.update)
]
givens.update(OrderedDict(zip(other_shared_scan_args, other_shared_inner_args)))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论