Commit 6b71a80f authored by ricardoV94, committed by Ricardo Vieira

Small tweaks to XRV Ops

* Fix core_dims_needed calculation
* Handle lazy dtype
* Nicer __str__ with use of `name`
Parent 10b84747
@@ -392,6 +392,13 @@ class RandomVariable(RNGConsumerOp):
         out_type = TensorType(dtype=self.dtype, shape=static_shape)
         outputs = (rng.type(), out_type())
 
+        if self.dtype == "floatX":
+            # Commit to a specific float type if the Op is still using "floatX"
+            dtype = config.floatX
+            props = self._props_dict()
+            props["dtype"] = dtype
+            self = type(self)(**props)
+
         return Apply(self, inputs, outputs)
 
     def batch_ndim(self, node: Apply) -> int:
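The hunk above resolves the lazy "floatX" placeholder at node-construction time, swapping in a copy of the Op committed to the configured float type. A minimal sketch of the resulting behavior, assuming a RandomVariable such as `normal` whose dtype is still the "floatX" placeholder:

```python
import pytensor.tensor.random.basic as ptr
from pytensor import config

# Building a node commits the lazy "floatX" dtype to the configured float type
x = ptr.normal(0, 1)
assert x.type.dtype == config.floatX

# The dtype is fixed by whatever floatX is at node-construction time
with config.change_flags(floatX="float32"):
    assert ptr.normal(0, 1).type.dtype == "float32"
```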
...
 import warnings
 import pytensor.xtensor.rewriting
-from pytensor.xtensor import linalg
+from pytensor.xtensor import linalg, random
 from pytensor.xtensor.math import dot
 from pytensor.xtensor.shape import concat
 from pytensor.xtensor.type import (
...
@@ -5,8 +5,8 @@ from typing import Literal
 import pytensor.tensor.random.basic as ptr
 from pytensor.graph.basic import Variable
 from pytensor.tensor.random.op import RandomVariable
+from pytensor.xtensor import as_xtensor
 from pytensor.xtensor.math import sqrt
-from pytensor.xtensor.type import as_xtensor
 from pytensor.xtensor.vectorization import XRV
@@ -14,6 +14,7 @@ def _as_xrv(
     core_op: RandomVariable,
     core_inps_dims_map: Sequence[Sequence[int]] | None = None,
     core_out_dims_map: Sequence[int] | None = None,
+    name: str | None = None,
 ):
     """Helper function to define an XRV constructor.
@@ -41,7 +42,14 @@ def _as_xrv(
         core_out_dims_map = tuple(range(core_op.ndim_supp))
 
     core_dims_needed = max(
-        (*(len(i) for i in core_inps_dims_map), len(core_out_dims_map)), default=0
+        max(
+            (
+                max((entry + 1 for entry in dims_map), default=0)
+                for dims_map in core_inps_dims_map
+            ),
+            default=0,
+        ),
+        max((entry + 1 for entry in core_out_dims_map), default=0),
     )
 
     @wraps(core_op)
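The old computation counted how many entries each dims map listed; the new one looks at the largest index referenced. Since a dims map holds positional indices into the core dims, a map can reference a later position without listing the earlier ones. A small worked example with hypothetical maps:

```python
# Hypothetical maps: the single input uses core dim index 1, the output index 0
core_inps_dims_map = [(1,)]
core_out_dims_map = (0,)

# Old: length-based count -> 1, even though index 1 implies two core dims
old = max((*(len(i) for i in core_inps_dims_map), len(core_out_dims_map)), default=0)

# New: largest referenced index + 1 -> 2
new = max(
    max(
        (max((entry + 1 for entry in m), default=0) for m in core_inps_dims_map),
        default=0,
    ),
    max((entry + 1 for entry in core_out_dims_map), default=0),
)
assert (old, new) == (1, 2)
```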
@@ -76,7 +84,10 @@ def _as_xrv(
             extra_dims = {}
 
         return XRV(
-            core_op, core_dims=full_core_dims, extra_dims=tuple(extra_dims.keys())
+            core_op,
+            core_dims=full_core_dims,
+            extra_dims=tuple(extra_dims.keys()),
+            name=name,
         )(rng, *extra_dims.values(), *params)
 
     return xrv_constructor
...
@@ -116,7 +116,7 @@ def lower_rv(fgraph, node):
     size = [*extra_dim_lengths, *param_batch_shape]
 
     # RVs are their own core Op
-    new_next_rng, tensor_out = core_op(*tensor_params, rng=rng, size=size).owner.outputs
+    new_next_rng, tensor_out = core_op.make_node(rng, size, *tensor_params).outputs
 
     # Convert output Tensors to XTensors
     new_out = xtensor_from_tensor(tensor_out, dims=old_out.type.dims)
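One plausible motivation for this change: the rewrite already holds lowered `rng`, `size`, and parameter tensors, so calling `make_node` directly builds the node from them as-is instead of going through the user-facing `__call__`, which re-normalizes its arguments. A rough sketch of the two paths (argument values are illustrative):

```python
import pytensor.tensor.random.basic as ptr
from pytensor.tensor.random.type import random_generator_type

rng = random_generator_type("rng")

# User-facing path: keyword arguments, extra input normalization
next_rng, draws = ptr.normal(0.0, 1.0, rng=rng, size=(2, 3)).owner.outputs

# Direct path used by lower_rv: positional (rng, size, *params)
next_rng2, draws2 = ptr.normal.make_node(rng, (2, 3), 0.0, 1.0).outputs
```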
...
@@ -142,8 +142,12 @@ class XRV(XOp, RNGConsumerOp):
         core_op,
         core_dims: tuple[tuple[tuple[str, ...], ...], tuple[str, ...]],
         extra_dims: tuple[str, ...],
+        name: str | None = None,
     ):
         super().__init__()
+        if name is None:
+            name = getattr(core_op, "name", None)
+        self.name = name
         self.core_op = core_op
         inps_core_dims, out_core_dims = core_dims
         for operand_dims in (*inps_core_dims, out_core_dims):
@@ -154,6 +158,15 @@ class XRV(XOp, RNGConsumerOp):
             raise ValueError("size_dims must be unique")
         self.extra_dims = tuple(extra_dims)
 
+    def __str__(self):
+        if self.name is not None:
+            name = self.name
+            attrs = f"(core_dims={self.core_dims}, extra_dims={self.extra_dims})"
+        else:
+            name = self.__class__.__name__
+            attrs = f"(core_op={self.core_op}, core_dims={self.core_dims}, extra_dims={self.extra_dims})"
+        return f"{name}{attrs}"
+
     def update(self, node):
         # RNG input and update are the first input and output respectively
         return {node.inputs[0]: node.outputs[0]}
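With a `name` set (here inherited from the core Op when one is not passed), the printed Op reads much more compactly. The exact strings below are illustrative:

```python
import pytensor.xtensor.random as pxr

x = pxr.normal(0, 1)
print(x.owner.op)
# e.g. "normal(core_dims=(((), ()), ()), extra_dims=())"
# instead of the full "XRV(core_op=normal_rv{...}, core_dims=..., extra_dims=())"
```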
...
@@ -7,7 +7,7 @@ import pytest
 import pytensor.tensor.random as ptr
 import pytensor.xtensor.random as pxr
-from pytensor import function, shared
+from pytensor import config, function, shared
 from pytensor.graph import rewrite_graph
 from pytensor.graph.basic import equal_computations
 from pytensor.tensor import broadcast_arrays, tensor
@@ -112,6 +112,19 @@ def test_output_dim_does_not_map_from_input_dims():
     )
 
 
+def test_dtype():
+    x = normal(0, 1)
+    assert x.type.dtype == config.floatX
+
+    with config.change_flags(floatX="float64"):
+        x = normal(0, 1)
+        assert x.type.dtype == "float64"
+
+    with config.change_flags(floatX="float32"):
+        x = normal(0, 1)
+        assert x.type.dtype == "float32"
+
+
 def test_normal():
     rng = random_generator_type("rng")
     c_size = tensor("c_size", shape=(), dtype=int)
...