提交 7bc40c67 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Remove inherit_stack_trace, nodes_constructed, and Variable.*_construction_observer methods

上级 16d2e934
......@@ -48,7 +48,7 @@ from aesara.gpuarray.optdb import (
)
from aesara.gpuarray.reduction import GpuMaxAndArgmax
from aesara.gpuarray.type import list_contexts
from aesara.graph.opt import GlobalOptimizer, inherit_stack_trace, local_optimizer
from aesara.graph.opt import GlobalOptimizer, copy_stack_trace, local_optimizer
from aesara.scalar import Log
from aesara.tensor.math import Argmax
from aesara.tensor.nnet.abstract_conv import (
......@@ -79,15 +79,17 @@ def local_abstractconv_cudnn(fgraph, node):
# Asymmetric padding not yet supported
return None
if isinstance(node.op, AbstractConv2d):
with inherit_stack_trace(node.outputs):
return local_abstractconv_cudnn_graph(
node.op, ctx, node.inputs, node.outputs
)
new_out = local_abstractconv_cudnn_graph(
node.op, ctx, node.inputs, node.outputs
)
copy_stack_trace(node.outputs, new_out)
return new_out
elif isinstance(node.op, AbstractConv3d):
with inherit_stack_trace(node.outputs):
return local_abstractconv3d_cudnn_graph(
node.op, ctx, node.inputs, node.outputs
)
new_out = local_abstractconv3d_cudnn_graph(
node.op, ctx, node.inputs, node.outputs
)
copy_stack_trace(node.outputs, new_out)
return new_out
@local_optimizer(
......@@ -362,15 +364,17 @@ def local_abstractconv_gw_cudnn(fgraph, node):
# Asymmetric padding not yet supported
return None
if isinstance(node.op, AbstractConv2d_gradWeights):
with inherit_stack_trace(node.outputs):
return local_abstractconv_cudnn_graph(
node.op, ctx, node.inputs, node.outputs
)
new_out = local_abstractconv_cudnn_graph(
node.op, ctx, node.inputs, node.outputs
)
copy_stack_trace(node.outputs, new_out)
return new_out
elif isinstance(node.op, AbstractConv3d_gradWeights):
with inherit_stack_trace(node.outputs):
return local_abstractconv3d_cudnn_graph(
node.op, ctx, node.inputs, node.outputs
)
new_out = local_abstractconv3d_cudnn_graph(
node.op, ctx, node.inputs, node.outputs
)
copy_stack_trace(node.outputs, new_out)
return new_out
@local_optimizer([AbstractConv2d_gradInputs, AbstractConv3d_gradInputs])
......@@ -386,15 +390,17 @@ def local_abstractconv_gi_cudnn(fgraph, node):
# Asymmetric padding not yet supported
return None
if isinstance(node.op, AbstractConv2d_gradInputs):
with inherit_stack_trace(node.outputs):
return local_abstractconv_cudnn_graph(
node.op, ctx, node.inputs, node.outputs
)
new_out = local_abstractconv_cudnn_graph(
node.op, ctx, node.inputs, node.outputs
)
copy_stack_trace(node.outputs, new_out)
return new_out
elif isinstance(node.op, AbstractConv3d_gradInputs):
with inherit_stack_trace(node.outputs):
return local_abstractconv3d_cudnn_graph(
node.op, ctx, node.inputs, node.outputs
)
new_out = local_abstractconv3d_cudnn_graph(
node.op, ctx, node.inputs, node.outputs
)
copy_stack_trace(node.outputs, new_out)
return new_out
@inplace_allocempty(GpuDnnConv, 2)
......@@ -748,11 +754,12 @@ def local_dnn_reduction(fgraph, node):
if not cudnn.cudnnReduceTensorOp_t.has_alias(scal):
return
with inherit_stack_trace(node.outputs):
ret = GpuDnnReduction(scal, node.op.axis, acc_dtype, node.op.dtype, False)(
node.inputs[0]
)
return [post(ret)]
ret = GpuDnnReduction(scal, node.op.axis, acc_dtype, node.op.dtype, False)(
node.inputs[0]
)
new_out = [post(ret)]
copy_stack_trace(node.outputs, new_out)
return new_out
@register_opt("cudnn")
......
......@@ -155,7 +155,6 @@ from aesara.graph.opt import (
GlobalOptimizer,
LocalMetaOptimizer,
copy_stack_trace,
inherit_stack_trace,
local_optimizer,
)
from aesara.ifelse import IfElse
......@@ -408,10 +407,9 @@ class GraphToGPU(GlobalOptimizer):
outputs = []
if isinstance(new_ops, aesara.graph.op.Op):
with inherit_stack_trace(node.outputs):
outputs = new_ops(
*[mapping[i] for i in node.inputs], return_list=True
)
outputs = new_ops(*[mapping[i] for i in node.inputs], return_list=True)
copy_stack_trace(node.outputs, outputs)
elif not new_ops:
newnode = node.clone_with_new_inputs(
[mapping.get(i) for i in node.inputs]
......@@ -685,8 +683,9 @@ def local_gpualloc_memset_0(fgraph, node):
and (np.asarray(inp.data) == 0).all()
):
new_op = GpuAlloc(node.op.context_name, memset_0=True)
with inherit_stack_trace(node.outputs):
return new_op(*node.inputs, return_list=True)
new_out = new_op(*node.inputs, return_list=True)
copy_stack_trace(node.outputs, new_out)
return new_out
# Don't register by default.
......@@ -695,12 +694,11 @@ def local_gpua_alloc_empty_to_zeros(fgraph, node):
if isinstance(node.op, GpuAllocEmpty):
context_name = infer_context_name(*node.inputs)
z = np.asarray(0, dtype=node.outputs[0].dtype)
with inherit_stack_trace(node.outputs):
return [
GpuAlloc(context_name)(
as_gpuarray_variable(z, context_name), *node.inputs
)
]
new_out = [
GpuAlloc(context_name)(as_gpuarray_variable(z, context_name), *node.inputs)
]
copy_stack_trace(node.outputs, new_out)
return new_out
optdb.register(
......@@ -944,11 +942,12 @@ def gpu_print_wrapper(op, cnda):
@register_opt2([aesara.printing.Print], "fast_compile")
def local_gpua_print_op(fgraph, op, context_name, inputs, outputs):
(x,) = inputs
with inherit_stack_trace(outputs):
gpu_x = as_gpuarray_variable(x, context_name=context_name)
new_op = op.__class__(global_fn=gpu_print_wrapper)
new_op.old_op = op
return new_op(gpu_x)
gpu_x = as_gpuarray_variable(x, context_name=context_name)
new_op = op.__class__(global_fn=gpu_print_wrapper)
new_op.old_op = op
new_out = new_op(gpu_x)
copy_stack_trace(outputs, new_out)
return new_out
@register_opt("fast_compile")
......@@ -1002,19 +1001,19 @@ def local_gpu_pdbbreakpoint_op(fgraph, node):
return False
# Apply the op on the new inputs
with inherit_stack_trace(node.outputs):
new_op_outputs = node.op(*new_inputs, return_list=True)
# Propagate the transfer to the gpu through the outputs that require
# it
new_outputs = []
for i in range(len(new_op_outputs)):
if input_transfered[i]:
new_outputs.append(new_op_outputs[i].transfer("cpu"))
else:
new_outputs.append(new_op_outputs[i])
new_op_outputs = node.op(*new_inputs, return_list=True)
# Propagate the transfer to the gpu through the outputs that require
# it
new_outputs = []
for i in range(len(new_op_outputs)):
if input_transfered[i]:
new_outputs.append(new_op_outputs[i].transfer("cpu"))
else:
new_outputs.append(new_op_outputs[i])
return new_outputs
copy_stack_trace(node.outputs, new_outputs)
return new_outputs
return False
......@@ -1273,8 +1272,8 @@ def local_gpua_careduce(fgraph, op, context_name, inputs, outputs):
adtype = "float32"
greduce = op2(op.scalar_op, axis=op.axis, dtype=odtype, acc_dtype=adtype)
with inherit_stack_trace(outputs):
gvar = greduce(x)
gvar = greduce(x)
copy_stack_trace(outputs, gvar)
# We need to have the make node called, otherwise the mask can
# be None
if op2 is GpuCAReduceCPY or gvar.owner.op.supports_c_code(
......@@ -1315,27 +1314,27 @@ def local_gpua_careduce(fgraph, op, context_name, inputs, outputs):
dtype=odtype,
acc_dtype=adtype,
)
with inherit_stack_trace(outputs):
reshaped_x = x.reshape(at.stack(new_in_shp))
gpu_reshaped_x = as_gpuarray_variable(reshaped_x, context_name)
# We need to have the make node called, otherwise the mask can
# be None
gvar = greduce(gpu_reshaped_x)
reshaped_gpu_inputs = [gpu_reshaped_x]
if greduce.supports_c_code(reshaped_gpu_inputs):
reduce_reshaped_x = greduce(gpu_reshaped_x)
if reduce_reshaped_x.ndim != outputs[0].ndim:
out_shp = []
for i in range(x.ndim):
if i not in op.axis:
out_shp.append(shape_i(x, i))
unreshaped_reduce = GpuReshape(len(out_shp))(
reduce_reshaped_x, at.stack(out_shp)
)
else:
unreshaped_reduce = reduce_reshaped_x
return [unreshaped_reduce]
reshaped_x = x.reshape(at.stack(new_in_shp))
gpu_reshaped_x = as_gpuarray_variable(reshaped_x, context_name)
# We need to have the make node called, otherwise the mask can
# be None
gvar = greduce(gpu_reshaped_x)
reshaped_gpu_inputs = [gpu_reshaped_x]
if greduce.supports_c_code(reshaped_gpu_inputs):
reduce_reshaped_x = greduce(gpu_reshaped_x)
if reduce_reshaped_x.ndim != outputs[0].ndim:
out_shp = []
for i in range(x.ndim):
if i not in op.axis:
out_shp.append(shape_i(x, i))
unreshaped_reduce = GpuReshape(len(out_shp))(
reduce_reshaped_x, at.stack(out_shp)
)
else:
unreshaped_reduce = reduce_reshaped_x
copy_stack_trace(outputs, unreshaped_reduce)
return [unreshaped_reduce]
@register_opt("fast_compile")
......@@ -1374,34 +1373,34 @@ def local_gpua_gemm(fgraph, op, context_name, inputs, outputs):
def local_gpua_gemmbatch(fgraph, op, context_name, inputs, outputs):
if inputs[0].dtype not in ("float16", "float32", "float64"):
return
with inherit_stack_trace(outputs):
a, b = inputs
# Since GpuGemmBatch only supports 3D inputs and output,
# we need to add broadcastable dims to the inputs, and drop
# them from outputs
output_dims = [0, 1, 2]
if a.ndim == 2:
a = GpuDimShuffle(a.broadcastable, (0, "x", 1))(a)
del output_dims[1]
if b.ndim == 2:
b = GpuDimShuffle(b.broadcastable, (0, 1, "x"))(b)
del output_dims[-1]
# In case of mismatched dtypes, we also have to upcast
out_dtype = outputs[0].dtype
if a.dtype != out_dtype or b.dtype != out_dtype:
gpu_cast_op = GpuElemwise(Cast(Scalar(out_dtype)))
if a.dtype != out_dtype:
a = gpu_cast_op(a)
if b.dtype != out_dtype:
b = gpu_cast_op(b)
c = GpuAllocEmpty(out_dtype, context_name)(a.shape[0], a.shape[1], b.shape[2])
out = gpugemmbatch_no_inplace(
c, np.asarray(1.0, dtype=out_dtype), a, b, np.asarray(0.0, dtype=out_dtype)
)
if len(output_dims) != 3:
out = GpuDimShuffle(out.broadcastable, output_dims)(out)
return out
a, b = inputs
# Since GpuGemmBatch only supports 3D inputs and output,
# we need to add broadcastable dims to the inputs, and drop
# them from outputs
output_dims = [0, 1, 2]
if a.ndim == 2:
a = GpuDimShuffle(a.broadcastable, (0, "x", 1))(a)
del output_dims[1]
if b.ndim == 2:
b = GpuDimShuffle(b.broadcastable, (0, 1, "x"))(b)
del output_dims[-1]
# In case of mismatched dtypes, we also have to upcast
out_dtype = outputs[0].dtype
if a.dtype != out_dtype or b.dtype != out_dtype:
gpu_cast_op = GpuElemwise(Cast(Scalar(out_dtype)))
if a.dtype != out_dtype:
a = gpu_cast_op(a)
if b.dtype != out_dtype:
b = gpu_cast_op(b)
c = GpuAllocEmpty(out_dtype, context_name)(a.shape[0], a.shape[1], b.shape[2])
out = gpugemmbatch_no_inplace(
c, np.asarray(1.0, dtype=out_dtype), a, b, np.asarray(0.0, dtype=out_dtype)
)
if len(output_dims) != 3:
out = GpuDimShuffle(out.broadcastable, output_dims)(out)
copy_stack_trace(outputs, out)
return out
@register_opt()
......@@ -1461,12 +1460,13 @@ def local_gpua_dot22(fgraph, op, context_name, inputs, outputs):
@op_lifter([aesara.tensor.blas.Dot22Scalar])
@register_opt2([aesara.tensor.blas.Dot22Scalar], "fast_compile")
def local_gpua_dot22scalar(fgraph, op, context_name, inputs, outputs):
with inherit_stack_trace(outputs):
x, y, a = inputs
x = as_gpuarray_variable(x, context_name)
y = as_gpuarray_variable(y, context_name)
z = GpuAllocEmpty(x.dtype, context_name)(x.shape[0], y.shape[1])
return [gpugemm_no_inplace(z, a, x, y, 0)]
x, y, a = inputs
x = as_gpuarray_variable(x, context_name)
y = as_gpuarray_variable(y, context_name)
z = GpuAllocEmpty(x.dtype, context_name)(x.shape[0], y.shape[1])
new_out = [gpugemm_no_inplace(z, a, x, y, 0)]
copy_stack_trace(outputs, new_out)
return new_out
@register_opt("fast_compile")
......@@ -2584,9 +2584,10 @@ def local_gpu_elemwise_careduce(fgraph, node):
inp = node.inputs[0].owner.inputs[0]
props = node.op._props_dict()
props["pre_scalar_op"] = node.inputs[0].owner.op.scalar_op
with inherit_stack_trace(node.outputs):
out = GpuCAReduceCuda(**props)(inp)
return [out]
out = GpuCAReduceCuda(**props)(inp)
new_out = [out]
copy_stack_trace(node.outputs, new_out)
return new_out
@local_optimizer(None)
......@@ -2783,12 +2784,13 @@ def local_gpu_solve(fgraph, op, context_name, inputs, outputs):
@local_optimizer([GpuCusolverSolve], inplace=True)
def local_inplace_gpu_solve(fgraph, node):
if isinstance(node.op, GpuCusolverSolve) and not node.op.inplace:
with inherit_stack_trace(node.outputs):
return [
GpuCusolverSolve(
A_structure=node.op.A_structure, trans=node.op.trans, inplace=True
)(*node.inputs)
]
new_out = [
GpuCusolverSolve(
A_structure=node.op.A_structure, trans=node.op.trans, inplace=True
)(*node.inputs)
]
copy_stack_trace(node.outputs, new_out)
return new_out
# Cholesky decomposition
......@@ -2835,8 +2837,9 @@ register_opt2([slinalg.Solve], "fast_compile", name="matrix_ops_db2")(matrix_ops
@local_optimizer([GpuCholesky], inplace=True)
def local_inplace_gpu_cholesky(fgraph, node):
if isinstance(node.op, GpuCholesky) and not node.op.inplace:
with inherit_stack_trace(node.outputs):
return [node.op.clone_inplace()(*node.inputs)]
new_out = [node.op.clone_inplace()(*node.inputs)]
copy_stack_trace(node.outputs, new_out)
return new_out
def local_gpu_magma_cholesky(fgraph, op, context_name, inputs, outputs):
......@@ -2915,8 +2918,9 @@ def local_gpu_magma_matrix_inverse(fgraph, op, context_name, inputs, outputs):
@local_optimizer([GpuMagmaMatrixInverse])
def local_inplace_gpu_magma_matrix_inverse(fgraph, node):
if isinstance(node.op, GpuMagmaMatrixInverse) and not node.op.inplace:
with inherit_stack_trace(node.outputs):
return [node.op.clone_inplace()(*node.inputs)]
new_out = [node.op.clone_inplace()(*node.inputs)]
copy_stack_trace(node.outputs, new_out)
return new_out
# Eigen decomposition of a symmetric matrix
......
......@@ -14,7 +14,7 @@ from aesara.gpuarray.elemwise import GpuDimShuffle, GpuElemwise
from aesara.gpuarray.type import GpuArrayType, get_context, move_to_gpu
from aesara.graph.basic import Constant
from aesara.graph.op import Op
from aesara.graph.opt import copy_stack_trace, inherit_stack_trace, local_optimizer
from aesara.graph.opt import copy_stack_trace, local_optimizer
from aesara.tensor.basic import as_tensor, cast, get_scalar_constant_value, join
from aesara.tensor.elemwise import DimShuffle
from aesara.tensor.exceptions import NotScalarConstantError
......@@ -213,8 +213,9 @@ def alpha_merge(cls, alpha_in, beta_in):
except NotScalarConstantError:
inputs[alpha_in] = lr * targ.inputs[alpha_in]
inputs[beta_in] = lr * targ.inputs[beta_in]
with inherit_stack_trace(node.outputs):
return maker(targ, *inputs)
new_out = maker(targ, *inputs)
copy_stack_trace(node.outputs, new_out)
return new_out
return opt
......@@ -309,8 +310,9 @@ def output_merge(cls, alpha_in, beta_in, out_in):
dtype = inputs[beta_in].dtype
one = aes.constant(np.asarray(1.0, dtype=dtype))
inputs[beta_in] = one
with inherit_stack_trace(node.outputs):
return maker(targ, *inputs)
new_out = maker(targ, *inputs)
copy_stack_trace(node.outputs, new_out)
return new_out
return opt
......@@ -371,8 +373,9 @@ def inplace_allocempty(op, idx):
alloc.owner.op.dtype, alloc.owner.op.context_name
)
inputs[idx] = alloc_op(*alloc.owner.inputs)
with inherit_stack_trace(node.outputs):
return maker(node, inputs)
new_out = maker(node, inputs)
copy_stack_trace(node.outputs, new_out)
return new_out
return opt
......
"""Core graph classes."""
import contextlib
import warnings
from collections import deque
from copy import copy
......@@ -398,8 +397,6 @@ class Variable(Node):
self.auto_name = "auto_" + str(next(self.__count__))
Variable.notify_construction_observers(self)
def get_test_value(self):
"""Get the test value.
......@@ -562,22 +559,6 @@ class Variable(Node):
d["tag"] = t
return d
# refer to doc in nodes_constructed.
construction_observers: List = []
@classmethod
def append_construction_observer(cls, observer):
cls.construction_observers.append(observer)
@classmethod
def remove_construction_observer(cls, observer):
cls.construction_observers.remove(observer)
@classmethod
def notify_construction_observers(cls, instance):
for observer in cls.construction_observers:
observer(instance)
class Constant(Variable):
"""A `Variable` with a fixed `data` field.
......@@ -1448,42 +1429,6 @@ def is_in_ancestors(l_apply: Apply, f_node: Apply) -> bool:
return False
@contextlib.contextmanager
def nodes_constructed():
r"""
A context manager that is used in ``inherit_stack_trace`` and keeps track
of all the newly created variable nodes inside an optimization. A list
of ``new_nodes`` is instantiated but will be filled in a lazy manner (when
``Variable.notify_construction_observers`` is called).
``observer`` is the entity that updates the ``new_nodes`` list.
``construction_observers`` is a list inside `Variable` class and contains
a list of observer functions. The observer functions inside
``construction_observers`` are only called when a `Variable` is
instantiated (where ``Variable.notify_construction_observers`` is called).
When the observer function is called, a new `Variable` is added to
the `new_nodes` list.
Parameters
----------
new_nodes
A list of all the `Variable`\s that are created inside the optimization.
yields
``new_nodes`` list.
"""
new_nodes = []
def observer(node):
new_nodes.append(node)
Variable.append_construction_observer(observer)
yield new_nodes
Variable.remove_construction_observer(observer)
def equal_computations(xs, ys, in_xs=None, in_ys=None):
"""Checks if Aesara graphs represent the same computations.
......
......@@ -4,7 +4,6 @@ amount of useful generic optimization tools.
"""
import abc
import contextlib
import copy
import functools
import inspect
......@@ -29,7 +28,6 @@ from aesara.graph.basic import (
Variable,
applys_between,
io_toposort,
nodes_constructed,
vars_between,
)
from aesara.graph.features import Feature, NodeFinder
......@@ -3000,24 +2998,6 @@ def copy_stack_trace(from_var, to_var):
return to_var
@contextlib.contextmanager
def inherit_stack_trace(from_var):
"""
A context manager that copies the stack trace from one or more variable nodes to all
variable nodes constructed in the body. ``new_nodes`` is the list of all the newly created
variable nodes inside an optimization that is managed by ``graph.nodes_constructed``.
Parameters
----------
from_var :
`Variable` node or a list of `Variable` nodes to copy stack traces from.
"""
with nodes_constructed() as new_nodes:
yield
copy_stack_trace(from_var, new_nodes)
def check_stack_trace(f_or_fgraph, ops_to_check="last", bug_print="raise"):
r"""Checks if the outputs of specific `Op`\s have a stack trace.
......
......@@ -152,8 +152,8 @@ from aesara.graph.op import COp, Op
from aesara.graph.opt import (
EquilibriumOptimizer,
GlobalOptimizer,
copy_stack_trace,
in2out,
inherit_stack_trace,
local_optimizer,
)
from aesara.graph.optdb import SequenceDB
......@@ -1672,15 +1672,18 @@ def local_dot_to_dot22(fgraph, node):
return
if y.type.dtype in ("float16", "float32", "float64", "complex64", "complex128"):
with inherit_stack_trace(node.outputs):
if x.ndim == 2 and y.ndim == 2:
return [_dot22(*node.inputs)]
if x.ndim == 2 and y.ndim == 1:
return [_dot22(x, y.dimshuffle(0, "x")).dimshuffle(0)]
if x.ndim == 1 and y.ndim == 2:
return [_dot22(x.dimshuffle("x", 0), y).dimshuffle(1)]
if x.ndim == 1 and y.ndim == 1:
return [_dot22(x.dimshuffle("x", 0), y.dimshuffle(0, "x")).dimshuffle()]
if x.ndim == 2 and y.ndim == 2:
new_out = [_dot22(*node.inputs)]
elif x.ndim == 2 and y.ndim == 1:
new_out = [_dot22(x, y.dimshuffle(0, "x")).dimshuffle(0)]
elif x.ndim == 1 and y.ndim == 2:
new_out = [_dot22(x.dimshuffle("x", 0), y).dimshuffle(1)]
elif x.ndim == 1 and y.ndim == 1:
new_out = [_dot22(x.dimshuffle("x", 0), y.dimshuffle(0, "x")).dimshuffle()]
else:
return
copy_stack_trace(node.outputs, new_out)
return new_out
_logger.info(f"Not optimizing dot with inputs {x} {y} {x.type} {y.type}")
......@@ -1688,22 +1691,25 @@ def local_dot_to_dot22(fgraph, node):
@local_optimizer([gemm_no_inplace], inplace=True)
def local_inplace_gemm(fgraph, node):
if node.op == gemm_no_inplace:
with inherit_stack_trace(node.outputs):
return [gemm_inplace(*node.inputs)]
new_out = [gemm_inplace(*node.inputs)]
copy_stack_trace(node.outputs, new_out)
return new_out
@local_optimizer([gemv_no_inplace], inplace=True)
def local_inplace_gemv(fgraph, node):
if node.op == gemv_no_inplace:
with inherit_stack_trace(node.outputs):
return [gemv_inplace(*node.inputs)]
new_out = [gemv_inplace(*node.inputs)]
copy_stack_trace(node.outputs, new_out)
return new_out
@local_optimizer([ger], inplace=True)
def local_inplace_ger(fgraph, node):
if node.op == ger:
with inherit_stack_trace(node.outputs):
return [ger_destructive(*node.inputs)]
new_out = [ger_destructive(*node.inputs)]
copy_stack_trace(node.outputs, new_out)
return new_out
@local_optimizer([gemm_no_inplace])
......@@ -1711,13 +1717,16 @@ def local_gemm_to_gemv(fgraph, node):
"""GEMM acting on row or column matrices -> GEMV."""
if node.op == gemm_no_inplace:
z, a, x, y, b = node.inputs
with inherit_stack_trace(node.outputs):
if z.broadcastable == x.broadcastable == (True, False):
r = gemv_no_inplace(z.dimshuffle(1), a, y.T, x.dimshuffle(1), b)
return [r.dimshuffle("x", 0)]
if z.broadcastable == y.broadcastable == (False, True):
r = gemv_no_inplace(z.dimshuffle(0), a, x, y.dimshuffle(0), b)
return [r.dimshuffle(0, "x")]
if z.broadcastable == x.broadcastable == (True, False):
r = gemv_no_inplace(z.dimshuffle(1), a, y.T, x.dimshuffle(1), b)
new_out = [r.dimshuffle("x", 0)]
elif z.broadcastable == y.broadcastable == (False, True):
r = gemv_no_inplace(z.dimshuffle(0), a, x, y.dimshuffle(0), b)
new_out = [r.dimshuffle(0, "x")]
else:
return
copy_stack_trace(node.outputs, new_out)
return new_out
@local_optimizer([gemm_no_inplace])
......@@ -1726,27 +1735,28 @@ def local_gemm_to_ger(fgraph, node):
if node.op == gemm_no_inplace:
z, a, x, y, b = node.inputs
if x.broadcastable[1] and y.broadcastable[0]:
with inherit_stack_trace(node.outputs):
# x and y are both vectors so this might qualifies for a GER
xv = x.dimshuffle(0)
yv = y.dimshuffle(1)
try:
bval = at.get_scalar_constant_value(b)
except NotScalarConstantError:
# b isn't a constant, GEMM is doing useful pre-scaling
return
if bval == 1: # best case a natural GER
rval = ger(z, a, xv, yv)
return [rval]
elif bval == 0: # GER on zeros_like should be faster than GEMM
zeros = at.zeros([x.shape[0], y.shape[1]], x.dtype)
rval = ger(zeros, a, xv, yv)
return [rval]
else:
# if bval is another constant, then z is being usefully
# pre-scaled and GER isn't really the right tool for the job.
return
# x and y are both vectors so this might qualifies for a GER
xv = x.dimshuffle(0)
yv = y.dimshuffle(1)
try:
bval = at.get_scalar_constant_value(b)
except NotScalarConstantError:
# b isn't a constant, GEMM is doing useful pre-scaling
return
if bval == 1: # best case a natural GER
rval = ger(z, a, xv, yv)
new_out = [rval]
elif bval == 0: # GER on zeros_like should be faster than GEMM
zeros = at.zeros([x.shape[0], y.shape[1]], x.dtype)
rval = ger(zeros, a, xv, yv)
new_out = [rval]
else:
# if bval is another constant, then z is being usefully
# pre-scaled and GER isn't really the right tool for the job.
return
copy_stack_trace(node.outputs, new_out)
return new_out
# TODO: delete this optimization when we have the proper dot->gemm->ger pipeline
......@@ -1755,38 +1765,41 @@ def local_gemm_to_ger(fgraph, node):
def local_dot22_to_ger_or_gemv(fgraph, node):
"""dot22 computing an outer-product -> GER."""
if node.op == _dot22:
with inherit_stack_trace(node.outputs):
x, y = node.inputs
xb = x.broadcastable
yb = y.broadcastable
one = at.as_tensor_variable(np.asarray(1, dtype=x.dtype))
zero = at.as_tensor_variable(np.asarray(0, dtype=x.dtype))
if xb[1] and yb[0]:
# x and y are both vectors so this might qualifies for a GER
xv = x.dimshuffle(0)
yv = y.dimshuffle(1)
zeros = at.zeros([x.shape[0], y.shape[1]], dtype=x.dtype)
rval = ger(zeros, one, xv, yv)
return [rval]
if xb[0] and yb[1]:
# x and y are both vectors so this qualifies for a sdot / ddot
# TODO: Aesara doesn't have a sdot, but gemv is better than _dot22
xv = x.dimshuffle(1)
zeros = at.AllocEmpty(x.dtype)(1)
rval = gemv_no_inplace(zeros, one, y.T, xv, zero)
return [rval.dimshuffle("x", 0)]
if xb[0] and not yb[0] and not yb[1]:
# x is vector, y is matrix so try gemv
xv = x.dimshuffle(1)
zeros = at.AllocEmpty(x.dtype)(y.shape[1])
rval = gemv_no_inplace(zeros, one, y.T, xv, zero)
return [rval.dimshuffle("x", 0)]
if not xb[0] and not xb[1] and yb[1]:
# x is matrix, y is vector, try gemv
yv = y.dimshuffle(0)
zeros = at.AllocEmpty(x.dtype)(x.shape[0])
rval = gemv_no_inplace(zeros, one, x, yv, zero)
return [rval.dimshuffle(0, "x")]
x, y = node.inputs
xb = x.broadcastable
yb = y.broadcastable
one = at.as_tensor_variable(np.asarray(1, dtype=x.dtype))
zero = at.as_tensor_variable(np.asarray(0, dtype=x.dtype))
if xb[1] and yb[0]:
# x and y are both vectors so this might qualifies for a GER
xv = x.dimshuffle(0)
yv = y.dimshuffle(1)
zeros = at.zeros([x.shape[0], y.shape[1]], dtype=x.dtype)
rval = ger(zeros, one, xv, yv)
new_out = [rval]
elif xb[0] and yb[1]:
# x and y are both vectors so this qualifies for a sdot / ddot
# TODO: Aesara doesn't have a sdot, but gemv is better than _dot22
xv = x.dimshuffle(1)
zeros = at.AllocEmpty(x.dtype)(1)
rval = gemv_no_inplace(zeros, one, y.T, xv, zero)
new_out = [rval.dimshuffle("x", 0)]
elif xb[0] and not yb[0] and not yb[1]:
# x is vector, y is matrix so try gemv
xv = x.dimshuffle(1)
zeros = at.AllocEmpty(x.dtype)(y.shape[1])
rval = gemv_no_inplace(zeros, one, y.T, xv, zero)
new_out = [rval.dimshuffle("x", 0)]
elif not xb[0] and not xb[1] and yb[1]:
# x is matrix, y is vector, try gemv
yv = y.dimshuffle(0)
zeros = at.AllocEmpty(x.dtype)(x.shape[0])
rval = gemv_no_inplace(zeros, one, x, yv, zero)
new_out = [rval.dimshuffle(0, "x")]
else:
return
copy_stack_trace(node.outputs, new_out)
return new_out
#################################
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论