提交 881f3494 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Move _fill_chain to aesara.tensor.math_opt

上级 f64d243d
......@@ -88,12 +88,6 @@ _logger = logging.getLogger("aesara.tensor.basic_opt")
_logger.addFilter(NoDuplicateOptWarningFilter())
def _fill_chain(new_out, orig_inputs):
for i in orig_inputs:
new_out = fill(i, new_out)
return [new_out]
def encompasses_broadcastable(b1, b2):
"""
......
......@@ -40,7 +40,6 @@ from aesara.tensor.basic import (
)
from aesara.tensor.basic_opt import (
FusionOptimizer,
_fill_chain,
broadcast_like,
encompasses_broadcastable,
fuse_seqopt,
......@@ -104,6 +103,12 @@ _logger = logging.getLogger("aesara.tensor.math_opt")
_logger.addFilter(NoDuplicateOptWarningFilter())
def fill_chain(new_out, orig_inputs):
for i in orig_inputs:
new_out = fill(i, new_out)
return [new_out]
@register_canonicalize
@register_stabilize
@local_optimizer([Dot])
......@@ -1001,7 +1006,7 @@ class AlgebraicCanonizer(LocalOptimizer):
assert (new.type == out.type) == (not (new.type != out.type))
if not (new.type == out.type):
new = _fill_chain(new, node.inputs)[0]
new = fill_chain(new, node.inputs)[0]
if new.type == out.type:
# This happen with test
......@@ -1792,7 +1797,7 @@ def local_mul_zero(fgraph, node):
# print 'MUL by value', value, node.inputs
if value == 0:
# print '... returning zeros'
return _fill_chain(_asarray(0, dtype=otype.dtype), node.inputs)
return fill_chain(_asarray(0, dtype=otype.dtype), node.inputs)
register_canonicalize(local_mul_zero)
......@@ -2074,8 +2079,8 @@ register_specialize(local_mul_specialize)
@local_optimizer([add])
def local_add_specialize(fgraph, node):
def fill_chain(v):
out = _fill_chain(v, node.inputs)
def _fill_chain(v):
out = fill_chain(v, node.inputs)
return out
# here, we are past the point of canonicalization, so we don't want
......@@ -2099,12 +2104,12 @@ def local_add_specialize(fgraph, node):
# Reuse call to constant for cache()
cst = constant(np.zeros((1,) * ndim, dtype=dtype))
assert cst.type.broadcastable == (True,) * ndim
return fill_chain(cst)
return _fill_chain(cst)
if len(new_inputs) == 1:
ret = fill_chain(new_inputs[0])
ret = _fill_chain(new_inputs[0])
else:
ret = fill_chain(add(*new_inputs))
ret = _fill_chain(add(*new_inputs))
# The dtype should not be changed. It can happen if the input
# that was forcing upcasting was equal to 0.
if ret[0].dtype != dtype:
......@@ -2223,7 +2228,7 @@ def local_log1p(fgraph, node):
ninp = nonconsts[0]
if ninp.dtype != log_arg.type.dtype:
ninp = ninp.astype(node.outputs[0].dtype)
return _fill_chain(log1p(ninp), scalar_inputs)
return fill_chain(log1p(ninp), scalar_inputs)
elif log_arg.owner and log_arg.owner.op == sub:
one = extract_constant(log_arg.owner.inputs[0], only_process_constants=True)
......@@ -3496,7 +3501,7 @@ def local_reciprocal_1_plus_exp(fgraph, node):
if len(nonconsts) == 1:
if nonconsts[0].owner and nonconsts[0].owner.op == exp:
if scalars_ and np.allclose(np.sum(scalars_), 1):
out = _fill_chain(
out = fill_chain(
sigmoid(neg(nonconsts[0].owner.inputs[0])),
scalar_inputs,
)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论