提交 d6858fe2 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Make Elemwise.infer_shape return TensorType-ed values

上级 90a0f733
from copy import copy
from typing import Tuple, Union
from typing import List, Tuple, Union
import numpy as np
......@@ -29,6 +29,7 @@ from aesara.tensor.type import (
float_dtypes,
lvector,
)
from aesara.tensor.var import TensorVariable
from aesara.utils import uniq
......@@ -802,7 +803,7 @@ class Elemwise(OpenMPOp):
else:
storage[0] = variable
def infer_shape(self, fgraph, node, i_shapes):
def infer_shape(self, fgraph, node, i_shapes) -> List[Tuple[TensorVariable, ...]]:
if len(node.outputs) > 1:
from aesara.tensor.basic_opt import ShapeError
......@@ -813,7 +814,8 @@ class Elemwise(OpenMPOp):
out_shape = aesara.tensor.broadcast_shape(*i_shapes, arrays_are_shapes=True)
return [out_shape]
# The `as_tensor_variable` should convert `ScalarType`s to `TensorType`s
return [tuple(as_tensor_variable(s) for s in out_shape)]
def _c_all(self, node, nodename, inames, onames, sub):
# Some `Op`s directly call `Elemwise._c_all` or `Elemwise.c_code`
......
......@@ -26,6 +26,7 @@ from aesara.tensor.type import (
bmatrix,
bscalar,
discrete_dtypes,
lscalar,
matrix,
scalar,
tensor,
......@@ -815,8 +816,8 @@ class TestElemwise(unittest_tools.InferShapeTester):
assert len(res_shape) == 1
assert len(res_shape[0]) == 2
assert res_shape[0][0].data == 1
assert res_shape[0][1].data == 1
assert aesara.get_scalar_constant_value(res_shape[0][0]) == 1
assert aesara.get_scalar_constant_value(res_shape[0][1]) == 1
def test_multi_output(self):
class CustomElemwise(Elemwise):
......@@ -841,6 +842,18 @@ class TestElemwise(unittest_tools.InferShapeTester):
with pytest.raises(ShapeError):
z_1.owner.op.infer_shape(None, z_1.owner, [in_1_shape, in_1_shape])
def test_shape_types(self):
x = tensor(np.float64, (None, 1))
y = tensor(np.float64, (50, 10))
z = x * y
assert isinstance(z.owner.op, Elemwise)
(out_shape,) = z.owner.op.infer_shape(None, z.owner, [(lscalar(), 1), (50, 10)])
assert all(isinstance(v.type, TensorType) for v in out_shape)
def test_not_implemented_elemwise_grad():
# Regression test for unimplemented gradient in an Elemwise Op.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论