Commit b088cc8f authored by Brandon T. Willard, committed by Brandon T. Willard

Introduce aesara.as_symbolic

`aesara.as_symbolic` is a new function that converts all eligible objects to `Variable`s. Unlike `aesara.tensor.as_tensor_variable`, `as_symbolic` will also convert `None`s and `slice`s, as well as any other types that have equivalent Aesara `Type`s.
Parent 3edbbc4e
...@@ -26,6 +26,8 @@ __docformat__ = "restructuredtext en" ...@@ -26,6 +26,8 @@ __docformat__ = "restructuredtext en"
import logging import logging
import os import os
import sys import sys
from functools import singledispatch
from typing import Any, NoReturn, Optional
aesara_logger = logging.getLogger("aesara") aesara_logger = logging.getLogger("aesara")
...@@ -76,6 +78,51 @@ change_flags = deprecated("Use aesara.config.change_flags instead!")( ...@@ -76,6 +78,51 @@ change_flags = deprecated("Use aesara.config.change_flags instead!")(
# very rarely. # very rarely.
__api_version__ = 1 __api_version__ = 1
# isort: off
from aesara.graph.basic import Variable, clone_replace
# isort: on
def as_symbolic(x: Any, name: Optional[str] = None, **kwargs) -> Variable:
    """Convert `x` into an equivalent Aesara `Variable`.

    Parameters
    ----------
    x
        The object to be converted into a ``Variable`` type. A
        ``numpy.ndarray`` argument will not be copied, but a list of numbers
        will be copied to make an ``numpy.ndarray``.
    name
        If a new ``Variable`` instance is created, it will be named with this
        string.
    kwargs
        Options passed to the appropriate sub-dispatch functions. For example,
        `ndim` and `dtype` can be passed when `x` is an `numpy.ndarray` or
        `Number` type.

    Returns
    -------
    Variable
        `x` itself when it is already a `Variable`; otherwise, the result of
        the `_as_symbolic` dispatch for ``type(x)``.

    Raises
    ------
    TypeError
        If `x` cannot be converted to a `Variable`.

    """
    if isinstance(x, Variable):
        return x

    res = _as_symbolic(x, **kwargs)

    # Only assign a name when one was actually requested: some dispatch
    # implementations return shared singleton constants (e.g. ``NoneConst``),
    # and unconditionally writing ``res.name = None`` would wipe their names.
    if name is not None:
        res.name = name

    return res
@singledispatch
def _as_symbolic(x, **kwargs) -> Variable:
    """Dispatch on ``type(x)`` to construct a `Variable` equivalent of `x`.

    This is the default (fallback) implementation; it defers to
    `aesara.tensor.as_tensor_variable`.  Types with their own Aesara
    `Type`\\ s (e.g. ``slice`` and ``NoneType``) register more specific
    implementations elsewhere.
    """
    # Imported lazily — presumably to avoid a circular import at module load
    # time, since `aesara.tensor` modules import `_as_symbolic` from this
    # package.
    from aesara.tensor import as_tensor_variable

    return as_tensor_variable(x, **kwargs)
# isort: off
from aesara import scalar, tensor from aesara import scalar, tensor
from aesara.compile import ( from aesara.compile import (
In, In,
...@@ -95,6 +142,8 @@ from aesara.printing import debugprint as dprint ...@@ -95,6 +142,8 @@ from aesara.printing import debugprint as dprint
from aesara.printing import pp, pprint from aesara.printing import pp, pprint
from aesara.updates import OrderedUpdates from aesara.updates import OrderedUpdates
# isort: on
if ( if (
config.device.startswith("cuda") config.device.startswith("cuda")
...@@ -126,13 +175,16 @@ def get_scalar_constant_value(v): ...@@ -126,13 +175,16 @@ def get_scalar_constant_value(v):
return tensor.get_scalar_constant_value(v) return tensor.get_scalar_constant_value(v)
# isort: off
import aesara.tensor.random.var import aesara.tensor.random.var
from aesara.graph.basic import clone_replace
from aesara.scan import checkpoints from aesara.scan import checkpoints
from aesara.scan.basic import scan from aesara.scan.basic import scan
from aesara.scan.views import foldl, foldr, map, reduce from aesara.scan.views import foldl, foldr, map, reduce
# isort: on
# Some config variables are registered by submodules. Only after all those imports # Some config variables are registered by submodules. Only after all those
# were executed, we can warn about remaining flags provided by the user through AESARA_FLAGS. # imports were executed, we can warn about remaining flags provided by the user
# through AESARA_FLAGS.
config.warn_unused_flags() config.warn_unused_flags()
...@@ -37,7 +37,6 @@ from aesara.graph.fg import FunctionGraph, InconsistencyError ...@@ -37,7 +37,6 @@ from aesara.graph.fg import FunctionGraph, InconsistencyError
from aesara.graph.op import Op from aesara.graph.op import Op
from aesara.graph.utils import AssocList from aesara.graph.utils import AssocList
from aesara.misc.ordered_set import OrderedSet from aesara.misc.ordered_set import OrderedSet
from aesara.raise_op import CheckAndRaise
from aesara.utils import flatten from aesara.utils import flatten
...@@ -789,6 +788,8 @@ class MergeOptimizer(GlobalOptimizer): ...@@ -789,6 +788,8 @@ class MergeOptimizer(GlobalOptimizer):
fgraph.attach_feature(MergeFeature()) fgraph.attach_feature(MergeFeature())
def apply(self, fgraph): def apply(self, fgraph):
from aesara.raise_op import CheckAndRaise
# Constant and non-constant are now applied in the same phase. # Constant and non-constant are now applied in the same phase.
# I am not sure why, but it seems to be faster this way. # I am not sure why, but it seems to be faster this way.
sched = fgraph.merge_feature.scheduled sched = fgraph.merge_feature.scheduled
......
...@@ -9,32 +9,34 @@ from aesara.graph.op import Op ...@@ -9,32 +9,34 @@ from aesara.graph.op import Op
def as_tensor_variable(
    x: Any, name: Optional[str] = None, ndim: Optional[int] = None, **kwargs
) -> Variable:
    """Convert `x` into an equivalent `TensorVariable`.

    This function can be used to turn ndarrays, numbers, `Scalar` instances,
    `Apply` instances and `TensorVariable` instances into valid input list
    elements.

    See `aesara.as_symbolic` for a more general conversion function.

    Parameters
    ----------
    x
        The object to be converted into a `Variable` type. A
        `numpy.ndarray` argument will not be copied, but a list of numbers
        will be copied to make an `numpy.ndarray`.
    name
        If a new `Variable` instance is created, it will be named with this
        string.
    ndim
        Return a `Variable` with this many dimensions.
    dtype
        The dtype to use for the resulting `Variable`. If `x` is already
        a `Variable` type, then the dtype will not be changed.

    Raises
    ------
    TypeError
        If `x` cannot be converted to a `TensorVariable`.

    """
    return _as_tensor_variable(x, name, ndim, **kwargs)
......
...@@ -5,8 +5,9 @@ ...@@ -5,8 +5,9 @@
import numpy as np import numpy as np
import aesara import aesara
from aesara import _as_symbolic
from aesara.gradient import DisconnectedType from aesara.gradient import DisconnectedType
from aesara.graph.basic import Apply, Constant from aesara.graph.basic import Apply, Constant, Variable
from aesara.graph.op import Op from aesara.graph.op import Op
from aesara.graph.type import Generic, Type from aesara.graph.type import Generic, Type
from aesara.tensor.type import integer_dtypes from aesara.tensor.type import integer_dtypes
...@@ -108,6 +109,15 @@ class SliceConstant(Constant): ...@@ -108,6 +109,15 @@ class SliceConstant(Constant):
SliceType.Constant = SliceConstant SliceType.Constant = SliceConstant
@_as_symbolic.register(slice)
def as_symbolic_slice(x, **kwargs):
    """Convert a ``slice`` into a symbolic slice variable or constant."""
    bounds = (x.start, x.stop, x.step)

    # When any bound is itself symbolic, the result must be a graph built by
    # `make_slice`; otherwise the whole slice is a compile-time constant.
    if any(isinstance(b, Variable) for b in bounds):
        return make_slice(x)

    return SliceConstant(slicetype, x)
class NoneTypeT(Generic): class NoneTypeT(Generic):
""" """
Inherit from Generic to have c code working. Inherit from Generic to have c code working.
...@@ -129,9 +139,12 @@ class NoneTypeT(Generic): ...@@ -129,9 +139,12 @@ class NoneTypeT(Generic):
none_type_t = NoneTypeT() none_type_t = NoneTypeT()
# This is a variable instance. It can be used only once per fgraph,
# so use NoneConst.clone() before using it in an Aesara graph.
# Use NoneConst.equals(x) to check whether a variable is NoneConst.
NoneConst = Constant(none_type_t, None, name="NoneConst") NoneConst = Constant(none_type_t, None, name="NoneConst")
@_as_symbolic.register(type(None))
def as_symbolic_None(x, **kwargs):
    # ``None`` always maps to the shared `NoneConst` singleton; per the note
    # on `NoneConst` above, callers that need a distinct variable should use
    # ``NoneConst.clone()``.
    return NoneConst
__all__ = ["make_slice", "slicetype", "none_type_t", "NoneConst"] __all__ = ["make_slice", "slicetype", "none_type_t", "NoneConst"]
""" This file don't test everything. It only test one past crash error.""" """ This file don't test everything. It only test one past crash error."""
import aesara import aesara
from aesara import as_symbolic
from aesara.graph.basic import Constant from aesara.graph.basic import Constant
from aesara.tensor.math import argmax from aesara.tensor.math import argmax
from aesara.tensor.type import iscalar, vector from aesara.tensor.type import iscalar, vector
from aesara.tensor.type_other import MakeSlice, NoneConst, NoneTypeT, make_slice from aesara.tensor.type_other import (
MakeSlice,
NoneConst,
NoneTypeT,
SliceConstant,
make_slice,
)
def test_make_slice_merge(): def test_make_slice_merge():
...@@ -44,3 +51,15 @@ def test_none_Constant(): ...@@ -44,3 +51,15 @@ def test_none_Constant():
kwargs = {"mode": "FAST_RUN"} kwargs = {"mode": "FAST_RUN"}
f = aesara.function([x], [y], **kwargs) f = aesara.function([x], [y], **kwargs)
pickle.loads(pickle.dumps(f)) pickle.loads(pickle.dumps(f))
def test_as_symbolic():
    """Check the `_as_symbolic` dispatches for ``NoneType`` and ``slice``."""
    # `None` maps to the shared `NoneConst` singleton.
    none_res = as_symbolic(None)
    assert none_res is NoneConst

    # A slice with a symbolic bound becomes a `make_slice` graph.
    sym_slice_res = as_symbolic(slice(iscalar()))
    assert sym_slice_res.owner.op == make_slice

    # A fully concrete slice becomes a `SliceConstant`.
    const_slice_res = as_symbolic(slice(1, 2))
    assert isinstance(const_slice_res, SliceConstant)
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to comment