提交 639b0871 作者: Ricardo Vieira 提交者: Ricardo Vieira

Get rid of redundant checks in tracked node_rewriters

上级 cb2c40ba
......@@ -328,7 +328,7 @@ def local_func_inv(fgraph, node):
@register_canonicalize
@register_specialize
@node_rewriter([Elemwise])
@node_rewriter([log, log1p, exp, expm1])
def local_exp_log(fgraph, node):
x = node.inputs[0]
......@@ -368,7 +368,7 @@ def local_exp_log(fgraph, node):
@register_specialize
@node_rewriter([Elemwise])
@node_rewriter([exp, expm1])
def local_exp_log_nan_switch(fgraph, node):
# Rewrites of the kind exp(log...(x)) that require a `nan` switch
x = node.inputs[0]
......@@ -431,11 +431,7 @@ def local_sumsqr2dot(fgraph, node):
``pt.sqr(W.dimshuffle("x", 0, 1) * G.dimshuffle(0, "x", 1) ).sum(axis=(1, 2))``
and converts it to ``pt.dot(pt.sqr(G), pt.sqr(W).sum(axis=0))``.
"""
if (
isinstance(node.op, Sum)
and isinstance(node.op.scalar_op, ps.Add)
and node.op.axis == (1, 2)
):
if node.op.axis == (1, 2):
in1 = node.inputs[0]
out = node.outputs[0]
......@@ -479,7 +475,7 @@ def local_mul_exp_to_exp_add(fgraph, node):
n.owner.inputs[0]
for n in node.inputs
if n.owner
and hasattr(n.owner.op, "scalar_op")
and isinstance(n.owner.op, Elemwise)
and isinstance(n.owner.op.scalar_op, ps.Exp)
]
# Can only do any rewrite if there are at least two exp-s
......@@ -523,7 +519,7 @@ def local_mul_pow_to_pow_add(fgraph, node):
for n in node.inputs:
if (
n.owner
and hasattr(n.owner.op, "scalar_op")
and isinstance(n.owner.op, Elemwise)
and isinstance(n.owner.op.scalar_op, ps.Pow)
):
base_node = n.owner.inputs[0]
......@@ -567,28 +563,27 @@ def local_mul_pow_to_pow_add(fgraph, node):
@register_stabilize
@register_specialize
@register_canonicalize
@node_rewriter([Elemwise])
@node_rewriter([sub])
def local_expm1(fgraph, node):
"""Detect ``exp(a) - 1`` and convert them to ``expm1(a)``."""
if isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, ps.Sub):
in1, in2 = node.inputs
out = node.outputs[0]
in1, in2 = node.inputs
out = node.outputs[0]
if (
in1.owner
and isinstance(in1.owner.op, Elemwise)
and isinstance(in1.owner.op.scalar_op, ps.Exp)
and extract_constant(in2, only_process_constants=False) == 1
):
in11 = in1.owner.inputs[0]
new_out = expm1(in11)
if (
in1.owner
and isinstance(in1.owner.op, Elemwise)
and isinstance(in1.owner.op.scalar_op, ps.Exp)
and extract_constant(in2, only_process_constants=False) == 1
):
in11 = in1.owner.inputs[0]
new_out = expm1(in11)
if new_out.dtype != out.dtype:
new_out = cast(new_out, dtype=out.dtype)
if new_out.dtype != out.dtype:
new_out = cast(new_out, dtype=out.dtype)
if not out.type.is_super(new_out.type):
return
return [new_out]
if not out.type.is_super(new_out.type):
return
return [new_out]
@register_specialize
......@@ -625,8 +620,6 @@ def local_mul_switch_sink(fgraph, node):
part of the graph.
"""
if node.op != mul:
return False
for idx, i in enumerate(node.inputs):
if i.owner and i.owner.op == switch:
switch_node = i.owner
......@@ -705,8 +698,6 @@ def local_div_switch_sink(fgraph, node):
See `local_mul_switch_sink` for more details.
"""
if node.op != true_div and node.op != int_div:
return False
op = node.op
if node.inputs[0].owner and node.inputs[0].owner.op == switch:
switch_node = node.inputs[0].owner
......@@ -1235,8 +1226,7 @@ register_canonicalize(local_mul_canonizer, "shape_unsafe", name="local_mul_canon
@register_canonicalize
@node_rewriter([neg])
def local_neg_to_mul(fgraph, node):
if node.op == neg:
return [mul(np.array(-1, dtype=node.inputs[0].dtype), node.inputs[0])]
return [mul(np.array(-1, dtype=node.inputs[0].dtype), node.inputs[0])]
@register_specialize
......@@ -1347,17 +1337,12 @@ def local_sum_of_neg_to_neg_of_sum(fgraph, node):
@register_specialize
@node_rewriter([Elemwise])
@node_rewriter([sub])
def local_elemwise_sub_zeros(fgraph, node):
"""
Elemwise{sub}(X,X) -> zeros_like(X)
"""
if (
isinstance(node.op, Elemwise)
and node.op.scalar_op.nin == 2
and node.op.scalar_op == ps.sub
and node.inputs[0] == node.inputs[1]
):
if node.inputs[0] == node.inputs[1]:
res = zeros_like(node.inputs[0])
# Copy over stacktrace from previous output.
# This could help for failures due to out-of-memory.
......@@ -1400,8 +1385,6 @@ def local_useless_elemwise_comparison(fgraph, node):
the graph easier to read.
"""
if not isinstance(node.op, Elemwise):
return
if node.op.scalar_op.nin != 2:
return
......@@ -1590,14 +1573,13 @@ def local_sum_prod_all_to_none(fgraph, node):
Prod{0,1,...N} -> Prod{}
"""
if isinstance(node.op, Sum) or isinstance(node.op, Prod):
op_type = Sum if isinstance(node.op, Sum) else Prod
# if all the axes are named, then use None as a shorthand
# this permits more merging
if node.op.axis is None:
return
if set(node.op.axis) == set(range(node.inputs[0].type.ndim)):
return [op_type(axis=None, dtype=node.op.dtype)(node.inputs[0])]
op_type = Sum if isinstance(node.op, Sum) else Prod
# if all the axes are named, then use None as a shorthand
# this permits more merging
if node.op.axis is None:
return
if set(node.op.axis) == set(range(node.inputs[0].type.ndim)):
return [op_type(axis=None, dtype=node.op.dtype)(node.inputs[0])]
@register_canonicalize
......@@ -1609,35 +1591,34 @@ def local_op_of_op(fgraph, node):
Sum(Sum()) -> single Sum()
"""
if isinstance(node.op, Prod) or isinstance(node.op, Sum):
op_type = Sum if isinstance(node.op, Sum) else Prod
(node_inps,) = node.inputs
out_dtype = node.op.dtype
# This is done to make sure the rewrite doesn't affect other
# computations.
if len(fgraph.clients[node_inps]) == 1:
if node_inps.owner and (isinstance(node_inps.owner.op, node.op.__class__)):
# check whether either the inner or outer prod is doing a
# product over all axes, in which case we can remove it
if node_inps.owner.op.axis is None or node.op.axis is None:
return [op_type(None, dtype=out_dtype)(node_inps.owner.inputs[0])]
# figure out which axes were in the original sum
newaxis = list(node_inps.owner.op.axis)
for i in node.op.axis:
new_i = i
for ii in node_inps.owner.op.axis:
if new_i >= ii:
new_i += 1
assert new_i not in newaxis
newaxis.append(new_i)
assert len(newaxis) == len(
list(node_inps.owner.op.axis) + list(node.op.axis)
)
op_type = Sum if isinstance(node.op, Sum) else Prod
(node_inps,) = node.inputs
out_dtype = node.op.dtype
# This is done to make sure the rewrite doesn't affect other
# computations.
if len(fgraph.clients[node_inps]) == 1:
if node_inps.owner and (isinstance(node_inps.owner.op, node.op.__class__)):
# check whether either the inner or outer prod is doing a
# product over all axes, in which case we can remove it
if node_inps.owner.op.axis is None or node.op.axis is None:
return [op_type(None, dtype=out_dtype)(node_inps.owner.inputs[0])]
# figure out which axes were in the original sum
newaxis = list(node_inps.owner.op.axis)
for i in node.op.axis:
new_i = i
for ii in node_inps.owner.op.axis:
if new_i >= ii:
new_i += 1
assert new_i not in newaxis
newaxis.append(new_i)
assert len(newaxis) == len(
list(node_inps.owner.op.axis) + list(node.op.axis)
)
combined = op_type(newaxis, dtype=out_dtype)
return [combined(node_inps.owner.inputs[0])]
combined = op_type(newaxis, dtype=out_dtype)
return [combined(node_inps.owner.inputs[0])]
ALL_REDUCE = [
......@@ -1669,11 +1650,7 @@ def local_reduce_join(fgraph, node):
where we join and reduce on the same set of axes.
"""
if (
isinstance(node.op, CAReduce)
and node.inputs[0].owner
and isinstance(node.inputs[0].owner.op, Join)
):
if node.inputs[0].owner and isinstance(node.inputs[0].owner.op, Join):
join_node = node.inputs[0].owner
if extract_constant(join_node.inputs[0], only_process_constants=True) != 0:
return
......@@ -1732,11 +1709,10 @@ def local_reduce_join(fgraph, node):
@node_rewriter(ALL_REDUCE)
def local_useless_reduce(fgraph, node):
"""Sum(a, axis=[]) -> a"""
if isinstance(node.op, CAReduce):
(summed,) = node.inputs
# if reduce were doing anything, the output ndim would be reduced
if summed.type == node.outputs[0].type:
return [summed]
(summed,) = node.inputs
# if reduce were doing anything, the output ndim would be reduced
if summed.type == node.outputs[0].type:
return [summed]
@register_canonicalize
......@@ -1745,42 +1721,41 @@ def local_useless_reduce(fgraph, node):
@node_rewriter(ALL_REDUCE)
def local_reduce_broadcastable(fgraph, node):
"""Remove reduction over broadcastable dimensions."""
if isinstance(node.op, CAReduce):
(reduced,) = node.inputs
odtype = node.outputs[0].dtype
if node.op.axis is None:
if all(reduced.broadcastable):
return [reduced.dimshuffle().astype(odtype)]
else:
axis = list(node.op.axis)
cuttable = [a for a in axis if reduced.broadcastable[a]]
if cuttable:
# -- we can remove some axes of summation.
new_axis = []
pattern = []
ii = 0
for p in range(reduced.ndim):
if p not in cuttable:
if p in axis:
new_axis.append(ii)
pattern.append(p)
ii += 1
new_reduced = reduced.dimshuffle(*pattern)
if new_axis:
if type(node.op) == CAReduce:
# This case handles `CAReduce` instances
# (e.g. generated by `scalar_elemwise`), and not the
# scalar `Op`-specific subclasses
# TODO FIXME: This highlights a major design flaw in
# `CAReduce` (or at least our use of it), and it needs
# to be fixed
new_op = node.op.__class__(node.op.scalar_op, axis=new_axis)
else:
new_op = node.op.__class__(axis=new_axis)
return [new_op(new_reduced)]
(reduced,) = node.inputs
odtype = node.outputs[0].dtype
if node.op.axis is None:
if all(reduced.broadcastable):
return [reduced.dimshuffle().astype(odtype)]
else:
axis = list(node.op.axis)
cuttable = [a for a in axis if reduced.broadcastable[a]]
if cuttable:
# -- we can remove some axes of summation.
new_axis = []
pattern = []
ii = 0
for p in range(reduced.ndim):
if p not in cuttable:
if p in axis:
new_axis.append(ii)
pattern.append(p)
ii += 1
new_reduced = reduced.dimshuffle(*pattern)
if new_axis:
if type(node.op) == CAReduce:
# This case handles `CAReduce` instances
# (e.g. generated by `scalar_elemwise`), and not the
# scalar `Op`-specific subclasses
# TODO FIXME: This highlights a major design flaw in
# `CAReduce` (or at least our use of it), and it needs
# to be fixed
new_op = node.op.__class__(node.op.scalar_op, axis=new_axis)
else:
# -- in this case we can remove the reduction completely
return [new_reduced.astype(odtype)]
new_op = node.op.__class__(axis=new_axis)
return [new_op(new_reduced)]
else:
# -- in this case we can remove the reduction completely
return [new_reduced.astype(odtype)]
@register_specialize
......@@ -1792,61 +1767,54 @@ def local_opt_alloc(fgraph, node):
prod(alloc(constant,shapes...)) => constant**prod(shapes)
"""
if isinstance(node.op, Sum) or isinstance(node.op, Prod):
(node_inps,) = node.inputs
if node_inps.owner and isinstance(node_inps.owner.op, Alloc):
inp = node_inps.owner.inputs[0]
shapes = node_inps.owner.inputs[1:]
try:
val = get_underlying_scalar_constant_value(
inp, only_process_constants=True
)
assert val.size == 1
val = val.reshape(1)[0]
# check which type of op
size = mul(*shapes)
if inp.dtype in ("float16", "float32"):
# shapes are ints and normally int64.
# We don't want to have a float64 upcast
# We don't want to downcast to float16
# as we fear it could lose too much precision
# that will be amplified by the mul/pow below.
size = size.astype("float32")
if node.op.axis is None or node.op.axis == tuple(range(inp.ndim)):
if isinstance(node.op, Sum):
val = val * size
else:
val = val**size
# Sum can change the input dtype (upcast or bool
# -> float32) by default or by user request.
# We can ignore the acc_dtype, as there is only 1
# elemwise we will do and not a sequence, so there is no
# accumulation of errors.
# So mostly, we just need to cast the output to the old
# dtype.
val = val.astype(node.outputs[0].dtype)
return [val]
to_prod = [shapes[i] for i in range(len(shapes)) if i in node.op.axis]
if to_prod:
size = mul(*to_prod)
if isinstance(node.op, Sum):
val *= size
else:
val = val**size
# See comments above.
(node_inps,) = node.inputs
if node_inps.owner and isinstance(node_inps.owner.op, Alloc):
inp = node_inps.owner.inputs[0]
shapes = node_inps.owner.inputs[1:]
try:
val = get_underlying_scalar_constant_value(inp, only_process_constants=True)
assert val.size == 1
val = val.reshape(1)[0]
# check which type of op
size = mul(*shapes)
if inp.dtype in ("float16", "float32"):
# shapes are ints and normally int64.
# We don't want to have a float64 upcast
# We don't want to downcast to float16
# as we fear it could lose too much precision
# that will be amplified by the mul/pow below.
size = size.astype("float32")
if node.op.axis is None or node.op.axis == tuple(range(inp.ndim)):
if isinstance(node.op, Sum):
val = val * size
else:
val = val**size
# Sum can change the input dtype (upcast or bool
# -> float32) by default or by user request.
# We can ignore the acc_dtype, as there is only 1
# elemwise we will do and not a sequence, so there is no
# accumulation of errors.
# So mostly, we just need to cast the output to the old
# dtype.
val = val.astype(node.outputs[0].dtype)
return [
alloc(
val,
*[
shapes[i]
for i in range(len(shapes))
if i not in node.op.axis
],
)
]
except NotScalarConstantError:
pass
return [val]
to_prod = [shapes[i] for i in range(len(shapes)) if i in node.op.axis]
if to_prod:
size = mul(*to_prod)
if isinstance(node.op, Sum):
val *= size
else:
val = val**size
# See comments above.
val = val.astype(node.outputs[0].dtype)
return [
alloc(
val,
*[shapes[i] for i in range(len(shapes)) if i not in node.op.axis],
)
]
except NotScalarConstantError:
pass
@register_specialize
......@@ -1858,19 +1826,18 @@ def local_neg_div_neg(fgraph, node):
Also performs - (c / b) -> ((-c) / b) when c is a scalar constant.
"""
if node.op == neg:
if node.inputs[0].owner and node.inputs[0].owner.op == true_div:
frac = node.inputs[0]
num, denom = frac.owner.inputs
if num.owner and num.owner.op == neg:
if len(fgraph.clients[frac]) == 1:
# No other clients of the original division
new_num = num.owner.inputs[0]
return [true_div(new_num, denom)]
elif all(num.broadcastable) and isinstance(num, Constant):
if len(fgraph.clients[frac]) == 1:
new_num = -num.data
return [true_div(new_num, denom)]
if node.inputs[0].owner and node.inputs[0].owner.op == true_div:
frac = node.inputs[0]
num, denom = frac.owner.inputs
if num.owner and num.owner.op == neg:
if len(fgraph.clients[frac]) == 1:
# No other clients of the original division
new_num = num.owner.inputs[0]
return [true_div(new_num, denom)]
elif all(num.broadcastable) and isinstance(num, Constant):
if len(fgraph.clients[frac]) == 1:
new_num = -num.data
return [true_div(new_num, denom)]
@register_canonicalize
......@@ -1881,14 +1848,13 @@ def local_sub_neg_to_add(fgraph, node):
x - (-y) -> x + y
"""
if node.op == sub:
minuend, subtrahend = node.inputs
minuend, subtrahend = node.inputs
if subtrahend.owner:
if subtrahend.owner.op == neg:
pre_neg = subtrahend.owner.inputs[0]
new_out = add(minuend, pre_neg)
return [new_out]
if subtrahend.owner:
if subtrahend.owner.op == neg:
pre_neg = subtrahend.owner.inputs[0]
new_out = add(minuend, pre_neg)
return [new_out]
@register_specialize
......@@ -1903,7 +1869,7 @@ def local_add_neg_to_sub(fgraph, node):
# `local_neg_to_mul` rewrite modifies the relevant pattern during canonicalization
# Rewrite is only applicable when there are two inputs to add
if node.op == add and len(node.inputs) == 2:
if len(node.inputs) == 2:
# Look for pattern with either input order
for first, second in (node.inputs, reversed(node.inputs)):
if second.owner:
......@@ -1927,27 +1893,24 @@ def local_mul_zero(fgraph, node):
with zero.
"""
if node.op == mul:
otype = node.outputs[0].type
otype = node.outputs[0].type
for i in node.inputs:
try:
value = get_underlying_scalar_constant_value(i)
except NotScalarConstantError:
continue
# print 'MUL by value', value, node.inputs
if value == 0:
# print '... returning zeros'
return [
broadcast_arrays(_asarray(0, dtype=otype.dtype), *node.inputs)[0]
]
for i in node.inputs:
try:
value = get_underlying_scalar_constant_value(i)
except NotScalarConstantError:
continue
# print 'MUL by value', value, node.inputs
if value == 0:
# print '... returning zeros'
return [broadcast_arrays(_asarray(0, dtype=otype.dtype), *node.inputs)[0]]
# TODO: Add this to the canonicalization to reduce redundancy.
@register_specialize
@node_rewriter([true_div])
def local_div_to_reciprocal(fgraph, node):
if node.op == true_div and np.all(get_constant(node.inputs[0]) == 1.0):
if np.all(get_constant(node.inputs[0]) == 1.0):
out = node.outputs[0]
new_out = reciprocal(local_mul_canonizer.merge_num_denum(node.inputs[1:], []))
# The ones could have forced upcasting
......@@ -1957,30 +1920,22 @@ def local_div_to_reciprocal(fgraph, node):
if not out.type.is_super(new_out.type):
new_out = alloc_like(new_out, out, fgraph)
return [new_out]
else:
return False
@register_canonicalize
@node_rewriter([reciprocal])
def local_reciprocal_canon(fgraph, node):
if node.op == reciprocal:
return [pt_pow(node.inputs[0], -1.0)]
else:
return False
return [pt_pow(node.inputs[0], -1.0)]
@register_canonicalize
@node_rewriter([pt_pow])
def local_pow_canonicalize(fgraph, node):
if node.op == pt_pow:
cst = get_constant(node.inputs[1])
if cst == 0:
return [alloc_like(1, node.outputs[0], fgraph)]
if cst == 1:
return [alloc_like(node.inputs[0], node.outputs[0], fgraph)]
else:
return False
cst = get_constant(node.inputs[1])
if cst == 0:
return [alloc_like(1, node.outputs[0], fgraph)]
if cst == 1:
return [alloc_like(node.inputs[0], node.outputs[0], fgraph)]
@register_specialize
......@@ -1989,21 +1944,17 @@ def local_mul_to_sqr(fgraph, node):
"""
x*x -> sqr(x)
"""
if node.op == mul:
if len(node.inputs) == 2:
if node.inputs[0] is node.inputs[1]:
return [sqr(node.inputs[0])]
if len(node.inputs) == 2:
if node.inputs[0] is node.inputs[1]:
return [sqr(node.inputs[0])]
@register_canonicalize
@node_rewriter([int_div])
def local_intdiv_by_one(fgraph, node):
"""x // 1 -> x"""
if node.op in [int_div]:
if isinstance(node.inputs[1], TensorConstant) and np.all(
node.inputs[1].value == 1
):
return [node.inputs[0].astype(node.outputs[0].dtype)]
if isinstance(node.inputs[1], TensorConstant) and np.all(node.inputs[1].value == 1):
return [node.inputs[0].astype(node.outputs[0].dtype)]
@register_canonicalize
......@@ -2011,49 +1962,43 @@ def local_intdiv_by_one(fgraph, node):
@node_rewriter([int_div, true_div])
def local_zero_div(fgraph, node):
"""0 / x -> 0"""
if isinstance(node.op, Elemwise) and isinstance(
node.op.scalar_op, ps.IntDiv | ps.TrueDiv
):
if get_constant(node.inputs[0]) == 0:
ret = alloc_like(0, node.outputs[0], fgraph)
ret.tag.values_eq_approx = values_eq_approx_remove_nan
return [ret]
if get_constant(node.inputs[0]) == 0:
ret = alloc_like(0, node.outputs[0], fgraph)
ret.tag.values_eq_approx = values_eq_approx_remove_nan
return [ret]
@register_specialize
@node_rewriter([pt_pow])
def local_pow_specialize(fgraph, node):
if node.op == pt_pow:
# the idea here is that we have pow(x, y)
odtype = node.outputs[0].dtype
xsym = node.inputs[0]
ysym = node.inputs[1]
y = get_constant(ysym)
if (y is not None) and not broadcasted_by(xsym, ysym):
rval = None
if np.all(y == 2):
rval = [sqr(xsym)]
if np.all(y == 1):
rval = [xsym]
if np.all(y == 0):
rval = [alloc_like(1, xsym, fgraph)]
if np.all(y == 0.5):
rval = [sqrt(xsym)]
if np.all(y == -0.5):
rval = [reciprocal(sqrt(xsym))]
if np.all(y == -1):
rval = [reciprocal(xsym)]
if np.all(y == -2):
rval = [reciprocal(sqr(xsym))]
if rval:
if not rval[0].type.broadcastable == node.outputs[0].type.broadcastable:
return None
rval[0] = cast(rval[0], odtype)
assert rval[0].type.dtype == node.outputs[0].type.dtype
return rval
else:
return False
# the idea here is that we have pow(x, y)
odtype = node.outputs[0].dtype
xsym = node.inputs[0]
ysym = node.inputs[1]
y = get_constant(ysym)
if (y is not None) and not broadcasted_by(xsym, ysym):
rval = None
if np.all(y == 2):
rval = [sqr(xsym)]
if np.all(y == 1):
rval = [xsym]
if np.all(y == 0):
rval = [alloc_like(1, xsym, fgraph)]
if np.all(y == 0.5):
rval = [sqrt(xsym)]
if np.all(y == -0.5):
rval = [reciprocal(sqrt(xsym))]
if np.all(y == -1):
rval = [reciprocal(xsym)]
if np.all(y == -2):
rval = [reciprocal(sqr(xsym))]
if rval:
if not rval[0].type.broadcastable == node.outputs[0].type.broadcastable:
return None
rval[0] = cast(rval[0], odtype)
assert rval[0].type.dtype == node.outputs[0].type.dtype
return rval
@register_specialize
......@@ -2138,61 +2083,60 @@ def local_mul_specialize(fgraph, node):
"""
# at this point [post canonicalize], mul() may have many inputs.
if node.op == mul:
# the idea here is to strip neg nodes and constant factors of 1, -1 and 0 from mul(...)
has_neg = False
new_inputs = []
nb_neg_node = 0
nb_cst = 0
for inp in node.inputs:
# remove any neg arguments
while inp.owner and inp.owner.op == neg:
has_neg ^= True
inp = inp.owner.inputs[0]
nb_neg_node += 1
# remove special case arguments of 1, -1 or 0
y = get_constant(inp)
if y == 1.0:
nb_cst += 1
elif y == -1.0:
nb_cst += 1
has_neg ^= True # toggles
elif y == 0.0:
# if we find any zero, we just return right away
return [alloc_like(0, node.outputs[0], fgraph)]
else:
new_inputs.append(inp)
if new_inputs != node.inputs:
if new_inputs:
if len(new_inputs) == 1:
if has_neg:
if new_inputs[0].dtype in ([*uint_dtypes, "bool"]):
return
else:
rval = -new_inputs[0]
# the idea here is to strip neg nodes and constant factors of 1, -1 and 0 from mul(...)
has_neg = False
new_inputs = []
nb_neg_node = 0
nb_cst = 0
for inp in node.inputs:
# remove any neg arguments
while inp.owner and inp.owner.op == neg:
has_neg ^= True
inp = inp.owner.inputs[0]
nb_neg_node += 1
# remove special case arguments of 1, -1 or 0
y = get_constant(inp)
if y == 1.0:
nb_cst += 1
elif y == -1.0:
nb_cst += 1
has_neg ^= True # toggles
elif y == 0.0:
# if we find any zero, we just return right away
return [alloc_like(0, node.outputs[0], fgraph)]
else:
new_inputs.append(inp)
if new_inputs != node.inputs:
if new_inputs:
if len(new_inputs) == 1:
if has_neg:
if new_inputs[0].dtype in ([*uint_dtypes, "bool"]):
return
else:
rval = new_inputs[0]
rval = -new_inputs[0]
else:
# The next case would cause a replace by an equivalent case.
if has_neg and nb_neg_node == 0 and nb_cst == 1:
return
elif has_neg:
# Don't add an extra neg node as we can't
# fully replace this mul by a neg.
m1 = np.asarray(-1, dtype=node.outputs[0].dtype)
new_inputs = [m1, *new_inputs]
rval = mul(*new_inputs)
return [alloc_like(rval, node.outputs[0], fgraph)]
rval = new_inputs[0]
else:
# there are no variable inputs to mul
# N.B. this could have been constant-folded...
if has_neg:
return [alloc_like(-1, node.outputs[0], fgraph)]
else:
return [alloc_like(1, node.outputs[0], fgraph)]
# The next case would cause a replace by an equivalent case.
if has_neg and nb_neg_node == 0 and nb_cst == 1:
return
elif has_neg:
# Don't add an extra neg node as we can't
# fully replace this mul by a neg.
m1 = np.asarray(-1, dtype=node.outputs[0].dtype)
new_inputs = [m1, *new_inputs]
rval = mul(*new_inputs)
return [alloc_like(rval, node.outputs[0], fgraph)]
else:
# there are no variable inputs to mul
# N.B. this could have been constant-folded...
if has_neg:
return [alloc_like(-1, node.outputs[0], fgraph)]
else:
return [alloc_like(1, node.outputs[0], fgraph)]
@register_specialize
......@@ -2276,7 +2220,7 @@ def local_abs_lift(fgraph, node):
This is needed for check_for_x_over_absX to apply in more cases.
"""
if node.op == pt_abs and node.inputs[0].owner:
if node.inputs[0].owner:
assert node.nin == 1
if node.inputs[0].owner.op == mul:
return [mul(*[pt_abs(i) for i in node.inputs[0].owner.inputs])]
......@@ -2328,31 +2272,30 @@ def local_abs_merge(fgraph, node):
def local_log1p(fgraph, node):
# log(1+x) -> log1p(x)
# log(1-x) -> log1p(-x)
if node.op == log:
(log_arg,) = node.inputs
if log_arg.owner and log_arg.owner.op == add:
scalars, scalar_inputs, nonconsts = scalarconsts_rest(
log_arg.owner.inputs, only_process_constants=True
)
# scalar_inputs are potentially dimshuffled and fill'd scalars
if scalars and np.allclose(np.sum(scalars), 1):
if nonconsts:
if len(nonconsts) > 1:
ninp = add(*nonconsts)
else:
ninp = nonconsts[0]
if ninp.dtype != log_arg.type.dtype:
ninp = ninp.astype(node.outputs[0].dtype)
return [alloc_like(log1p(ninp), node.outputs[0], fgraph)]
elif log_arg.owner and log_arg.owner.op == sub:
one = extract_constant(log_arg.owner.inputs[0], only_process_constants=True)
if one != 1:
return
other = log_arg.owner.inputs[1]
if other.dtype != log_arg.dtype:
other = other.astype(log_arg.dtype)
return [log1p(neg(other))]
(log_arg,) = node.inputs
if log_arg.owner and log_arg.owner.op == add:
scalars, scalar_inputs, nonconsts = scalarconsts_rest(
log_arg.owner.inputs, only_process_constants=True
)
# scalar_inputs are potentially dimshuffled and fill'd scalars
if scalars and np.allclose(np.sum(scalars), 1):
if nonconsts:
if len(nonconsts) > 1:
ninp = add(*nonconsts)
else:
ninp = nonconsts[0]
if ninp.dtype != log_arg.type.dtype:
ninp = ninp.astype(node.outputs[0].dtype)
return [alloc_like(log1p(ninp), node.outputs[0], fgraph)]
elif log_arg.owner and log_arg.owner.op == sub:
one = extract_constant(log_arg.owner.inputs[0], only_process_constants=True)
if one != 1:
return
other = log_arg.owner.inputs[1]
if other.dtype != log_arg.dtype:
other = other.astype(log_arg.dtype)
return [log1p(neg(other))]
@register_stabilize
......@@ -2365,26 +2308,25 @@ def local_log_add_exp(fgraph, node):
TODO: in canonicalize, change log10 and log2 -> log
"""
if node.op == log:
z = node.inputs[0]
if z.owner and z.owner.op == add:
zi = z.owner.inputs
pre_exp = [x.owner.inputs[0] for x in zi if x.owner and x.owner.op == exp]
# all arguments to add are exp(<something>)
if len(pre_exp) == len(zi):
# Do not offset when max_pre = -np.inf, to avoid nan in the output
# Switch statement is placed directly inside add to break the self-symmetry
# of the returned output (otherwise the rewrite would not stabilize)
max_pre = reduce(maximum, pre_exp)
ret = max_pre + log(
add(
*[
switch(isinf(max_pre), exp(max_pre), exp(p - max_pre))
for p in pre_exp
]
)
z = node.inputs[0]
if z.owner and z.owner.op == add:
zi = z.owner.inputs
pre_exp = [x.owner.inputs[0] for x in zi if x.owner and x.owner.op == exp]
# all arguments to add are exp(<something>)
if len(pre_exp) == len(zi):
# Do not offset when max_pre = -np.inf, to avoid nan in the output
# Switch statement is placed directly inside add to break the self-symmetry
# of the returned output (otherwise the rewrite would not stabilize)
max_pre = reduce(maximum, pre_exp)
ret = max_pre + log(
add(
*[
switch(isinf(max_pre), exp(max_pre), exp(p - max_pre))
for p in pre_exp
]
)
return [ret]
)
return [ret]
@register_stabilize
......@@ -2393,9 +2335,6 @@ def local_log_add_exp(fgraph, node):
def local_log_sum_exp(fgraph, node):
# log(sum_i(exp(x_i))) = x_max + log(sum_i(exp(x_i - x_max)))
if node.op != log:
return
sum_node = node.inputs[0].owner
# If the sum has keepdims=True, there might be a dimshuffle
if sum_node and isinstance(sum_node.op, DimShuffle):
......@@ -2720,8 +2659,7 @@ def local_log_erfc(fgraph, node):
numpy.asarray([i],dtype='float32')))) for i in numpy.arange(
10.0541948,10.0541951,.0000001)]
"""
if node.op != log:
return False
if not node.inputs[0].owner or node.inputs[0].owner.op != erfc:
return False
......@@ -2773,8 +2711,6 @@ def local_grad_log_erfc_neg(fgraph, node):
Make it so that the test does not generate an error in that case!
"""
if node.op != true_div:
return False
if not node.inputs[1].owner or node.inputs[1].owner.op != erfc:
return False
......@@ -3147,46 +3083,45 @@ def local_exp_over_1_plus_exp(fgraph, node):
"""
# This rewrite should be done for numerical stability
# so we don't care to check client counts
if node.op == true_div:
# find all the exp() terms in the numerator
num, denom = node.inputs
num_exp_x, num_rest, num_neg = partition_num_or_denom(num, is_exp)
denom_1pexp, denom_rest, denom_neg = partition_num_or_denom(denom, is_1pexp)
sigmoids = []
for t in denom_1pexp:
if t in num_exp_x:
# case: exp(x) /(1+exp(x))
sigmoids.append(sigmoid(t))
del num_exp_x[num_exp_x.index(t)]
else:
# case: 1/(1+exp(x))
sigmoids.append(sigmoid(-t))
copy_stack_trace(node.outputs[0], sigmoids[-1])
if not sigmoids: # we didn't find any. abort
return
# put the new numerator together
new_num = sigmoids + [exp(t) for t in num_exp_x] + num_rest
if len(new_num) == 1:
new_num = new_num[0]
# find all the exp() terms in the numerator
num, denom = node.inputs
num_exp_x, num_rest, num_neg = partition_num_or_denom(num, is_exp)
denom_1pexp, denom_rest, denom_neg = partition_num_or_denom(denom, is_1pexp)
sigmoids = []
for t in denom_1pexp:
if t in num_exp_x:
# case: exp(x) /(1+exp(x))
sigmoids.append(sigmoid(t))
del num_exp_x[num_exp_x.index(t)]
else:
new_num = mul(*new_num)
# case: 1/(1+exp(x))
sigmoids.append(sigmoid(-t))
copy_stack_trace(node.outputs[0], sigmoids[-1])
if num_neg ^ denom_neg:
new_num = -new_num
if not sigmoids: # we didn't find any. abort
return
# put the new numerator together
new_num = sigmoids + [exp(t) for t in num_exp_x] + num_rest
if len(new_num) == 1:
new_num = new_num[0]
else:
new_num = mul(*new_num)
copy_stack_trace(num, new_num)
if num_neg ^ denom_neg:
new_num = -new_num
if len(denom_rest) == 0:
return [new_num]
elif len(denom_rest) == 1:
out = new_num / denom_rest[0]
else:
out = new_num / mul(*denom_rest)
copy_stack_trace(num, new_num)
copy_stack_trace(node.outputs[0], out)
return [out]
if len(denom_rest) == 0:
return [new_num]
elif len(denom_rest) == 1:
out = new_num / denom_rest[0]
else:
out = new_num / mul(*denom_rest)
copy_stack_trace(node.outputs[0], out)
return [out]
def parse_mul_tree(root):
......@@ -3498,9 +3433,6 @@ def local_sigm_times_exp(fgraph, node):
todo: add stack traces to the intermediate variables
"""
# Bail early if it is not a multiplication.
if node.op != mul:
return None
# Obtain tree of multiplications starting at this node.
mul_tree = parse_mul_tree(node.outputs[0])
did_something = perform_sigm_times_exp(mul_tree)
......@@ -3528,31 +3460,30 @@ def local_reciprocal_1_plus_exp(fgraph, node):
"""
# This rewrite should be done for numerical stability
# so we don't care to check client counts
if node.op == reciprocal:
reciprocal_arg = node.inputs[0]
if reciprocal_arg.owner and reciprocal_arg.owner.op == add:
scalars_, scalar_inputs, nonconsts = scalarconsts_rest(
reciprocal_arg.owner.inputs, only_process_constants=True
)
# scalar_inputs are potentially dimshuffled and fill'd scalars
if len(nonconsts) == 1:
if nonconsts[0].owner and nonconsts[0].owner.op == exp:
if scalars_ and np.allclose(np.sum(scalars_), 1):
out = [
alloc_like(
sigmoid(neg(nonconsts[0].owner.inputs[0])),
node.outputs[0],
fgraph,
)
]
# keep combined stack traces of
# exp(x): nonconsts[0],
# 1 + exp(x): reciprocal_arg,
# 1 / (1 + exp(x)): node.outputs[0]
copy_stack_trace(
[nonconsts[0], reciprocal_arg, node.outputs[0]], out
reciprocal_arg = node.inputs[0]
if reciprocal_arg.owner and reciprocal_arg.owner.op == add:
scalars_, scalar_inputs, nonconsts = scalarconsts_rest(
reciprocal_arg.owner.inputs, only_process_constants=True
)
# scalar_inputs are potentially dimshuffled and fill'd scalars
if len(nonconsts) == 1:
if nonconsts[0].owner and nonconsts[0].owner.op == exp:
if scalars_ and np.allclose(np.sum(scalars_), 1):
out = [
alloc_like(
sigmoid(neg(nonconsts[0].owner.inputs[0])),
node.outputs[0],
fgraph,
)
return out
]
# keep combined stack traces of
# exp(x): nonconsts[0],
# 1 + exp(x): reciprocal_arg,
# 1 / (1 + exp(x)): node.outputs[0]
copy_stack_trace(
[nonconsts[0], reciprocal_arg, node.outputs[0]], out
)
return out
# 1 - sigmoid(x) -> sigmoid(-x)
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论