提交 dea9940a authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Vectorize CumOp and simplify infer_shape for vector case

上级 6f8bb555
......@@ -13,6 +13,7 @@ from pytensor.gradient import (
)
from pytensor.graph.basic import Apply, Constant, Variable
from pytensor.graph.op import Op
from pytensor.graph.replace import _vectorize_node
from pytensor.link.c.op import COp
from pytensor.link.c.params_type import ParamsType
from pytensor.link.c.type import EnumList, Generic
......@@ -360,7 +361,7 @@ class CumOp(COp):
)
def infer_shape(self, fgraph, node, shapes):
if self.axis is None:
if self.axis is None and len(shapes[0]) > 1:
return [(prod(shapes[0]),)] # Flatten
return shapes
......@@ -473,6 +474,25 @@ def cumprod(x, axis=None):
return CumOp(axis=axis, mode="mul")(x)
@_vectorize_node.register(CumOp)
def vectorize_cum_op(op: CumOp, node: Apply, batch_x):
"""Vectorize the CumOp to work on a batch of inputs."""
[original_x] = node.inputs
batch_ndim = batch_x.ndim - original_x.ndim
axis = op.axis
if axis is None and original_x.ndim == 1:
axis = 0
elif axis is not None:
axis = normalize_axis_index(op.axis, original_x.ndim)
if axis is None:
# Ravel all unbatched dimensions and perform CumOp on the last axis
batch_x_raveled = [batch_x.flatten(ndim=batch_ndim + 1) for x in batch_x]
return type(op)(axis=-1, mode=op.mode).make_node(batch_x_raveled)
else:
return type(op)(axis=axis + batch_ndim, mode=op.mode).make_node(batch_x)
def diff(x, n=1, axis=-1):
"""Calculate the `n`-th order discrete difference along the given `axis`.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论