提交 6dd61726 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix Blockwise infer shape from core Op

Sometimes `_create_dummy_core_node` can create a multi-node graph, where the root inputs are not `node.inputs`. In that case `infer_shape` may bypass the intermediate nodes. This was the case with `Subtensor`, which introduces `ScalarFromTensor` nodes but ignores them in the shape graph (for a cleaner graph).
上级 9578bd3b
......@@ -7,7 +7,7 @@ from pytensor import config
from pytensor.compile.builders import OpFromGraph
from pytensor.gradient import DisconnectedType
from pytensor.graph import FunctionGraph
from pytensor.graph.basic import Apply, Constant, ancestors
from pytensor.graph.basic import Apply, Constant, explicit_graph_inputs
from pytensor.graph.null_type import NullType
from pytensor.graph.op import Op
from pytensor.graph.replace import (
......@@ -190,7 +190,7 @@ class Blockwise(Op):
core_op_infer_shape = getattr(self.core_op, "infer_shape", None)
if core_op_infer_shape is not None:
dummy_core_node = self._create_dummy_core_node(node.inputs)
dummy_core_inputs = dummy_core_node.inputs
dummy_core_inputs = tuple(explicit_graph_inputs(dummy_core_node.inputs))
dummy_fgraph = FunctionGraph(outputs=dummy_core_node.outputs, clone=False)
core_input_shapes = [
input_shape[batch_ndims:] for input_shape in input_shapes
......@@ -214,7 +214,8 @@ class Blockwise(Op):
# of the core_node as the value is not constant across batch dims of the Blockwise
core_out_dim = core_output_shapes[o][i]
if not (
set(dummy_core_inputs) & set(ancestors([core_out_dim]))
set(dummy_core_inputs)
& set(explicit_graph_inputs([core_out_dim]))
):
core_out_shape.append(core_out_dim)
continue
......
......@@ -264,9 +264,13 @@ def test_blockwise_infer_core_shape():
def make_node(self, a, b):
assert a.type.ndim == 1
assert b.type.ndim == 1
# Simulate make_node that introduces operations on inputs
a_identity = a.copy()
b_identity = b.copy()
c = tensor(shape=(None,))
d = tensor(shape=(None,))
return Apply(self, [a, b], [c, d])
return Apply(self, [a_identity, b_identity], [c, d])
def perform(self, node, inputs, outputs):
a, b = inputs
......@@ -277,9 +281,12 @@ def test_blockwise_infer_core_shape():
def infer_shape(self, fgraph, node, input_shapes):
# First output shape depends only on input_shapes
# Second output shape depends on input values
x, y = node.inputs
[(x_shape,), (y_shape,)] = input_shapes
return (x_shape + y_shape,), (x.sum() + y.sum(),)
a_identity, b_identity = node.inputs
# Simulate shape depending on original inputs, not the ones that go directly into the node
a = a_identity.owner.inputs[0]
b = b_identity.owner.inputs[0]
[(a_shape,), (b_shape,)] = input_shapes
return (a_shape + b_shape,), (a.sum() + b.sum(),)
blockwise_op = Blockwise(
core_op=TestOpWithInferShape(), signature="(a),(b)->(c),(d)"
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论