提交 639b0871 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Get rid of redundant checks in tracked node_rewriters

上级 cb2c40ba
......@@ -328,7 +328,7 @@ def local_func_inv(fgraph, node):
@register_canonicalize
@register_specialize
@node_rewriter([Elemwise])
@node_rewriter([log, log1p, exp, expm1])
def local_exp_log(fgraph, node):
x = node.inputs[0]
......@@ -368,7 +368,7 @@ def local_exp_log(fgraph, node):
@register_specialize
@node_rewriter([Elemwise])
@node_rewriter([exp, expm1])
def local_exp_log_nan_switch(fgraph, node):
# Rewrites of the kind exp(log...(x)) that require a `nan` switch
x = node.inputs[0]
......@@ -431,11 +431,7 @@ def local_sumsqr2dot(fgraph, node):
``pt.sqr(W.dimshuffle("x", 0, 1) * G.dimshuffle(0, "x", 1) ).sum(axis=(1, 2))``
and converts it to ``pt.dot(pt.sqr(G), pt.sqr(W).sum(axis=0))``.
"""
if (
isinstance(node.op, Sum)
and isinstance(node.op.scalar_op, ps.Add)
and node.op.axis == (1, 2)
):
if node.op.axis == (1, 2):
in1 = node.inputs[0]
out = node.outputs[0]
......@@ -479,7 +475,7 @@ def local_mul_exp_to_exp_add(fgraph, node):
n.owner.inputs[0]
for n in node.inputs
if n.owner
and hasattr(n.owner.op, "scalar_op")
and isinstance(n.owner.op, Elemwise)
and isinstance(n.owner.op.scalar_op, ps.Exp)
]
# Can only do any rewrite if there are at least two exp-s
......@@ -523,7 +519,7 @@ def local_mul_pow_to_pow_add(fgraph, node):
for n in node.inputs:
if (
n.owner
and hasattr(n.owner.op, "scalar_op")
and isinstance(n.owner.op, Elemwise)
and isinstance(n.owner.op.scalar_op, ps.Pow)
):
base_node = n.owner.inputs[0]
......@@ -567,10 +563,9 @@ def local_mul_pow_to_pow_add(fgraph, node):
@register_stabilize
@register_specialize
@register_canonicalize
@node_rewriter([Elemwise])
@node_rewriter([sub])
def local_expm1(fgraph, node):
"""Detect ``exp(a) - 1`` and convert them to ``expm1(a)``."""
if isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, ps.Sub):
in1, in2 = node.inputs
out = node.outputs[0]
......@@ -625,8 +620,6 @@ def local_mul_switch_sink(fgraph, node):
part of the graph.
"""
if node.op != mul:
return False
for idx, i in enumerate(node.inputs):
if i.owner and i.owner.op == switch:
switch_node = i.owner
......@@ -705,8 +698,6 @@ def local_div_switch_sink(fgraph, node):
See `local_mul_switch_sink` for more details.
"""
if node.op != true_div and node.op != int_div:
return False
op = node.op
if node.inputs[0].owner and node.inputs[0].owner.op == switch:
switch_node = node.inputs[0].owner
......@@ -1235,7 +1226,6 @@ register_canonicalize(local_mul_canonizer, "shape_unsafe", name="local_mul_canon
@register_canonicalize
@node_rewriter([neg])
def local_neg_to_mul(fgraph, node):
if node.op == neg:
return [mul(np.array(-1, dtype=node.inputs[0].dtype), node.inputs[0])]
......@@ -1347,17 +1337,12 @@ def local_sum_of_neg_to_neg_of_sum(fgraph, node):
@register_specialize
@node_rewriter([Elemwise])
@node_rewriter([sub])
def local_elemwise_sub_zeros(fgraph, node):
"""
Elemwise{sub}(X,X) -> zeros_like(X)
"""
if (
isinstance(node.op, Elemwise)
and node.op.scalar_op.nin == 2
and node.op.scalar_op == ps.sub
and node.inputs[0] == node.inputs[1]
):
if node.inputs[0] == node.inputs[1]:
res = zeros_like(node.inputs[0])
# Copy over stacktrace from previous output.
# This could help for failures due to out-of-memory.
......@@ -1400,8 +1385,6 @@ def local_useless_elemwise_comparison(fgraph, node):
the graph easier to read.
"""
if not isinstance(node.op, Elemwise):
return
if node.op.scalar_op.nin != 2:
return
......@@ -1590,7 +1573,6 @@ def local_sum_prod_all_to_none(fgraph, node):
Prod{0,1,...N} -> Prod{}
"""
if isinstance(node.op, Sum) or isinstance(node.op, Prod):
op_type = Sum if isinstance(node.op, Sum) else Prod
# if all the axes are named, then use None as a shorthand
# this permits more merging
......@@ -1609,7 +1591,6 @@ def local_op_of_op(fgraph, node):
Sum(Sum()) -> single Sum()
"""
if isinstance(node.op, Prod) or isinstance(node.op, Sum):
op_type = Sum if isinstance(node.op, Sum) else Prod
(node_inps,) = node.inputs
out_dtype = node.op.dtype
......@@ -1669,11 +1650,7 @@ def local_reduce_join(fgraph, node):
where we join and reduce on the same set of axis.
"""
if (
isinstance(node.op, CAReduce)
and node.inputs[0].owner
and isinstance(node.inputs[0].owner.op, Join)
):
if node.inputs[0].owner and isinstance(node.inputs[0].owner.op, Join):
join_node = node.inputs[0].owner
if extract_constant(join_node.inputs[0], only_process_constants=True) != 0:
return
......@@ -1732,7 +1709,6 @@ def local_reduce_join(fgraph, node):
@node_rewriter(ALL_REDUCE)
def local_useless_reduce(fgraph, node):
"""Sum(a, axis=[]) -> a"""
if isinstance(node.op, CAReduce):
(summed,) = node.inputs
# if reduce were doing anything, the output ndim would be reduced
if summed.type == node.outputs[0].type:
......@@ -1745,7 +1721,6 @@ def local_useless_reduce(fgraph, node):
@node_rewriter(ALL_REDUCE)
def local_reduce_broadcastable(fgraph, node):
"""Remove reduction over broadcastable dimensions."""
if isinstance(node.op, CAReduce):
(reduced,) = node.inputs
odtype = node.outputs[0].dtype
if node.op.axis is None:
......@@ -1792,15 +1767,12 @@ def local_opt_alloc(fgraph, node):
prod(alloc(constant,shapes...)) => constant**prod(shapes)
"""
if isinstance(node.op, Sum) or isinstance(node.op, Prod):
(node_inps,) = node.inputs
if node_inps.owner and isinstance(node_inps.owner.op, Alloc):
inp = node_inps.owner.inputs[0]
shapes = node_inps.owner.inputs[1:]
try:
val = get_underlying_scalar_constant_value(
inp, only_process_constants=True
)
val = get_underlying_scalar_constant_value(inp, only_process_constants=True)
assert val.size == 1
val = val.reshape(1)[0]
# check which type of op
......@@ -1838,11 +1810,7 @@ def local_opt_alloc(fgraph, node):
return [
alloc(
val,
*[
shapes[i]
for i in range(len(shapes))
if i not in node.op.axis
],
*[shapes[i] for i in range(len(shapes)) if i not in node.op.axis],
)
]
except NotScalarConstantError:
......@@ -1858,7 +1826,6 @@ def local_neg_div_neg(fgraph, node):
Also performs - (c / b) -> ((-c) / b) when c is a scalar constant.
"""
if node.op == neg:
if node.inputs[0].owner and node.inputs[0].owner.op == true_div:
frac = node.inputs[0]
num, denom = frac.owner.inputs
......@@ -1881,7 +1848,6 @@ def local_sub_neg_to_add(fgraph, node):
x - (-y) -> x + y
"""
if node.op == sub:
minuend, subtrahend = node.inputs
if subtrahend.owner:
......@@ -1903,7 +1869,7 @@ def local_add_neg_to_sub(fgraph, node):
# `local_neg_to_mul` rewrite modifies the relevant pattern during canonicalization
# Rewrite is only applicable when there are two inputs to add
if node.op == add and len(node.inputs) == 2:
if len(node.inputs) == 2:
# Look for pattern with either input order
for first, second in (node.inputs, reversed(node.inputs)):
if second.owner:
......@@ -1927,7 +1893,6 @@ def local_mul_zero(fgraph, node):
with zero.
"""
if node.op == mul:
otype = node.outputs[0].type
for i in node.inputs:
......@@ -1938,16 +1903,14 @@ def local_mul_zero(fgraph, node):
# print 'MUL by value', value, node.inputs
if value == 0:
# print '... returning zeros'
return [
broadcast_arrays(_asarray(0, dtype=otype.dtype), *node.inputs)[0]
]
return [broadcast_arrays(_asarray(0, dtype=otype.dtype), *node.inputs)[0]]
# TODO: Add this to the canonicalization to reduce redundancy.
@register_specialize
@node_rewriter([true_div])
def local_div_to_reciprocal(fgraph, node):
if node.op == true_div and np.all(get_constant(node.inputs[0]) == 1.0):
if np.all(get_constant(node.inputs[0]) == 1.0):
out = node.outputs[0]
new_out = reciprocal(local_mul_canonizer.merge_num_denum(node.inputs[1:], []))
# The ones could have forced upcasting
......@@ -1957,30 +1920,22 @@ def local_div_to_reciprocal(fgraph, node):
if not out.type.is_super(new_out.type):
new_out = alloc_like(new_out, out, fgraph)
return [new_out]
else:
return False
@register_canonicalize
@node_rewriter([reciprocal])
def local_reciprocal_canon(fgraph, node):
if node.op == reciprocal:
return [pt_pow(node.inputs[0], -1.0)]
else:
return False
@register_canonicalize
@node_rewriter([pt_pow])
def local_pow_canonicalize(fgraph, node):
if node.op == pt_pow:
cst = get_constant(node.inputs[1])
if cst == 0:
return [alloc_like(1, node.outputs[0], fgraph)]
if cst == 1:
return [alloc_like(node.inputs[0], node.outputs[0], fgraph)]
else:
return False
@register_specialize
......@@ -1989,7 +1944,6 @@ def local_mul_to_sqr(fgraph, node):
"""
x*x -> sqr(x)
"""
if node.op == mul:
if len(node.inputs) == 2:
if node.inputs[0] is node.inputs[1]:
return [sqr(node.inputs[0])]
......@@ -1999,10 +1953,7 @@ def local_mul_to_sqr(fgraph, node):
@node_rewriter([int_div])
def local_intdiv_by_one(fgraph, node):
"""x // 1 -> x"""
if node.op in [int_div]:
if isinstance(node.inputs[1], TensorConstant) and np.all(
node.inputs[1].value == 1
):
if isinstance(node.inputs[1], TensorConstant) and np.all(node.inputs[1].value == 1):
return [node.inputs[0].astype(node.outputs[0].dtype)]
......@@ -2011,9 +1962,6 @@ def local_intdiv_by_one(fgraph, node):
@node_rewriter([int_div, true_div])
def local_zero_div(fgraph, node):
"""0 / x -> 0"""
if isinstance(node.op, Elemwise) and isinstance(
node.op.scalar_op, ps.IntDiv | ps.TrueDiv
):
if get_constant(node.inputs[0]) == 0:
ret = alloc_like(0, node.outputs[0], fgraph)
ret.tag.values_eq_approx = values_eq_approx_remove_nan
......@@ -2023,7 +1971,6 @@ def local_zero_div(fgraph, node):
@register_specialize
@node_rewriter([pt_pow])
def local_pow_specialize(fgraph, node):
if node.op == pt_pow:
# the idea here is that we have pow(x, y)
odtype = node.outputs[0].dtype
xsym = node.inputs[0]
......@@ -2052,8 +1999,6 @@ def local_pow_specialize(fgraph, node):
rval[0] = cast(rval[0], odtype)
assert rval[0].type.dtype == node.outputs[0].type.dtype
return rval
else:
return False
@register_specialize
......@@ -2138,7 +2083,6 @@ def local_mul_specialize(fgraph, node):
"""
# at this point [post canonicalize], mul() may have many inputs.
if node.op == mul:
# the idea here is that we have pow(x, y)
has_neg = False
new_inputs = []
......@@ -2276,7 +2220,7 @@ def local_abs_lift(fgraph, node):
This is needed for check_for_x_over_absX to apply in more case.
"""
if node.op == pt_abs and node.inputs[0].owner:
if node.inputs[0].owner:
assert node.nin == 1
if node.inputs[0].owner.op == mul:
return [mul(*[pt_abs(i) for i in node.inputs[0].owner.inputs])]
......@@ -2328,7 +2272,6 @@ def local_abs_merge(fgraph, node):
def local_log1p(fgraph, node):
# log(1+x) -> log1p(x)
# log(1-x) -> log1p(-x)
if node.op == log:
(log_arg,) = node.inputs
if log_arg.owner and log_arg.owner.op == add:
scalars, scalar_inputs, nonconsts = scalarconsts_rest(
......@@ -2365,7 +2308,6 @@ def local_log_add_exp(fgraph, node):
TODO: in canonicalize, change log10 and log2 -> log
"""
if node.op == log:
z = node.inputs[0]
if z.owner and z.owner.op == add:
zi = z.owner.inputs
......@@ -2393,9 +2335,6 @@ def local_log_add_exp(fgraph, node):
def local_log_sum_exp(fgraph, node):
# log(sum_i(exp(x_i))) = x_max + log(sum_i(exp(x_i - x_max)))
if node.op != log:
return
sum_node = node.inputs[0].owner
# If the sum has keepdims=True, there might be a dimshuffle
if sum_node and isinstance(sum_node.op, DimShuffle):
......@@ -2720,8 +2659,7 @@ def local_log_erfc(fgraph, node):
numpy.asarray([i],dtype='float32')))) for i in numpy.arange(
10.0541948,10.0541951,.0000001)]
"""
if node.op != log:
return False
if not node.inputs[0].owner or node.inputs[0].owner.op != erfc:
return False
......@@ -2773,8 +2711,6 @@ def local_grad_log_erfc_neg(fgraph, node):
Make it so that the test does not generate an error in that case!
"""
if node.op != true_div:
return False
if not node.inputs[1].owner or node.inputs[1].owner.op != erfc:
return False
......@@ -3147,7 +3083,6 @@ def local_exp_over_1_plus_exp(fgraph, node):
"""
# This rewrite should be done for numerical stability
# so we don't care to check client counts
if node.op == true_div:
# find all the exp() terms in the numerator
num, denom = node.inputs
num_exp_x, num_rest, num_neg = partition_num_or_denom(num, is_exp)
......@@ -3498,9 +3433,6 @@ def local_sigm_times_exp(fgraph, node):
todo: add stack traces to the intermediate variables
"""
# Bail early if it is not a multiplication.
if node.op != mul:
return None
# Obtain tree of multiplications starting at this node.
mul_tree = parse_mul_tree(node.outputs[0])
did_something = perform_sigm_times_exp(mul_tree)
......@@ -3528,7 +3460,6 @@ def local_reciprocal_1_plus_exp(fgraph, node):
"""
# This Rewrite should be done for numerical stability
# so we don't care to check client counts
if node.op == reciprocal:
reciprocal_arg = node.inputs[0]
if reciprocal_arg.owner and reciprocal_arg.owner.op == add:
scalars_, scalar_inputs, nonconsts = scalarconsts_rest(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论