提交 7bc40c67 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Remove inherit_stack_trace, nodes_constructed, and Variable.*_construction_observer methods

上级 16d2e934
......@@ -48,7 +48,7 @@ from aesara.gpuarray.optdb import (
)
from aesara.gpuarray.reduction import GpuMaxAndArgmax
from aesara.gpuarray.type import list_contexts
from aesara.graph.opt import GlobalOptimizer, inherit_stack_trace, local_optimizer
from aesara.graph.opt import GlobalOptimizer, copy_stack_trace, local_optimizer
from aesara.scalar import Log
from aesara.tensor.math import Argmax
from aesara.tensor.nnet.abstract_conv import (
......@@ -79,15 +79,17 @@ def local_abstractconv_cudnn(fgraph, node):
# Asymmetric padding not yet supported
return None
if isinstance(node.op, AbstractConv2d):
with inherit_stack_trace(node.outputs):
return local_abstractconv_cudnn_graph(
node.op, ctx, node.inputs, node.outputs
)
new_out = local_abstractconv_cudnn_graph(
node.op, ctx, node.inputs, node.outputs
)
copy_stack_trace(node.outputs, new_out)
return new_out
elif isinstance(node.op, AbstractConv3d):
with inherit_stack_trace(node.outputs):
return local_abstractconv3d_cudnn_graph(
node.op, ctx, node.inputs, node.outputs
)
new_out = local_abstractconv3d_cudnn_graph(
node.op, ctx, node.inputs, node.outputs
)
copy_stack_trace(node.outputs, new_out)
return new_out
@local_optimizer(
......@@ -362,15 +364,17 @@ def local_abstractconv_gw_cudnn(fgraph, node):
# Asymmetric padding not yet supported
return None
if isinstance(node.op, AbstractConv2d_gradWeights):
with inherit_stack_trace(node.outputs):
return local_abstractconv_cudnn_graph(
node.op, ctx, node.inputs, node.outputs
)
new_out = local_abstractconv_cudnn_graph(
node.op, ctx, node.inputs, node.outputs
)
copy_stack_trace(node.outputs, new_out)
return new_out
elif isinstance(node.op, AbstractConv3d_gradWeights):
with inherit_stack_trace(node.outputs):
return local_abstractconv3d_cudnn_graph(
node.op, ctx, node.inputs, node.outputs
)
new_out = local_abstractconv3d_cudnn_graph(
node.op, ctx, node.inputs, node.outputs
)
copy_stack_trace(node.outputs, new_out)
return new_out
@local_optimizer([AbstractConv2d_gradInputs, AbstractConv3d_gradInputs])
......@@ -386,15 +390,17 @@ def local_abstractconv_gi_cudnn(fgraph, node):
# Asymmetric padding not yet supported
return None
if isinstance(node.op, AbstractConv2d_gradInputs):
with inherit_stack_trace(node.outputs):
return local_abstractconv_cudnn_graph(
node.op, ctx, node.inputs, node.outputs
)
new_out = local_abstractconv_cudnn_graph(
node.op, ctx, node.inputs, node.outputs
)
copy_stack_trace(node.outputs, new_out)
return new_out
elif isinstance(node.op, AbstractConv3d_gradInputs):
with inherit_stack_trace(node.outputs):
return local_abstractconv3d_cudnn_graph(
node.op, ctx, node.inputs, node.outputs
)
new_out = local_abstractconv3d_cudnn_graph(
node.op, ctx, node.inputs, node.outputs
)
copy_stack_trace(node.outputs, new_out)
return new_out
@inplace_allocempty(GpuDnnConv, 2)
......@@ -748,11 +754,12 @@ def local_dnn_reduction(fgraph, node):
if not cudnn.cudnnReduceTensorOp_t.has_alias(scal):
return
with inherit_stack_trace(node.outputs):
ret = GpuDnnReduction(scal, node.op.axis, acc_dtype, node.op.dtype, False)(
node.inputs[0]
)
return [post(ret)]
ret = GpuDnnReduction(scal, node.op.axis, acc_dtype, node.op.dtype, False)(
node.inputs[0]
)
new_out = [post(ret)]
copy_stack_trace(node.outputs, new_out)
return new_out
@register_opt("cudnn")
......
差异被折叠。
......@@ -14,7 +14,7 @@ from aesara.gpuarray.elemwise import GpuDimShuffle, GpuElemwise
from aesara.gpuarray.type import GpuArrayType, get_context, move_to_gpu
from aesara.graph.basic import Constant
from aesara.graph.op import Op
from aesara.graph.opt import copy_stack_trace, inherit_stack_trace, local_optimizer
from aesara.graph.opt import copy_stack_trace, local_optimizer
from aesara.tensor.basic import as_tensor, cast, get_scalar_constant_value, join
from aesara.tensor.elemwise import DimShuffle
from aesara.tensor.exceptions import NotScalarConstantError
......@@ -213,8 +213,9 @@ def alpha_merge(cls, alpha_in, beta_in):
except NotScalarConstantError:
inputs[alpha_in] = lr * targ.inputs[alpha_in]
inputs[beta_in] = lr * targ.inputs[beta_in]
with inherit_stack_trace(node.outputs):
return maker(targ, *inputs)
new_out = maker(targ, *inputs)
copy_stack_trace(node.outputs, new_out)
return new_out
return opt
......@@ -309,8 +310,9 @@ def output_merge(cls, alpha_in, beta_in, out_in):
dtype = inputs[beta_in].dtype
one = aes.constant(np.asarray(1.0, dtype=dtype))
inputs[beta_in] = one
with inherit_stack_trace(node.outputs):
return maker(targ, *inputs)
new_out = maker(targ, *inputs)
copy_stack_trace(node.outputs, new_out)
return new_out
return opt
......@@ -371,8 +373,9 @@ def inplace_allocempty(op, idx):
alloc.owner.op.dtype, alloc.owner.op.context_name
)
inputs[idx] = alloc_op(*alloc.owner.inputs)
with inherit_stack_trace(node.outputs):
return maker(node, inputs)
new_out = maker(node, inputs)
copy_stack_trace(node.outputs, new_out)
return new_out
return opt
......
"""Core graph classes."""
import contextlib
import warnings
from collections import deque
from copy import copy
......@@ -398,8 +397,6 @@ class Variable(Node):
self.auto_name = "auto_" + str(next(self.__count__))
Variable.notify_construction_observers(self)
def get_test_value(self):
"""Get the test value.
......@@ -562,22 +559,6 @@ class Variable(Node):
d["tag"] = t
return d
# refer to doc in nodes_constructed.
construction_observers: List = []
@classmethod
def append_construction_observer(cls, observer):
    """Register *observer* to be invoked for every newly built `Variable`.

    Parameters
    ----------
    observer : callable
        A one-argument callable; it receives each new `Variable` instance.
    """
    registry = cls.construction_observers
    registry.append(observer)
@classmethod
def remove_construction_observer(cls, observer):
    """Unregister a previously added construction *observer*.

    Raises
    ------
    ValueError
        If *observer* was never registered (propagated from ``list``).
    """
    registry = cls.construction_observers
    registry.remove(observer)
@classmethod
def notify_construction_observers(cls, instance):
    """Call every registered construction observer with *instance*.

    Parameters
    ----------
    instance : Variable
        The freshly constructed `Variable` being announced.
    """
    for callback in cls.construction_observers:
        callback(instance)
class Constant(Variable):
"""A `Variable` with a fixed `data` field.
......@@ -1448,42 +1429,6 @@ def is_in_ancestors(l_apply: Apply, f_node: Apply) -> bool:
return False
@contextlib.contextmanager
def nodes_constructed():
    r"""Track every `Variable` created while this context is active.

    Used by ``inherit_stack_trace`` to find all variable nodes built inside
    an optimization.  A fresh ``new_nodes`` list is yielded and filled
    lazily: an ``observer`` callable is registered on the `Variable` class
    (via ``Variable.append_construction_observer``), and each time a
    `Variable` is instantiated, ``Variable.notify_construction_observers``
    invokes it, appending the new variable to the list.

    Yields
    ------
    new_nodes : list of Variable
        All `Variable`\s constructed while the context manager is active.
    """
    new_nodes = []

    def observer(node):
        new_nodes.append(node)

    Variable.append_construction_observer(observer)
    try:
        yield new_nodes
    finally:
        # Always deregister, even when the body raises; otherwise the
        # observer would leak and keep recording (and keeping alive) every
        # Variable created for the rest of the process.
        Variable.remove_construction_observer(observer)
def equal_computations(xs, ys, in_xs=None, in_ys=None):
"""Checks if Aesara graphs represent the same computations.
......
......@@ -4,7 +4,6 @@ amount of useful generic optimization tools.
"""
import abc
import contextlib
import copy
import functools
import inspect
......@@ -29,7 +28,6 @@ from aesara.graph.basic import (
Variable,
applys_between,
io_toposort,
nodes_constructed,
vars_between,
)
from aesara.graph.features import Feature, NodeFinder
......@@ -3000,24 +2998,6 @@ def copy_stack_trace(from_var, to_var):
return to_var
@contextlib.contextmanager
def inherit_stack_trace(from_var):
    """Propagate stack traces onto every variable built inside the block.

    All `Variable`\\s constructed in the managed body are collected via
    ``nodes_constructed``; when the body finishes, the stack trace(s) of
    *from_var* are copied onto each of them with ``copy_stack_trace``.

    Parameters
    ----------
    from_var :
        A `Variable` node, or a list of `Variable` nodes, whose stack
        traces are copied to the newly created nodes.
    """
    with nodes_constructed() as created_nodes:
        yield
    copy_stack_trace(from_var, created_nodes)
def check_stack_trace(f_or_fgraph, ops_to_check="last", bug_print="raise"):
r"""Checks if the outputs of specific `Op`\s have a stack trace.
......
......@@ -152,8 +152,8 @@ from aesara.graph.op import COp, Op
from aesara.graph.opt import (
EquilibriumOptimizer,
GlobalOptimizer,
copy_stack_trace,
in2out,
inherit_stack_trace,
local_optimizer,
)
from aesara.graph.optdb import SequenceDB
......@@ -1672,15 +1672,18 @@ def local_dot_to_dot22(fgraph, node):
return
if y.type.dtype in ("float16", "float32", "float64", "complex64", "complex128"):
with inherit_stack_trace(node.outputs):
if x.ndim == 2 and y.ndim == 2:
return [_dot22(*node.inputs)]
if x.ndim == 2 and y.ndim == 1:
return [_dot22(x, y.dimshuffle(0, "x")).dimshuffle(0)]
if x.ndim == 1 and y.ndim == 2:
return [_dot22(x.dimshuffle("x", 0), y).dimshuffle(1)]
if x.ndim == 1 and y.ndim == 1:
return [_dot22(x.dimshuffle("x", 0), y.dimshuffle(0, "x")).dimshuffle()]
if x.ndim == 2 and y.ndim == 2:
new_out = [_dot22(*node.inputs)]
elif x.ndim == 2 and y.ndim == 1:
new_out = [_dot22(x, y.dimshuffle(0, "x")).dimshuffle(0)]
elif x.ndim == 1 and y.ndim == 2:
new_out = [_dot22(x.dimshuffle("x", 0), y).dimshuffle(1)]
elif x.ndim == 1 and y.ndim == 1:
new_out = [_dot22(x.dimshuffle("x", 0), y.dimshuffle(0, "x")).dimshuffle()]
else:
return
copy_stack_trace(node.outputs, new_out)
return new_out
_logger.info(f"Not optimizing dot with inputs {x} {y} {x.type} {y.type}")
......@@ -1688,22 +1691,25 @@ def local_dot_to_dot22(fgraph, node):
@local_optimizer([gemm_no_inplace], inplace=True)
def local_inplace_gemm(fgraph, node):
if node.op == gemm_no_inplace:
with inherit_stack_trace(node.outputs):
return [gemm_inplace(*node.inputs)]
new_out = [gemm_inplace(*node.inputs)]
copy_stack_trace(node.outputs, new_out)
return new_out
@local_optimizer([gemv_no_inplace], inplace=True)
def local_inplace_gemv(fgraph, node):
if node.op == gemv_no_inplace:
with inherit_stack_trace(node.outputs):
return [gemv_inplace(*node.inputs)]
new_out = [gemv_inplace(*node.inputs)]
copy_stack_trace(node.outputs, new_out)
return new_out
@local_optimizer([ger], inplace=True)
def local_inplace_ger(fgraph, node):
if node.op == ger:
with inherit_stack_trace(node.outputs):
return [ger_destructive(*node.inputs)]
new_out = [ger_destructive(*node.inputs)]
copy_stack_trace(node.outputs, new_out)
return new_out
@local_optimizer([gemm_no_inplace])
......@@ -1711,13 +1717,16 @@ def local_gemm_to_gemv(fgraph, node):
"""GEMM acting on row or column matrices -> GEMV."""
if node.op == gemm_no_inplace:
z, a, x, y, b = node.inputs
with inherit_stack_trace(node.outputs):
if z.broadcastable == x.broadcastable == (True, False):
r = gemv_no_inplace(z.dimshuffle(1), a, y.T, x.dimshuffle(1), b)
return [r.dimshuffle("x", 0)]
if z.broadcastable == y.broadcastable == (False, True):
r = gemv_no_inplace(z.dimshuffle(0), a, x, y.dimshuffle(0), b)
return [r.dimshuffle(0, "x")]
if z.broadcastable == x.broadcastable == (True, False):
r = gemv_no_inplace(z.dimshuffle(1), a, y.T, x.dimshuffle(1), b)
new_out = [r.dimshuffle("x", 0)]
elif z.broadcastable == y.broadcastable == (False, True):
r = gemv_no_inplace(z.dimshuffle(0), a, x, y.dimshuffle(0), b)
new_out = [r.dimshuffle(0, "x")]
else:
return
copy_stack_trace(node.outputs, new_out)
return new_out
@local_optimizer([gemm_no_inplace])
......@@ -1726,27 +1735,28 @@ def local_gemm_to_ger(fgraph, node):
if node.op == gemm_no_inplace:
z, a, x, y, b = node.inputs
if x.broadcastable[1] and y.broadcastable[0]:
with inherit_stack_trace(node.outputs):
# x and y are both vectors so this might qualifies for a GER
xv = x.dimshuffle(0)
yv = y.dimshuffle(1)
try:
bval = at.get_scalar_constant_value(b)
except NotScalarConstantError:
# b isn't a constant, GEMM is doing useful pre-scaling
return
if bval == 1: # best case a natural GER
rval = ger(z, a, xv, yv)
return [rval]
elif bval == 0: # GER on zeros_like should be faster than GEMM
zeros = at.zeros([x.shape[0], y.shape[1]], x.dtype)
rval = ger(zeros, a, xv, yv)
return [rval]
else:
# if bval is another constant, then z is being usefully
# pre-scaled and GER isn't really the right tool for the job.
return
# x and y are both vectors so this might qualifies for a GER
xv = x.dimshuffle(0)
yv = y.dimshuffle(1)
try:
bval = at.get_scalar_constant_value(b)
except NotScalarConstantError:
# b isn't a constant, GEMM is doing useful pre-scaling
return
if bval == 1: # best case a natural GER
rval = ger(z, a, xv, yv)
new_out = [rval]
elif bval == 0: # GER on zeros_like should be faster than GEMM
zeros = at.zeros([x.shape[0], y.shape[1]], x.dtype)
rval = ger(zeros, a, xv, yv)
new_out = [rval]
else:
# if bval is another constant, then z is being usefully
# pre-scaled and GER isn't really the right tool for the job.
return
copy_stack_trace(node.outputs, new_out)
return new_out
# TODO: delete this optimization when we have the proper dot->gemm->ger pipeline
......@@ -1755,38 +1765,41 @@ def local_gemm_to_ger(fgraph, node):
def local_dot22_to_ger_or_gemv(fgraph, node):
"""dot22 computing an outer-product -> GER."""
if node.op == _dot22:
with inherit_stack_trace(node.outputs):
x, y = node.inputs
xb = x.broadcastable
yb = y.broadcastable
one = at.as_tensor_variable(np.asarray(1, dtype=x.dtype))
zero = at.as_tensor_variable(np.asarray(0, dtype=x.dtype))
if xb[1] and yb[0]:
# x and y are both vectors so this might qualifies for a GER
xv = x.dimshuffle(0)
yv = y.dimshuffle(1)
zeros = at.zeros([x.shape[0], y.shape[1]], dtype=x.dtype)
rval = ger(zeros, one, xv, yv)
return [rval]
if xb[0] and yb[1]:
# x and y are both vectors so this qualifies for a sdot / ddot
# TODO: Aesara doesn't have a sdot, but gemv is better than _dot22
xv = x.dimshuffle(1)
zeros = at.AllocEmpty(x.dtype)(1)
rval = gemv_no_inplace(zeros, one, y.T, xv, zero)
return [rval.dimshuffle("x", 0)]
if xb[0] and not yb[0] and not yb[1]:
# x is vector, y is matrix so try gemv
xv = x.dimshuffle(1)
zeros = at.AllocEmpty(x.dtype)(y.shape[1])
rval = gemv_no_inplace(zeros, one, y.T, xv, zero)
return [rval.dimshuffle("x", 0)]
if not xb[0] and not xb[1] and yb[1]:
# x is matrix, y is vector, try gemv
yv = y.dimshuffle(0)
zeros = at.AllocEmpty(x.dtype)(x.shape[0])
rval = gemv_no_inplace(zeros, one, x, yv, zero)
return [rval.dimshuffle(0, "x")]
x, y = node.inputs
xb = x.broadcastable
yb = y.broadcastable
one = at.as_tensor_variable(np.asarray(1, dtype=x.dtype))
zero = at.as_tensor_variable(np.asarray(0, dtype=x.dtype))
if xb[1] and yb[0]:
# x and y are both vectors so this might qualifies for a GER
xv = x.dimshuffle(0)
yv = y.dimshuffle(1)
zeros = at.zeros([x.shape[0], y.shape[1]], dtype=x.dtype)
rval = ger(zeros, one, xv, yv)
new_out = [rval]
elif xb[0] and yb[1]:
# x and y are both vectors so this qualifies for a sdot / ddot
# TODO: Aesara doesn't have a sdot, but gemv is better than _dot22
xv = x.dimshuffle(1)
zeros = at.AllocEmpty(x.dtype)(1)
rval = gemv_no_inplace(zeros, one, y.T, xv, zero)
new_out = [rval.dimshuffle("x", 0)]
elif xb[0] and not yb[0] and not yb[1]:
# x is vector, y is matrix so try gemv
xv = x.dimshuffle(1)
zeros = at.AllocEmpty(x.dtype)(y.shape[1])
rval = gemv_no_inplace(zeros, one, y.T, xv, zero)
new_out = [rval.dimshuffle("x", 0)]
elif not xb[0] and not xb[1] and yb[1]:
# x is matrix, y is vector, try gemv
yv = y.dimshuffle(0)
zeros = at.AllocEmpty(x.dtype)(x.shape[0])
rval = gemv_no_inplace(zeros, one, x, yv, zero)
new_out = [rval.dimshuffle(0, "x")]
else:
return
copy_stack_trace(node.outputs, new_out)
return new_out
#################################
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论