提交 ddd73220 authored 作者: Ricardo's avatar Ricardo 提交者: Brandon T. Willard

Use static shapes in outputs of Elemwise

上级 70cf7e3b
...@@ -418,12 +418,29 @@ class Elemwise(OpenMPOp): ...@@ -418,12 +418,29 @@ class Elemwise(OpenMPOp):
# of all inputs in parallel... the all() gives us each output # of all inputs in parallel... the all() gives us each output
# broadcastable bit in turn. # broadcastable bit in turn.
def get_most_specialized_shape(shapes):
if None not in shapes:
# We could check if shapes are valid under broadcasting
# len(set(dims).discard(1)) <= 1
return max(shapes)
known_shapes = [shape for shape in shapes if shape is not None]
if known_shapes:
largest_known_shape = max(known_shapes)
# If largest known shape is 1, and there is an unknown shape, we don't
# know the final shape, because this could be broadcasted
if largest_known_shape > 1:
# Again, we could check that known shapes are valid under broacasting
return largest_known_shape
return None
# it is multiplied by nout because Elemwise supports multiple outputs # it is multiplied by nout because Elemwise supports multiple outputs
# (nout of them) # (nout of them)
out_broadcastables = [ out_shapes = [
[ [
all(bcast) get_most_specialized_shape(shape)
for bcast in zip(*[input.type.broadcastable for input in inputs]) for shape in zip(*[input.type.shape for input in inputs])
] ]
] * shadow.nout ] * shadow.nout
...@@ -432,10 +449,10 @@ class Elemwise(OpenMPOp): ...@@ -432,10 +449,10 @@ class Elemwise(OpenMPOp):
if inplace_pattern: if inplace_pattern:
for overwriter, overwritten in inplace_pattern.items(): for overwriter, overwritten in inplace_pattern.items():
for ob, ib in zip( for ob, ib in zip(
out_broadcastables[overwriter], out_shapes[overwriter],
inputs[overwritten].type.broadcastable, inputs[overwritten].type.broadcastable,
): ):
if ib and not ob: if ib and not ob == 1:
raise ValueError( raise ValueError(
"Operation cannot be done inplace on an input " "Operation cannot be done inplace on an input "
"with broadcasted dimensions." "with broadcasted dimensions."
...@@ -451,8 +468,8 @@ class Elemwise(OpenMPOp): ...@@ -451,8 +468,8 @@ class Elemwise(OpenMPOp):
([i.type.dtype for i in inputs], out_dtypes, inplace_pattern), ([i.type.dtype for i in inputs], out_dtypes, inplace_pattern),
) )
) )
assert len(out_dtypes) == len(out_broadcastables) assert len(out_dtypes) == len(out_shapes)
return out_dtypes, out_broadcastables, inputs return out_dtypes, out_shapes, inputs
def make_node(self, *inputs): def make_node(self, *inputs):
""" """
...@@ -461,12 +478,10 @@ class Elemwise(OpenMPOp): ...@@ -461,12 +478,10 @@ class Elemwise(OpenMPOp):
using DimShuffle. using DimShuffle.
""" """
inputs = [as_tensor_variable(i) for i in inputs] inputs = [as_tensor_variable(i) for i in inputs]
out_dtypes, out_broadcastables, inputs = self.get_output_info( out_dtypes, out_shapes, inputs = self.get_output_info(DimShuffle, *inputs)
DimShuffle, *inputs
)
outputs = [ outputs = [
TensorType(dtype=dtype, shape=broadcastable)() TensorType(dtype=dtype, shape=shape)()
for dtype, broadcastable in zip(out_dtypes, out_broadcastables) for dtype, shape in zip(out_dtypes, out_shapes)
] ]
return Apply(self, inputs, outputs) return Apply(self, inputs, outputs)
......
...@@ -21,6 +21,7 @@ from aesara.tensor.elemwise import CAReduce, CAReduceDtype, DimShuffle, Elemwise ...@@ -21,6 +21,7 @@ from aesara.tensor.elemwise import CAReduce, CAReduceDtype, DimShuffle, Elemwise
from aesara.tensor.exceptions import ShapeError from aesara.tensor.exceptions import ShapeError
from aesara.tensor.math import all as at_all from aesara.tensor.math import all as at_all
from aesara.tensor.math import any as at_any from aesara.tensor.math import any as at_any
from aesara.tensor.math import exp
from aesara.tensor.type import ( from aesara.tensor.type import (
TensorType, TensorType,
bmatrix, bmatrix,
...@@ -854,6 +855,27 @@ class TestElemwise(unittest_tools.InferShapeTester): ...@@ -854,6 +855,27 @@ class TestElemwise(unittest_tools.InferShapeTester):
assert all(isinstance(v.type, TensorType) for v in out_shape) assert all(isinstance(v.type, TensorType) for v in out_shape)
def test_static_shape_unary(self):
x = tensor("float64", shape=(None, 1, 5))
exp(x).type.shape == (None, 1, 5)
def test_static_shape_binary(self):
x = tensor("float64", shape=(None, 5))
y = tensor("float64", shape=(None, 5))
assert (x + y).type.shape == (None, 5)
x = tensor("float64", shape=(None, 5))
y = tensor("float64", shape=(10, 5))
assert (x + y).type.shape == (10, 5)
x = tensor("float64", shape=(1, 5))
y = tensor("float64", shape=(10, 5))
assert (x + y).type.shape == (10, 5)
x = tensor("float64", shape=(None, 1))
y = tensor("float64", shape=(1, 1))
assert (x + y).type.shape == (None, 1)
def test_not_implemented_elemwise_grad(): def test_not_implemented_elemwise_grad():
# Regression test for unimplemented gradient in an Elemwise Op. # Regression test for unimplemented gradient in an Elemwise Op.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论