Commit 18ba52cd authored by ricardoV94, committed by Ricardo Vieira

Use infer_shape of core_op to infer Blockwise core shapes

This can only be done when the output of infer_shape of the core_op depends only on the input shapes, and not their values.
Parent ef97287b
......@@ -6,7 +6,8 @@ import numpy as np
from pytensor import config
from pytensor.compile.builders import OpFromGraph
from pytensor.gradient import DisconnectedType
from pytensor.graph.basic import Apply, Constant
from pytensor.graph import FunctionGraph
from pytensor.graph.basic import Apply, Constant, ancestors
from pytensor.graph.null_type import NullType
from pytensor.graph.op import Op
from pytensor.graph.replace import (
......@@ -185,15 +186,40 @@ class Blockwise(Op):
batch_shape = broadcast_shape(*batch_shapes, arrays_are_shapes=True)
# Try to extract the core shapes from the core_op
core_op_infer_shape = getattr(self.core_op, "infer_shape", None)
if core_op_infer_shape is not None:
dummy_core_node = self._create_dummy_core_node(node.inputs)
dummy_core_inputs = dummy_core_node.inputs
dummy_fgraph = FunctionGraph(outputs=dummy_core_node.outputs, clone=False)
core_input_shapes = [
input_shape[batch_ndims:] for input_shape in input_shapes
]
core_output_shapes = core_op_infer_shape(
dummy_fgraph, dummy_core_node, core_input_shapes
)
out_shapes = []
for output, sig in zip(node.outputs, self.outputs_sig, strict=True):
for o, (output, sig) in enumerate(
zip(node.outputs, self.outputs_sig, strict=True)
):
core_out_shape = []
for i, dim_name in enumerate(sig):
# The output dim is the same as another input dim
if dim_name in core_dims:
core_out_shape.append(core_dims[dim_name])
else:
# TODO: We could try to make use of infer_shape of core_op
if core_op_infer_shape is not None:
# If the input values are needed to compute the dimension length, we can't use the infer_shape
# of the core_node as the value is not constant across batch dims of the Blockwise
core_out_dim = core_output_shapes[o][i]
if not (
set(dummy_core_inputs) & set(ancestors([core_out_dim]))
):
core_out_shape.append(core_out_dim)
continue
# Fallback shape requires evaluating the Blockwise Op
core_out_shape.append(Shape_i(batch_ndims + i)(output))
out_shapes.append((*batch_shape, *core_out_shape))
......
......@@ -259,6 +259,58 @@ def test_blockwise_shape():
assert tuple(shape_fn(inp1_test, inp2_test)[1]) == (7, 5, 4)
def test_blockwise_infer_core_shape():
    """Blockwise should delegate to the core_op's ``infer_shape`` whenever a
    core output dimension is a function of input shapes only, and fall back to
    evaluating the Blockwise itself when the dimension depends on input values.
    """

    class TestOpWithInferShape(Op):
        def make_node(self, a, b):
            assert a.type.ndim == 1
            assert b.type.ndim == 1
            c = tensor(shape=(None,))
            d = tensor(shape=(None,))
            return Apply(self, [a, b], [c, d])

        def perform(self, node, inputs, outputs):
            a, b = inputs
            c, d = outputs
            c[0] = np.arange(a.size + b.size)
            d[0] = np.arange(a.sum() + b.sum())

        def infer_shape(self, fgraph, node, input_shapes):
            # First output: length is computable from the input shapes alone.
            # Second output: length depends on the input *values*.
            x, y = node.inputs
            [(x_shape,), (y_shape,)] = input_shapes
            return (x_shape + y_shape,), (x.sum() + y.sum(),)

    def compiled_graph_uses_core_op(compiled_fn):
        # True when the compiled shape graph still contains the core Op
        # (possibly wrapped inside a Blockwise node).
        return any(
            isinstance(getattr(apply.op, "core_op", apply.op), TestOpWithInferShape)
            for apply in compiled_fn.maker.fgraph.apply_nodes
        )

    blockwise_op = Blockwise(
        core_op=TestOpWithInferShape(), signature="(a),(b)->(c),(d)"
    )
    a = tensor("a", shape=(5, 3))
    b = tensor("b", shape=(1, 4))
    c, d = blockwise_op(a, b)
    assert c.type.shape == (5, None)
    assert d.type.shape == (5, None)

    # c's shape comes from input shapes alone, so the shape graph should
    # not need to evaluate the Blockwise Op itself.
    c_shape_fn = pytensor.function([a, b], c.shape)
    assert not compiled_graph_uses_core_op(c_shape_fn)

    # d's shape requires the input values, so the Blockwise must remain
    # in the compiled shape graph.
    d_shape_fn = pytensor.function([a, b], d.shape)
    assert compiled_graph_uses_core_op(d_shape_fn)

    a_val = np.zeros(a.type.shape, dtype=a.type.dtype)
    b_val = np.zeros(b.type.shape, dtype=b.type.dtype)
    assert tuple(c_shape_fn(a_val, b_val)) == (5, 7)
    # Inputs are all zeros, so a.sum() + b.sum() == 0 -> empty second output.
    assert tuple(d_shape_fn(a_val, b_val)) == (5, 0)
class BlockwiseOpTester:
"""Base class to test Blockwise works for specific Ops"""
......
Markdown formatting supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Please finish editing this comment first!
Register or sign in to comment