提交 d6858fe2 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Make Elemwise.infer_shape return TensorType-ed values

上级 90a0f733
from copy import copy from copy import copy
from typing import Tuple, Union from typing import List, Tuple, Union
import numpy as np import numpy as np
...@@ -29,6 +29,7 @@ from aesara.tensor.type import ( ...@@ -29,6 +29,7 @@ from aesara.tensor.type import (
float_dtypes, float_dtypes,
lvector, lvector,
) )
from aesara.tensor.var import TensorVariable
from aesara.utils import uniq from aesara.utils import uniq
...@@ -802,7 +803,7 @@ class Elemwise(OpenMPOp): ...@@ -802,7 +803,7 @@ class Elemwise(OpenMPOp):
else: else:
storage[0] = variable storage[0] = variable
def infer_shape(self, fgraph, node, i_shapes): def infer_shape(self, fgraph, node, i_shapes) -> List[Tuple[TensorVariable, ...]]:
if len(node.outputs) > 1: if len(node.outputs) > 1:
from aesara.tensor.basic_opt import ShapeError from aesara.tensor.basic_opt import ShapeError
...@@ -813,7 +814,8 @@ class Elemwise(OpenMPOp): ...@@ -813,7 +814,8 @@ class Elemwise(OpenMPOp):
out_shape = aesara.tensor.broadcast_shape(*i_shapes, arrays_are_shapes=True) out_shape = aesara.tensor.broadcast_shape(*i_shapes, arrays_are_shapes=True)
return [out_shape] # The `as_tensor_variable` should convert `ScalarType`s to `TensorType`s
return [tuple(as_tensor_variable(s) for s in out_shape)]
def _c_all(self, node, nodename, inames, onames, sub): def _c_all(self, node, nodename, inames, onames, sub):
# Some `Op`s directly call `Elemwise._c_all` or `Elemwise.c_code` # Some `Op`s directly call `Elemwise._c_all` or `Elemwise.c_code`
......
...@@ -26,6 +26,7 @@ from aesara.tensor.type import ( ...@@ -26,6 +26,7 @@ from aesara.tensor.type import (
bmatrix, bmatrix,
bscalar, bscalar,
discrete_dtypes, discrete_dtypes,
lscalar,
matrix, matrix,
scalar, scalar,
tensor, tensor,
...@@ -815,8 +816,8 @@ class TestElemwise(unittest_tools.InferShapeTester): ...@@ -815,8 +816,8 @@ class TestElemwise(unittest_tools.InferShapeTester):
assert len(res_shape) == 1 assert len(res_shape) == 1
assert len(res_shape[0]) == 2 assert len(res_shape[0]) == 2
assert res_shape[0][0].data == 1 assert aesara.get_scalar_constant_value(res_shape[0][0]) == 1
assert res_shape[0][1].data == 1 assert aesara.get_scalar_constant_value(res_shape[0][1]) == 1
def test_multi_output(self): def test_multi_output(self):
class CustomElemwise(Elemwise): class CustomElemwise(Elemwise):
...@@ -841,6 +842,18 @@ class TestElemwise(unittest_tools.InferShapeTester): ...@@ -841,6 +842,18 @@ class TestElemwise(unittest_tools.InferShapeTester):
with pytest.raises(ShapeError): with pytest.raises(ShapeError):
z_1.owner.op.infer_shape(None, z_1.owner, [in_1_shape, in_1_shape]) z_1.owner.op.infer_shape(None, z_1.owner, [in_1_shape, in_1_shape])
def test_shape_types(self):
x = tensor(np.float64, (None, 1))
y = tensor(np.float64, (50, 10))
z = x * y
assert isinstance(z.owner.op, Elemwise)
(out_shape,) = z.owner.op.infer_shape(None, z.owner, [(lscalar(), 1), (50, 10)])
assert all(isinstance(v.type, TensorType) for v in out_shape)
def test_not_implemented_elemwise_grad(): def test_not_implemented_elemwise_grad():
# Regression test for unimplemented gradient in an Elemwise Op. # Regression test for unimplemented gradient in an Elemwise Op.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论