提交 7bc40c67 作者: Brandon T. Willard 提交者: Brandon T. Willard

Remove inherit_stack_trace, nodes_constructed, and Variable.*_construction_observer methods

上级 16d2e934
...@@ -48,7 +48,7 @@ from aesara.gpuarray.optdb import ( ...@@ -48,7 +48,7 @@ from aesara.gpuarray.optdb import (
) )
from aesara.gpuarray.reduction import GpuMaxAndArgmax from aesara.gpuarray.reduction import GpuMaxAndArgmax
from aesara.gpuarray.type import list_contexts from aesara.gpuarray.type import list_contexts
from aesara.graph.opt import GlobalOptimizer, inherit_stack_trace, local_optimizer from aesara.graph.opt import GlobalOptimizer, copy_stack_trace, local_optimizer
from aesara.scalar import Log from aesara.scalar import Log
from aesara.tensor.math import Argmax from aesara.tensor.math import Argmax
from aesara.tensor.nnet.abstract_conv import ( from aesara.tensor.nnet.abstract_conv import (
...@@ -79,15 +79,17 @@ def local_abstractconv_cudnn(fgraph, node): ...@@ -79,15 +79,17 @@ def local_abstractconv_cudnn(fgraph, node):
# Asymmetric padding not yet supported # Asymmetric padding not yet supported
return None return None
if isinstance(node.op, AbstractConv2d): if isinstance(node.op, AbstractConv2d):
with inherit_stack_trace(node.outputs): new_out = local_abstractconv_cudnn_graph(
return local_abstractconv_cudnn_graph(
node.op, ctx, node.inputs, node.outputs node.op, ctx, node.inputs, node.outputs
) )
copy_stack_trace(node.outputs, new_out)
return new_out
elif isinstance(node.op, AbstractConv3d): elif isinstance(node.op, AbstractConv3d):
with inherit_stack_trace(node.outputs): new_out = local_abstractconv3d_cudnn_graph(
return local_abstractconv3d_cudnn_graph(
node.op, ctx, node.inputs, node.outputs node.op, ctx, node.inputs, node.outputs
) )
copy_stack_trace(node.outputs, new_out)
return new_out
@local_optimizer( @local_optimizer(
...@@ -362,15 +364,17 @@ def local_abstractconv_gw_cudnn(fgraph, node): ...@@ -362,15 +364,17 @@ def local_abstractconv_gw_cudnn(fgraph, node):
# Asymmetric padding not yet supported # Asymmetric padding not yet supported
return None return None
if isinstance(node.op, AbstractConv2d_gradWeights): if isinstance(node.op, AbstractConv2d_gradWeights):
with inherit_stack_trace(node.outputs): new_out = local_abstractconv_cudnn_graph(
return local_abstractconv_cudnn_graph(
node.op, ctx, node.inputs, node.outputs node.op, ctx, node.inputs, node.outputs
) )
copy_stack_trace(node.outputs, new_out)
return new_out
elif isinstance(node.op, AbstractConv3d_gradWeights): elif isinstance(node.op, AbstractConv3d_gradWeights):
with inherit_stack_trace(node.outputs): new_out = local_abstractconv3d_cudnn_graph(
return local_abstractconv3d_cudnn_graph(
node.op, ctx, node.inputs, node.outputs node.op, ctx, node.inputs, node.outputs
) )
copy_stack_trace(node.outputs, new_out)
return new_out
@local_optimizer([AbstractConv2d_gradInputs, AbstractConv3d_gradInputs]) @local_optimizer([AbstractConv2d_gradInputs, AbstractConv3d_gradInputs])
...@@ -386,15 +390,17 @@ def local_abstractconv_gi_cudnn(fgraph, node): ...@@ -386,15 +390,17 @@ def local_abstractconv_gi_cudnn(fgraph, node):
# Asymmetric padding not yet supported # Asymmetric padding not yet supported
return None return None
if isinstance(node.op, AbstractConv2d_gradInputs): if isinstance(node.op, AbstractConv2d_gradInputs):
with inherit_stack_trace(node.outputs): new_out = local_abstractconv_cudnn_graph(
return local_abstractconv_cudnn_graph(
node.op, ctx, node.inputs, node.outputs node.op, ctx, node.inputs, node.outputs
) )
copy_stack_trace(node.outputs, new_out)
return new_out
elif isinstance(node.op, AbstractConv3d_gradInputs): elif isinstance(node.op, AbstractConv3d_gradInputs):
with inherit_stack_trace(node.outputs): new_out = local_abstractconv3d_cudnn_graph(
return local_abstractconv3d_cudnn_graph(
node.op, ctx, node.inputs, node.outputs node.op, ctx, node.inputs, node.outputs
) )
copy_stack_trace(node.outputs, new_out)
return new_out
@inplace_allocempty(GpuDnnConv, 2) @inplace_allocempty(GpuDnnConv, 2)
...@@ -748,11 +754,12 @@ def local_dnn_reduction(fgraph, node): ...@@ -748,11 +754,12 @@ def local_dnn_reduction(fgraph, node):
if not cudnn.cudnnReduceTensorOp_t.has_alias(scal): if not cudnn.cudnnReduceTensorOp_t.has_alias(scal):
return return
with inherit_stack_trace(node.outputs):
ret = GpuDnnReduction(scal, node.op.axis, acc_dtype, node.op.dtype, False)( ret = GpuDnnReduction(scal, node.op.axis, acc_dtype, node.op.dtype, False)(
node.inputs[0] node.inputs[0]
) )
return [post(ret)] new_out = [post(ret)]
copy_stack_trace(node.outputs, new_out)
return new_out
@register_opt("cudnn") @register_opt("cudnn")
......
...@@ -155,7 +155,6 @@ from aesara.graph.opt import ( ...@@ -155,7 +155,6 @@ from aesara.graph.opt import (
GlobalOptimizer, GlobalOptimizer,
LocalMetaOptimizer, LocalMetaOptimizer,
copy_stack_trace, copy_stack_trace,
inherit_stack_trace,
local_optimizer, local_optimizer,
) )
from aesara.ifelse import IfElse from aesara.ifelse import IfElse
...@@ -408,10 +407,9 @@ class GraphToGPU(GlobalOptimizer): ...@@ -408,10 +407,9 @@ class GraphToGPU(GlobalOptimizer):
outputs = [] outputs = []
if isinstance(new_ops, aesara.graph.op.Op): if isinstance(new_ops, aesara.graph.op.Op):
with inherit_stack_trace(node.outputs): outputs = new_ops(*[mapping[i] for i in node.inputs], return_list=True)
outputs = new_ops( copy_stack_trace(node.outputs, outputs)
*[mapping[i] for i in node.inputs], return_list=True
)
elif not new_ops: elif not new_ops:
newnode = node.clone_with_new_inputs( newnode = node.clone_with_new_inputs(
[mapping.get(i) for i in node.inputs] [mapping.get(i) for i in node.inputs]
...@@ -685,8 +683,9 @@ def local_gpualloc_memset_0(fgraph, node): ...@@ -685,8 +683,9 @@ def local_gpualloc_memset_0(fgraph, node):
and (np.asarray(inp.data) == 0).all() and (np.asarray(inp.data) == 0).all()
): ):
new_op = GpuAlloc(node.op.context_name, memset_0=True) new_op = GpuAlloc(node.op.context_name, memset_0=True)
with inherit_stack_trace(node.outputs): new_out = new_op(*node.inputs, return_list=True)
return new_op(*node.inputs, return_list=True) copy_stack_trace(node.outputs, new_out)
return new_out
# Don't register by default. # Don't register by default.
...@@ -695,12 +694,11 @@ def local_gpua_alloc_empty_to_zeros(fgraph, node): ...@@ -695,12 +694,11 @@ def local_gpua_alloc_empty_to_zeros(fgraph, node):
if isinstance(node.op, GpuAllocEmpty): if isinstance(node.op, GpuAllocEmpty):
context_name = infer_context_name(*node.inputs) context_name = infer_context_name(*node.inputs)
z = np.asarray(0, dtype=node.outputs[0].dtype) z = np.asarray(0, dtype=node.outputs[0].dtype)
with inherit_stack_trace(node.outputs): new_out = [
return [ GpuAlloc(context_name)(as_gpuarray_variable(z, context_name), *node.inputs)
GpuAlloc(context_name)(
as_gpuarray_variable(z, context_name), *node.inputs
)
] ]
copy_stack_trace(node.outputs, new_out)
return new_out
optdb.register( optdb.register(
...@@ -944,11 +942,12 @@ def gpu_print_wrapper(op, cnda): ...@@ -944,11 +942,12 @@ def gpu_print_wrapper(op, cnda):
@register_opt2([aesara.printing.Print], "fast_compile") @register_opt2([aesara.printing.Print], "fast_compile")
def local_gpua_print_op(fgraph, op, context_name, inputs, outputs): def local_gpua_print_op(fgraph, op, context_name, inputs, outputs):
(x,) = inputs (x,) = inputs
with inherit_stack_trace(outputs):
gpu_x = as_gpuarray_variable(x, context_name=context_name) gpu_x = as_gpuarray_variable(x, context_name=context_name)
new_op = op.__class__(global_fn=gpu_print_wrapper) new_op = op.__class__(global_fn=gpu_print_wrapper)
new_op.old_op = op new_op.old_op = op
return new_op(gpu_x) new_out = new_op(gpu_x)
copy_stack_trace(outputs, new_out)
return new_out
@register_opt("fast_compile") @register_opt("fast_compile")
...@@ -1002,7 +1001,6 @@ def local_gpu_pdbbreakpoint_op(fgraph, node): ...@@ -1002,7 +1001,6 @@ def local_gpu_pdbbreakpoint_op(fgraph, node):
return False return False
# Apply the op on the new inputs # Apply the op on the new inputs
with inherit_stack_trace(node.outputs):
new_op_outputs = node.op(*new_inputs, return_list=True) new_op_outputs = node.op(*new_inputs, return_list=True)
# Propagate the transfer to the gpu through the outputs that require # Propagate the transfer to the gpu through the outputs that require
...@@ -1014,6 +1012,7 @@ def local_gpu_pdbbreakpoint_op(fgraph, node): ...@@ -1014,6 +1012,7 @@ def local_gpu_pdbbreakpoint_op(fgraph, node):
else: else:
new_outputs.append(new_op_outputs[i]) new_outputs.append(new_op_outputs[i])
copy_stack_trace(node.outputs, new_outputs)
return new_outputs return new_outputs
return False return False
...@@ -1273,8 +1272,8 @@ def local_gpua_careduce(fgraph, op, context_name, inputs, outputs): ...@@ -1273,8 +1272,8 @@ def local_gpua_careduce(fgraph, op, context_name, inputs, outputs):
adtype = "float32" adtype = "float32"
greduce = op2(op.scalar_op, axis=op.axis, dtype=odtype, acc_dtype=adtype) greduce = op2(op.scalar_op, axis=op.axis, dtype=odtype, acc_dtype=adtype)
with inherit_stack_trace(outputs):
gvar = greduce(x) gvar = greduce(x)
copy_stack_trace(outputs, gvar)
# We need to have the make node called, otherwise the mask can # We need to have the make node called, otherwise the mask can
# be None # be None
if op2 is GpuCAReduceCPY or gvar.owner.op.supports_c_code( if op2 is GpuCAReduceCPY or gvar.owner.op.supports_c_code(
...@@ -1315,7 +1314,6 @@ def local_gpua_careduce(fgraph, op, context_name, inputs, outputs): ...@@ -1315,7 +1314,6 @@ def local_gpua_careduce(fgraph, op, context_name, inputs, outputs):
dtype=odtype, dtype=odtype,
acc_dtype=adtype, acc_dtype=adtype,
) )
with inherit_stack_trace(outputs):
reshaped_x = x.reshape(at.stack(new_in_shp)) reshaped_x = x.reshape(at.stack(new_in_shp))
gpu_reshaped_x = as_gpuarray_variable(reshaped_x, context_name) gpu_reshaped_x = as_gpuarray_variable(reshaped_x, context_name)
# We need to have the make node called, otherwise the mask can # We need to have the make node called, otherwise the mask can
...@@ -1335,6 +1333,7 @@ def local_gpua_careduce(fgraph, op, context_name, inputs, outputs): ...@@ -1335,6 +1333,7 @@ def local_gpua_careduce(fgraph, op, context_name, inputs, outputs):
) )
else: else:
unreshaped_reduce = reduce_reshaped_x unreshaped_reduce = reduce_reshaped_x
copy_stack_trace(outputs, unreshaped_reduce)
return [unreshaped_reduce] return [unreshaped_reduce]
...@@ -1374,7 +1373,6 @@ def local_gpua_gemm(fgraph, op, context_name, inputs, outputs): ...@@ -1374,7 +1373,6 @@ def local_gpua_gemm(fgraph, op, context_name, inputs, outputs):
def local_gpua_gemmbatch(fgraph, op, context_name, inputs, outputs): def local_gpua_gemmbatch(fgraph, op, context_name, inputs, outputs):
if inputs[0].dtype not in ("float16", "float32", "float64"): if inputs[0].dtype not in ("float16", "float32", "float64"):
return return
with inherit_stack_trace(outputs):
a, b = inputs a, b = inputs
# Since GpuGemmBatch only supports 3D inputs and output, # Since GpuGemmBatch only supports 3D inputs and output,
# we need to add broadcastable dims to the inputs, and drop # we need to add broadcastable dims to the inputs, and drop
...@@ -1401,6 +1399,7 @@ def local_gpua_gemmbatch(fgraph, op, context_name, inputs, outputs): ...@@ -1401,6 +1399,7 @@ def local_gpua_gemmbatch(fgraph, op, context_name, inputs, outputs):
) )
if len(output_dims) != 3: if len(output_dims) != 3:
out = GpuDimShuffle(out.broadcastable, output_dims)(out) out = GpuDimShuffle(out.broadcastable, output_dims)(out)
copy_stack_trace(outputs, out)
return out return out
...@@ -1461,12 +1460,13 @@ def local_gpua_dot22(fgraph, op, context_name, inputs, outputs): ...@@ -1461,12 +1460,13 @@ def local_gpua_dot22(fgraph, op, context_name, inputs, outputs):
@op_lifter([aesara.tensor.blas.Dot22Scalar]) @op_lifter([aesara.tensor.blas.Dot22Scalar])
@register_opt2([aesara.tensor.blas.Dot22Scalar], "fast_compile") @register_opt2([aesara.tensor.blas.Dot22Scalar], "fast_compile")
def local_gpua_dot22scalar(fgraph, op, context_name, inputs, outputs): def local_gpua_dot22scalar(fgraph, op, context_name, inputs, outputs):
with inherit_stack_trace(outputs):
x, y, a = inputs x, y, a = inputs
x = as_gpuarray_variable(x, context_name) x = as_gpuarray_variable(x, context_name)
y = as_gpuarray_variable(y, context_name) y = as_gpuarray_variable(y, context_name)
z = GpuAllocEmpty(x.dtype, context_name)(x.shape[0], y.shape[1]) z = GpuAllocEmpty(x.dtype, context_name)(x.shape[0], y.shape[1])
return [gpugemm_no_inplace(z, a, x, y, 0)] new_out = [gpugemm_no_inplace(z, a, x, y, 0)]
copy_stack_trace(outputs, new_out)
return new_out
@register_opt("fast_compile") @register_opt("fast_compile")
...@@ -2584,9 +2584,10 @@ def local_gpu_elemwise_careduce(fgraph, node): ...@@ -2584,9 +2584,10 @@ def local_gpu_elemwise_careduce(fgraph, node):
inp = node.inputs[0].owner.inputs[0] inp = node.inputs[0].owner.inputs[0]
props = node.op._props_dict() props = node.op._props_dict()
props["pre_scalar_op"] = node.inputs[0].owner.op.scalar_op props["pre_scalar_op"] = node.inputs[0].owner.op.scalar_op
with inherit_stack_trace(node.outputs):
out = GpuCAReduceCuda(**props)(inp) out = GpuCAReduceCuda(**props)(inp)
return [out] new_out = [out]
copy_stack_trace(node.outputs, new_out)
return new_out
@local_optimizer(None) @local_optimizer(None)
...@@ -2783,12 +2784,13 @@ def local_gpu_solve(fgraph, op, context_name, inputs, outputs): ...@@ -2783,12 +2784,13 @@ def local_gpu_solve(fgraph, op, context_name, inputs, outputs):
@local_optimizer([GpuCusolverSolve], inplace=True) @local_optimizer([GpuCusolverSolve], inplace=True)
def local_inplace_gpu_solve(fgraph, node): def local_inplace_gpu_solve(fgraph, node):
if isinstance(node.op, GpuCusolverSolve) and not node.op.inplace: if isinstance(node.op, GpuCusolverSolve) and not node.op.inplace:
with inherit_stack_trace(node.outputs): new_out = [
return [
GpuCusolverSolve( GpuCusolverSolve(
A_structure=node.op.A_structure, trans=node.op.trans, inplace=True A_structure=node.op.A_structure, trans=node.op.trans, inplace=True
)(*node.inputs) )(*node.inputs)
] ]
copy_stack_trace(node.outputs, new_out)
return new_out
# Cholesky decomposition # Cholesky decomposition
...@@ -2835,8 +2837,9 @@ register_opt2([slinalg.Solve], "fast_compile", name="matrix_ops_db2")(matrix_ops ...@@ -2835,8 +2837,9 @@ register_opt2([slinalg.Solve], "fast_compile", name="matrix_ops_db2")(matrix_ops
@local_optimizer([GpuCholesky], inplace=True) @local_optimizer([GpuCholesky], inplace=True)
def local_inplace_gpu_cholesky(fgraph, node): def local_inplace_gpu_cholesky(fgraph, node):
if isinstance(node.op, GpuCholesky) and not node.op.inplace: if isinstance(node.op, GpuCholesky) and not node.op.inplace:
with inherit_stack_trace(node.outputs): new_out = [node.op.clone_inplace()(*node.inputs)]
return [node.op.clone_inplace()(*node.inputs)] copy_stack_trace(node.outputs, new_out)
return new_out
def local_gpu_magma_cholesky(fgraph, op, context_name, inputs, outputs): def local_gpu_magma_cholesky(fgraph, op, context_name, inputs, outputs):
...@@ -2915,8 +2918,9 @@ def local_gpu_magma_matrix_inverse(fgraph, op, context_name, inputs, outputs): ...@@ -2915,8 +2918,9 @@ def local_gpu_magma_matrix_inverse(fgraph, op, context_name, inputs, outputs):
@local_optimizer([GpuMagmaMatrixInverse]) @local_optimizer([GpuMagmaMatrixInverse])
def local_inplace_gpu_magma_matrix_inverse(fgraph, node): def local_inplace_gpu_magma_matrix_inverse(fgraph, node):
if isinstance(node.op, GpuMagmaMatrixInverse) and not node.op.inplace: if isinstance(node.op, GpuMagmaMatrixInverse) and not node.op.inplace:
with inherit_stack_trace(node.outputs): new_out = [node.op.clone_inplace()(*node.inputs)]
return [node.op.clone_inplace()(*node.inputs)] copy_stack_trace(node.outputs, new_out)
return new_out
# Eigen decomposition of a symmetric matrix # Eigen decomposition of a symmetric matrix
......
...@@ -14,7 +14,7 @@ from aesara.gpuarray.elemwise import GpuDimShuffle, GpuElemwise ...@@ -14,7 +14,7 @@ from aesara.gpuarray.elemwise import GpuDimShuffle, GpuElemwise
from aesara.gpuarray.type import GpuArrayType, get_context, move_to_gpu from aesara.gpuarray.type import GpuArrayType, get_context, move_to_gpu
from aesara.graph.basic import Constant from aesara.graph.basic import Constant
from aesara.graph.op import Op from aesara.graph.op import Op
from aesara.graph.opt import copy_stack_trace, inherit_stack_trace, local_optimizer from aesara.graph.opt import copy_stack_trace, local_optimizer
from aesara.tensor.basic import as_tensor, cast, get_scalar_constant_value, join from aesara.tensor.basic import as_tensor, cast, get_scalar_constant_value, join
from aesara.tensor.elemwise import DimShuffle from aesara.tensor.elemwise import DimShuffle
from aesara.tensor.exceptions import NotScalarConstantError from aesara.tensor.exceptions import NotScalarConstantError
...@@ -213,8 +213,9 @@ def alpha_merge(cls, alpha_in, beta_in): ...@@ -213,8 +213,9 @@ def alpha_merge(cls, alpha_in, beta_in):
except NotScalarConstantError: except NotScalarConstantError:
inputs[alpha_in] = lr * targ.inputs[alpha_in] inputs[alpha_in] = lr * targ.inputs[alpha_in]
inputs[beta_in] = lr * targ.inputs[beta_in] inputs[beta_in] = lr * targ.inputs[beta_in]
with inherit_stack_trace(node.outputs): new_out = maker(targ, *inputs)
return maker(targ, *inputs) copy_stack_trace(node.outputs, new_out)
return new_out
return opt return opt
...@@ -309,8 +310,9 @@ def output_merge(cls, alpha_in, beta_in, out_in): ...@@ -309,8 +310,9 @@ def output_merge(cls, alpha_in, beta_in, out_in):
dtype = inputs[beta_in].dtype dtype = inputs[beta_in].dtype
one = aes.constant(np.asarray(1.0, dtype=dtype)) one = aes.constant(np.asarray(1.0, dtype=dtype))
inputs[beta_in] = one inputs[beta_in] = one
with inherit_stack_trace(node.outputs): new_out = maker(targ, *inputs)
return maker(targ, *inputs) copy_stack_trace(node.outputs, new_out)
return new_out
return opt return opt
...@@ -371,8 +373,9 @@ def inplace_allocempty(op, idx): ...@@ -371,8 +373,9 @@ def inplace_allocempty(op, idx):
alloc.owner.op.dtype, alloc.owner.op.context_name alloc.owner.op.dtype, alloc.owner.op.context_name
) )
inputs[idx] = alloc_op(*alloc.owner.inputs) inputs[idx] = alloc_op(*alloc.owner.inputs)
with inherit_stack_trace(node.outputs): new_out = maker(node, inputs)
return maker(node, inputs) copy_stack_trace(node.outputs, new_out)
return new_out
return opt return opt
......
"""Core graph classes.""" """Core graph classes."""
import contextlib
import warnings import warnings
from collections import deque from collections import deque
from copy import copy from copy import copy
...@@ -398,8 +397,6 @@ class Variable(Node): ...@@ -398,8 +397,6 @@ class Variable(Node):
self.auto_name = "auto_" + str(next(self.__count__)) self.auto_name = "auto_" + str(next(self.__count__))
Variable.notify_construction_observers(self)
def get_test_value(self): def get_test_value(self):
"""Get the test value. """Get the test value.
...@@ -562,22 +559,6 @@ class Variable(Node): ...@@ -562,22 +559,6 @@ class Variable(Node):
d["tag"] = t d["tag"] = t
return d return d
# refer to doc in nodes_constructed.
construction_observers: List = []
@classmethod
def append_construction_observer(cls, observer):
cls.construction_observers.append(observer)
@classmethod
def remove_construction_observer(cls, observer):
cls.construction_observers.remove(observer)
@classmethod
def notify_construction_observers(cls, instance):
for observer in cls.construction_observers:
observer(instance)
class Constant(Variable): class Constant(Variable):
"""A `Variable` with a fixed `data` field. """A `Variable` with a fixed `data` field.
...@@ -1448,42 +1429,6 @@ def is_in_ancestors(l_apply: Apply, f_node: Apply) -> bool: ...@@ -1448,42 +1429,6 @@ def is_in_ancestors(l_apply: Apply, f_node: Apply) -> bool:
return False return False
@contextlib.contextmanager
def nodes_constructed():
r"""
A context manager that is used in ``inherit_stack_trace`` and keeps track
of all the newly created variable nodes inside an optimization. A list
of ``new_nodes`` is instantiated but will be filled in a lazy manner (when
``Variable.notify_construction_observers`` is called).
``observer`` is the entity that updates the ``new_nodes`` list.
``construction_observers`` is a list inside `Variable` class and contains
a list of observer functions. The observer functions inside
``construction_observers`` are only called when a `Variable` is
instantiated (where ``Variable.notify_construction_observers`` is called).
When the observer function is called, a new `Variable` is added to
the `new_nodes` list.
Parameters
----------
new_nodes
A list of all the `Variable`\s that are created inside the optimization.
yields
``new_nodes`` list.
"""
new_nodes = []
def observer(node):
new_nodes.append(node)
Variable.append_construction_observer(observer)
yield new_nodes
Variable.remove_construction_observer(observer)
def equal_computations(xs, ys, in_xs=None, in_ys=None): def equal_computations(xs, ys, in_xs=None, in_ys=None):
"""Checks if Aesara graphs represent the same computations. """Checks if Aesara graphs represent the same computations.
......
...@@ -4,7 +4,6 @@ amount of useful generic optimization tools. ...@@ -4,7 +4,6 @@ amount of useful generic optimization tools.
""" """
import abc import abc
import contextlib
import copy import copy
import functools import functools
import inspect import inspect
...@@ -29,7 +28,6 @@ from aesara.graph.basic import ( ...@@ -29,7 +28,6 @@ from aesara.graph.basic import (
Variable, Variable,
applys_between, applys_between,
io_toposort, io_toposort,
nodes_constructed,
vars_between, vars_between,
) )
from aesara.graph.features import Feature, NodeFinder from aesara.graph.features import Feature, NodeFinder
...@@ -3000,24 +2998,6 @@ def copy_stack_trace(from_var, to_var): ...@@ -3000,24 +2998,6 @@ def copy_stack_trace(from_var, to_var):
return to_var return to_var
@contextlib.contextmanager
def inherit_stack_trace(from_var):
"""
A context manager that copies the stack trace from one or more variable nodes to all
variable nodes constructed in the body. ``new_nodes`` is the list of all the newly created
variable nodes inside an optimization that is managed by ``graph.nodes_constructed``.
Parameters
----------
from_var :
`Variable` node or a list of `Variable` nodes to copy stack traces from.
"""
with nodes_constructed() as new_nodes:
yield
copy_stack_trace(from_var, new_nodes)
def check_stack_trace(f_or_fgraph, ops_to_check="last", bug_print="raise"): def check_stack_trace(f_or_fgraph, ops_to_check="last", bug_print="raise"):
r"""Checks if the outputs of specific `Op`\s have a stack trace. r"""Checks if the outputs of specific `Op`\s have a stack trace.
......
...@@ -152,8 +152,8 @@ from aesara.graph.op import COp, Op ...@@ -152,8 +152,8 @@ from aesara.graph.op import COp, Op
from aesara.graph.opt import ( from aesara.graph.opt import (
EquilibriumOptimizer, EquilibriumOptimizer,
GlobalOptimizer, GlobalOptimizer,
copy_stack_trace,
in2out, in2out,
inherit_stack_trace,
local_optimizer, local_optimizer,
) )
from aesara.graph.optdb import SequenceDB from aesara.graph.optdb import SequenceDB
...@@ -1672,15 +1672,18 @@ def local_dot_to_dot22(fgraph, node): ...@@ -1672,15 +1672,18 @@ def local_dot_to_dot22(fgraph, node):
return return
if y.type.dtype in ("float16", "float32", "float64", "complex64", "complex128"): if y.type.dtype in ("float16", "float32", "float64", "complex64", "complex128"):
with inherit_stack_trace(node.outputs):
if x.ndim == 2 and y.ndim == 2: if x.ndim == 2 and y.ndim == 2:
return [_dot22(*node.inputs)] new_out = [_dot22(*node.inputs)]
if x.ndim == 2 and y.ndim == 1: elif x.ndim == 2 and y.ndim == 1:
return [_dot22(x, y.dimshuffle(0, "x")).dimshuffle(0)] new_out = [_dot22(x, y.dimshuffle(0, "x")).dimshuffle(0)]
if x.ndim == 1 and y.ndim == 2: elif x.ndim == 1 and y.ndim == 2:
return [_dot22(x.dimshuffle("x", 0), y).dimshuffle(1)] new_out = [_dot22(x.dimshuffle("x", 0), y).dimshuffle(1)]
if x.ndim == 1 and y.ndim == 1: elif x.ndim == 1 and y.ndim == 1:
return [_dot22(x.dimshuffle("x", 0), y.dimshuffle(0, "x")).dimshuffle()] new_out = [_dot22(x.dimshuffle("x", 0), y.dimshuffle(0, "x")).dimshuffle()]
else:
return
copy_stack_trace(node.outputs, new_out)
return new_out
_logger.info(f"Not optimizing dot with inputs {x} {y} {x.type} {y.type}") _logger.info(f"Not optimizing dot with inputs {x} {y} {x.type} {y.type}")
...@@ -1688,22 +1691,25 @@ def local_dot_to_dot22(fgraph, node): ...@@ -1688,22 +1691,25 @@ def local_dot_to_dot22(fgraph, node):
@local_optimizer([gemm_no_inplace], inplace=True) @local_optimizer([gemm_no_inplace], inplace=True)
def local_inplace_gemm(fgraph, node): def local_inplace_gemm(fgraph, node):
if node.op == gemm_no_inplace: if node.op == gemm_no_inplace:
with inherit_stack_trace(node.outputs): new_out = [gemm_inplace(*node.inputs)]
return [gemm_inplace(*node.inputs)] copy_stack_trace(node.outputs, new_out)
return new_out
@local_optimizer([gemv_no_inplace], inplace=True) @local_optimizer([gemv_no_inplace], inplace=True)
def local_inplace_gemv(fgraph, node): def local_inplace_gemv(fgraph, node):
if node.op == gemv_no_inplace: if node.op == gemv_no_inplace:
with inherit_stack_trace(node.outputs): new_out = [gemv_inplace(*node.inputs)]
return [gemv_inplace(*node.inputs)] copy_stack_trace(node.outputs, new_out)
return new_out
@local_optimizer([ger], inplace=True) @local_optimizer([ger], inplace=True)
def local_inplace_ger(fgraph, node): def local_inplace_ger(fgraph, node):
if node.op == ger: if node.op == ger:
with inherit_stack_trace(node.outputs): new_out = [ger_destructive(*node.inputs)]
return [ger_destructive(*node.inputs)] copy_stack_trace(node.outputs, new_out)
return new_out
@local_optimizer([gemm_no_inplace]) @local_optimizer([gemm_no_inplace])
...@@ -1711,13 +1717,16 @@ def local_gemm_to_gemv(fgraph, node): ...@@ -1711,13 +1717,16 @@ def local_gemm_to_gemv(fgraph, node):
"""GEMM acting on row or column matrices -> GEMV.""" """GEMM acting on row or column matrices -> GEMV."""
if node.op == gemm_no_inplace: if node.op == gemm_no_inplace:
z, a, x, y, b = node.inputs z, a, x, y, b = node.inputs
with inherit_stack_trace(node.outputs):
if z.broadcastable == x.broadcastable == (True, False): if z.broadcastable == x.broadcastable == (True, False):
r = gemv_no_inplace(z.dimshuffle(1), a, y.T, x.dimshuffle(1), b) r = gemv_no_inplace(z.dimshuffle(1), a, y.T, x.dimshuffle(1), b)
return [r.dimshuffle("x", 0)] new_out = [r.dimshuffle("x", 0)]
if z.broadcastable == y.broadcastable == (False, True): elif z.broadcastable == y.broadcastable == (False, True):
r = gemv_no_inplace(z.dimshuffle(0), a, x, y.dimshuffle(0), b) r = gemv_no_inplace(z.dimshuffle(0), a, x, y.dimshuffle(0), b)
return [r.dimshuffle(0, "x")] new_out = [r.dimshuffle(0, "x")]
else:
return
copy_stack_trace(node.outputs, new_out)
return new_out
@local_optimizer([gemm_no_inplace]) @local_optimizer([gemm_no_inplace])
...@@ -1726,7 +1735,6 @@ def local_gemm_to_ger(fgraph, node): ...@@ -1726,7 +1735,6 @@ def local_gemm_to_ger(fgraph, node):
if node.op == gemm_no_inplace: if node.op == gemm_no_inplace:
z, a, x, y, b = node.inputs z, a, x, y, b = node.inputs
if x.broadcastable[1] and y.broadcastable[0]: if x.broadcastable[1] and y.broadcastable[0]:
with inherit_stack_trace(node.outputs):
# x and y are both vectors so this might qualifies for a GER # x and y are both vectors so this might qualifies for a GER
xv = x.dimshuffle(0) xv = x.dimshuffle(0)
yv = y.dimshuffle(1) yv = y.dimshuffle(1)
...@@ -1738,15 +1746,17 @@ def local_gemm_to_ger(fgraph, node): ...@@ -1738,15 +1746,17 @@ def local_gemm_to_ger(fgraph, node):
if bval == 1: # best case a natural GER if bval == 1: # best case a natural GER
rval = ger(z, a, xv, yv) rval = ger(z, a, xv, yv)
return [rval] new_out = [rval]
elif bval == 0: # GER on zeros_like should be faster than GEMM elif bval == 0: # GER on zeros_like should be faster than GEMM
zeros = at.zeros([x.shape[0], y.shape[1]], x.dtype) zeros = at.zeros([x.shape[0], y.shape[1]], x.dtype)
rval = ger(zeros, a, xv, yv) rval = ger(zeros, a, xv, yv)
return [rval] new_out = [rval]
else: else:
# if bval is another constant, then z is being usefully # if bval is another constant, then z is being usefully
# pre-scaled and GER isn't really the right tool for the job. # pre-scaled and GER isn't really the right tool for the job.
return return
copy_stack_trace(node.outputs, new_out)
return new_out
# TODO: delete this optimization when we have the proper dot->gemm->ger pipeline # TODO: delete this optimization when we have the proper dot->gemm->ger pipeline
...@@ -1755,7 +1765,6 @@ def local_gemm_to_ger(fgraph, node): ...@@ -1755,7 +1765,6 @@ def local_gemm_to_ger(fgraph, node):
def local_dot22_to_ger_or_gemv(fgraph, node): def local_dot22_to_ger_or_gemv(fgraph, node):
"""dot22 computing an outer-product -> GER.""" """dot22 computing an outer-product -> GER."""
if node.op == _dot22: if node.op == _dot22:
with inherit_stack_trace(node.outputs):
x, y = node.inputs x, y = node.inputs
xb = x.broadcastable xb = x.broadcastable
yb = y.broadcastable yb = y.broadcastable
...@@ -1767,26 +1776,30 @@ def local_dot22_to_ger_or_gemv(fgraph, node): ...@@ -1767,26 +1776,30 @@ def local_dot22_to_ger_or_gemv(fgraph, node):
yv = y.dimshuffle(1) yv = y.dimshuffle(1)
zeros = at.zeros([x.shape[0], y.shape[1]], dtype=x.dtype) zeros = at.zeros([x.shape[0], y.shape[1]], dtype=x.dtype)
rval = ger(zeros, one, xv, yv) rval = ger(zeros, one, xv, yv)
return [rval] new_out = [rval]
if xb[0] and yb[1]: elif xb[0] and yb[1]:
# x and y are both vectors so this qualifies for a sdot / ddot # x and y are both vectors so this qualifies for a sdot / ddot
# TODO: Aesara doesn't have a sdot, but gemv is better than _dot22 # TODO: Aesara doesn't have a sdot, but gemv is better than _dot22
xv = x.dimshuffle(1) xv = x.dimshuffle(1)
zeros = at.AllocEmpty(x.dtype)(1) zeros = at.AllocEmpty(x.dtype)(1)
rval = gemv_no_inplace(zeros, one, y.T, xv, zero) rval = gemv_no_inplace(zeros, one, y.T, xv, zero)
return [rval.dimshuffle("x", 0)] new_out = [rval.dimshuffle("x", 0)]
if xb[0] and not yb[0] and not yb[1]: elif xb[0] and not yb[0] and not yb[1]:
# x is vector, y is matrix so try gemv # x is vector, y is matrix so try gemv
xv = x.dimshuffle(1) xv = x.dimshuffle(1)
zeros = at.AllocEmpty(x.dtype)(y.shape[1]) zeros = at.AllocEmpty(x.dtype)(y.shape[1])
rval = gemv_no_inplace(zeros, one, y.T, xv, zero) rval = gemv_no_inplace(zeros, one, y.T, xv, zero)
return [rval.dimshuffle("x", 0)] new_out = [rval.dimshuffle("x", 0)]
if not xb[0] and not xb[1] and yb[1]: elif not xb[0] and not xb[1] and yb[1]:
# x is matrix, y is vector, try gemv # x is matrix, y is vector, try gemv
yv = y.dimshuffle(0) yv = y.dimshuffle(0)
zeros = at.AllocEmpty(x.dtype)(x.shape[0]) zeros = at.AllocEmpty(x.dtype)(x.shape[0])
rval = gemv_no_inplace(zeros, one, x, yv, zero) rval = gemv_no_inplace(zeros, one, x, yv, zero)
return [rval.dimshuffle(0, "x")] new_out = [rval.dimshuffle(0, "x")]
else:
return
copy_stack_trace(node.outputs, new_out)
return new_out
################################# #################################
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论