提交 064e72f4 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Only use input shapes to compute output shape in Elemwise.infer_shape

上级 22416ba6
......@@ -16,7 +16,6 @@ from aesara.misc.frozendict import frozendict
from aesara.misc.safe_asarray import _asarray
from aesara.printing import FunctionPrinter, Printer, pprint
from aesara.scalar import get_scalar_type
from aesara.scalar.basic import ScalarType
from aesara.scalar.basic import bool as scalar_bool
from aesara.scalar.basic import identity as scalar_identity
from aesara.scalar.basic import transfer_type, upcast
......@@ -804,37 +803,17 @@ class Elemwise(OpenMPOp):
storage[0] = variable
def infer_shape(self, fgraph, node, i_shapes):
rval = []
for o in node.outputs:
oshp = []
for dim, b in enumerate(o.type.broadcastable):
b_dim = None
if b:
# this is broadcastable
b_dim = 1
else:
# there must be some input that is not broadcastable in
# dimension 'dim'
for ishp, i in zip(i_shapes, node.inputs):
if isinstance(i.type, ScalarType):
continue # we skip scalar
if not i.type.broadcastable[dim]:
# input i is not broadcastable in position dim
# therefore if its shape is known, we can use it
# as the output shape
if ishp[dim]:
b_dim = ishp[dim]
break
# b_dim might still be None, if every input's shape was unknown
# in dimension 'dim'
oshp.append(b_dim)
# TODO: it would be interesting to return the constraining
# information that if one of the inputs shape[dim] is known
# and another input's shape[dim] is not, that we can now assume
# that the other input's shape[dim] is the same as the first.
rval.append(tuple(oshp))
return rval
if len(node.outputs) > 1:
from aesara.tensor.basic_opt import ShapeError
raise ShapeError(
"Multiple outputs are not supported by the default `Elemwise.infer_shape`"
)
out_shape = aesara.tensor.broadcast_shape(*i_shapes, arrays_are_shapes=True)
return [out_shape]
def _c_all(self, node, nodename, inames, onames, sub):
# Some `Op`s directly call `Elemwise._c_all` or `Elemwise.c_code`
......
......@@ -11,12 +11,13 @@ import aesara.scalar as aes
import tests.unittest_tools as utt
from aesara.compile.mode import Mode
from aesara.configdefaults import config
from aesara.graph.basic import Variable
from aesara.graph.basic import Apply, Variable
from aesara.graph.fg import FunctionGraph
from aesara.link.basic import PerformLinker
from aesara.link.c.basic import CLinker, OpWiseCLinker
from aesara.tensor import as_tensor_variable
from aesara.tensor.basic import second
from aesara.tensor.basic_opt import ShapeError
from aesara.tensor.elemwise import CAReduce, CAReduceDtype, DimShuffle, Elemwise
from aesara.tensor.math import all as at_all
from aesara.tensor.math import any as at_any
......@@ -800,6 +801,46 @@ class TestElemwise(unittest_tools.InferShapeTester):
op = Elemwise(aes.add, inplace_pattern=None, name="my_op")
assert str(op) == "my_op"
def test_partial_static_shape_info(self):
"""Make sure that `Elemwise.infer_shape` can handle changes in the static shape information during rewriting."""
x = TensorType("floatX", shape=(None, None))()
z = Elemwise(aes.add)(x, x)
x_inferred_shape = (aes.constant(1), aes.constant(1))
res_shape = z.owner.op.infer_shape(
None, z.owner, [x_inferred_shape, x_inferred_shape]
)
assert len(res_shape) == 1
assert len(res_shape[0]) == 2
assert res_shape[0][0].data == 1
assert res_shape[0][1].data == 1
def test_multi_output(self):
class CustomElemwise(Elemwise):
def make_node(self, *args):
res = super().make_node(*args)
return Apply(
self,
res.inputs,
# Return two outputs
[
TensorType(dtype="float64", shape=(None, None))()
for i in range(2)
],
)
z_1, z_2 = CustomElemwise(aes.add)(
as_tensor_variable(np.eye(1)), as_tensor_variable(np.eye(1))
)
in_1_shape = (aes.constant(1), aes.constant(1))
with pytest.raises(ShapeError):
z_1.owner.op.infer_shape(None, z_1.owner, [in_1_shape, in_1_shape])
def test_not_implemented_elemwise_grad():
# Regression test for unimplemented gradient in an Elemwise Op.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论