提交 bd54469c authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Convert as_tensor_variable to a singledispatch function

上级 d374dcc6
......@@ -643,11 +643,6 @@ EQ_MAP.update(list((v, k) for k, v in EQ_MAP.items()))
class _operators(_tensor_py_operators):
def _as_TensorVariable(self):
from .basic_ops import host_from_gpu
return host_from_gpu(self)
def _as_GpuArrayVariable(self, context_name):
if self.type.context_name == context_name:
return self
......@@ -657,6 +652,13 @@ class _operators(_tensor_py_operators):
return GpuToGpu(context_name)(self)
@aet._as_tensor_variable.register(_operators)
def _as_tensor_operators(x, **kwargs):
    """Convert a GPU-array variable to a host tensor via `host_from_gpu`."""
    # Imported lazily to avoid a circular import at module load time.
    from aesara.gpuarray.basic_ops import host_from_gpu

    host_var = host_from_gpu(x)
    return host_var
class GpuArrayVariable(_operators, Variable):
"""
A variable representing a computation on a certain GPU.
......
......@@ -166,33 +166,33 @@ class IfElse(_NoPythonOp):
return out_shapes
def make_node(self, c, *args):
assert (
len(args) == 2 * self.n_outs
), f"Wrong number of arguments to make_node: expected {int(2 * self.n_outs)}, got {len(args)}"
if len(args) != 2 * self.n_outs:
raise ValueError(
f"Wrong number of arguments to make_node: expected "
f"{int(2 * self.n_outs)}, got {len(args)}"
)
c = aet.basic.as_tensor_variable(c)
if not self.gpu:
# When gpu is true, we are given only gpuarrays, and we want
# to keep them as gpuarrays
nw_args = []
for x in args:
if hasattr(x, "_as_TensorVariable"):
nw_args.append(x._as_TensorVariable())
elif isinstance(x, Variable):
if isinstance(x, Variable):
nw_args.append(x)
else:
nw_args.append(aet.basic.as_tensor_variable(x))
nw_args.append(aet.as_tensor_variable(x))
args = nw_args
aes = args[: self.n_outs]
fs = args[self.n_outs :]
for t, f in zip(aes, fs):
# TODO: Attempt to convert types so that they match?
# new_f = t.type.filter_variable(f)
if t.type != f.type:
raise TypeError(
("IfElse requires same types for true and " "false return values"),
t,
f,
t.type,
f.type,
"IfElse requires same types for true and false return values: "
f"true_branch={t.type}, false_branch={f.type}"
)
if c.ndim > 0:
raise TypeError(
......
......@@ -4,6 +4,41 @@
__docformat__ = "restructuredtext en"
import warnings
from functools import singledispatch
def as_tensor_variable(x, name=None, ndim=None, **kwargs):
    """Convert `x` into the appropriate `TensorType`.

    This function is often used by `make_node` methods of `Op` subclasses to
    turn ndarrays, numbers, `Scalar` instances, `Apply` instances and
    `TensorType` instances into valid input list elements.

    Parameters
    ----------
    x : Apply or Variable or numpy.ndarray or number
        This thing will be transformed into a `Variable` in a sensible way. An
        ndarray argument will not be copied, but a list of numbers will be
        copied to make an ndarray.
    name : str or None
        If a new `Variable` instance is created, it will be named with this
        string.
    ndim : None or integer
        Return a `Variable` with this many dimensions.

    Raises
    ------
    NotImplementedError
        If no conversion is registered for the type of `x`.
    TypeError
        If a registered conversion determines that `x` cannot be converted
        to a `TensorType` `Variable`.

    """
    return _as_tensor_variable(x, name, ndim, **kwargs)


@singledispatch
def _as_tensor_variable(x, name, ndim, **kwargs):
    """Dispatch-point for `as_tensor_variable`; register per-type converters here."""
    # Base case: no registered implementation knows this type.  Include the
    # offending value and its type in the message (the original empty string
    # made these failures impossible to diagnose).
    raise NotImplementedError(
        f"Cannot convert {x!r} (of type {type(x)}) to a tensor variable."
    )
import aesara.tensor.exceptions
from aesara.gradient import consider_constant, grad, hessian, jacobian
......
......@@ -10,6 +10,7 @@ import logging
import warnings
from collections import OrderedDict
from collections.abc import Sequence
from numbers import Number
import numpy as np
......@@ -26,6 +27,8 @@ from aesara.graph.type import CType
from aesara.misc.safe_asarray import _asarray
from aesara.printing import min_informative_str, pprint
from aesara.scalar import int32
from aesara.scalar.basic import ScalarConstant, ScalarVariable
from aesara.tensor import _as_tensor_variable, as_tensor_variable
from aesara.tensor.elemwise import DimShuffle, Elemwise, scalar_elemwise
from aesara.tensor.exceptions import EmptyConstantError, NotScalarConstantError
from aesara.tensor.shape import (
......@@ -82,123 +85,111 @@ def __oplist_tag(thing, tag):
thing.__oplist_tags = tags
def as_tensor_variable(x, name=None, ndim=None):
"""Convert `x` into the appropriate `TensorType`.
@_as_tensor_variable.register(Apply)
def _as_tensor_Apply(x, name, ndim):
    """Convert an `Apply` node by converting its default (or sole) output."""
    # Use `Apply`'s default output mechanism; with several outputs and no
    # default output, the choice would be ambiguous, so refuse it.
    if (x.op.default_output is None) and (len(x.outputs) != 1):
        raise TypeError(
            "Multi-output Op encountered. "
            "Retry using only one of the outputs directly."
        )

    x = x.default_output()

    return as_tensor_variable(x, name=name, ndim=ndim)
Raises
------
TypeError
If `x` cannot be converted to a TensorType Variable.
"""
if (
isinstance(getattr(x, "type", None), TensorType)
and (name is None or x.name == name)
and (ndim is None or x.ndim == ndim)
):
return x
@_as_tensor_variable.register(ScalarVariable)
@_as_tensor_variable.register(ScalarConstant)
def _as_tensor_Scalar(x, name, ndim):
    """Lift a scalar `Variable` to a tensor, then convert it normally."""
    lifted = tensor_from_scalar(x)
    return as_tensor_variable(lifted, name=name, ndim=ndim)
if hasattr(x, "_as_TensorVariable"):
return x._as_TensorVariable() # TODO: pass name and ndim arguments
if isinstance(x, Apply):
# use Apply's default output mechanism
if (x.op.default_output is None) and (len(x.outputs) != 1):
raise TypeError(
"Multi-output Op encountered. "
"Retry using only one of the outputs directly."
)
@_as_tensor_variable.register(Variable)
def _as_tensor_Variable(x, name, ndim):
    """Convert a `Variable`, reshaping to `ndim` dimensions when requested."""
    if not isinstance(x.type, TensorType):
        raise TypeError(
            "Tensor type field must be a TensorType; found {}.".format(type(x.type))
        )

    if ndim is None:
        return x

    if x.type.ndim > ndim:
        # strip off leading broadcastable dimensions
        # NOTE(review): if *every* dimension is broadcastable this raises
        # IndexError rather than the ValueError below — confirm intended.
        first_non_broadcastable = [
            idx for idx in range(x.ndim) if not x.broadcastable[idx]
        ][0]
        x = x.dimshuffle(list(range(x.ndim))[first_non_broadcastable:])
        if x.ndim > ndim:
            raise ValueError(
                "Tensor of type {} could not be cast to have {} dimensions".format(
                    x.type, ndim
                )
            )
        return x
    elif x.type.ndim < ndim:
        # Pad with leading broadcastable dimensions to reach `ndim`.
        return shape_padleft(x, n_ones=(ndim - x.type.ndim))
    else:
        return x
if isinstance(x, Constant):
return as_tensor_variable(x.data, name=name, ndim=ndim)
if isinstance(x.type, aes.Scalar):
x = tensor_from_scalar(x)
@_as_tensor_variable.register(list)
@_as_tensor_variable.register(tuple)
def _as_tensor_Sequence(x, name, ndim):
    """Convert a ``list``/``tuple`` into a tensor variable.

    Sequences of constants become a constant tensor; sequences containing
    symbolic `Variable`s are joined symbolically instead.
    """
    if len(x) == 0:
        return constant(x, name=name, ndim=ndim)

    # If a sequence has `Variable`s in it, then we want
    # to customize the conversion to a tensor type.
    def extract_constants(i):
        # Replace `Constant`s by their raw data; a non-constant `Variable`
        # signals (via TypeError) that a purely numeric conversion is
        # impossible.
        if isinstance(i, Variable):
            if isinstance(i, Constant):
                return i.data
            else:
                raise TypeError
        else:
            return i

    try:
        x = type(x)(extract_constants(i) for i in x)
    except TypeError:
        if builtins.all(getattr(i, "ndim", None) == 0 for i in x) and (
            ndim is None or ndim == 1
        ):
            # In this instance, we have a sequence of scalars with which we
            # want to construct a vector, so we can use `MakeVector` directly.
            dtype = aes.upcast(*[i.dtype for i in x if hasattr(i, "dtype")])
            return MakeVector(dtype)(*x)

        # At least one non-`Constant` term means we could not obtain an
        # underlying non-symbolic sequence of objects, so we need to
        # symbolically join the terms.
        return stack(x)

    return constant(x, name=name, ndim=ndim)
elif isinstance(x, bool):
raise TypeError(
"Cannot cast True or False as a tensor variable. Please use "
"np.array(True) or np.array(False) if you need these constants. "
"This error might be caused by using the == operator on "
"Variables. v == w does not do what you think it does, "
"use aesara.tensor.eq(v, w) instead."
)
@_as_tensor_variable.register(Number)
@_as_tensor_variable.register(np.bool_)
@_as_tensor_variable.register(np.number)
@_as_tensor_variable.register(np.ndarray)
def _as_tensor_numbers(x, name, ndim):
    """Convert Python numbers, NumPy scalars and ndarrays to constant tensors."""
    return constant(x, name=name, ndim=ndim)
@_as_tensor_variable.register(bool)
def _as_tensor_bool(x, name, ndim):
    """Reject plain Python booleans with a diagnostic error.

    A `bool` often reaches this point via ``v == w`` on `Variable`s, which is
    almost never what the user intended.
    """
    raise TypeError(
        "Cannot cast True or False as a tensor variable. Please use "
        "np.array(True) or np.array(False) if you need these constants. "
        "This error might be caused by using the == operator on "
        "Variables. v == w does not do what you think it does, "
        "use aesara.tensor.eq(v, w) instead."
    )
as_tensor = as_tensor_variable
......@@ -347,6 +338,7 @@ def get_scalar_constant_value(
data = v.tag.unique_value
else:
data = v.data
if isinstance(data, np.ndarray):
return numpy_scalar(data).copy()
else:
......
......@@ -241,9 +241,6 @@ class TensorType(CType):
and dtype and have "compatible" broadcastable pattern.
"""
if hasattr(other, "_as_TensorVariable"):
other = other._as_TensorVariable()
if not isinstance(other, Variable):
# The value is not a Variable: we cast it into
# a Constant of the appropriate Type.
......
......@@ -522,6 +522,14 @@ class TestAsTensorVariable:
res = aet.as_tensor(y)
assert isinstance(res.owner.op, MakeVector)
def test_multi_out(self):
    """`as_tensor` must raise `TypeError` for a multi-output `Apply`."""

    class MultiOutOp(Op):
        def make_node(self, a, b):
            return Apply(self, [a, b], [a, b])

    # NOTE(review): `MultiOutOp(matrix(), matrix())` passes the matrices to
    # `Op.__init__`, not through `make_node` — confirm this builds the
    # intended multi-output node rather than failing at construction.
    with pytest.raises(TypeError):
        aet.as_tensor(MultiOutOp(matrix(), matrix()))
class TestAlloc:
dtype = config.floatX
......@@ -3049,7 +3057,7 @@ def test_dimshuffle_duplicate():
class TestGetScalarConstantValue:
def test_get_scalar_constant_value(self):
def test_basic(self):
a = aet.stack([1, 2, 3])
assert get_scalar_constant_value(a[0]) == 1
assert get_scalar_constant_value(a[1]) == 2
......
......@@ -35,6 +35,21 @@ class TestIfelse(utt.OptimizationTestMixin):
else:
return IfElse(n, as_view=True)
def test_wrong_n_outs(self):
    """`IfElse(n_outs)` must reject a mismatched number of branch arguments."""
    cond = iscalar("c")
    branch = vector("x", dtype=self.dtype)
    with pytest.raises(ValueError):
        IfElse(0)(cond, branch, branch)
def test_const_Op_argument(self):
    """A raw ndarray branch should be converted by `make_node` and usable."""
    cond = iscalar("c")
    true_branch = vector("x", dtype=self.dtype)
    false_branch = np.array([2.0, 3.0], dtype=self.dtype)
    f = function(
        [cond, true_branch],
        IfElse(1)(cond, true_branch, false_branch),
        mode=self.mode,
    )
    result = f(0, np.r_[1.0, 2.0].astype(self.dtype))
    # Condition is 0, so the constant false branch must be returned.
    assert np.array_equal(result, false_branch)
def test_lazy_if(self):
# Tests that lazy if works .. even if the two results have different
# shapes but the same type (i.e. both vectors, or matrices or
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论