提交 b088cc8f authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Introduce aesara.as_symbolic

`aesara.as_symbolic` is a new function that converts all eligible objects to `Variable`s. Unlike `aesara.tensor.as_tensor_variable`, `as_symbolic` will convert `None`s and `slice`s, or any other types that have equivalent Aesara `Type`s.
上级 3edbbc4e
......@@ -26,6 +26,8 @@ __docformat__ = "restructuredtext en"
import logging
import os
import sys
from functools import singledispatch
from typing import Any, NoReturn, Optional
aesara_logger = logging.getLogger("aesara")
......@@ -76,6 +78,51 @@ change_flags = deprecated("Use aesara.config.change_flags instead!")(
# very rarely.
__api_version__ = 1
# isort: off
from aesara.graph.basic import Variable, clone_replace
# isort: on
def as_symbolic(x: Any, name: Optional[str] = None, **kwargs) -> Variable:
"""Convert `x` into an equivalent Aesara `Variable`.
Parameters
----------
x
The object to be converted into a ``Variable`` type. A
``numpy.ndarray`` argument will not be copied, but a list of numbers
will be copied to make an ``numpy.ndarray``.
name
If a new ``Variable`` instance is created, it will be named with this
string.
kwargs
Options passed to the appropriate sub-dispatch functions. For example,
`ndim` and `dtype` can be passed when `x` is an `numpy.ndarray` or
`Number` type.
Raises
------
TypeError
If `x` cannot be converted to a `Variable`.
"""
if isinstance(x, Variable):
return x
res = _as_symbolic(x, **kwargs)
res.name = name
return res
@singledispatch
def _as_symbolic(x, **kwargs) -> Variable:
from aesara.tensor import as_tensor_variable
return as_tensor_variable(x, **kwargs)
# isort: off
from aesara import scalar, tensor
from aesara.compile import (
In,
......@@ -95,6 +142,8 @@ from aesara.printing import debugprint as dprint
from aesara.printing import pp, pprint
from aesara.updates import OrderedUpdates
# isort: on
if (
config.device.startswith("cuda")
......@@ -126,13 +175,16 @@ def get_scalar_constant_value(v):
return tensor.get_scalar_constant_value(v)
# isort: off
import aesara.tensor.random.var
from aesara.graph.basic import clone_replace
from aesara.scan import checkpoints
from aesara.scan.basic import scan
from aesara.scan.views import foldl, foldr, map, reduce
# isort: on
# Some config variables are registered by submodules. Only after all those imports
# were executed, we can warn about remaining flags provided by the user through AESARA_FLAGS.
# Some config variables are registered by submodules. Only after all those
# imports were executed, we can warn about remaining flags provided by the user
# through AESARA_FLAGS.
config.warn_unused_flags()
......@@ -37,7 +37,6 @@ from aesara.graph.fg import FunctionGraph, InconsistencyError
from aesara.graph.op import Op
from aesara.graph.utils import AssocList
from aesara.misc.ordered_set import OrderedSet
from aesara.raise_op import CheckAndRaise
from aesara.utils import flatten
......@@ -789,6 +788,8 @@ class MergeOptimizer(GlobalOptimizer):
fgraph.attach_feature(MergeFeature())
def apply(self, fgraph):
from aesara.raise_op import CheckAndRaise
# Constant and non-constant are now applied in the same phase.
# I am not sure why, but it seems to be faster this way.
sched = fgraph.merge_feature.scheduled
......
......@@ -9,32 +9,34 @@ from aesara.graph.op import Op
def as_tensor_variable(
x: Any, name: Optional[str] = None, ndim: Optional[int] = None, **kwargs
) -> Callable:
"""Convert `x` into the appropriate ``TensorType``.
) -> Variable:
"""Convert `x` into an equivalent `TensorVariable`.
This function is often used by ``make_node`` methods of ``Op`` subclasses
to turn ndarrays, numbers, ``Scalar`` instances, ``Apply`` instances and
``TensorType`` instances into valid input list elements.
This function can be used to turn ndarrays, numbers, `Scalar` instances,
`Apply` instances and `TensorVariable` instances into valid input list
elements.
See `aesara.as_symbolic` for a more general conversion function.
Parameters
----------
x
The object to be converted into a ``Variable`` type. A
``numpy.ndarray`` argument will not be copied, but a list of numbers
will be copied to make an ``numpy.ndarray``.
The object to be converted into a `Variable` type. A
`numpy.ndarray` argument will not be copied, but a list of numbers
will be copied to make an `numpy.ndarray`.
name
If a new ``Variable`` instance is created, it will be named with this
If a new `Variable` instance is created, it will be named with this
string.
ndim
Return a ``Variable`` with this many dimensions.
Return a `Variable` with this many dimensions.
dtype
The dtype to use for the resulting ``Variable``. If `x` is already
a ``Variable`` type, then the dtype will not be changed.
The dtype to use for the resulting `Variable`. If `x` is already
a `Variable` type, then the dtype will not be changed.
Raises
------
TypeError
If `x` cannot be converted to a ``TensorType`` Variable.
If `x` cannot be converted to a `TensorVariable`.
"""
return _as_tensor_variable(x, name, ndim, **kwargs)
......
......@@ -5,8 +5,9 @@
import numpy as np
import aesara
from aesara import _as_symbolic
from aesara.gradient import DisconnectedType
from aesara.graph.basic import Apply, Constant
from aesara.graph.basic import Apply, Constant, Variable
from aesara.graph.op import Op
from aesara.graph.type import Generic, Type
from aesara.tensor.type import integer_dtypes
......@@ -108,6 +109,15 @@ class SliceConstant(Constant):
SliceType.Constant = SliceConstant
@_as_symbolic.register(slice)
def as_symbolic_slice(x, **kwargs):
if any(isinstance(i, Variable) for i in (x.start, x.stop, x.step)):
return make_slice(x)
return SliceConstant(slicetype, x)
class NoneTypeT(Generic):
"""
Inherit from Generic to have c code working.
......@@ -129,9 +139,12 @@ class NoneTypeT(Generic):
none_type_t = NoneTypeT()
# This is a variable instance. It can be used only once per fgraph.
# So use NoneConst.clone() before using it in an Aesara graph.
# Use NoneConst.equals(x) to check if two variable are NoneConst.
NoneConst = Constant(none_type_t, None, name="NoneConst")
@_as_symbolic.register(type(None))
def as_symbolic_None(x, **kwargs):
return NoneConst
__all__ = ["make_slice", "slicetype", "none_type_t", "NoneConst"]
""" This file don't test everything. It only test one past crash error."""
import aesara
from aesara import as_symbolic
from aesara.graph.basic import Constant
from aesara.tensor.math import argmax
from aesara.tensor.type import iscalar, vector
from aesara.tensor.type_other import MakeSlice, NoneConst, NoneTypeT, make_slice
from aesara.tensor.type_other import (
MakeSlice,
NoneConst,
NoneTypeT,
SliceConstant,
make_slice,
)
def test_make_slice_merge():
......@@ -44,3 +51,15 @@ def test_none_Constant():
kwargs = {"mode": "FAST_RUN"}
f = aesara.function([x], [y], **kwargs)
pickle.loads(pickle.dumps(f))
def test_as_symbolic():
res = as_symbolic(None)
assert res is NoneConst
res = as_symbolic(slice(iscalar()))
assert res.owner.op == make_slice
res = as_symbolic(slice(1, 2))
assert isinstance(res, SliceConstant)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论