提交 7bc40c67 作者: Brandon T. Willard 提交者: Brandon T. Willard

Remove inherit_stack_trace, nodes_constructed, and Variable.*_construction_observer methods

上级 16d2e934
...@@ -48,7 +48,7 @@ from aesara.gpuarray.optdb import ( ...@@ -48,7 +48,7 @@ from aesara.gpuarray.optdb import (
) )
from aesara.gpuarray.reduction import GpuMaxAndArgmax from aesara.gpuarray.reduction import GpuMaxAndArgmax
from aesara.gpuarray.type import list_contexts from aesara.gpuarray.type import list_contexts
from aesara.graph.opt import GlobalOptimizer, inherit_stack_trace, local_optimizer from aesara.graph.opt import GlobalOptimizer, copy_stack_trace, local_optimizer
from aesara.scalar import Log from aesara.scalar import Log
from aesara.tensor.math import Argmax from aesara.tensor.math import Argmax
from aesara.tensor.nnet.abstract_conv import ( from aesara.tensor.nnet.abstract_conv import (
...@@ -79,15 +79,17 @@ def local_abstractconv_cudnn(fgraph, node): ...@@ -79,15 +79,17 @@ def local_abstractconv_cudnn(fgraph, node):
# Asymmetric padding not yet supported # Asymmetric padding not yet supported
return None return None
if isinstance(node.op, AbstractConv2d): if isinstance(node.op, AbstractConv2d):
with inherit_stack_trace(node.outputs): new_out = local_abstractconv_cudnn_graph(
return local_abstractconv_cudnn_graph( node.op, ctx, node.inputs, node.outputs
node.op, ctx, node.inputs, node.outputs )
) copy_stack_trace(node.outputs, new_out)
return new_out
elif isinstance(node.op, AbstractConv3d): elif isinstance(node.op, AbstractConv3d):
with inherit_stack_trace(node.outputs): new_out = local_abstractconv3d_cudnn_graph(
return local_abstractconv3d_cudnn_graph( node.op, ctx, node.inputs, node.outputs
node.op, ctx, node.inputs, node.outputs )
) copy_stack_trace(node.outputs, new_out)
return new_out
@local_optimizer( @local_optimizer(
...@@ -362,15 +364,17 @@ def local_abstractconv_gw_cudnn(fgraph, node): ...@@ -362,15 +364,17 @@ def local_abstractconv_gw_cudnn(fgraph, node):
# Asymmetric padding not yet supported # Asymmetric padding not yet supported
return None return None
if isinstance(node.op, AbstractConv2d_gradWeights): if isinstance(node.op, AbstractConv2d_gradWeights):
with inherit_stack_trace(node.outputs): new_out = local_abstractconv_cudnn_graph(
return local_abstractconv_cudnn_graph( node.op, ctx, node.inputs, node.outputs
node.op, ctx, node.inputs, node.outputs )
) copy_stack_trace(node.outputs, new_out)
return new_out
elif isinstance(node.op, AbstractConv3d_gradWeights): elif isinstance(node.op, AbstractConv3d_gradWeights):
with inherit_stack_trace(node.outputs): new_out = local_abstractconv3d_cudnn_graph(
return local_abstractconv3d_cudnn_graph( node.op, ctx, node.inputs, node.outputs
node.op, ctx, node.inputs, node.outputs )
) copy_stack_trace(node.outputs, new_out)
return new_out
@local_optimizer([AbstractConv2d_gradInputs, AbstractConv3d_gradInputs]) @local_optimizer([AbstractConv2d_gradInputs, AbstractConv3d_gradInputs])
...@@ -386,15 +390,17 @@ def local_abstractconv_gi_cudnn(fgraph, node): ...@@ -386,15 +390,17 @@ def local_abstractconv_gi_cudnn(fgraph, node):
# Asymmetric padding not yet supported # Asymmetric padding not yet supported
return None return None
if isinstance(node.op, AbstractConv2d_gradInputs): if isinstance(node.op, AbstractConv2d_gradInputs):
with inherit_stack_trace(node.outputs): new_out = local_abstractconv_cudnn_graph(
return local_abstractconv_cudnn_graph( node.op, ctx, node.inputs, node.outputs
node.op, ctx, node.inputs, node.outputs )
) copy_stack_trace(node.outputs, new_out)
return new_out
elif isinstance(node.op, AbstractConv3d_gradInputs): elif isinstance(node.op, AbstractConv3d_gradInputs):
with inherit_stack_trace(node.outputs): new_out = local_abstractconv3d_cudnn_graph(
return local_abstractconv3d_cudnn_graph( node.op, ctx, node.inputs, node.outputs
node.op, ctx, node.inputs, node.outputs )
) copy_stack_trace(node.outputs, new_out)
return new_out
@inplace_allocempty(GpuDnnConv, 2) @inplace_allocempty(GpuDnnConv, 2)
...@@ -748,11 +754,12 @@ def local_dnn_reduction(fgraph, node): ...@@ -748,11 +754,12 @@ def local_dnn_reduction(fgraph, node):
if not cudnn.cudnnReduceTensorOp_t.has_alias(scal): if not cudnn.cudnnReduceTensorOp_t.has_alias(scal):
return return
with inherit_stack_trace(node.outputs): ret = GpuDnnReduction(scal, node.op.axis, acc_dtype, node.op.dtype, False)(
ret = GpuDnnReduction(scal, node.op.axis, acc_dtype, node.op.dtype, False)( node.inputs[0]
node.inputs[0] )
) new_out = [post(ret)]
return [post(ret)] copy_stack_trace(node.outputs, new_out)
return new_out
@register_opt("cudnn") @register_opt("cudnn")
......
差异被折叠。
...@@ -14,7 +14,7 @@ from aesara.gpuarray.elemwise import GpuDimShuffle, GpuElemwise ...@@ -14,7 +14,7 @@ from aesara.gpuarray.elemwise import GpuDimShuffle, GpuElemwise
from aesara.gpuarray.type import GpuArrayType, get_context, move_to_gpu from aesara.gpuarray.type import GpuArrayType, get_context, move_to_gpu
from aesara.graph.basic import Constant from aesara.graph.basic import Constant
from aesara.graph.op import Op from aesara.graph.op import Op
from aesara.graph.opt import copy_stack_trace, inherit_stack_trace, local_optimizer from aesara.graph.opt import copy_stack_trace, local_optimizer
from aesara.tensor.basic import as_tensor, cast, get_scalar_constant_value, join from aesara.tensor.basic import as_tensor, cast, get_scalar_constant_value, join
from aesara.tensor.elemwise import DimShuffle from aesara.tensor.elemwise import DimShuffle
from aesara.tensor.exceptions import NotScalarConstantError from aesara.tensor.exceptions import NotScalarConstantError
...@@ -213,8 +213,9 @@ def alpha_merge(cls, alpha_in, beta_in): ...@@ -213,8 +213,9 @@ def alpha_merge(cls, alpha_in, beta_in):
except NotScalarConstantError: except NotScalarConstantError:
inputs[alpha_in] = lr * targ.inputs[alpha_in] inputs[alpha_in] = lr * targ.inputs[alpha_in]
inputs[beta_in] = lr * targ.inputs[beta_in] inputs[beta_in] = lr * targ.inputs[beta_in]
with inherit_stack_trace(node.outputs): new_out = maker(targ, *inputs)
return maker(targ, *inputs) copy_stack_trace(node.outputs, new_out)
return new_out
return opt return opt
...@@ -309,8 +310,9 @@ def output_merge(cls, alpha_in, beta_in, out_in): ...@@ -309,8 +310,9 @@ def output_merge(cls, alpha_in, beta_in, out_in):
dtype = inputs[beta_in].dtype dtype = inputs[beta_in].dtype
one = aes.constant(np.asarray(1.0, dtype=dtype)) one = aes.constant(np.asarray(1.0, dtype=dtype))
inputs[beta_in] = one inputs[beta_in] = one
with inherit_stack_trace(node.outputs): new_out = maker(targ, *inputs)
return maker(targ, *inputs) copy_stack_trace(node.outputs, new_out)
return new_out
return opt return opt
...@@ -371,8 +373,9 @@ def inplace_allocempty(op, idx): ...@@ -371,8 +373,9 @@ def inplace_allocempty(op, idx):
alloc.owner.op.dtype, alloc.owner.op.context_name alloc.owner.op.dtype, alloc.owner.op.context_name
) )
inputs[idx] = alloc_op(*alloc.owner.inputs) inputs[idx] = alloc_op(*alloc.owner.inputs)
with inherit_stack_trace(node.outputs): new_out = maker(node, inputs)
return maker(node, inputs) copy_stack_trace(node.outputs, new_out)
return new_out
return opt return opt
......
"""Core graph classes.""" """Core graph classes."""
import contextlib
import warnings import warnings
from collections import deque from collections import deque
from copy import copy from copy import copy
...@@ -398,8 +397,6 @@ class Variable(Node): ...@@ -398,8 +397,6 @@ class Variable(Node):
self.auto_name = "auto_" + str(next(self.__count__)) self.auto_name = "auto_" + str(next(self.__count__))
Variable.notify_construction_observers(self)
def get_test_value(self): def get_test_value(self):
"""Get the test value. """Get the test value.
...@@ -562,22 +559,6 @@ class Variable(Node): ...@@ -562,22 +559,6 @@ class Variable(Node):
d["tag"] = t d["tag"] = t
return d return d
# refer to doc in nodes_constructed.
construction_observers: List = []
@classmethod
def append_construction_observer(cls, observer):
cls.construction_observers.append(observer)
@classmethod
def remove_construction_observer(cls, observer):
cls.construction_observers.remove(observer)
@classmethod
def notify_construction_observers(cls, instance):
for observer in cls.construction_observers:
observer(instance)
class Constant(Variable): class Constant(Variable):
"""A `Variable` with a fixed `data` field. """A `Variable` with a fixed `data` field.
...@@ -1448,42 +1429,6 @@ def is_in_ancestors(l_apply: Apply, f_node: Apply) -> bool: ...@@ -1448,42 +1429,6 @@ def is_in_ancestors(l_apply: Apply, f_node: Apply) -> bool:
return False return False
@contextlib.contextmanager
def nodes_constructed():
r"""
A context manager that is used in ``inherit_stack_trace`` and keeps track
of all the newly created variable nodes inside an optimization. A list
of ``new_nodes`` is instantiated but will be filled in a lazy manner (when
``Variable.notify_construction_observers`` is called).
``observer`` is the entity that updates the ``new_nodes`` list.
``construction_observers`` is a list inside `Variable` class and contains
a list of observer functions. The observer functions inside
``construction_observers`` are only called when a `Variable` is
instantiated (where ``Variable.notify_construction_observers`` is called).
When the observer function is called, a new `Variable` is added to
the `new_nodes` list.
Parameters
----------
new_nodes
A list of all the `Variable`\s that are created inside the optimization.
yields
``new_nodes`` list.
"""
new_nodes = []
def observer(node):
new_nodes.append(node)
Variable.append_construction_observer(observer)
yield new_nodes
Variable.remove_construction_observer(observer)
def equal_computations(xs, ys, in_xs=None, in_ys=None): def equal_computations(xs, ys, in_xs=None, in_ys=None):
"""Checks if Aesara graphs represent the same computations. """Checks if Aesara graphs represent the same computations.
......
...@@ -4,7 +4,6 @@ amount of useful generic optimization tools. ...@@ -4,7 +4,6 @@ amount of useful generic optimization tools.
""" """
import abc import abc
import contextlib
import copy import copy
import functools import functools
import inspect import inspect
...@@ -29,7 +28,6 @@ from aesara.graph.basic import ( ...@@ -29,7 +28,6 @@ from aesara.graph.basic import (
Variable, Variable,
applys_between, applys_between,
io_toposort, io_toposort,
nodes_constructed,
vars_between, vars_between,
) )
from aesara.graph.features import Feature, NodeFinder from aesara.graph.features import Feature, NodeFinder
...@@ -3000,24 +2998,6 @@ def copy_stack_trace(from_var, to_var): ...@@ -3000,24 +2998,6 @@ def copy_stack_trace(from_var, to_var):
return to_var return to_var
@contextlib.contextmanager
def inherit_stack_trace(from_var):
"""
A context manager that copies the stack trace from one or more variable nodes to all
variable nodes constructed in the body. ``new_nodes`` is the list of all the newly created
variable nodes inside an optimization that is managed by ``graph.nodes_constructed``.
Parameters
----------
from_var :
`Variable` node or a list of `Variable` nodes to copy stack traces from.
"""
with nodes_constructed() as new_nodes:
yield
copy_stack_trace(from_var, new_nodes)
def check_stack_trace(f_or_fgraph, ops_to_check="last", bug_print="raise"): def check_stack_trace(f_or_fgraph, ops_to_check="last", bug_print="raise"):
r"""Checks if the outputs of specific `Op`\s have a stack trace. r"""Checks if the outputs of specific `Op`\s have a stack trace.
......
...@@ -152,8 +152,8 @@ from aesara.graph.op import COp, Op ...@@ -152,8 +152,8 @@ from aesara.graph.op import COp, Op
from aesara.graph.opt import ( from aesara.graph.opt import (
EquilibriumOptimizer, EquilibriumOptimizer,
GlobalOptimizer, GlobalOptimizer,
copy_stack_trace,
in2out, in2out,
inherit_stack_trace,
local_optimizer, local_optimizer,
) )
from aesara.graph.optdb import SequenceDB from aesara.graph.optdb import SequenceDB
...@@ -1672,15 +1672,18 @@ def local_dot_to_dot22(fgraph, node): ...@@ -1672,15 +1672,18 @@ def local_dot_to_dot22(fgraph, node):
return return
if y.type.dtype in ("float16", "float32", "float64", "complex64", "complex128"): if y.type.dtype in ("float16", "float32", "float64", "complex64", "complex128"):
with inherit_stack_trace(node.outputs): if x.ndim == 2 and y.ndim == 2:
if x.ndim == 2 and y.ndim == 2: new_out = [_dot22(*node.inputs)]
return [_dot22(*node.inputs)] elif x.ndim == 2 and y.ndim == 1:
if x.ndim == 2 and y.ndim == 1: new_out = [_dot22(x, y.dimshuffle(0, "x")).dimshuffle(0)]
return [_dot22(x, y.dimshuffle(0, "x")).dimshuffle(0)] elif x.ndim == 1 and y.ndim == 2:
if x.ndim == 1 and y.ndim == 2: new_out = [_dot22(x.dimshuffle("x", 0), y).dimshuffle(1)]
return [_dot22(x.dimshuffle("x", 0), y).dimshuffle(1)] elif x.ndim == 1 and y.ndim == 1:
if x.ndim == 1 and y.ndim == 1: new_out = [_dot22(x.dimshuffle("x", 0), y.dimshuffle(0, "x")).dimshuffle()]
return [_dot22(x.dimshuffle("x", 0), y.dimshuffle(0, "x")).dimshuffle()] else:
return
copy_stack_trace(node.outputs, new_out)
return new_out
_logger.info(f"Not optimizing dot with inputs {x} {y} {x.type} {y.type}") _logger.info(f"Not optimizing dot with inputs {x} {y} {x.type} {y.type}")
...@@ -1688,22 +1691,25 @@ def local_dot_to_dot22(fgraph, node): ...@@ -1688,22 +1691,25 @@ def local_dot_to_dot22(fgraph, node):
@local_optimizer([gemm_no_inplace], inplace=True) @local_optimizer([gemm_no_inplace], inplace=True)
def local_inplace_gemm(fgraph, node): def local_inplace_gemm(fgraph, node):
if node.op == gemm_no_inplace: if node.op == gemm_no_inplace:
with inherit_stack_trace(node.outputs): new_out = [gemm_inplace(*node.inputs)]
return [gemm_inplace(*node.inputs)] copy_stack_trace(node.outputs, new_out)
return new_out
@local_optimizer([gemv_no_inplace], inplace=True) @local_optimizer([gemv_no_inplace], inplace=True)
def local_inplace_gemv(fgraph, node): def local_inplace_gemv(fgraph, node):
if node.op == gemv_no_inplace: if node.op == gemv_no_inplace:
with inherit_stack_trace(node.outputs): new_out = [gemv_inplace(*node.inputs)]
return [gemv_inplace(*node.inputs)] copy_stack_trace(node.outputs, new_out)
return new_out
@local_optimizer([ger], inplace=True) @local_optimizer([ger], inplace=True)
def local_inplace_ger(fgraph, node): def local_inplace_ger(fgraph, node):
if node.op == ger: if node.op == ger:
with inherit_stack_trace(node.outputs): new_out = [ger_destructive(*node.inputs)]
return [ger_destructive(*node.inputs)] copy_stack_trace(node.outputs, new_out)
return new_out
@local_optimizer([gemm_no_inplace]) @local_optimizer([gemm_no_inplace])
...@@ -1711,13 +1717,16 @@ def local_gemm_to_gemv(fgraph, node): ...@@ -1711,13 +1717,16 @@ def local_gemm_to_gemv(fgraph, node):
"""GEMM acting on row or column matrices -> GEMV.""" """GEMM acting on row or column matrices -> GEMV."""
if node.op == gemm_no_inplace: if node.op == gemm_no_inplace:
z, a, x, y, b = node.inputs z, a, x, y, b = node.inputs
with inherit_stack_trace(node.outputs): if z.broadcastable == x.broadcastable == (True, False):
if z.broadcastable == x.broadcastable == (True, False): r = gemv_no_inplace(z.dimshuffle(1), a, y.T, x.dimshuffle(1), b)
r = gemv_no_inplace(z.dimshuffle(1), a, y.T, x.dimshuffle(1), b) new_out = [r.dimshuffle("x", 0)]
return [r.dimshuffle("x", 0)] elif z.broadcastable == y.broadcastable == (False, True):
if z.broadcastable == y.broadcastable == (False, True): r = gemv_no_inplace(z.dimshuffle(0), a, x, y.dimshuffle(0), b)
r = gemv_no_inplace(z.dimshuffle(0), a, x, y.dimshuffle(0), b) new_out = [r.dimshuffle(0, "x")]
return [r.dimshuffle(0, "x")] else:
return
copy_stack_trace(node.outputs, new_out)
return new_out
@local_optimizer([gemm_no_inplace]) @local_optimizer([gemm_no_inplace])
...@@ -1726,27 +1735,28 @@ def local_gemm_to_ger(fgraph, node): ...@@ -1726,27 +1735,28 @@ def local_gemm_to_ger(fgraph, node):
if node.op == gemm_no_inplace: if node.op == gemm_no_inplace:
z, a, x, y, b = node.inputs z, a, x, y, b = node.inputs
if x.broadcastable[1] and y.broadcastable[0]: if x.broadcastable[1] and y.broadcastable[0]:
with inherit_stack_trace(node.outputs): # x and y are both vectors so this might qualifies for a GER
# x and y are both vectors so this might qualifies for a GER xv = x.dimshuffle(0)
xv = x.dimshuffle(0) yv = y.dimshuffle(1)
yv = y.dimshuffle(1) try:
try: bval = at.get_scalar_constant_value(b)
bval = at.get_scalar_constant_value(b) except NotScalarConstantError:
except NotScalarConstantError: # b isn't a constant, GEMM is doing useful pre-scaling
# b isn't a constant, GEMM is doing useful pre-scaling return
return
if bval == 1: # best case a natural GER
if bval == 1: # best case a natural GER rval = ger(z, a, xv, yv)
rval = ger(z, a, xv, yv) new_out = [rval]
return [rval] elif bval == 0: # GER on zeros_like should be faster than GEMM
elif bval == 0: # GER on zeros_like should be faster than GEMM zeros = at.zeros([x.shape[0], y.shape[1]], x.dtype)
zeros = at.zeros([x.shape[0], y.shape[1]], x.dtype) rval = ger(zeros, a, xv, yv)
rval = ger(zeros, a, xv, yv) new_out = [rval]
return [rval] else:
else: # if bval is another constant, then z is being usefully
# if bval is another constant, then z is being usefully # pre-scaled and GER isn't really the right tool for the job.
# pre-scaled and GER isn't really the right tool for the job. return
return copy_stack_trace(node.outputs, new_out)
return new_out
# TODO: delete this optimization when we have the proper dot->gemm->ger pipeline # TODO: delete this optimization when we have the proper dot->gemm->ger pipeline
...@@ -1755,38 +1765,41 @@ def local_gemm_to_ger(fgraph, node): ...@@ -1755,38 +1765,41 @@ def local_gemm_to_ger(fgraph, node):
def local_dot22_to_ger_or_gemv(fgraph, node): def local_dot22_to_ger_or_gemv(fgraph, node):
"""dot22 computing an outer-product -> GER.""" """dot22 computing an outer-product -> GER."""
if node.op == _dot22: if node.op == _dot22:
with inherit_stack_trace(node.outputs): x, y = node.inputs
x, y = node.inputs xb = x.broadcastable
xb = x.broadcastable yb = y.broadcastable
yb = y.broadcastable one = at.as_tensor_variable(np.asarray(1, dtype=x.dtype))
one = at.as_tensor_variable(np.asarray(1, dtype=x.dtype)) zero = at.as_tensor_variable(np.asarray(0, dtype=x.dtype))
zero = at.as_tensor_variable(np.asarray(0, dtype=x.dtype)) if xb[1] and yb[0]:
if xb[1] and yb[0]: # x and y are both vectors so this might qualifies for a GER
# x and y are both vectors so this might qualifies for a GER xv = x.dimshuffle(0)
xv = x.dimshuffle(0) yv = y.dimshuffle(1)
yv = y.dimshuffle(1) zeros = at.zeros([x.shape[0], y.shape[1]], dtype=x.dtype)
zeros = at.zeros([x.shape[0], y.shape[1]], dtype=x.dtype) rval = ger(zeros, one, xv, yv)
rval = ger(zeros, one, xv, yv) new_out = [rval]
return [rval] elif xb[0] and yb[1]:
if xb[0] and yb[1]: # x and y are both vectors so this qualifies for a sdot / ddot
# x and y are both vectors so this qualifies for a sdot / ddot # TODO: Aesara doesn't have a sdot, but gemv is better than _dot22
# TODO: Aesara doesn't have a sdot, but gemv is better than _dot22 xv = x.dimshuffle(1)
xv = x.dimshuffle(1) zeros = at.AllocEmpty(x.dtype)(1)
zeros = at.AllocEmpty(x.dtype)(1) rval = gemv_no_inplace(zeros, one, y.T, xv, zero)
rval = gemv_no_inplace(zeros, one, y.T, xv, zero) new_out = [rval.dimshuffle("x", 0)]
return [rval.dimshuffle("x", 0)] elif xb[0] and not yb[0] and not yb[1]:
if xb[0] and not yb[0] and not yb[1]: # x is vector, y is matrix so try gemv
# x is vector, y is matrix so try gemv xv = x.dimshuffle(1)
xv = x.dimshuffle(1) zeros = at.AllocEmpty(x.dtype)(y.shape[1])
zeros = at.AllocEmpty(x.dtype)(y.shape[1]) rval = gemv_no_inplace(zeros, one, y.T, xv, zero)
rval = gemv_no_inplace(zeros, one, y.T, xv, zero) new_out = [rval.dimshuffle("x", 0)]
return [rval.dimshuffle("x", 0)] elif not xb[0] and not xb[1] and yb[1]:
if not xb[0] and not xb[1] and yb[1]: # x is matrix, y is vector, try gemv
# x is matrix, y is vector, try gemv yv = y.dimshuffle(0)
yv = y.dimshuffle(0) zeros = at.AllocEmpty(x.dtype)(x.shape[0])
zeros = at.AllocEmpty(x.dtype)(x.shape[0]) rval = gemv_no_inplace(zeros, one, x, yv, zero)
rval = gemv_no_inplace(zeros, one, x, yv, zero) new_out = [rval.dimshuffle(0, "x")]
return [rval.dimshuffle(0, "x")] else:
return
copy_stack_trace(node.outputs, new_out)
return new_out
################################# #################################
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论