Commit 6b71a80f authored by ricardoV94, committed by Ricardo Vieira

Small tweaks to XRV Ops

* Fix core_dims_needed calculation
* Handle lazy dtype
* Nicer __str__ with use of `name`
Parent 10b84747
pytensor/tensor/random/op.py

@@ -392,6 +392,13 @@ class RandomVariable(RNGConsumerOp):
         out_type = TensorType(dtype=self.dtype, shape=static_shape)
         outputs = (rng.type(), out_type())

+        if self.dtype == "floatX":
+            # Commit to a specific float type if the Op is still using "floatX"
+            dtype = config.floatX
+            props = self._props_dict()
+            props["dtype"] = dtype
+            self = type(self)(**props)
+
         return Apply(self, inputs, outputs)

     def batch_ndim(self, node: Apply) -> int:
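A quick sketch of what the new "floatX" handling does (my illustration, not part of the commit; `ptr.normal` stands in for any RandomVariable declared with the lazy dtype):

import pytensor.tensor.random as ptr
from pytensor import config

with config.change_flags(floatX="float32"):
    x = ptr.normal(0, 1)
    # make_node swaps the Op for a copy carrying the concrete dtype, so
    # the "floatX" placeholder never reaches the graph
    assert x.owner.op.dtype == "float32"
    assert x.type.dtype == "float32"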
pytensor/xtensor/__init__.py

 import warnings

 import pytensor.xtensor.rewriting
-from pytensor.xtensor import linalg
+from pytensor.xtensor import linalg, random
 from pytensor.xtensor.math import dot
 from pytensor.xtensor.shape import concat
 from pytensor.xtensor.type import (
pytensor/xtensor/random.py

@@ -5,8 +5,8 @@
 from typing import Literal

 import pytensor.tensor.random.basic as ptr
 from pytensor.graph.basic import Variable
 from pytensor.tensor.random.op import RandomVariable
+from pytensor.xtensor import as_xtensor
 from pytensor.xtensor.math import sqrt
-from pytensor.xtensor.type import as_xtensor
 from pytensor.xtensor.vectorization import XRV
@@ -14,6 +14,7 @@
 def _as_xrv(
     core_op: RandomVariable,
     core_inps_dims_map: Sequence[Sequence[int]] | None = None,
     core_out_dims_map: Sequence[int] | None = None,
+    name: str | None = None,
 ):
     """Helper function to define an XRV constructor.
@@ -41,7 +42,14 @@
         core_out_dims_map = tuple(range(core_op.ndim_supp))

     core_dims_needed = max(
-        (*(len(i) for i in core_inps_dims_map), len(core_out_dims_map)), default=0
+        max(
+            (
+                max((entry + 1 for entry in dims_map), default=0)
+                for dims_map in core_inps_dims_map
+            ),
+            default=0,
+        ),
+        max((entry + 1 for entry in core_out_dims_map), default=0),
     )

     @wraps(core_op)
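To see why the old `core_dims_needed` was off (hypothetical dims maps, my own illustration): map entries are indices into the user-supplied core dims, so the count must follow the largest index referenced, not the number of entries in any one map:

# One input binds its only core dim to user core dim index 2, while the
# output uses only index 0; three core dim names are therefore needed.
core_inps_dims_map = ((2,),)
core_out_dims_map = (0,)

# Old formula counted entries per map: max(1, 1) == 1, too few
old = max(
    (*(len(i) for i in core_inps_dims_map), len(core_out_dims_map)), default=0
)

# New formula takes the largest referenced index plus one: max(3, 1) == 3
new = max(
    max(
        (max((e + 1 for e in m), default=0) for m in core_inps_dims_map),
        default=0,
    ),
    max((e + 1 for e in core_out_dims_map), default=0),
)
assert (old, new) == (1, 3)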
@@ -76,7 +84,10 @@
             extra_dims = {}

         return XRV(
-            core_op, core_dims=full_core_dims, extra_dims=tuple(extra_dims.keys())
+            core_op,
+            core_dims=full_core_dims,
+            extra_dims=tuple(extra_dims.keys()),
+            name=name,
         )(rng, *extra_dims.values(), *params)

     return xrv_constructor
pytensor/xtensor/rewriting/vectorization.py

@@ -116,7 +116,7 @@ def lower_rv(fgraph, node):
     size = [*extra_dim_lengths, *param_batch_shape]

     # RVs are their own core Op
-    new_next_rng, tensor_out = core_op(*tensor_params, rng=rng, size=size).owner.outputs
+    new_next_rng, tensor_out = core_op.make_node(rng, size, *tensor_params).outputs

     # Convert output Tensors to XTensors
     new_out = xtensor_from_tensor(tensor_out, dims=old_out.type.dims)
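The lowering change reads as sidestepping `RandomVariable.__call__`, which returns only the draw (hence the old `.owner.outputs` round-trip) and re-normalizes its arguments. A minimal sketch of the two forms, assuming `ptr.normal` as the core op:

import pytensor.tensor.random as ptr
from pytensor.tensor.random.type import random_generator_type

rng = random_generator_type("rng")
# Old style: __call__ yields just the draw; climb back through .owner
next_rng, draw = ptr.normal(0, 1, rng=rng).owner.outputs
# New style: make_node(rng, size, *params) exposes both outputs directly
next_rng2, draw2 = ptr.normal.make_node(rng, None, 0, 1).outputs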
pytensor/xtensor/vectorization.py

@@ -142,8 +142,12 @@ class XRV(XOp, RNGConsumerOp):
         core_op,
         core_dims: tuple[tuple[tuple[str, ...], ...], tuple[str, ...]],
         extra_dims: tuple[str, ...],
+        name: str | None = None,
     ):
         super().__init__()
+        if name is None:
+            name = getattr(core_op, "name", None)
+        self.name = name
         self.core_op = core_op
         inps_core_dims, out_core_dims = core_dims
         for operand_dims in (*inps_core_dims, out_core_dims):
@@ -154,6 +158,15 @@ class XRV(XOp, RNGConsumerOp):
             raise ValueError("size_dims must be unique")
         self.extra_dims = tuple(extra_dims)

+    def __str__(self):
+        if self.name is not None:
+            name = self.name
+            attrs = f"(core_dims={self.core_dims}, extra_dims={self.extra_dims})"
+        else:
+            name = self.__class__.__name__
+            attrs = f"(core_op={self.core_op}, core_dims={self.core_dims}, extra_dims={self.extra_dims})"
+        return f"{name}{attrs}"
+
     def update(self, node):
         # RNG input and update are the first input and output respectively
         return {node.inputs[0]: node.outputs[0]}
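The payoff of the `__str__` tweak, sketched with a hypothetical scalar-normal XRV (the core_dims below are made up for illustration). Note the `attrs` strings already carry parentheses, so the return wraps them no further:

import pytensor.tensor.random as ptr
from pytensor.xtensor.vectorization import XRV

# Two scalar core inputs (mu, sigma) and a scalar core output
op = XRV(ptr.normal, core_dims=(((), ()), ()), extra_dims=(), name="normal")
print(op)  # normal(core_dims=(((), ()), ()), extra_dims=())
# Without a name, it falls back to the verbose XRV(core_op=..., ...) form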
tests/xtensor/test_random.py

@@ -7,7 +7,7 @@
 import pytest

 import pytensor.tensor.random as ptr
 import pytensor.xtensor.random as pxr
-from pytensor import function, shared
+from pytensor import config, function, shared
 from pytensor.graph import rewrite_graph
 from pytensor.graph.basic import equal_computations
 from pytensor.tensor import broadcast_arrays, tensor
@@ -112,6 +112,19 @@ def test_output_dim_does_not_map_from_input_dims():
     )


+def test_dtype():
+    x = normal(0, 1)
+    assert x.type.dtype == config.floatX
+
+    with config.change_flags(floatX="float64"):
+        x = normal(0, 1)
+        assert x.type.dtype == "float64"
+
+    with config.change_flags(floatX="float32"):
+        x = normal(0, 1)
+        assert x.type.dtype == "float32"
+
+
 def test_normal():
     rng = random_generator_type("rng")
     c_size = tensor("c_size", shape=(), dtype=int)