提交 31304be2 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Blockwise: Handle integers in signature and core_op.infer_shape

上级 ee47dcc9
......@@ -8,7 +8,7 @@ from pytensor import config
from pytensor.compile.builders import OpFromGraph
from pytensor.gradient import DisconnectedType
from pytensor.graph import FunctionGraph
from pytensor.graph.basic import Apply, Constant
from pytensor.graph.basic import Apply, Constant, Variable
from pytensor.graph.null_type import NullType
from pytensor.graph.op import Op
from pytensor.graph.replace import (
......@@ -371,8 +371,12 @@ class Blockwise(COp):
safe_core_output_shapes = [list(shape) for shape in core_output_shapes]
for core_out_shape in safe_core_output_shapes:
for o, core_out_dim in enumerate(core_out_shape):
if set_dummy_core_inputs & set(
explicit_graph_inputs([core_out_dim])
if (
# Some Ops return integers / literals from infer_shape...
# If it's not a Variable it can't depend on inputs
isinstance(core_out_dim, Variable)
and set_dummy_core_inputs
& set(explicit_graph_inputs([core_out_dim]))
):
core_out_shape[o] = None
......@@ -386,9 +390,16 @@ class Blockwise(COp):
):
core_out_shape = []
for i, dim_name in enumerate(sig):
# The output dim is the same as another input dim
if dim_name in core_dims:
# The output dim is the same as another input dim
core_out_shape.append(core_dims[dim_name])
elif str.isnumeric(dim_name):
# The core_dim has a constant size
from pytensor.tensor.basic import constant
core_out_shape.append(
constant(np.array(int(dim_name), dtype="int64"))
)
else:
if safe_core_out_shape is None:
# Extract the core shape from the core_op infer_shape on demand
......
......@@ -19,6 +19,7 @@ from pytensor.tensor import (
dmatrix,
log,
matrices,
matrix,
ones_like,
scalar,
tensor,
......@@ -352,6 +353,32 @@ def test_blockwise_infer_core_shape():
assert tuple(d_shape_fn(a_test, b_test)) == (5, 0)
def test_infer_shape_literals():
# Define a CoreOp whose infer_shape is (symbolic operation on itself, literal)
# Then tell Blockwise that the first dimension is constant.
# The Op has no perform method, so it will fail to evaluate if the infer_shape of that dimension isn't ignored
class TestCoreOp(Op):
def make_node(self, x):
assert x.type.ndim == 0
return Apply(self, [x], [matrix()])
def perform(self, node, inputs, outputs):
raise NotImplementedError()
def infer_shape(self, fgraph, node, input_shapes):
y = node.outputs[0]
# Apparently it's valid to return integers in infer_shape.
# DimShuffle does this. Modify test if that is no longer allowed.
return [(y[0][0].astype(int), 3)]
op = Blockwise(TestCoreOp(), signature="()->(2,a)")
x = scalar("x")
y = op(x)
fn = function([x], y.shape)
assert tuple(fn(0)) == (2, 3)
class BlockwiseOpTester:
"""Base class to test Blockwise works for specific Ops"""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论