提交 98d73d78 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Remove RandomVariable dtype input

上级 df32683c
...@@ -114,7 +114,7 @@ def jax_funcify_RandomVariable(op: ptr.RandomVariable, node, **kwargs): ...@@ -114,7 +114,7 @@ def jax_funcify_RandomVariable(op: ptr.RandomVariable, node, **kwargs):
if None in static_size: if None in static_size:
assert_size_argument_jax_compatible(node) assert_size_argument_jax_compatible(node)
def sample_fn(rng, size, dtype, *parameters): def sample_fn(rng, size, *parameters):
# PyTensor uses empty size to represent size = None # PyTensor uses empty size to represent size = None
if jax.numpy.asarray(size).shape == (0,): if jax.numpy.asarray(size).shape == (0,):
size = None size = None
...@@ -122,7 +122,7 @@ def jax_funcify_RandomVariable(op: ptr.RandomVariable, node, **kwargs): ...@@ -122,7 +122,7 @@ def jax_funcify_RandomVariable(op: ptr.RandomVariable, node, **kwargs):
else: else:
def sample_fn(rng, size, dtype, *parameters): def sample_fn(rng, size, *parameters):
return jax_sample_fn(op, node=node)( return jax_sample_fn(op, node=node)(
rng, static_size, out_dtype, *parameters rng, static_size, out_dtype, *parameters
) )
......
...@@ -123,7 +123,6 @@ def make_numba_random_fn(node, np_random_func): ...@@ -123,7 +123,6 @@ def make_numba_random_fn(node, np_random_func):
"size_dims", "size_dims",
"rng", "rng",
"size", "size",
"dtype",
], ],
suffix_sep="_", suffix_sep="_",
) )
...@@ -146,7 +145,7 @@ def {bcast_fn_name}({bcast_fn_input_names}): ...@@ -146,7 +145,7 @@ def {bcast_fn_name}({bcast_fn_input_names}):
) )
random_fn_input_names = ", ".join( random_fn_input_names = ", ".join(
["rng", "size", "dtype"] + [unique_names(i) for i in dist_params] ["rng", "size"] + [unique_names(i) for i in dist_params]
) )
# Now, create a Numba JITable function that implements the `size` parameter # Now, create a Numba JITable function that implements the `size` parameter
...@@ -243,7 +242,7 @@ def create_numba_random_fn( ...@@ -243,7 +242,7 @@ def create_numba_random_fn(
np_global_env["numba_vectorize"] = numba_basic.numba_vectorize np_global_env["numba_vectorize"] = numba_basic.numba_vectorize
unique_names = unique_name_generator( unique_names = unique_name_generator(
[np_random_fn_name, *np_global_env.keys(), "rng", "size", "dtype"], [np_random_fn_name, *np_global_env.keys(), "rng", "size"],
suffix_sep="_", suffix_sep="_",
) )
...@@ -310,7 +309,7 @@ def numba_funcify_CategoricalRV(op: ptr.CategoricalRV, node, **kwargs): ...@@ -310,7 +309,7 @@ def numba_funcify_CategoricalRV(op: ptr.CategoricalRV, node, **kwargs):
p_ndim = node.inputs[-1].ndim p_ndim = node.inputs[-1].ndim
@numba_basic.numba_njit @numba_basic.numba_njit
def categorical_rv(rng, size, dtype, p): def categorical_rv(rng, size, p):
if not size_len: if not size_len:
size_tpl = p.shape[:-1] size_tpl = p.shape[:-1]
else: else:
...@@ -342,7 +341,7 @@ def numba_funcify_DirichletRV(op, node, **kwargs): ...@@ -342,7 +341,7 @@ def numba_funcify_DirichletRV(op, node, **kwargs):
if alphas_ndim > 1: if alphas_ndim > 1:
@numba_basic.numba_njit @numba_basic.numba_njit
def dirichlet_rv(rng, size, dtype, alphas): def dirichlet_rv(rng, size, alphas):
if size_len > 0: if size_len > 0:
size_tpl = numba_ndarray.to_fixed_tuple(size, size_len) size_tpl = numba_ndarray.to_fixed_tuple(size, size_len)
if ( if (
...@@ -365,7 +364,7 @@ def numba_funcify_DirichletRV(op, node, **kwargs): ...@@ -365,7 +364,7 @@ def numba_funcify_DirichletRV(op, node, **kwargs):
else: else:
@numba_basic.numba_njit @numba_basic.numba_njit
def dirichlet_rv(rng, size, dtype, alphas): def dirichlet_rv(rng, size, alphas):
size = numba_ndarray.to_fixed_tuple(size, size_len) size = numba_ndarray.to_fixed_tuple(size, size_len)
return (rng, np.random.dirichlet(alphas, size)) return (rng, np.random.dirichlet(alphas, size))
...@@ -388,14 +387,14 @@ def numba_funcify_choice_without_replacement(op, node, **kwargs): ...@@ -388,14 +387,14 @@ def numba_funcify_choice_without_replacement(op, node, **kwargs):
if op.has_p_param: if op.has_p_param:
@numba_basic.numba_njit @numba_basic.numba_njit
def choice_without_replacement_rv(rng, size, dtype, a, p, core_shape): def choice_without_replacement_rv(rng, size, a, p, core_shape):
core_shape = numba_ndarray.to_fixed_tuple(core_shape, core_shape_len) core_shape = numba_ndarray.to_fixed_tuple(core_shape, core_shape_len)
samples = np.random.choice(a, size=core_shape, replace=False, p=p) samples = np.random.choice(a, size=core_shape, replace=False, p=p)
return (rng, samples) return (rng, samples)
else: else:
@numba_basic.numba_njit @numba_basic.numba_njit
def choice_without_replacement_rv(rng, size, dtype, a, core_shape): def choice_without_replacement_rv(rng, size, a, core_shape):
core_shape = numba_ndarray.to_fixed_tuple(core_shape, core_shape_len) core_shape = numba_ndarray.to_fixed_tuple(core_shape, core_shape_len)
samples = np.random.choice(a, size=core_shape, replace=False) samples = np.random.choice(a, size=core_shape, replace=False)
return (rng, samples) return (rng, samples)
...@@ -411,7 +410,7 @@ def numba_funcify_permutation(op: ptr.PermutationRV, node, **kwargs): ...@@ -411,7 +410,7 @@ def numba_funcify_permutation(op: ptr.PermutationRV, node, **kwargs):
x_batch_ndim = node.inputs[-1].type.ndim - op.ndims_params[0] x_batch_ndim = node.inputs[-1].type.ndim - op.ndims_params[0]
@numba_basic.numba_njit @numba_basic.numba_njit
def permutation_rv(rng, size, dtype, x): def permutation_rv(rng, size, x):
if batch_ndim: if batch_ndim:
x_core_shape = x.shape[x_batch_ndim:] x_core_shape = x.shape[x_batch_ndim:]
if size_is_none: if size_is_none:
......
...@@ -27,7 +27,7 @@ from pytensor.tensor.random.utils import ( ...@@ -27,7 +27,7 @@ from pytensor.tensor.random.utils import (
normalize_size_param, normalize_size_param,
) )
from pytensor.tensor.shape import shape_tuple from pytensor.tensor.shape import shape_tuple
from pytensor.tensor.type import TensorType, all_dtypes from pytensor.tensor.type import TensorType
from pytensor.tensor.type_other import NoneConst from pytensor.tensor.type_other import NoneConst
from pytensor.tensor.utils import _parse_gufunc_signature, safe_signature from pytensor.tensor.utils import _parse_gufunc_signature, safe_signature
from pytensor.tensor.variable import TensorVariable from pytensor.tensor.variable import TensorVariable
...@@ -65,7 +65,7 @@ class RandomVariable(Op): ...@@ -65,7 +65,7 @@ class RandomVariable(Op):
signature: str signature: str
Numpy-like vectorized signature of the random variable. Numpy-like vectorized signature of the random variable.
dtype: str (optional) dtype: str (optional)
The dtype of the sampled output. If the value ``"floatX"`` is The default dtype of the sampled output. If the value ``"floatX"`` is
given, then ``dtype`` is set to ``pytensor.config.floatX``. If given, then ``dtype`` is set to ``pytensor.config.floatX``. If
``None`` (the default), the `dtype` keyword must be set when ``None`` (the default), the `dtype` keyword must be set when
`RandomVariable.make_node` is called. `RandomVariable.make_node` is called.
...@@ -287,8 +287,8 @@ class RandomVariable(Op): ...@@ -287,8 +287,8 @@ class RandomVariable(Op):
return shape return shape
def infer_shape(self, fgraph, node, input_shapes): def infer_shape(self, fgraph, node, input_shapes):
_, size, _, *dist_params = node.inputs _, size, *dist_params = node.inputs
_, size_shape, _, *param_shapes = input_shapes _, size_shape, *param_shapes = input_shapes
try: try:
size_len = get_vector_length(size) size_len = get_vector_length(size)
...@@ -302,14 +302,34 @@ class RandomVariable(Op): ...@@ -302,14 +302,34 @@ class RandomVariable(Op):
return [None, list(shape)] return [None, list(shape)]
def __call__(self, *args, size=None, name=None, rng=None, dtype=None, **kwargs): def __call__(self, *args, size=None, name=None, rng=None, dtype=None, **kwargs):
res = super().__call__(rng, size, dtype, *args, **kwargs) if dtype is None:
dtype = self.dtype
if dtype == "floatX":
dtype = config.floatX
# We need to recreate the Op with the right dtype
if dtype != self.dtype:
# Check we are not switching from float to int
if self.dtype is not None:
if dtype.startswith("float") != self.dtype.startswith("float"):
raise ValueError(
f"Cannot change the dtype of a {self.name} RV from {self.dtype} to {dtype}"
)
props = self._props_dict()
props["dtype"] = dtype
new_op = type(self)(**props)
return new_op.__call__(
*args, size=size, name=name, rng=rng, dtype=dtype, **kwargs
)
res = super().__call__(rng, size, *args, **kwargs)
if name is not None: if name is not None:
res.name = name res.name = name
return res return res
def make_node(self, rng, size, dtype, *dist_params): def make_node(self, rng, size, *dist_params):
"""Create a random variable node. """Create a random variable node.
Parameters Parameters
...@@ -349,23 +369,10 @@ class RandomVariable(Op): ...@@ -349,23 +369,10 @@ class RandomVariable(Op):
shape = self._infer_shape(size, dist_params) shape = self._infer_shape(size, dist_params)
_, static_shape = infer_static_shape(shape) _, static_shape = infer_static_shape(shape)
dtype = self.dtype or dtype
if dtype == "floatX": inputs = (rng, size, *dist_params)
dtype = config.floatX out_type = TensorType(dtype=self.dtype, shape=static_shape)
elif dtype is None or (isinstance(dtype, str) and dtype not in all_dtypes): outputs = (rng.type(), out_type())
raise TypeError("dtype is unspecified")
if isinstance(dtype, str):
dtype_idx = constant(all_dtypes.index(dtype), dtype="int64")
else:
dtype_idx = constant(dtype, dtype="int64")
dtype = all_dtypes[dtype_idx.data]
inputs = (rng, size, dtype_idx, *dist_params)
out_var = TensorType(dtype=dtype, shape=static_shape)()
outputs = (rng.type(), out_var)
return Apply(self, inputs, outputs) return Apply(self, inputs, outputs)
...@@ -382,14 +389,12 @@ class RandomVariable(Op): ...@@ -382,14 +389,12 @@ class RandomVariable(Op):
def dist_params(self, node) -> Sequence[Variable]: def dist_params(self, node) -> Sequence[Variable]:
"""Return the node inpust corresponding to dist params""" """Return the node inpust corresponding to dist params"""
return node.inputs[3:] return node.inputs[2:]
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
rng_var_out, smpl_out = outputs rng_var_out, smpl_out = outputs
rng, size, dtype, *args = inputs rng, size, *args = inputs
out_var = node.outputs[1]
# If `size == []`, that means no size is enforced, and NumPy is trusted # If `size == []`, that means no size is enforced, and NumPy is trusted
# to draw the appropriate number of samples, NumPy uses `size=None` to # to draw the appropriate number of samples, NumPy uses `size=None` to
...@@ -408,11 +413,8 @@ class RandomVariable(Op): ...@@ -408,11 +413,8 @@ class RandomVariable(Op):
smpl_val = self.rng_fn(rng, *([*args, size])) smpl_val = self.rng_fn(rng, *([*args, size]))
if ( if not isinstance(smpl_val, np.ndarray) or str(smpl_val.dtype) != self.dtype:
not isinstance(smpl_val, np.ndarray) smpl_val = _asarray(smpl_val, dtype=self.dtype)
or str(smpl_val.dtype) != out_var.type.dtype
):
smpl_val = _asarray(smpl_val, dtype=out_var.type.dtype)
smpl_out[0] = smpl_val smpl_out[0] = smpl_val
...@@ -463,7 +465,7 @@ default_rng = DefaultGeneratorMakerOp() ...@@ -463,7 +465,7 @@ default_rng = DefaultGeneratorMakerOp()
@_vectorize_node.register(RandomVariable) @_vectorize_node.register(RandomVariable)
def vectorize_random_variable( def vectorize_random_variable(
op: RandomVariable, node: Apply, rng, size, dtype, *dist_params op: RandomVariable, node: Apply, rng, size, *dist_params
) -> Apply: ) -> Apply:
# If size was provided originally and a new size hasn't been provided, # If size was provided originally and a new size hasn't been provided,
# We extend it to accommodate the new input batch dimensions. # We extend it to accommodate the new input batch dimensions.
...@@ -491,4 +493,4 @@ def vectorize_random_variable( ...@@ -491,4 +493,4 @@ def vectorize_random_variable(
new_size_dims = broadcasted_batch_shape[:new_ndim] new_size_dims = broadcasted_batch_shape[:new_ndim]
size = concatenate([new_size_dims, size]) size = concatenate([new_size_dims, size])
return op.make_node(rng, size, dtype, *dist_params) return op.make_node(rng, size, *dist_params)
...@@ -81,7 +81,7 @@ def local_rv_size_lift(fgraph, node): ...@@ -81,7 +81,7 @@ def local_rv_size_lift(fgraph, node):
if not isinstance(node.op, RandomVariable): if not isinstance(node.op, RandomVariable):
return return
rng, size, dtype, *dist_params = node.inputs rng, size, *dist_params = node.inputs
dist_params = broadcast_params(dist_params, node.op.ndims_params) dist_params = broadcast_params(dist_params, node.op.ndims_params)
...@@ -105,7 +105,7 @@ def local_rv_size_lift(fgraph, node): ...@@ -105,7 +105,7 @@ def local_rv_size_lift(fgraph, node):
else: else:
return return
new_node = node.op.make_node(rng, None, dtype, *dist_params) new_node = node.op.make_node(rng, None, *dist_params)
if config.compute_test_value != "off": if config.compute_test_value != "off":
compute_test_value(new_node) compute_test_value(new_node)
...@@ -141,7 +141,7 @@ def local_dimshuffle_rv_lift(fgraph, node): ...@@ -141,7 +141,7 @@ def local_dimshuffle_rv_lift(fgraph, node):
return False return False
rv_op = rv_node.op rv_op = rv_node.op
rng, size, dtype, *dist_params = rv_node.inputs rng, size, *dist_params = rv_node.inputs
rv = rv_node.default_output() rv = rv_node.default_output()
# Check that Dimshuffle does not affect support dims # Check that Dimshuffle does not affect support dims
...@@ -185,7 +185,7 @@ def local_dimshuffle_rv_lift(fgraph, node): ...@@ -185,7 +185,7 @@ def local_dimshuffle_rv_lift(fgraph, node):
) )
new_dist_params.append(param.dimshuffle(param_new_order)) new_dist_params.append(param.dimshuffle(param_new_order))
new_node = rv_op.make_node(rng, new_size, dtype, *new_dist_params) new_node = rv_op.make_node(rng, new_size, *new_dist_params)
if config.compute_test_value != "off": if config.compute_test_value != "off":
compute_test_value(new_node) compute_test_value(new_node)
...@@ -233,7 +233,7 @@ def local_subtensor_rv_lift(fgraph, node): ...@@ -233,7 +233,7 @@ def local_subtensor_rv_lift(fgraph, node):
return None return None
rv_op = rv_node.op rv_op = rv_node.op
rng, size, dtype, *dist_params = rv_node.inputs rng, size, *dist_params = rv_node.inputs
# Parse indices # Parse indices
idx_list = getattr(subtensor_op, "idx_list", None) idx_list = getattr(subtensor_op, "idx_list", None)
...@@ -346,7 +346,7 @@ def local_subtensor_rv_lift(fgraph, node): ...@@ -346,7 +346,7 @@ def local_subtensor_rv_lift(fgraph, node):
new_dist_params.append(batch_param[tuple(batch_indices)]) new_dist_params.append(batch_param[tuple(batch_indices)])
# Create new RV # Create new RV
new_node = rv_op.make_node(rng, new_size, dtype, *new_dist_params) new_node = rv_op.make_node(rng, new_size, *new_dist_params)
new_rv = new_node.default_output() new_rv = new_node.default_output()
copy_stack_trace(rv, new_rv) copy_stack_trace(rv, new_rv)
......
...@@ -12,6 +12,7 @@ from pytensor.graph.rewriting.db import RewriteDatabaseQuery ...@@ -12,6 +12,7 @@ from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.tensor import constant from pytensor.tensor import constant
from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.random.basic import ( from pytensor.tensor.random.basic import (
NormalRV,
categorical, categorical,
dirichlet, dirichlet,
multinomial, multinomial,
...@@ -397,7 +398,7 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol): ...@@ -397,7 +398,7 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol):
) )
if lifted: if lifted:
assert new_out.owner.op == dist_op assert isinstance(new_out.owner.op, type(dist_op))
assert all( assert all(
isinstance(i.owner.op, DimShuffle) isinstance(i.owner.op, DimShuffle)
for i in new_out.owner.op.dist_params(new_out.owner) for i in new_out.owner.op.dist_params(new_out.owner)
...@@ -832,7 +833,7 @@ def test_Subtensor_lift_restrictions(): ...@@ -832,7 +833,7 @@ def test_Subtensor_lift_restrictions():
subtensor_node = fg.outputs[0].owner.inputs[1].owner.inputs[0].owner subtensor_node = fg.outputs[0].owner.inputs[1].owner.inputs[0].owner
assert subtensor_node == y.owner assert subtensor_node == y.owner
assert isinstance(subtensor_node.op, Subtensor) assert isinstance(subtensor_node.op, Subtensor)
assert subtensor_node.inputs[0].owner.op == normal assert isinstance(subtensor_node.inputs[0].owner.op, NormalRV)
z = pt.ones(x.shape) - x[1] z = pt.ones(x.shape) - x[1]
...@@ -850,7 +851,7 @@ def test_Subtensor_lift_restrictions(): ...@@ -850,7 +851,7 @@ def test_Subtensor_lift_restrictions():
EquilibriumGraphRewriter([local_subtensor_rv_lift], max_use_ratio=100).apply(fg) EquilibriumGraphRewriter([local_subtensor_rv_lift], max_use_ratio=100).apply(fg)
rv_node = fg.outputs[0].owner.inputs[1].owner.inputs[0].owner rv_node = fg.outputs[0].owner.inputs[1].owner.inputs[0].owner
assert rv_node.op == normal assert isinstance(rv_node.op, NormalRV)
assert isinstance(rv_node.inputs[-1].owner.op, Subtensor) assert isinstance(rv_node.inputs[-1].owner.op, Subtensor)
assert isinstance(rv_node.inputs[-2].owner.op, Subtensor) assert isinstance(rv_node.inputs[-2].owner.op, Subtensor)
...@@ -872,7 +873,7 @@ def test_Dimshuffle_lift_restrictions(): ...@@ -872,7 +873,7 @@ def test_Dimshuffle_lift_restrictions():
dimshuffle_node = fg.outputs[0].owner.inputs[1].owner dimshuffle_node = fg.outputs[0].owner.inputs[1].owner
assert dimshuffle_node == y.owner assert dimshuffle_node == y.owner
assert isinstance(dimshuffle_node.op, DimShuffle) assert isinstance(dimshuffle_node.op, DimShuffle)
assert dimshuffle_node.inputs[0].owner.op == normal assert isinstance(dimshuffle_node.inputs[0].owner.op, NormalRV)
z = pt.ones(x.shape) - y z = pt.ones(x.shape) - y
...@@ -890,7 +891,7 @@ def test_Dimshuffle_lift_restrictions(): ...@@ -890,7 +891,7 @@ def test_Dimshuffle_lift_restrictions():
EquilibriumGraphRewriter([local_dimshuffle_rv_lift], max_use_ratio=100).apply(fg) EquilibriumGraphRewriter([local_dimshuffle_rv_lift], max_use_ratio=100).apply(fg)
rv_node = fg.outputs[0].owner.inputs[1].owner rv_node = fg.outputs[0].owner.inputs[1].owner
assert rv_node.op == normal assert isinstance(rv_node.op, NormalRV)
assert isinstance(rv_node.inputs[-1].owner.op, DimShuffle) assert isinstance(rv_node.inputs[-1].owner.op, DimShuffle)
assert isinstance(rv_node.inputs[-2].owner.op, DimShuffle) assert isinstance(rv_node.inputs[-2].owner.op, DimShuffle)
......
...@@ -3,14 +3,14 @@ import pytest ...@@ -3,14 +3,14 @@ import pytest
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor import config, function from pytensor import config, function
from pytensor.gradient import NullTypeGradError, grad from pytensor.graph.replace import vectorize_graph
from pytensor.graph.replace import vectorize_node
from pytensor.raise_op import Assert from pytensor.raise_op import Assert
from pytensor.tensor.math import eq from pytensor.tensor.math import eq
from pytensor.tensor.random import normal from pytensor.tensor.random import normal
from pytensor.tensor.random.basic import NormalRV
from pytensor.tensor.random.op import RandomState, RandomVariable, default_rng from pytensor.tensor.random.op import RandomState, RandomVariable, default_rng
from pytensor.tensor.shape import specify_shape from pytensor.tensor.shape import specify_shape
from pytensor.tensor.type import all_dtypes, iscalar, tensor from pytensor.tensor.type import iscalar, tensor
@pytest.fixture(scope="function", autouse=False) @pytest.fixture(scope="function", autouse=False)
...@@ -51,15 +51,6 @@ def test_RandomVariable_basics(strict_test_value_flags): ...@@ -51,15 +51,6 @@ def test_RandomVariable_basics(strict_test_value_flags):
inplace=True, inplace=True,
)(0, 1, size={1, 2}) )(0, 1, size={1, 2})
# No dtype
with pytest.raises(TypeError, match="^dtype*"):
RandomVariable(
"normal",
0,
[0, 0],
inplace=True,
)(0, 1)
# Confirm that `inplace` works # Confirm that `inplace` works
rv = RandomVariable( rv = RandomVariable(
"normal", "normal",
...@@ -80,16 +71,19 @@ def test_RandomVariable_basics(strict_test_value_flags): ...@@ -80,16 +71,19 @@ def test_RandomVariable_basics(strict_test_value_flags):
rv_shape = rv._infer_shape(pt.constant([]), (), []) rv_shape = rv._infer_shape(pt.constant([]), (), [])
assert rv_shape.equals(pt.constant([], dtype="int64")) assert rv_shape.equals(pt.constant([], dtype="int64"))
# Integer-specified `dtype` # `dtype` is respected
dtype_1 = all_dtypes[1] rv = RandomVariable("normal", signature="(),()->()", dtype="int32")
rv_node = rv.make_node(None, None, 1) with config.change_flags(compute_test_value="off"):
rv_out = rv_node.outputs[1] rv_out = rv()
rv_out.tag.test_value = 1 assert rv_out.dtype == "int32"
rv_out = rv(dtype="int64")
assert rv_out.dtype == dtype_1 assert rv_out.dtype == "int64"
with pytest.raises(NullTypeGradError): with pytest.raises(
grad(rv_out, [rv_node.inputs[0]]) ValueError,
match="Cannot change the dtype of a normal RV from int32 to float32",
):
assert rv(dtype="float32").dtype == "float32"
def test_RandomVariable_bcast(strict_test_value_flags): def test_RandomVariable_bcast(strict_test_value_flags):
...@@ -238,70 +232,70 @@ def test_multivariate_rv_infer_static_shape(): ...@@ -238,70 +232,70 @@ def test_multivariate_rv_infer_static_shape():
assert mv_op(param1, param2, size=(10, 2)).type.shape == (10, 2, 3) assert mv_op(param1, param2, size=(10, 2)).type.shape == (10, 2, 3)
def test_vectorize_node(): def test_vectorize():
vec = tensor(shape=(None,)) vec = tensor(shape=(None,))
mat = tensor(shape=(None, None)) mat = tensor(shape=(None, None))
# Test without size # Test without size
node = normal(vec).owner out = normal(vec)
new_inputs = node.inputs.copy() vect_node = vectorize_graph(out, {vec: mat}).owner
new_inputs[3] = mat # mu assert isinstance(vect_node.op, NormalRV)
vect_node = vectorize_node(node, *new_inputs) assert vect_node.op.dist_params(vect_node)[0] is mat
assert vect_node.op is normal
assert vect_node.inputs[3] is mat
# Test with size, new size provided # Test with size, new size provided
node = normal(vec, size=(3,)).owner size = pt.as_tensor(np.array((3,), dtype="int64"))
new_inputs = node.inputs.copy() out = normal(vec, size=size)
new_inputs[1] = (2, 3) # size vect_node = vectorize_graph(out, {vec: mat, size: (2, 3)}).owner
new_inputs[3] = mat # mu assert isinstance(vect_node.op, NormalRV)
vect_node = vectorize_node(node, *new_inputs) assert tuple(vect_node.op.size_param(vect_node).eval()) == (2, 3)
assert vect_node.op is normal assert vect_node.op.dist_params(vect_node)[0] is mat
assert tuple(vect_node.inputs[1].eval()) == (2, 3)
assert vect_node.inputs[3] is mat
# Test with size, new size not provided # Test with size, new size not provided
node = normal(vec, size=(3,)).owner out = normal(vec, size=(3,))
new_inputs = node.inputs.copy() vect_node = vectorize_graph(out, {vec: mat}).owner
new_inputs[3] = mat # mu assert isinstance(vect_node.op, NormalRV)
vect_node = vectorize_node(node, *new_inputs) assert vect_node.op.dist_params(vect_node)[0] is mat
assert vect_node.op is normal
assert vect_node.inputs[3] is mat
assert tuple( assert tuple(
vect_node.inputs[1].eval({mat: np.zeros((2, 3), dtype=config.floatX)}) vect_node.op.size_param(vect_node).eval(
{mat: np.zeros((2, 3), dtype=config.floatX)}
)
) == (2, 3) ) == (2, 3)
# Test parameter broadcasting # Test parameter broadcasting
node = normal(vec).owner mu = vec
new_inputs = node.inputs.copy() sigma = pt.as_tensor(np.array(1.0))
new_inputs[3] = tensor("mu", shape=(10, 5)) # mu out = normal(mu, sigma)
new_inputs[4] = tensor("sigma", shape=(10,)) # sigma new_mu = tensor("mu", shape=(10, 5))
vect_node = vectorize_node(node, *new_inputs) new_sigma = tensor("sigma", shape=(10,))
assert vect_node.op is normal vect_node = vectorize_graph(out, {mu: new_mu, sigma: new_sigma}).owner
assert isinstance(vect_node.op, NormalRV)
assert vect_node.default_output().type.shape == (10, 5) assert vect_node.default_output().type.shape == (10, 5)
# Test parameter broadcasting with non-expanding size # Test parameter broadcasting with non-expanding size
node = normal(vec, size=(5,)).owner mu = vec
new_inputs = node.inputs.copy() sigma = pt.as_tensor(np.array(1.0))
new_inputs[3] = tensor("mu", shape=(10, 5)) # mu out = normal(mu, sigma, size=(5,))
new_inputs[4] = tensor("sigma", shape=(10,)) # sigma new_mu = tensor("mu", shape=(10, 5))
vect_node = vectorize_node(node, *new_inputs) new_sigma = tensor("sigma", shape=(10,))
assert vect_node.op is normal vect_node = vectorize_graph(out, {mu: new_mu, sigma: new_sigma}).owner
assert isinstance(vect_node.op, NormalRV)
assert vect_node.default_output().type.shape == (10, 5) assert vect_node.default_output().type.shape == (10, 5)
node = normal(vec, size=(5,)).owner mu = vec
new_inputs = node.inputs.copy() sigma = pt.as_tensor(np.array(1.0))
new_inputs[3] = tensor("mu", shape=(1, 5)) # mu out = normal(mu, sigma, size=(5,))
new_inputs[4] = tensor("sigma", shape=(10,)) # sigma new_mu = tensor("mu", shape=(1, 5)) # mu
vect_node = vectorize_node(node, *new_inputs) new_sigma = tensor("sigma", shape=(10,)) # sigma
assert vect_node.op is normal vect_node = vectorize_graph(out, {mu: new_mu, sigma: new_sigma}).owner
assert isinstance(vect_node.op, NormalRV)
assert vect_node.default_output().type.shape == (10, 5) assert vect_node.default_output().type.shape == (10, 5)
# Test parameter broadcasting with expanding size # Test parameter broadcasting with expanding size
node = normal(vec, size=(2, 5)).owner mu = vec
new_inputs = node.inputs.copy() sigma = pt.as_tensor(np.array(1.0))
new_inputs[3] = tensor("mu", shape=(10, 5)) # mu out = normal(mu, sigma, size=(2, 5))
new_inputs[4] = tensor("sigma", shape=(10,)) # sigma new_mu = tensor("mu", shape=(1, 5))
vect_node = vectorize_node(node, *new_inputs) new_sigma = tensor("sigma", shape=(10,))
assert vect_node.op is normal vect_node = vectorize_graph(out, {mu: new_mu, sigma: new_sigma}).owner
assert isinstance(vect_node.op, NormalRV)
assert vect_node.default_output().type.shape == (10, 2, 5) assert vect_node.default_output().type.shape == (10, 2, 5)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论