提交 08538340 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Remove Hints and use instance checks in aesara.sandbox.linalg.ops

上级 2b12a455
from aesara.sandbox.linalg.ops import psd, spectral_radius_bound from aesara.sandbox.linalg.ops import spectral_radius_bound
from aesara.tensor.nlinalg import det, eig, eigh, matrix_inverse, trace
from aesara.tensor.slinalg import cholesky, eigvalsh, solve
import logging import logging
import aesara.tensor from aesara.graph.opt import local_optimizer
from aesara.graph.basic import Apply
from aesara.graph.op import Op
from aesara.graph.opt import GlobalOptimizer, local_optimizer
from aesara.tensor import basic as aet from aesara.tensor import basic as aet
from aesara.tensor.basic_opt import ( from aesara.tensor.basic_opt import (
register_canonicalize, register_canonicalize,
...@@ -12,208 +9,16 @@ from aesara.tensor.basic_opt import ( ...@@ -12,208 +9,16 @@ from aesara.tensor.basic_opt import (
) )
from aesara.tensor.blas import Dot22 from aesara.tensor.blas import Dot22
from aesara.tensor.elemwise import DimShuffle from aesara.tensor.elemwise import DimShuffle
from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.math import Dot, Prod, dot, log from aesara.tensor.math import Dot, Prod, dot, log
from aesara.tensor.math import pow as aet_pow from aesara.tensor.math import pow as aet_pow
from aesara.tensor.math import prod from aesara.tensor.math import prod
from aesara.tensor.nlinalg import MatrixInverse, det, matrix_inverse, trace from aesara.tensor.nlinalg import Det, MatrixInverse, trace
from aesara.tensor.slinalg import Cholesky, Solve, cholesky, solve from aesara.tensor.slinalg import Cholesky, Solve, cholesky, solve
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class Hint(Op):
    """
    Provide arbitrary information to the optimizer.

    These ops are removed from the graph during canonicalization
    in order to not interfere with other optimizations.
    The idea is that prior to canonicalization, one or more Features of the
    fgraph should register the information contained in any Hint node, and
    transfer that information out of the graph.
    """

    # Participates in Op equality/hashing; two Hint ops with the same
    # hint pairs compare equal.
    __props__ = ("hints",)

    def __init__(self, **kwargs):
        # Stored as a tuple of (key, value) pairs so the Op stays hashable.
        self.hints = tuple(kwargs.items())
        # Output 0 is a view of input 0: the op is an identity pass-through.
        self.view_map = {0: [0]}

    def make_node(self, x):
        # The output has exactly the same type as the input.
        return Apply(self, [x], [x.type()])

    def perform(self, node, inputs, outstor):
        # Identity: forward the single input unchanged.
        outstor[0][0] = inputs[0]

    def grad(self, inputs, g_out):
        # Gradient of an identity op is the incoming output gradient.
        return g_out
def hints(variable):
    """Return the hint dict attached to `variable` (empty when none).

    Hints are present only when `variable` is the output of a `Hint` Op.
    """
    owner = variable.owner
    if owner and isinstance(owner.op, Hint):
        return dict(owner.op.hints)
    return {}
@register_canonicalize
@local_optimizer([Hint])
def remove_hint_nodes(fgraph, node):
    """Strip `Hint` nodes from the graph during canonicalization.

    Before removal, the hint key/value pairs are handed to the fgraph's
    `HintsFeature` (when one is attached) so the information survives.
    Returns the Hint's input so it replaces the Hint's output.
    """
    # BUG FIX: `node` is an Apply instance, never a Hint Op, so the original
    # `isinstance(node, Hint)` was always False and this rewrite never fired
    # (note the body already reads `node.op.hints`). Test the node's op.
    if isinstance(node.op, Hint):
        # transfer hints from graph to Feature
        try:
            for k, v in node.op.hints:
                fgraph.hints_feature.add_hint(node.inputs[0], k, v)
        except AttributeError:
            # No HintsFeature attached; the hints are simply dropped.
            pass
        return node.inputs
class HintsFeature:
    """
    FunctionGraph Feature that tracks mathematical properties of variables.

    Similar in spirit to variable 'tags' (tags are one way to provide
    hints), but with explicit semantics for how the information moves
    around during optimization: hints describe a *position in the graph*,
    so when one variable is substituted for another, the assumptions
    transfer to the replacement.

    Example hints: shape information; matrix properties such as symmetry,
    positive semi-definiteness, bandedness, diagonality.

    Hint information is propagated much like graph optimizations are,
    except that adding a hint never changes the graph, and debugmode does
    not verify hints.

    #TODO: should a Hint be an object that can actually evaluate its
    #      truthfulness?
    #      Should the PSD property be an object that can check the
    #      PSD-ness of a variable?
    """

    def add_hint(self, r, k, v):
        """Record hint `k` = `v` for variable `r`."""
        logger.debug(f"adding hint; {r}, {k}, {v}")
        self.hints[r][k] = v

    def ensure_init_r(self, r):
        """Guarantee that `r` has a (possibly empty) hint dictionary."""
        self.hints.setdefault(r, {})

    #
    #
    # Feature interface
    #
    #
    def on_attach(self, fgraph):
        # Only one HintsFeature may be attached to a given fgraph.
        assert not hasattr(fgraph, "hints_feature")
        fgraph.hints_feature = self
        # Variable -> dict of hints; every known variable gets an entry.
        self.hints = {}
        for apply_node in fgraph.toposort():
            self.on_import(fgraph, apply_node, "on_attach")

    def on_import(self, fgraph, node, reason):
        if node.outputs[0] in self.hints:
            # Re-import of an already-known node (a revert, not really an
            # import): every variable of the node must already be tracked.
            for var in node.outputs + node.inputs:
                assert var in self.hints
            return
        # Initialize hint storage for all of the node's variables.
        for var in node.inputs + node.outputs:
            self.ensure_init_r(var)

    def update_second_from_first(self, r0, r1):
        """Copy `r0`'s hints onto `r1`; conflicting hints are an error."""
        source = self.hints[r0]
        target = self.hints[r1]
        for key, val in source.items():
            if key in target:
                if target[key] is not val:
                    raise NotImplementedError()
            else:
                target[key] = val

    def on_change_input(self, fgraph, node, i, r, new_r, reason):
        # `r` and `new_r` now stand for the same graph position, so their
        # hints must agree; propagate them in both directions.
        # (If we didn't know their shapes were related, now we do.)
        self.ensure_init_r(new_r)
        self.update_second_from_first(r, new_r)
        self.update_second_from_first(new_r, r)
        # change_input happens in two cases:
        # 1) we are trying to get rid of r, or
        # 2) we are putting things back after a failed transaction.
class HintsOptimizer(GlobalOptimizer):
    """
    Optimizer that serves to add HintsFeature as an fgraph feature.
    """

    def __init__(self):
        super().__init__()

    def add_requirements(self, fgraph):
        # Attaching the feature is this optimizer's entire purpose.
        fgraph.attach_feature(HintsFeature())

    def apply(self, fgraph):
        # Intentionally a no-op: all work happens via the attached feature.
        pass
# Register globally so hints are collected under both fast_run and
# fast_compile; -1 should make it run right before the first merge.
aesara.compile.mode.optdb.register(
    "HintsOpt", HintsOptimizer(), -1, "fast_run", "fast_compile"
)
def psd(v):
    r"""
    Apply a hint that the variable `v` is positive semi-definite, i.e.
    it is a symmetric matrix and :math:`x^T A x \ge 0` for any vector x.
    """
    hint_op = Hint(psd=True, symmetric=True)
    return hint_op(v)
def is_psd(v):
    """Return True when `v` carries a positive-semi-definite hint."""
    v_hints = hints(v)
    return v_hints.get("psd", False)
def is_symmetric(v):
    """Return True when `v` carries a symmetry hint."""
    v_hints = hints(v)
    return v_hints.get("symmetric", False)
def is_positive(v):
    """Best-effort check that `v` is known to be non-negative."""
    if hints(v).get("positive", False):
        return True
    # TODO: how to handle this - a registry?
    #       infer_hints on Ops?
    logger.debug(f"is_positive: {v}")
    # A power with a constant even exponent is non-negative.
    owner = v.owner
    if not (owner and owner.op == aet_pow):
        return False
    try:
        exponent = aet.get_scalar_constant_value(owner.inputs[1])
    except NotScalarConstantError:
        return False
    if exponent % 2 == 0:
        return True
    return False
@register_canonicalize @register_canonicalize
@local_optimizer([DimShuffle]) @local_optimizer([DimShuffle])
def transinv_to_invtrans(fgraph, node): def transinv_to_invtrans(fgraph, node):
...@@ -229,15 +34,19 @@ def transinv_to_invtrans(fgraph, node): ...@@ -229,15 +34,19 @@ def transinv_to_invtrans(fgraph, node):
@register_stabilize
@local_optimizer([Dot, Dot22])
def inv_as_solve(fgraph, node):
    """
    Rewrite products with a matrix inverse into `solve` calls.

    This utilizes a boolean `symmetric` tag on the matrices:
      - inv(A) @ r  ->  solve(A, r)
      - l @ inv(A)  ->  solve(A, l.T).T when A is tagged symmetric,
                        otherwise solve(A.T, l.T).T
    """
    # NOTE(review): this span was a mangled two-column diff rendering in the
    # source; reconstructed here as the post-change (new-side) code.
    if isinstance(node.op, (Dot, Dot22)):
        l, r = node.inputs
        if l.owner and isinstance(l.owner.op, MatrixInverse):
            return [solve(l.owner.inputs[0], r)]
        if r.owner and isinstance(r.owner.op, MatrixInverse):
            x = r.owner.inputs[0]
            if getattr(x.tag, "symmetric", None) is True:
                # Symmetric A: no transpose needed on the solve's matrix.
                return [solve(x, l.T).T]
            else:
                return [solve(x.T, l.T).T]
@register_stabilize @register_stabilize
...@@ -277,18 +86,20 @@ def tag_solve_triangular(fgraph, node): ...@@ -277,18 +86,20 @@ def tag_solve_triangular(fgraph, node):
def no_transpose_symmetric(fgraph, node):
    """Remove the transpose of a 2-D variable tagged as symmetric.

    This utilizes a boolean `symmetric` tag: a symmetric matrix equals
    its transpose, so the DimShuffle([1, 0]) is dropped.
    """
    # NOTE(review): this span was a mangled two-column diff rendering in the
    # source; reconstructed here as the post-change (new-side) code.
    if isinstance(node.op, DimShuffle):
        x = node.inputs[0]
        if x.type.ndim == 2 and getattr(x.tag, "symmetric", None) is True:
            if node.op.new_order == [1, 0]:
                return [x]
@register_stabilize @register_stabilize
@local_optimizer(None) # XXX: solve is defined later and can't be used here @local_optimizer([Solve])
def psd_solve_with_chol(fgraph, node): def psd_solve_with_chol(fgraph, node):
"""
This utilizes a boolean `psd` tag on matrices.
"""
if isinstance(node.op, Solve): if isinstance(node.op, Solve):
A, b = node.inputs # result is solution Ax=b A, b = node.inputs # result is solution Ax=b
if is_psd(A): if getattr(A.tag, "psd", None) is True:
L = cholesky(A) L = cholesky(A)
# N.B. this can be further reduced to a yet-unwritten cho_solve Op # N.B. this can be further reduced to a yet-unwritten cho_solve Op
# __if__ no other Op makes use of the the L matrix during the # __if__ no other Op makes use of the the L matrix during the
...@@ -300,14 +111,14 @@ def psd_solve_with_chol(fgraph, node): ...@@ -300,14 +111,14 @@ def psd_solve_with_chol(fgraph, node):
@register_stabilize @register_stabilize
@register_specialize @register_specialize
@local_optimizer(None) # XXX: det is defined later and can't be used here @local_optimizer([Det])
def local_det_chol(fgraph, node): def local_det_chol(fgraph, node):
""" """
If we have det(X) and there is already an L=cholesky(X) If we have det(X) and there is already an L=cholesky(X)
floating around, then we can use prod(diag(L)) to get the determinant. floating around, then we can use prod(diag(L)) to get the determinant.
""" """
if node.op == det: if isinstance(node.op, Det):
(x,) = node.inputs (x,) = node.inputs
for (cl, xpos) in fgraph.clients[x]: for (cl, xpos) in fgraph.clients[x]:
if isinstance(cl.op, Cholesky): if isinstance(cl.op, Cholesky):
...@@ -320,6 +131,9 @@ def local_det_chol(fgraph, node): ...@@ -320,6 +131,9 @@ def local_det_chol(fgraph, node):
@register_specialize @register_specialize
@local_optimizer([log]) @local_optimizer([log])
def local_log_prod_sqr(fgraph, node): def local_log_prod_sqr(fgraph, node):
"""
This utilizes a boolean `positive` tag on matrices.
"""
if node.op == log: if node.op == log:
(x,) = node.inputs (x,) = node.inputs
if x.owner and isinstance(x.owner.op, Prod): if x.owner and isinstance(x.owner.op, Prod):
...@@ -328,29 +142,16 @@ def local_log_prod_sqr(fgraph, node): ...@@ -328,29 +142,16 @@ def local_log_prod_sqr(fgraph, node):
p = x.owner.inputs[0] p = x.owner.inputs[0]
# p is the matrix we're reducing with prod # p is the matrix we're reducing with prod
if is_positive(p): if getattr(p.tag, "positive", None) is True:
return [log(p).sum(axis=x.owner.op.axis)] return [log(p).sum(axis=x.owner.op.axis)]
# TODO: have a reduction like prod and sum that simply # TODO: have a reduction like prod and sum that simply
# returns the sign of the prod multiplication. # returns the sign of the prod multiplication.
@register_canonicalize
@register_stabilize
@register_specialize
@local_optimizer([log])
def local_log_pow(fgraph, node):
    """Rewrite ``log(base ** exponent)`` as ``exponent * log(base)``."""
    if node.op != log:
        return
    (arg,) = node.inputs
    if arg.owner and arg.owner.op == aet_pow:
        base, exponent = arg.owner.inputs
        # TODO: reason to be careful with dtypes?
        return [exponent * log(base)]
def spectral_radius_bound(X, log2_exponent): def spectral_radius_bound(X, log2_exponent):
""" """
Returns upper bound on the largest eigenvalue of square symmetrix matrix X. Returns upper bound on the largest eigenvalue of square symmetric matrix X.
log2_exponent must be a positive-valued integer. The larger it is, the log2_exponent must be a positive-valued integer. The larger it is, the
slower and tighter the bound. Values up to 5 should usually suffice. The slower and tighter the bound. Values up to 5 should usually suffice. The
......
...@@ -5,19 +5,11 @@ import aesara ...@@ -5,19 +5,11 @@ import aesara
from aesara import function from aesara import function
from aesara import tensor as aet from aesara import tensor as aet
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.sandbox.linalg.ops import inv_as_solve, spectral_radius_bound
# The one in comment are not tested...
from aesara.sandbox.linalg.ops import Cholesky # PSD_hint,; op class
from aesara.sandbox.linalg.ops import (
Solve,
inv_as_solve,
matrix_inverse,
solve,
spectral_radius_bound,
)
from aesara.tensor.elemwise import DimShuffle from aesara.tensor.elemwise import DimShuffle
from aesara.tensor.math import _allclose from aesara.tensor.math import _allclose
from aesara.tensor.nlinalg import MatrixInverse from aesara.tensor.nlinalg import MatrixInverse, matrix_inverse
from aesara.tensor.slinalg import Cholesky, Solve, solve
from aesara.tensor.type import dmatrix, matrix, vector from aesara.tensor.type import dmatrix, matrix, vector
from tests import unittest_tools as utt from tests import unittest_tools as utt
from tests.test_rop import break_op from tests.test_rop import break_op
...@@ -120,7 +112,7 @@ def test_spectral_radius_bound(): ...@@ -120,7 +112,7 @@ def test_spectral_radius_bound():
def test_transinv_to_invtrans(): def test_transinv_to_invtrans():
X = matrix("X") X = matrix("X")
Y = aesara.tensor.nlinalg.matrix_inverse(X) Y = matrix_inverse(X)
Z = Y.transpose() Z = Y.transpose()
f = aesara.function([X], Z) f = aesara.function([X], Z)
if config.mode != "FAST_COMPILE": if config.mode != "FAST_COMPILE":
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论