Commit ddd73220 authored by Ricardo, committed by Brandon T. Willard

Use static shapes in outputs of Elemwise

Parent 70cf7e3b
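In short: Elemwise outputs now carry per-dimension static shape information instead of only broadcastable flags. A quick sketch of the user-visible effect, using the `tensor` helper from `aesara.tensor.type` the same way the new tests below do:

from aesara.tensor.type import tensor

x = tensor("float64", shape=(None, 5))  # first dim unknown, second dim 5
y = tensor("float64", shape=(10, 5))    # fully static

# The known 10 propagates into the output's static shape:
assert (x + y).type.shape == (10, 5)

z = tensor("float64", shape=(None, 1))
w = tensor("float64", shape=(1, 1))
# A dim whose only known lengths are 1 stays unknown: it may still broadcast.
assert (z + w).type.shape == (None, 1)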
@@ -418,12 +418,29 @@ class Elemwise(OpenMPOp):
         # of all inputs in parallel... the all() gives us each output
         # broadcastable bit in turn.
+        def get_most_specialized_shape(shapes):
+            if None not in shapes:
+                # We could check if shapes are valid under broadcasting:
+                # len(set(dims).discard(1)) <= 1
+                return max(shapes)
+            known_shapes = [shape for shape in shapes if shape is not None]
+            if known_shapes:
+                largest_known_shape = max(known_shapes)
+                # If the largest known shape is 1 and there is an unknown shape,
+                # we don't know the final shape, because it could be broadcasted
+                if largest_known_shape > 1:
+                    # Again, we could check that known shapes are valid under broadcasting
+                    return largest_known_shape
+            return None
+
         # it is multiplied by nout because Elemwise supports multiple outputs
         # (nout of them)
-        out_broadcastables = [
+        out_shapes = [
             [
-                all(bcast)
-                for bcast in zip(*[input.type.broadcastable for input in inputs])
+                get_most_specialized_shape(shape)
+                for shape in zip(*[input.type.shape for input in inputs])
             ]
         ] * shadow.nout
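For reference, here is the new helper exercised on a few per-dimension shape tuples; this is a standalone copy of the added code, runnable in plain Python, with the example tuples chosen for illustration:

def get_most_specialized_shape(shapes):
    # None marks a dimension whose length is unknown at graph-construction time
    if None not in shapes:
        return max(shapes)
    known_shapes = [shape for shape in shapes if shape is not None]
    if known_shapes:
        largest_known_shape = max(known_shapes)
        if largest_known_shape > 1:
            return largest_known_shape
    return None

assert get_most_specialized_shape((5, 5)) == 5        # all known: take the max
assert get_most_specialized_shape((1, 10)) == 10      # the 1 broadcasts up to 10
assert get_most_specialized_shape((None, 10)) == 10   # unknown dim must be 10 (or 1)
assert get_most_specialized_shape((None, 1)) is None  # could still be broadcasted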
@@ -432,10 +449,10 @@ class Elemwise(OpenMPOp):
         if inplace_pattern:
             for overwriter, overwritten in inplace_pattern.items():
                 for ob, ib in zip(
-                    out_broadcastables[overwriter],
+                    out_shapes[overwriter],
                     inputs[overwritten].type.broadcastable,
                 ):
-                    if ib and not ob:
+                    if ib and not ob == 1:
                         raise ValueError(
                             "Operation cannot be done inplace on an input "
                             "with broadcasted dimensions."
@@ -451,8 +468,8 @@ class Elemwise(OpenMPOp):
                 ([i.type.dtype for i in inputs], out_dtypes, inplace_pattern),
             )
         )
-        assert len(out_dtypes) == len(out_broadcastables)
-        return out_dtypes, out_broadcastables, inputs
+        assert len(out_dtypes) == len(out_shapes)
+        return out_dtypes, out_shapes, inputs

     def make_node(self, *inputs):
         """
@@ -461,12 +478,10 @@ class Elemwise(OpenMPOp):
         using DimShuffle.
         """
         inputs = [as_tensor_variable(i) for i in inputs]
-        out_dtypes, out_broadcastables, inputs = self.get_output_info(
-            DimShuffle, *inputs
-        )
+        out_dtypes, out_shapes, inputs = self.get_output_info(DimShuffle, *inputs)
         outputs = [
-            TensorType(dtype=dtype, shape=broadcastable)()
-            for dtype, broadcastable in zip(out_dtypes, out_broadcastables)
+            TensorType(dtype=dtype, shape=shape)()
+            for dtype, shape in zip(out_dtypes, out_shapes)
         ]
         return Apply(self, inputs, outputs)
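make_node now passes the combined static shape straight to TensorType, which derives the broadcastable flags from it: a dimension is broadcastable exactly when its static shape is 1. A small sketch of that correspondence, assuming TensorType's shape semantics at this commit:

from aesara.tensor.type import TensorType

t = TensorType(dtype="float64", shape=(10, 1, None))
print(t.shape)          # (10, 1, None)
print(t.broadcastable)  # (False, True, False): only the static-1 dim broadcasts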
@@ -21,6 +21,7 @@ from aesara.tensor.elemwise import CAReduce, CAReduceDtype, DimShuffle, Elemwise
 from aesara.tensor.exceptions import ShapeError
 from aesara.tensor.math import all as at_all
 from aesara.tensor.math import any as at_any
+from aesara.tensor.math import exp
 from aesara.tensor.type import (
     TensorType,
     bmatrix,
@@ -854,6 +855,27 @@ class TestElemwise(unittest_tools.InferShapeTester):
         assert all(isinstance(v.type, TensorType) for v in out_shape)

+    def test_static_shape_unary(self):
+        x = tensor("float64", shape=(None, 1, 5))
+        assert exp(x).type.shape == (None, 1, 5)
+
+    def test_static_shape_binary(self):
+        x = tensor("float64", shape=(None, 5))
+        y = tensor("float64", shape=(None, 5))
+        assert (x + y).type.shape == (None, 5)
+
+        x = tensor("float64", shape=(None, 5))
+        y = tensor("float64", shape=(10, 5))
+        assert (x + y).type.shape == (10, 5)
+
+        x = tensor("float64", shape=(1, 5))
+        y = tensor("float64", shape=(10, 5))
+        assert (x + y).type.shape == (10, 5)
+
+        x = tensor("float64", shape=(None, 1))
+        y = tensor("float64", shape=(1, 1))
+        assert (x + y).type.shape == (None, 1)
+

 def test_not_implemented_elemwise_grad():
     # Regression test for unimplemented gradient in an Elemwise Op.
......