Commit 881f3494, authored by Brandon T. Willard, committed by Brandon T. Willard

Move _fill_chain to aesara.tensor.math_opt

Parent: f64d243d
aesara/tensor/basic_opt.py
@@ -88,12 +88,6 @@ _logger = logging.getLogger("aesara.tensor.basic_opt")
 _logger.addFilter(NoDuplicateOptWarningFilter())
 
-def _fill_chain(new_out, orig_inputs):
-    for i in orig_inputs:
-        new_out = fill(i, new_out)
-    return [new_out]
-
 def encompasses_broadcastable(b1, b2):
     """
...
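For context (not part of the commit): this helper wraps a rewrite's replacement output in one `fill` per original input, so the replacement keeps the broadcast pattern of the expression it stands in for. A minimal standalone sketch of that behavior, assuming a standard Aesara install where `aesara.tensor.fill(template, value)` returns a tensor shaped like `template` and filled with `value`:

# A sketch, not the library's code path: the helper body is copied
# verbatim from the hunk above; everything around it is illustrative.
import numpy as np

import aesara
import aesara.tensor as at
from aesara.tensor.basic import fill


def fill_chain(new_out, orig_inputs):
    # One `fill` per original input re-broadcasts `new_out` to the
    # shape/broadcast pattern of the expression being replaced.
    for i in orig_inputs:
        new_out = fill(i, new_out)
    return [new_out]


x = at.matrix("x")
zero = at.as_tensor_variable(np.asarray(0.0, dtype=x.dtype))
(wrapped,) = fill_chain(zero, [x])  # a scalar 0 shaped like `x`

f = aesara.function([x], wrapped)
print(f(np.ones((2, 3), dtype=x.dtype)).shape)  # (2, 3)

This is why rewrites such as `local_mul_zero` below can return a scalar constant without changing the type of the node they replace.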
aesara/tensor/math_opt.py
@@ -40,7 +40,6 @@ from aesara.tensor.basic import (
 )
 from aesara.tensor.basic_opt import (
     FusionOptimizer,
-    _fill_chain,
     broadcast_like,
     encompasses_broadcastable,
     fuse_seqopt,
@@ -104,6 +103,12 @@ _logger = logging.getLogger("aesara.tensor.math_opt")
 _logger.addFilter(NoDuplicateOptWarningFilter())
 
+def fill_chain(new_out, orig_inputs):
+    for i in orig_inputs:
+        new_out = fill(i, new_out)
+    return [new_out]
+
 @register_canonicalize
 @register_stabilize
 @local_optimizer([Dot])
@@ -1001,7 +1006,7 @@ class AlgebraicCanonizer(LocalOptimizer):
         assert (new.type == out.type) == (not (new.type != out.type))
 
         if not (new.type == out.type):
-            new = _fill_chain(new, node.inputs)[0]
+            new = fill_chain(new, node.inputs)[0]
 
         if new.type == out.type:
             # This happen with test
@@ -1792,7 +1797,7 @@ def local_mul_zero(fgraph, node):
         # print 'MUL by value', value, node.inputs
         if value == 0:
             # print '... returning zeros'
-            return _fill_chain(_asarray(0, dtype=otype.dtype), node.inputs)
+            return fill_chain(_asarray(0, dtype=otype.dtype), node.inputs)
 
 register_canonicalize(local_mul_zero)
@@ -2074,8 +2079,8 @@ register_specialize(local_mul_specialize)
 @local_optimizer([add])
 def local_add_specialize(fgraph, node):
-    def fill_chain(v):
-        out = _fill_chain(v, node.inputs)
+    def _fill_chain(v):
+        out = fill_chain(v, node.inputs)
         return out
 
     # here, we are past the point of canonicalization, so we don't want
@@ -2099,12 +2104,12 @@ def local_add_specialize(fgraph, node):
             # Reuse call to constant for cache()
             cst = constant(np.zeros((1,) * ndim, dtype=dtype))
             assert cst.type.broadcastable == (True,) * ndim
-            return fill_chain(cst)
+            return _fill_chain(cst)
 
     if len(new_inputs) == 1:
-        ret = fill_chain(new_inputs[0])
+        ret = _fill_chain(new_inputs[0])
     else:
-        ret = fill_chain(add(*new_inputs))
+        ret = _fill_chain(add(*new_inputs))
 
     # The dtype should not be changed. It can happen if the input
     # that was forcing upcasting was equal to 0.
     if ret[0].dtype != dtype:
@@ -2223,7 +2228,7 @@ def local_log1p(fgraph, node):
                     ninp = nonconsts[0]
                     if ninp.dtype != log_arg.type.dtype:
                         ninp = ninp.astype(node.outputs[0].dtype)
-                    return _fill_chain(log1p(ninp), scalar_inputs)
+                    return fill_chain(log1p(ninp), scalar_inputs)
 
         elif log_arg.owner and log_arg.owner.op == sub:
             one = extract_constant(log_arg.owner.inputs[0], only_process_constants=True)
@@ -3496,7 +3501,7 @@ def local_reciprocal_1_plus_exp(fgraph, node):
             if len(nonconsts) == 1:
                 if nonconsts[0].owner and nonconsts[0].owner.op == exp:
                     if scalars_ and np.allclose(np.sum(scalars_), 1):
-                        out = _fill_chain(
+                        out = fill_chain(
                             sigmoid(neg(nonconsts[0].owner.inputs[0])),
                             scalar_inputs,
                         )
...
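A hedged end-to-end check (an assumption about runtime behavior, not part of the diff): with the default rewrites enabled, `local_mul_zero` relies on the chain of `fill`s to fold `x * 0` down to zeros that still carry `x`'s shape and dtype:

import numpy as np

import aesara
import aesara.tensor as at

x = at.matrix("x")
f = aesara.function([x], x * 0)  # local_mul_zero should apply here
print(f(np.ones((2, 2), dtype=x.dtype)))
# [[0. 0.]
#  [0. 0.]]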