Commit eeaa94d6 authored by Brandon T. Willard, committed by Brandon T. Willard

Allow RandomVariable float dtypes to be set dynamically

Parent 3707d812
...@@ -35,14 +35,14 @@ def test_default_shape_from_params(): ...@@ -35,14 +35,14 @@ def test_default_shape_from_params():
assert res == (3, 4) assert res == (3, 4)
def test_RandomVariable(): def test_RandomVariable_basics():
str_res = str( str_res = str(
RandomVariable( RandomVariable(
"normal", "normal",
0, 0,
[0, 0], [0, 0],
"normal", config.floatX,
inplace=True, inplace=True,
) )
) )
...@@ -130,6 +130,25 @@ def test_RandomVariable(): ...@@ -130,6 +130,25 @@ def test_RandomVariable():
assert res == [False] * 3 assert res == [False] * 3
def test_RandomVariable_floatX():
test_rv_op = RandomVariable(
"normal",
0,
[0, 0],
"floatX",
inplace=True,
)
assert test_rv_op.dtype == "floatX"
assert test_rv_op(0, 1).dtype == config.floatX
new_floatX = "float64" if config.floatX == "float32" else "float32"
with change_flags(floatX=new_floatX):
assert test_rv_op(0, 1).dtype == new_floatX
def test_observed(): def test_observed():
rv_var = normal(0, 1, size=3) rv_var = normal(0, 1, size=3)
obs_var = observed(rv_var, np.array([0.2, 0.1, -2.4], dtype=config.floatX)) obs_var = observed(rv_var, np.array([0.2, 0.1, -2.4], dtype=config.floatX))
......
...@@ -2,7 +2,6 @@ import numpy as np ...@@ -2,7 +2,6 @@ import numpy as np
import scipy.stats as stats import scipy.stats as stats
import theano import theano
from theano.configdefaults import config
from theano.tensor.basic import as_tensor_variable from theano.tensor.basic import as_tensor_variable
from theano.tensor.random.op import RandomVariable, default_shape_from_params from theano.tensor.random.op import RandomVariable, default_shape_from_params
from theano.tensor.random.utils import broadcast_params from theano.tensor.random.utils import broadcast_params
...@@ -20,7 +19,7 @@ class UniformRV(RandomVariable): ...@@ -20,7 +19,7 @@ class UniformRV(RandomVariable):
name = "uniform" name = "uniform"
ndim_supp = 0 ndim_supp = 0
ndims_params = [0, 0] ndims_params = [0, 0]
dtype = config.floatX dtype = "floatX"
_print_name = ("U", "\\operatorname{U}") _print_name = ("U", "\\operatorname{U}")
def __call__(self, low=0.0, high=1.0, size=None, **kwargs): def __call__(self, low=0.0, high=1.0, size=None, **kwargs):
...@@ -34,7 +33,7 @@ class BetaRV(RandomVariable): ...@@ -34,7 +33,7 @@ class BetaRV(RandomVariable):
name = "beta" name = "beta"
ndim_supp = 0 ndim_supp = 0
ndims_params = [0, 0] ndims_params = [0, 0]
dtype = config.floatX dtype = "floatX"
_print_name = ("Beta", "\\operatorname{Beta}") _print_name = ("Beta", "\\operatorname{Beta}")
...@@ -45,7 +44,7 @@ class NormalRV(RandomVariable): ...@@ -45,7 +44,7 @@ class NormalRV(RandomVariable):
name = "normal" name = "normal"
ndim_supp = 0 ndim_supp = 0
ndims_params = [0, 0] ndims_params = [0, 0]
dtype = config.floatX dtype = "floatX"
_print_name = ("N", "\\operatorname{N}") _print_name = ("N", "\\operatorname{N}")
def __call__(self, loc=0.0, scale=1.0, size=None, **kwargs): def __call__(self, loc=0.0, scale=1.0, size=None, **kwargs):
...@@ -59,7 +58,7 @@ class HalfNormalRV(RandomVariable): ...@@ -59,7 +58,7 @@ class HalfNormalRV(RandomVariable):
name = "halfnormal" name = "halfnormal"
ndim_supp = 0 ndim_supp = 0
ndims_params = [0, 0] ndims_params = [0, 0]
dtype = config.floatX dtype = "floatX"
_print_name = ("N**+", "\\operatorname{N^{+}}") _print_name = ("N**+", "\\operatorname{N^{+}}")
def __call__(self, loc=0.0, scale=1.0, size=None, **kwargs): def __call__(self, loc=0.0, scale=1.0, size=None, **kwargs):
...@@ -77,7 +76,7 @@ class GammaRV(RandomVariable): ...@@ -77,7 +76,7 @@ class GammaRV(RandomVariable):
name = "halfnormal" name = "halfnormal"
ndim_supp = 0 ndim_supp = 0
ndims_params = [0, 0] ndims_params = [0, 0]
dtype = config.floatX dtype = "floatX"
_print_name = ("Gamma", "\\operatorname{Gamma}") _print_name = ("Gamma", "\\operatorname{Gamma}")
def __call__(self, shape, rate, size=None, **kwargs): def __call__(self, shape, rate, size=None, **kwargs):
...@@ -95,7 +94,7 @@ class ExponentialRV(RandomVariable): ...@@ -95,7 +94,7 @@ class ExponentialRV(RandomVariable):
name = "exponential" name = "exponential"
ndim_supp = 0 ndim_supp = 0
ndims_params = [0] ndims_params = [0]
dtype = config.floatX dtype = "floatX"
_print_name = ("Exp", "\\operatorname{Exp}") _print_name = ("Exp", "\\operatorname{Exp}")
def __call__(self, scale=1.0, size=None, **kwargs): def __call__(self, scale=1.0, size=None, **kwargs):
...@@ -130,14 +129,17 @@ class MvNormalRV(RandomVariable): ...@@ -130,14 +129,17 @@ class MvNormalRV(RandomVariable):
name = "multivariate_normal" name = "multivariate_normal"
ndim_supp = 1 ndim_supp = 1
ndims_params = [1, 2] ndims_params = [1, 2]
dtype = config.floatX dtype = "floatX"
_print_name = ("N", "\\operatorname{N}") _print_name = ("N", "\\operatorname{N}")
def __call__(self, mean=None, cov=None, size=None, **kwargs): def __call__(self, mean=None, cov=None, size=None, **kwargs):
dtype = theano.config.floatX if self.dtype == "floatX" else self.dtype
if mean is None: if mean is None:
mean = np.array([0.0], dtype=self.dtype) mean = np.array([0.0], dtype=dtype)
if cov is None: if cov is None:
cov = np.array([[1.0]], dtype=self.dtype) cov = np.array([[1.0]], dtype=dtype)
return super().__call__(mean, cov, size=size, **kwargs) return super().__call__(mean, cov, size=size, **kwargs)
@classmethod @classmethod
...@@ -171,7 +173,7 @@ class DirichletRV(RandomVariable): ...@@ -171,7 +173,7 @@ class DirichletRV(RandomVariable):
name = "dirichlet" name = "dirichlet"
ndim_supp = 1 ndim_supp = 1
ndims_params = [1] ndims_params = [1]
dtype = config.floatX dtype = "floatX"
_print_name = ("Dir", "\\operatorname{Dir}") _print_name = ("Dir", "\\operatorname{Dir}")
@classmethod @classmethod
...@@ -209,7 +211,7 @@ class CauchyRV(RandomVariable): ...@@ -209,7 +211,7 @@ class CauchyRV(RandomVariable):
name = "cauchy" name = "cauchy"
ndim_supp = 0 ndim_supp = 0
ndims_params = [0, 0] ndims_params = [0, 0]
dtype = config.floatX dtype = "floatX"
_print_name = ("C", "\\operatorname{C}") _print_name = ("C", "\\operatorname{C}")
def __call__(self, loc=0.0, scale=1.0, size=None, **kwargs): def __call__(self, loc=0.0, scale=1.0, size=None, **kwargs):
...@@ -227,7 +229,7 @@ class HalfCauchyRV(RandomVariable): ...@@ -227,7 +229,7 @@ class HalfCauchyRV(RandomVariable):
name = "cauchy" name = "cauchy"
ndim_supp = 0 ndim_supp = 0
ndims_params = [0, 0] ndims_params = [0, 0]
dtype = config.floatX dtype = "floatX"
_print_name = ("C**+", "\\operatorname{C^{+}}") _print_name = ("C**+", "\\operatorname{C^{+}}")
def __call__(self, loc=0.0, scale=1.0, size=None, **kwargs): def __call__(self, loc=0.0, scale=1.0, size=None, **kwargs):
...@@ -245,7 +247,7 @@ class InvGammaRV(RandomVariable): ...@@ -245,7 +247,7 @@ class InvGammaRV(RandomVariable):
name = "invgamma" name = "invgamma"
ndim_supp = 0 ndim_supp = 0
ndims_params = [0, 0] ndims_params = [0, 0]
dtype = config.floatX dtype = "floatX"
_print_name = ("InvGamma", "\\operatorname{Gamma^{-1}}") _print_name = ("InvGamma", "\\operatorname{Gamma^{-1}}")
@classmethod @classmethod
...@@ -260,7 +262,7 @@ class TruncExponentialRV(RandomVariable): ...@@ -260,7 +262,7 @@ class TruncExponentialRV(RandomVariable):
name = "truncexpon" name = "truncexpon"
ndim_supp = 0 ndim_supp = 0
ndims_params = [0, 0, 0] ndims_params = [0, 0, 0]
dtype = config.floatX dtype = "floatX"
_print_name = ("TruncExp", "\\operatorname{TruncExp}") _print_name = ("TruncExp", "\\operatorname{TruncExp}")
@classmethod @classmethod
...@@ -392,7 +394,7 @@ class PolyaGammaRV(RandomVariable): ...@@ -392,7 +394,7 @@ class PolyaGammaRV(RandomVariable):
name = "polya-gamma" name = "polya-gamma"
ndim_supp = 0 ndim_supp = 0
ndims_params = [0, 0] ndims_params = [0, 0]
dtype = config.floatX dtype = "floatX"
_print_name = ("PG", "\\operatorname{PG}") _print_name = ("PG", "\\operatorname{PG}")
@classmethod @classmethod
......
...@@ -110,16 +110,17 @@ class RandomVariable(Op): ...@@ -110,16 +110,17 @@ class RandomVariable(Op):
The `Op`'s display name. The `Op`'s display name.
ndim_supp: int ndim_supp: int
Total number of dimensions for a single draw of the random variable Total number of dimensions for a single draw of the random variable
(e.g. a multivariate normal draw is 1D, so `ndim_supp = 1`). (e.g. a multivariate normal draw is 1D, so ``ndim_supp = 1``).
ndims_params: list of int ndims_params: list of int
Number of dimensions for each distribution parameter when the Number of dimensions for each distribution parameter when the
parameters only specify a single drawn of the random variable (e.g. a parameters only specify a single drawn of the random variable
multivariate normal's mean is 1D and covariance is 2D, so `ndims_params (e.g. a multivariate normal's mean is 1D and covariance is 2D, so
= [1, 2]`). ``ndims_params = [1, 2]``).
dtype: Theano dtype (optional) dtype: str (optional)
The dtype of the sampled output(s). If `None` (the default), the The dtype of the sampled output. If the value ``"floatX"`` is
`dtype` keyword must be set when `RandomVariable.make_node` is given, then ``dtype`` is set to ``theano.config.floatX``. If
called. ``None`` (the default), the `dtype` keyword must be set when
`RandomVariable.make_node` is called.
inplace: boolean (optional) inplace: boolean (optional)
Determine whether or not the underlying rng state is updated Determine whether or not the underlying rng state is updated
in-place or not (i.e. copied). in-place or not (i.e. copied).
...@@ -135,6 +136,7 @@ class RandomVariable(Op): ...@@ -135,6 +136,7 @@ class RandomVariable(Op):
ndims_params if ndims_params is not None else getattr(self, "ndims_params") ndims_params if ndims_params is not None else getattr(self, "ndims_params")
) )
self.dtype = dtype or getattr(self, "dtype", None) self.dtype = dtype or getattr(self, "dtype", None)
self.inplace = ( self.inplace = (
inplace if inplace is not None else getattr(self, "inplace", False) inplace if inplace is not None else getattr(self, "inplace", False)
) )
...@@ -333,9 +335,10 @@ class RandomVariable(Op): ...@@ -333,9 +335,10 @@ class RandomVariable(Op):
new one, if `None`. new one, if `None`.
size: int or Sequence size: int or Sequence
Numpy-like size of the output (i.e. replications). Numpy-like size of the output (i.e. replications).
dtype: Theano dtype dtype: str
The dtype of the sampled output. This value is only used when The dtype of the sampled output. If the value ``"floatX"`` is
`self.dtype` isn't set. given, then ``dtype`` is set to ``theano.config.floatX``. This
value is only used when `self.dtype` isn't set.
dist_params: list dist_params: list
Distribution parameters. Distribution parameters.
...@@ -372,7 +375,9 @@ class RandomVariable(Op): ...@@ -372,7 +375,9 @@ class RandomVariable(Op):
bcast = self.compute_bcast(dist_params, size) bcast = self.compute_bcast(dist_params, size)
dtype = self.dtype or dtype dtype = self.dtype or dtype
if dtype is None or (isinstance(dtype, str) and dtype not in all_dtypes): if dtype == "floatX":
dtype = config.floatX
elif dtype is None or (isinstance(dtype, str) and dtype not in all_dtypes):
# dtype = tt.scal.upcast(self.dtype, *[p.dtype for p in dist_params]) # dtype = tt.scal.upcast(self.dtype, *[p.dtype for p in dist_params])
raise TypeError("dtype is unspecified") raise TypeError("dtype is unspecified")
......
Markdown is supported
0%
You are adding 0 people to this discussion. Proceed with caution.
Finish editing this comment first!
Register or sign in to comment