提交 bd54469c authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Convert as_tensor_variable to a singledispatch function

上级 d374dcc6
...@@ -643,11 +643,6 @@ EQ_MAP.update(list((v, k) for k, v in EQ_MAP.items())) ...@@ -643,11 +643,6 @@ EQ_MAP.update(list((v, k) for k, v in EQ_MAP.items()))
class _operators(_tensor_py_operators): class _operators(_tensor_py_operators):
def _as_TensorVariable(self):
from .basic_ops import host_from_gpu
return host_from_gpu(self)
def _as_GpuArrayVariable(self, context_name): def _as_GpuArrayVariable(self, context_name):
if self.type.context_name == context_name: if self.type.context_name == context_name:
return self return self
...@@ -657,6 +652,13 @@ class _operators(_tensor_py_operators): ...@@ -657,6 +652,13 @@ class _operators(_tensor_py_operators):
return GpuToGpu(context_name)(self) return GpuToGpu(context_name)(self)
@aet._as_tensor_variable.register(_operators)
def _as_tensor_operators(x, **kwargs):
    """Convert a GPU-array variable to a host `TensorVariable`."""
    # Local import — presumably to avoid a circular import between this
    # module and `aesara.gpuarray.basic_ops`; confirm before hoisting.
    from aesara.gpuarray.basic_ops import host_from_gpu

    return host_from_gpu(x)
class GpuArrayVariable(_operators, Variable): class GpuArrayVariable(_operators, Variable):
""" """
A variable representing a computation on a certain GPU. A variable representing a computation on a certain GPU.
......
...@@ -166,33 +166,33 @@ class IfElse(_NoPythonOp): ...@@ -166,33 +166,33 @@ class IfElse(_NoPythonOp):
return out_shapes return out_shapes
def make_node(self, c, *args): def make_node(self, c, *args):
assert ( if len(args) != 2 * self.n_outs:
len(args) == 2 * self.n_outs raise ValueError(
), f"Wrong number of arguments to make_node: expected {int(2 * self.n_outs)}, got {len(args)}" f"Wrong number of arguments to make_node: expected "
f"{int(2 * self.n_outs)}, got {len(args)}"
)
c = aet.basic.as_tensor_variable(c) c = aet.basic.as_tensor_variable(c)
if not self.gpu: if not self.gpu:
# When gpu is true, we are given only gpuarrays, and we want # When gpu is true, we are given only gpuarrays, and we want
# to keep them as gpuarrays # to keep them as gpuarrays
nw_args = [] nw_args = []
for x in args: for x in args:
if hasattr(x, "_as_TensorVariable"): if isinstance(x, Variable):
nw_args.append(x._as_TensorVariable())
elif isinstance(x, Variable):
nw_args.append(x) nw_args.append(x)
else: else:
nw_args.append(aet.basic.as_tensor_variable(x)) nw_args.append(aet.as_tensor_variable(x))
args = nw_args args = nw_args
aes = args[: self.n_outs] aes = args[: self.n_outs]
fs = args[self.n_outs :] fs = args[self.n_outs :]
for t, f in zip(aes, fs): for t, f in zip(aes, fs):
# TODO: Attempt to convert types so that they match?
# new_f = t.type.filter_variable(f)
if t.type != f.type: if t.type != f.type:
raise TypeError( raise TypeError(
("IfElse requires same types for true and " "false return values"), "IfElse requires same types for true and false return values: "
t, f"true_branch={t.type}, false_branch={f.type}"
f,
t.type,
f.type,
) )
if c.ndim > 0: if c.ndim > 0:
raise TypeError( raise TypeError(
......
...@@ -4,6 +4,41 @@ ...@@ -4,6 +4,41 @@
__docformat__ = "restructuredtext en" __docformat__ = "restructuredtext en"
import warnings import warnings
from functools import singledispatch
def as_tensor_variable(x, name=None, ndim=None, **kwargs):
    """Convert `x` into the appropriate `TensorType`.

    This function is often used by `make_node` methods of `Op` subclasses to
    turn ndarrays, numbers, `Scalar` instances, `Apply` instances and
    `TensorType` instances into valid input list elements.

    Parameters
    ----------
    x : Apply or Variable or numpy.ndarray or number
        This thing will be transformed into a `Variable` in a sensible way. An
        ndarray argument will not be copied, but a list of numbers will be
        copied to make an ndarray.
    name : str or None
        If a new `Variable` instance is created, it will be named with this
        string.
    ndim : None or integer
        Return a Variable with this many dimensions.

    Raises
    ------
    TypeError
        If `x` cannot be converted to a TensorType Variable.

    """
    # Dispatches on ``type(x)``; per-type converters are added elsewhere via
    # ``_as_tensor_variable.register``.
    return _as_tensor_variable(x, name, ndim, **kwargs)
@singledispatch
def _as_tensor_variable(x, name, ndim, **kwargs):
raise NotImplementedError("")
import aesara.tensor.exceptions import aesara.tensor.exceptions
from aesara.gradient import consider_constant, grad, hessian, jacobian from aesara.gradient import consider_constant, grad, hessian, jacobian
......
...@@ -10,6 +10,7 @@ import logging ...@@ -10,6 +10,7 @@ import logging
import warnings import warnings
from collections import OrderedDict from collections import OrderedDict
from collections.abc import Sequence from collections.abc import Sequence
from numbers import Number
import numpy as np import numpy as np
...@@ -26,6 +27,8 @@ from aesara.graph.type import CType ...@@ -26,6 +27,8 @@ from aesara.graph.type import CType
from aesara.misc.safe_asarray import _asarray from aesara.misc.safe_asarray import _asarray
from aesara.printing import min_informative_str, pprint from aesara.printing import min_informative_str, pprint
from aesara.scalar import int32 from aesara.scalar import int32
from aesara.scalar.basic import ScalarConstant, ScalarVariable
from aesara.tensor import _as_tensor_variable, as_tensor_variable
from aesara.tensor.elemwise import DimShuffle, Elemwise, scalar_elemwise from aesara.tensor.elemwise import DimShuffle, Elemwise, scalar_elemwise
from aesara.tensor.exceptions import EmptyConstantError, NotScalarConstantError from aesara.tensor.exceptions import EmptyConstantError, NotScalarConstantError
from aesara.tensor.shape import ( from aesara.tensor.shape import (
...@@ -82,123 +85,111 @@ def __oplist_tag(thing, tag): ...@@ -82,123 +85,111 @@ def __oplist_tag(thing, tag):
thing.__oplist_tags = tags thing.__oplist_tags = tags
def as_tensor_variable(x, name=None, ndim=None): @_as_tensor_variable.register(Apply)
"""Convert `x` into the appropriate `TensorType`. def _as_tensor_Apply(x, name, ndim):
# use Apply's default output mechanism
if (x.op.default_output is None) and (len(x.outputs) != 1):
raise TypeError(
"Multi-output Op encountered. "
"Retry using only one of the outputs directly."
)
This function is often used by `make_node` methods of `Op` subclasses to x = x.default_output()
turn ndarrays, numbers, `Scalar` instances, `Apply` instances and
`TensorType` instances into valid input list elements.
Parameters return as_tensor_variable(x, name=name, ndim=ndim)
----------
x : Apply or Variable or numpy.ndarray or number
This thing will be transformed into a `Variable` in a sensible way. An
ndarray argument will not be copied, but a list of numbers will be
copied to make an ndarray.
name : str or None
If a new `Variable` instance is created, it will be named with this
string.
ndim : None or integer
Return a Variable with this many dimensions.
Raises
------
TypeError
If `x` cannot be converted to a TensorType Variable.
""" @_as_tensor_variable.register(ScalarVariable)
if ( @_as_tensor_variable.register(ScalarConstant)
isinstance(getattr(x, "type", None), TensorType) def _as_tensor_Scalar(x, name, ndim):
and (name is None or x.name == name) return as_tensor_variable(tensor_from_scalar(x), name=name, ndim=ndim)
and (ndim is None or x.ndim == ndim)
):
return x
if hasattr(x, "_as_TensorVariable"):
return x._as_TensorVariable() # TODO: pass name and ndim arguments
if isinstance(x, Apply): @_as_tensor_variable.register(Variable)
# use Apply's default output mechanism def _as_tensor_Variable(x, name, ndim):
if (x.op.default_output is None) and (len(x.outputs) != 1): if not isinstance(x.type, TensorType):
raise TypeError( raise TypeError(
"Multi-output Op encountered. " "Tensor type field must be a TensorType; found {}.".format(type(x.type))
"Retry using only one of the outputs directly." )
)
x = x.default_output() if ndim is None:
return x
if isinstance(x, Variable): if x.type.ndim > ndim:
# strip off leading broadcastable dimensions
first_non_broadcastable = [
idx for idx in range(x.ndim) if not x.broadcastable[idx]
][0]
x = x.dimshuffle(list(range(x.ndim))[first_non_broadcastable:])
if x.ndim > ndim:
raise ValueError(
"Tensor of type {} could not be cast to have {} dimensions".format(
x.type, ndim
)
)
return x
elif x.type.ndim < ndim:
return shape_padleft(x, n_ones=(ndim - x.type.ndim))
else:
return x
if isinstance(x, Constant):
return as_tensor_variable(x.data, name=name, ndim=ndim)
if isinstance(x.type, aes.Scalar): @_as_tensor_variable.register(list)
x = tensor_from_scalar(x) @_as_tensor_variable.register(tuple)
def _as_tensor_Sequence(x, name, ndim):
if not isinstance(x.type, TensorType): if len(x) == 0:
raise TypeError( return constant(x, name=name, ndim=ndim)
"Tensor type field must be a TensorType; found {}.".format(type(x.type))
)
if ndim is None: # If a sequence has `Variable`s in it, then we want
return x # to customize the conversion to a tensor type.
else: def extract_constants(i):
if x.type.ndim > ndim: if isinstance(i, Variable):
# strip off leading broadcastable dimensions if isinstance(i, Constant):
first_non_broadcastable = [ return i.data
idx for idx in range(x.ndim) if not x.broadcastable[idx]
][0]
x = x.dimshuffle(list(range(x.ndim))[first_non_broadcastable:])
if x.ndim > ndim:
raise ValueError(
"Tensor of type {} could not be cast to have {} dimensions".format(
x.type, ndim
)
)
return x
elif x.type.ndim < ndim:
return shape_padleft(x, n_ones=(ndim - x.type.ndim))
else: else:
return x raise TypeError
else:
elif isinstance(x, Sequence): return i
def extract_constants(i): try:
if isinstance(i, Variable): x = type(x)(extract_constants(i) for i in x)
if isinstance(i, Constant): except TypeError:
return i.data if builtins.all(getattr(i, "ndim", None) == 0 for i in x) and (
else: ndim is None or ndim == 1
raise TypeError ):
else: # In this instance, we have a sequence of constants with which we
return i # want to construct a vector, so we can use `MakeVector` directly.
dtype = aes.upcast(*[i.dtype for i in x if hasattr(i, "dtype")])
return MakeVector(dtype)(*x)
try: # In this case, we have at least one non-`Constant` term, so we
x = [extract_constants(i) for i in x] # couldn't get an underlying non-symbolic sequence of objects and we to
except TypeError: # symbolically join terms.
if builtins.all(getattr(i, "ndim", None) == 0 for i in x) and ( return stack(x)
ndim is None or ndim == 1
):
# In this instance, we can avoid making a `Join` `Op`, because
# we know that the result should be a vector.
# `MakeVector` is a better option due to its `get_scalar_constant_value`
# support.
dtype = aes.upcast(*[i.dtype for i in x if hasattr(i, "dtype")])
return MakeVector(dtype)(*x)
return stack(x) return constant(x, name=name, ndim=ndim)
elif isinstance(x, bool):
raise TypeError(
"Cannot cast True or False as a tensor variable. Please use "
"np.array(True) or np.array(False) if you need these constants. "
"This error might be caused by using the == operator on "
"Variables. v == w does not do what you think it does, "
"use aesara.tensor.eq(v, w) instead."
)
@_as_tensor_variable.register(np.bool_)
@_as_tensor_variable.register(np.number)
@_as_tensor_variable.register(Number)
@_as_tensor_variable.register(np.ndarray)
def _as_tensor_numbers(x, name, ndim):
    """Wrap a plain number, NumPy scalar, or ndarray as a tensor constant."""
    return constant(x, name=name, ndim=ndim)
@_as_tensor_variable.register(bool)
def _as_tensor_bool(x, name, ndim):
    """Reject Python `bool` inputs with an explanatory error.

    Registered separately from `np.bool_` (which is accepted by
    `_as_tensor_numbers`) so that e.g. accidental ``v == w`` comparisons on
    variables fail loudly instead of producing a constant.
    """
    raise TypeError(
        "Cannot cast True or False as a tensor variable. Please use "
        "np.array(True) or np.array(False) if you need these constants. "
        "This error might be caused by using the == operator on "
        "Variables. v == w does not do what you think it does, "
        "use aesara.tensor.eq(v, w) instead."
    )
as_tensor = as_tensor_variable as_tensor = as_tensor_variable
...@@ -347,6 +338,7 @@ def get_scalar_constant_value( ...@@ -347,6 +338,7 @@ def get_scalar_constant_value(
data = v.tag.unique_value data = v.tag.unique_value
else: else:
data = v.data data = v.data
if isinstance(data, np.ndarray): if isinstance(data, np.ndarray):
return numpy_scalar(data).copy() return numpy_scalar(data).copy()
else: else:
......
...@@ -241,9 +241,6 @@ class TensorType(CType): ...@@ -241,9 +241,6 @@ class TensorType(CType):
and dtype and have "compatible" broadcastable pattern. and dtype and have "compatible" broadcastable pattern.
""" """
if hasattr(other, "_as_TensorVariable"):
other = other._as_TensorVariable()
if not isinstance(other, Variable): if not isinstance(other, Variable):
# The value is not a Variable: we cast it into # The value is not a Variable: we cast it into
# a Constant of the appropriate Type. # a Constant of the appropriate Type.
......
...@@ -522,6 +522,14 @@ class TestAsTensorVariable: ...@@ -522,6 +522,14 @@ class TestAsTensorVariable:
res = aet.as_tensor(y) res = aet.as_tensor(y)
assert isinstance(res.owner.op, MakeVector) assert isinstance(res.owner.op, MakeVector)
def test_multi_out(self):
    """`as_tensor` should raise `TypeError` for a multi-output `Apply`."""

    class TestOp(Op):
        def make_node(self, a, b):
            # Two outputs and no `default_output` — the case the `Apply`
            # converter rejects.
            return Apply(self, [a, b], [a, b])

    with pytest.raises(TypeError):
        # NOTE(review): `TestOp(matrix(), matrix())` passes two matrices to
        # the Op *constructor* rather than building an Apply node via
        # `TestOp()(matrix(), matrix())` or `TestOp().make_node(...)`.
        # Confirm this exercises the intended multi-output conversion path
        # and not merely a constructor TypeError.
        aet.as_tensor(TestOp(matrix(), matrix()))
class TestAlloc: class TestAlloc:
dtype = config.floatX dtype = config.floatX
...@@ -3049,7 +3057,7 @@ def test_dimshuffle_duplicate(): ...@@ -3049,7 +3057,7 @@ def test_dimshuffle_duplicate():
class TestGetScalarConstantValue: class TestGetScalarConstantValue:
def test_get_scalar_constant_value(self): def test_basic(self):
a = aet.stack([1, 2, 3]) a = aet.stack([1, 2, 3])
assert get_scalar_constant_value(a[0]) == 1 assert get_scalar_constant_value(a[0]) == 1
assert get_scalar_constant_value(a[1]) == 2 assert get_scalar_constant_value(a[1]) == 2
......
...@@ -35,6 +35,21 @@ class TestIfelse(utt.OptimizationTestMixin): ...@@ -35,6 +35,21 @@ class TestIfelse(utt.OptimizationTestMixin):
else: else:
return IfElse(n, as_view=True) return IfElse(n, as_view=True)
def test_wrong_n_outs(self):
    """`IfElse.make_node` raises `ValueError` on an argument-count mismatch.

    `IfElse(n)` expects ``2 * n`` branch arguments after the condition;
    here ``n == 0`` but two branch arguments are supplied.
    """
    x = vector("x", dtype=self.dtype)
    c = iscalar("c")
    with pytest.raises(ValueError):
        IfElse(0)(c, x, x)
def test_const_Op_argument(self):
    """A non-`Variable` (ndarray) branch argument is accepted by `IfElse`.

    `make_node` converts the raw ndarray `y` via `as_tensor_variable`.
    """
    x = vector("x", dtype=self.dtype)
    y = np.array([2.0, 3.0], dtype=self.dtype)
    c = iscalar("c")
    f = function([c, x], IfElse(1)(c, x, y), mode=self.mode)
    # c == 0 selects the second (constant) branch, per the assertion below.
    val = f(0, np.r_[1.0, 2.0].astype(self.dtype))
    assert np.array_equal(val, y)
def test_lazy_if(self): def test_lazy_if(self):
# Tests that lazy if works .. even if the two results have different # Tests that lazy if works .. even if the two results have different
# shapes but the same type (i.e. both vectors, or matrices or # shapes but the same type (i.e. both vectors, or matrices or
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论