提交 31304be2 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Blockwise: Handle integers in signature and core_op.infer_shape

上级 ee47dcc9
...@@ -8,7 +8,7 @@ from pytensor import config ...@@ -8,7 +8,7 @@ from pytensor import config
from pytensor.compile.builders import OpFromGraph from pytensor.compile.builders import OpFromGraph
from pytensor.gradient import DisconnectedType from pytensor.gradient import DisconnectedType
from pytensor.graph import FunctionGraph from pytensor.graph import FunctionGraph
from pytensor.graph.basic import Apply, Constant from pytensor.graph.basic import Apply, Constant, Variable
from pytensor.graph.null_type import NullType from pytensor.graph.null_type import NullType
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.graph.replace import ( from pytensor.graph.replace import (
...@@ -371,8 +371,12 @@ class Blockwise(COp): ...@@ -371,8 +371,12 @@ class Blockwise(COp):
safe_core_output_shapes = [list(shape) for shape in core_output_shapes] safe_core_output_shapes = [list(shape) for shape in core_output_shapes]
for core_out_shape in safe_core_output_shapes: for core_out_shape in safe_core_output_shapes:
for o, core_out_dim in enumerate(core_out_shape): for o, core_out_dim in enumerate(core_out_shape):
if set_dummy_core_inputs & set( if (
explicit_graph_inputs([core_out_dim]) # Some Ops return integers / literals from infer_shape...
# If it's not a Variable it can't depend on inputs
isinstance(core_out_dim, Variable)
and set_dummy_core_inputs
& set(explicit_graph_inputs([core_out_dim]))
): ):
core_out_shape[o] = None core_out_shape[o] = None
...@@ -386,9 +390,16 @@ class Blockwise(COp): ...@@ -386,9 +390,16 @@ class Blockwise(COp):
): ):
core_out_shape = [] core_out_shape = []
for i, dim_name in enumerate(sig): for i, dim_name in enumerate(sig):
# The output dim is the same as another input dim
if dim_name in core_dims: if dim_name in core_dims:
# The output dim is the same as another input dim
core_out_shape.append(core_dims[dim_name]) core_out_shape.append(core_dims[dim_name])
elif str.isnumeric(dim_name):
# The core_dim has a constant size
from pytensor.tensor.basic import constant
core_out_shape.append(
constant(np.array(int(dim_name), dtype="int64"))
)
else: else:
if safe_core_out_shape is None: if safe_core_out_shape is None:
# Extract the core shape from the core_op infer_shape on demand # Extract the core shape from the core_op infer_shape on demand
......
...@@ -19,6 +19,7 @@ from pytensor.tensor import ( ...@@ -19,6 +19,7 @@ from pytensor.tensor import (
dmatrix, dmatrix,
log, log,
matrices, matrices,
matrix,
ones_like, ones_like,
scalar, scalar,
tensor, tensor,
...@@ -352,6 +353,32 @@ def test_blockwise_infer_core_shape(): ...@@ -352,6 +353,32 @@ def test_blockwise_infer_core_shape():
assert tuple(d_shape_fn(a_test, b_test)) == (5, 0) assert tuple(d_shape_fn(a_test, b_test)) == (5, 0)
def test_infer_shape_literals():
# Define a CoreOp whose infer_shape is (symbolic operation on itself, literal)
# Then tell Blockwise that the first dimension is constant.
# The Op has no perform method, so it will fail to evaluate if the infer_shape of that dimension isn't ignored
class TestCoreOp(Op):
def make_node(self, x):
assert x.type.ndim == 0
return Apply(self, [x], [matrix()])
def perform(self, node, inputs, outputs):
raise NotImplementedError()
def infer_shape(self, fgraph, node, input_shapes):
y = node.outputs[0]
# Apparently it's valid to return integers in infer_shape.
# DimShuffle does this. Modify test if that is no longer allowed.
return [(y[0][0].astype(int), 3)]
op = Blockwise(TestCoreOp(), signature="()->(2,a)")
x = scalar("x")
y = op(x)
fn = function([x], y.shape)
assert tuple(fn(0)) == (2, 3)
class BlockwiseOpTester: class BlockwiseOpTester:
"""Base class to test Blockwise works for specific Ops""" """Base class to test Blockwise works for specific Ops"""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论