Commit 4265d4fb authored by Brandon T. Willard, committed by Brandon T. Willard

Fix OpFromGraph.infer_shape when shapeless types are used

Parent 94254c7d
"""Define new Ops from existing Ops""" """Define new Ops from existing Ops"""
from collections import OrderedDict from collections import OrderedDict
from functools import partial, reduce from functools import partial
import aesara.tensor as aet import aesara.tensor as aet
from aesara.compile.function.pfunc import rebuild_collect_shared from aesara.compile.function.pfunc import rebuild_collect_shared
...@@ -781,23 +780,29 @@ class OpFromGraph(Op): ...@@ -781,23 +780,29 @@ class OpFromGraph(Op):
def infer_shape(self, fgraph, node, shapes):
    """Compute output shapes for this `OpFromGraph` node.

    Parameters
    ----------
    fgraph
        The `FunctionGraph` in which `node` resides.
    node
        The `Apply` node for which shapes are inferred.
    shapes
        Shape tuples for each of ``node.inputs``.

    Returns
    -------
    list
        One entry per output: a tuple of symbolic shape values, or
        ``None`` for outputs whose type has no shape (e.g. non-tensor
        types).
    """
    # TODO: Use `fgraph.shape_feature` to do this instead.

    # Shapes of the inner-graph outputs, expressed in terms of the
    # inner-graph inputs.  Entries are `None` for shapeless types.
    out_shapes = infer_shape(self.local_outputs, self.local_inputs, shapes)

    # Clone the output shapes so that shapes are computed from the outer
    # inputs.
    # Note:
    # Here we could do it more simply like:
    #      `ret = [aesara.clone_replace(shp, replace=repl) for shp in out_shapes]`
    # But doing it multiple times could duplicate common subgraphs between
    # each shape call.  The Aesara optimizer will clean this up later, but
    # this will make extra work for the optimizer.

    repl = dict(zip(self.local_inputs, node.inputs))

    # Only shaped outputs (tuples) participate in the clone; `None`
    # entries (shapeless types) are passed through untouched below.
    clone_out_shapes = [s for s in out_shapes if isinstance(s, tuple)]
    cloned = clone_replace(sum(clone_out_shapes, ()), replace=repl)

    # Re-split the flat list of cloned shape values back into per-output
    # tuples, re-inserting `None` for shapeless outputs.
    ret = []
    used = 0
    for out_shape in out_shapes:
        if out_shape is None:
            ret.append(None)
        else:
            nb = len(out_shape)
            ret.append(cloned[used : used + nb])
            used += nb

    return ret
......
...@@ -8,7 +8,11 @@ from aesara.compile.builders import OpFromGraph ...@@ -8,7 +8,11 @@ from aesara.compile.builders import OpFromGraph
from aesara.compile.function import function from aesara.compile.function import function
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.gradient import DisconnectedType, Rop, disconnected_type, grad from aesara.gradient import DisconnectedType, Rop, disconnected_type, grad
from aesara.graph.fg import FunctionGraph
from aesara.graph.null_type import NullType from aesara.graph.null_type import NullType
from aesara.graph.opt_utils import optimize_graph
from aesara.tensor.basic import as_tensor
from aesara.tensor.basic_opt import ShapeOptimizer
from aesara.tensor.math import dot, exp from aesara.tensor.math import dot, exp
from aesara.tensor.math import round as aet_round from aesara.tensor.math import round as aet_round
from aesara.tensor.math import sigmoid from aesara.tensor.math import sigmoid
...@@ -16,6 +20,7 @@ from aesara.tensor.math import sum as aet_sum ...@@ -16,6 +20,7 @@ from aesara.tensor.math import sum as aet_sum
from aesara.tensor.random.utils import RandomStream from aesara.tensor.random.utils import RandomStream
from aesara.tensor.type import TensorType, matrices, matrix, scalar, vector, vectors from aesara.tensor.type import TensorType, matrices, matrix, scalar, vector, vectors
from tests import unittest_tools from tests import unittest_tools
from tests.graph.utils import MyVariable
class TestOpFromGraph(unittest_tools.InferShapeTester): class TestOpFromGraph(unittest_tools.InferShapeTester):
...@@ -400,6 +405,22 @@ class TestOpFromGraph(unittest_tools.InferShapeTester): ...@@ -400,6 +405,22 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
OpFromGraph, OpFromGraph,
) )
# Make sure `OpFromGraph.infer_shape` can handle objects without a
# shape
x = MyVariable("x")
y = matrix("y")
z = as_tensor([1, 2])
op_graph = OpFromGraph([x, y, z], [x, y])
op_var = op_graph(x, y, z)
fg = FunctionGraph(outputs=[op_var[1]], clone=False)
opt_res = optimize_graph(fg, custom_opt=ShapeOptimizer())
assert opt_res.shape_feature.shape_of[x] is None
assert opt_res.shape_feature.shape_of[z][0].data == 2
@config.change_flags(compute_test_value="raise") @config.change_flags(compute_test_value="raise")
def test_compute_test_value(self): def test_compute_test_value(self):
x = scalar("x") x = scalar("x")
......
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Register or sign in to comment