提交 6dd61726 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix Blockwise infer shape from core Op

Sometimes `_create_dummy_core_node` can create a multi-node graph, where the root inputs are not `node.inputs`. Then infer_shape may bypass the intermediate nodes. This was the case with Subtensor, which introduces `ScalarFromTensor` nodes, but ignores them in the shape graph (for a cleaner graph)
上级 9578bd3b
...@@ -7,7 +7,7 @@ from pytensor import config ...@@ -7,7 +7,7 @@ from pytensor import config
from pytensor.compile.builders import OpFromGraph from pytensor.compile.builders import OpFromGraph
from pytensor.gradient import DisconnectedType from pytensor.gradient import DisconnectedType
from pytensor.graph import FunctionGraph from pytensor.graph import FunctionGraph
from pytensor.graph.basic import Apply, Constant, ancestors from pytensor.graph.basic import Apply, Constant, explicit_graph_inputs
from pytensor.graph.null_type import NullType from pytensor.graph.null_type import NullType
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.graph.replace import ( from pytensor.graph.replace import (
...@@ -190,7 +190,7 @@ class Blockwise(Op): ...@@ -190,7 +190,7 @@ class Blockwise(Op):
core_op_infer_shape = getattr(self.core_op, "infer_shape", None) core_op_infer_shape = getattr(self.core_op, "infer_shape", None)
if core_op_infer_shape is not None: if core_op_infer_shape is not None:
dummy_core_node = self._create_dummy_core_node(node.inputs) dummy_core_node = self._create_dummy_core_node(node.inputs)
dummy_core_inputs = dummy_core_node.inputs dummy_core_inputs = tuple(explicit_graph_inputs(dummy_core_node.inputs))
dummy_fgraph = FunctionGraph(outputs=dummy_core_node.outputs, clone=False) dummy_fgraph = FunctionGraph(outputs=dummy_core_node.outputs, clone=False)
core_input_shapes = [ core_input_shapes = [
input_shape[batch_ndims:] for input_shape in input_shapes input_shape[batch_ndims:] for input_shape in input_shapes
...@@ -214,7 +214,8 @@ class Blockwise(Op): ...@@ -214,7 +214,8 @@ class Blockwise(Op):
# of the core_node as the value is not constant across batch dims of the Blockwise # of the core_node as the value is not constant across batch dims of the Blockwise
core_out_dim = core_output_shapes[o][i] core_out_dim = core_output_shapes[o][i]
if not ( if not (
set(dummy_core_inputs) & set(ancestors([core_out_dim])) set(dummy_core_inputs)
& set(explicit_graph_inputs([core_out_dim]))
): ):
core_out_shape.append(core_out_dim) core_out_shape.append(core_out_dim)
continue continue
......
...@@ -264,9 +264,13 @@ def test_blockwise_infer_core_shape(): ...@@ -264,9 +264,13 @@ def test_blockwise_infer_core_shape():
def make_node(self, a, b): def make_node(self, a, b):
assert a.type.ndim == 1 assert a.type.ndim == 1
assert b.type.ndim == 1 assert b.type.ndim == 1
# Simulate make_node that introduces operations on inputs
a_identity = a.copy()
b_identity = b.copy()
c = tensor(shape=(None,)) c = tensor(shape=(None,))
d = tensor(shape=(None,)) d = tensor(shape=(None,))
return Apply(self, [a, b], [c, d]) return Apply(self, [a_identity, b_identity], [c, d])
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
a, b = inputs a, b = inputs
...@@ -277,9 +281,12 @@ def test_blockwise_infer_core_shape(): ...@@ -277,9 +281,12 @@ def test_blockwise_infer_core_shape():
def infer_shape(self, fgraph, node, input_shapes): def infer_shape(self, fgraph, node, input_shapes):
# First output shape depends only on input_shapes # First output shape depends only on input_shapes
# Second output shape depends on input values # Second output shape depends on input values
x, y = node.inputs a_identity, b_identity = node.inputs
[(x_shape,), (y_shape,)] = input_shapes # Simulate shape depending on original inputs, not the ones that go directly into the node
return (x_shape + y_shape,), (x.sum() + y.sum(),) a = a_identity.owner.inputs[0]
b = b_identity.owner.inputs[0]
[(a_shape,), (b_shape,)] = input_shapes
return (a_shape + b_shape,), (a.sum() + b.sum(),)
blockwise_op = Blockwise( blockwise_op = Blockwise(
core_op=TestOpWithInferShape(), signature="(a),(b)->(c),(d)" core_op=TestOpWithInferShape(), signature="(a),(b)->(c),(d)"
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论