提交 5f681ce5 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Move scalarconsts_rest to math_opt

上级 e7cd653b
...@@ -113,24 +113,6 @@ def merge_broadcastables(broadcastables): ...@@ -113,24 +113,6 @@ def merge_broadcastables(broadcastables):
return [all(bcast) for bcast in zip(*broadcastables)] return [all(bcast) for bcast in zip(*broadcastables)]
def scalarconsts_rest(inputs, elemwise=True, only_process_constants=False):
"""Partition a list of variables into two kinds:
scalar constants, and the rest."""
consts = []
origconsts = []
nonconsts = []
for i in inputs:
try:
v = get_scalar_constant_value(
i, elemwise=elemwise, only_process_constants=only_process_constants
)
consts.append(v)
origconsts.append(i)
except NotScalarConstantError:
nonconsts.append(i)
return consts, origconsts, nonconsts
def broadcast_like(value, template, fgraph, dtype=None): def broadcast_like(value, template, fgraph, dtype=None):
""" """
Return a Variable with the same shape and dtype as the template, Return a Variable with the same shape and dtype as the template,
......
...@@ -48,7 +48,6 @@ from aesara.tensor.basic_opt import ( ...@@ -48,7 +48,6 @@ from aesara.tensor.basic_opt import (
register_stabilize, register_stabilize,
register_uncanonicalize, register_uncanonicalize,
register_useless, register_useless,
scalarconsts_rest,
) )
from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from aesara.tensor.exceptions import NotScalarConstantError from aesara.tensor.exceptions import NotScalarConstantError
...@@ -101,6 +100,24 @@ _logger = logging.getLogger("aesara.tensor.math_opt") ...@@ -101,6 +100,24 @@ _logger = logging.getLogger("aesara.tensor.math_opt")
_logger.addFilter(NoDuplicateOptWarningFilter()) _logger.addFilter(NoDuplicateOptWarningFilter())
def scalarconsts_rest(inputs, elemwise=True, only_process_constants=False):
"""Partition a list of variables into two kinds:
scalar constants, and the rest."""
consts = []
origconsts = []
nonconsts = []
for i in inputs:
try:
v = get_scalar_constant_value(
i, elemwise=elemwise, only_process_constants=only_process_constants
)
consts.append(v)
origconsts.append(i)
except NotScalarConstantError:
nonconsts.append(i)
return consts, origconsts, nonconsts
def get_constant(v): def get_constant(v):
""" """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论