提交 639b0871 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Get rid of redundant checks in tracked node_rewriters

上级 cb2c40ba
...@@ -328,7 +328,7 @@ def local_func_inv(fgraph, node): ...@@ -328,7 +328,7 @@ def local_func_inv(fgraph, node):
@register_canonicalize @register_canonicalize
@register_specialize @register_specialize
@node_rewriter([Elemwise]) @node_rewriter([log, log1p, exp, expm1])
def local_exp_log(fgraph, node): def local_exp_log(fgraph, node):
x = node.inputs[0] x = node.inputs[0]
...@@ -368,7 +368,7 @@ def local_exp_log(fgraph, node): ...@@ -368,7 +368,7 @@ def local_exp_log(fgraph, node):
@register_specialize @register_specialize
@node_rewriter([Elemwise]) @node_rewriter([exp, expm1])
def local_exp_log_nan_switch(fgraph, node): def local_exp_log_nan_switch(fgraph, node):
# Rewrites of the kind exp(log...(x)) that require a `nan` switch # Rewrites of the kind exp(log...(x)) that require a `nan` switch
x = node.inputs[0] x = node.inputs[0]
...@@ -431,11 +431,7 @@ def local_sumsqr2dot(fgraph, node): ...@@ -431,11 +431,7 @@ def local_sumsqr2dot(fgraph, node):
``pt.sqr(W.dimshuffle("x", 0, 1) * G.dimshuffle(0, "x", 1) ).sum(axis=(1, 2))`` ``pt.sqr(W.dimshuffle("x", 0, 1) * G.dimshuffle(0, "x", 1) ).sum(axis=(1, 2))``
and converts it to ``pt.dot(pt.sqr(G), pt.sqr(W).sum(axis=0))``. and converts it to ``pt.dot(pt.sqr(G), pt.sqr(W).sum(axis=0))``.
""" """
if ( if node.op.axis == (1, 2):
isinstance(node.op, Sum)
and isinstance(node.op.scalar_op, ps.Add)
and node.op.axis == (1, 2)
):
in1 = node.inputs[0] in1 = node.inputs[0]
out = node.outputs[0] out = node.outputs[0]
...@@ -479,7 +475,7 @@ def local_mul_exp_to_exp_add(fgraph, node): ...@@ -479,7 +475,7 @@ def local_mul_exp_to_exp_add(fgraph, node):
n.owner.inputs[0] n.owner.inputs[0]
for n in node.inputs for n in node.inputs
if n.owner if n.owner
and hasattr(n.owner.op, "scalar_op") and isinstance(n.owner.op, Elemwise)
and isinstance(n.owner.op.scalar_op, ps.Exp) and isinstance(n.owner.op.scalar_op, ps.Exp)
] ]
# Can only do any rewrite if there are at least two exp-s # Can only do any rewrite if there are at least two exp-s
...@@ -523,7 +519,7 @@ def local_mul_pow_to_pow_add(fgraph, node): ...@@ -523,7 +519,7 @@ def local_mul_pow_to_pow_add(fgraph, node):
for n in node.inputs: for n in node.inputs:
if ( if (
n.owner n.owner
and hasattr(n.owner.op, "scalar_op") and isinstance(n.owner.op, Elemwise)
and isinstance(n.owner.op.scalar_op, ps.Pow) and isinstance(n.owner.op.scalar_op, ps.Pow)
): ):
base_node = n.owner.inputs[0] base_node = n.owner.inputs[0]
...@@ -567,28 +563,27 @@ def local_mul_pow_to_pow_add(fgraph, node): ...@@ -567,28 +563,27 @@ def local_mul_pow_to_pow_add(fgraph, node):
@register_stabilize @register_stabilize
@register_specialize @register_specialize
@register_canonicalize @register_canonicalize
@node_rewriter([Elemwise]) @node_rewriter([sub])
def local_expm1(fgraph, node): def local_expm1(fgraph, node):
"""Detect ``exp(a) - 1`` and convert them to ``expm1(a)``.""" """Detect ``exp(a) - 1`` and convert them to ``expm1(a)``."""
if isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, ps.Sub): in1, in2 = node.inputs
in1, in2 = node.inputs out = node.outputs[0]
out = node.outputs[0]
if ( if (
in1.owner in1.owner
and isinstance(in1.owner.op, Elemwise) and isinstance(in1.owner.op, Elemwise)
and isinstance(in1.owner.op.scalar_op, ps.Exp) and isinstance(in1.owner.op.scalar_op, ps.Exp)
and extract_constant(in2, only_process_constants=False) == 1 and extract_constant(in2, only_process_constants=False) == 1
): ):
in11 = in1.owner.inputs[0] in11 = in1.owner.inputs[0]
new_out = expm1(in11) new_out = expm1(in11)
if new_out.dtype != out.dtype: if new_out.dtype != out.dtype:
new_out = cast(new_out, dtype=out.dtype) new_out = cast(new_out, dtype=out.dtype)
if not out.type.is_super(new_out.type): if not out.type.is_super(new_out.type):
return return
return [new_out] return [new_out]
@register_specialize @register_specialize
...@@ -625,8 +620,6 @@ def local_mul_switch_sink(fgraph, node): ...@@ -625,8 +620,6 @@ def local_mul_switch_sink(fgraph, node):
part of the graph. part of the graph.
""" """
if node.op != mul:
return False
for idx, i in enumerate(node.inputs): for idx, i in enumerate(node.inputs):
if i.owner and i.owner.op == switch: if i.owner and i.owner.op == switch:
switch_node = i.owner switch_node = i.owner
...@@ -705,8 +698,6 @@ def local_div_switch_sink(fgraph, node): ...@@ -705,8 +698,6 @@ def local_div_switch_sink(fgraph, node):
See `local_mul_switch_sink` for more details. See `local_mul_switch_sink` for more details.
""" """
if node.op != true_div and node.op != int_div:
return False
op = node.op op = node.op
if node.inputs[0].owner and node.inputs[0].owner.op == switch: if node.inputs[0].owner and node.inputs[0].owner.op == switch:
switch_node = node.inputs[0].owner switch_node = node.inputs[0].owner
...@@ -1235,8 +1226,7 @@ register_canonicalize(local_mul_canonizer, "shape_unsafe", name="local_mul_canon ...@@ -1235,8 +1226,7 @@ register_canonicalize(local_mul_canonizer, "shape_unsafe", name="local_mul_canon
@register_canonicalize @register_canonicalize
@node_rewriter([neg]) @node_rewriter([neg])
def local_neg_to_mul(fgraph, node): def local_neg_to_mul(fgraph, node):
if node.op == neg: return [mul(np.array(-1, dtype=node.inputs[0].dtype), node.inputs[0])]
return [mul(np.array(-1, dtype=node.inputs[0].dtype), node.inputs[0])]
@register_specialize @register_specialize
...@@ -1347,17 +1337,12 @@ def local_sum_of_neg_to_neg_of_sum(fgraph, node): ...@@ -1347,17 +1337,12 @@ def local_sum_of_neg_to_neg_of_sum(fgraph, node):
@register_specialize @register_specialize
@node_rewriter([Elemwise]) @node_rewriter([sub])
def local_elemwise_sub_zeros(fgraph, node): def local_elemwise_sub_zeros(fgraph, node):
""" """
Elemwise{sub}(X,X) -> zeros_like(X) Elemwise{sub}(X,X) -> zeros_like(X)
""" """
if ( if node.inputs[0] == node.inputs[1]:
isinstance(node.op, Elemwise)
and node.op.scalar_op.nin == 2
and node.op.scalar_op == ps.sub
and node.inputs[0] == node.inputs[1]
):
res = zeros_like(node.inputs[0]) res = zeros_like(node.inputs[0])
# Copy over stacktrace from previous output. # Copy over stacktrace from previous output.
# This could help for failures due to out-of-memory. # This could help for failures due to out-of-memory.
...@@ -1400,8 +1385,6 @@ def local_useless_elemwise_comparison(fgraph, node): ...@@ -1400,8 +1385,6 @@ def local_useless_elemwise_comparison(fgraph, node):
the graph easier to read. the graph easier to read.
""" """
if not isinstance(node.op, Elemwise):
return
if node.op.scalar_op.nin != 2: if node.op.scalar_op.nin != 2:
return return
...@@ -1590,14 +1573,13 @@ def local_sum_prod_all_to_none(fgraph, node): ...@@ -1590,14 +1573,13 @@ def local_sum_prod_all_to_none(fgraph, node):
Prod{0,1,...N} -> Prod{} Prod{0,1,...N} -> Prod{}
""" """
if isinstance(node.op, Sum) or isinstance(node.op, Prod): op_type = Sum if isinstance(node.op, Sum) else Prod
op_type = Sum if isinstance(node.op, Sum) else Prod # if all the axes are named, then use None as a shorthand
# if all the axes are named, then use None as a shorthand # this permits more merging
# this permits more merging if node.op.axis is None:
if node.op.axis is None: return
return if set(node.op.axis) == set(range(node.inputs[0].type.ndim)):
if set(node.op.axis) == set(range(node.inputs[0].type.ndim)): return [op_type(axis=None, dtype=node.op.dtype)(node.inputs[0])]
return [op_type(axis=None, dtype=node.op.dtype)(node.inputs[0])]
@register_canonicalize @register_canonicalize
...@@ -1609,35 +1591,34 @@ def local_op_of_op(fgraph, node): ...@@ -1609,35 +1591,34 @@ def local_op_of_op(fgraph, node):
Sum(Sum()) -> single Sum() Sum(Sum()) -> single Sum()
""" """
if isinstance(node.op, Prod) or isinstance(node.op, Sum): op_type = Sum if isinstance(node.op, Sum) else Prod
op_type = Sum if isinstance(node.op, Sum) else Prod (node_inps,) = node.inputs
(node_inps,) = node.inputs out_dtype = node.op.dtype
out_dtype = node.op.dtype # This is done to make sure the rewrite doesn't affect other
# This is done to make sure the rewrite doesn't affect other # computations.
# computations. if len(fgraph.clients[node_inps]) == 1:
if len(fgraph.clients[node_inps]) == 1: if node_inps.owner and (isinstance(node_inps.owner.op, node.op.__class__)):
if node_inps.owner and (isinstance(node_inps.owner.op, node.op.__class__)): # check to see either the inner or outer prod is doing a
# check to see either the inner or outer prod is doing a # product over all axis, in which case we can remove it
# product over all axis, in which case we can remove it if node_inps.owner.op.axis is None or node.op.axis is None:
if node_inps.owner.op.axis is None or node.op.axis is None: return [op_type(None, dtype=out_dtype)(node_inps.owner.inputs[0])]
return [op_type(None, dtype=out_dtype)(node_inps.owner.inputs[0])]
# figure out which axes were in the original sum
# figure out which axes were in the original sum newaxis = list(node_inps.owner.op.axis)
newaxis = list(node_inps.owner.op.axis) for i in node.op.axis:
for i in node.op.axis: new_i = i
new_i = i for ii in node_inps.owner.op.axis:
for ii in node_inps.owner.op.axis: if new_i >= ii:
if new_i >= ii: new_i += 1
new_i += 1 assert new_i not in newaxis
assert new_i not in newaxis newaxis.append(new_i)
newaxis.append(new_i)
assert len(newaxis) == len(
assert len(newaxis) == len( list(node_inps.owner.op.axis) + list(node.op.axis)
list(node_inps.owner.op.axis) + list(node.op.axis) )
)
combined = op_type(newaxis, dtype=out_dtype) combined = op_type(newaxis, dtype=out_dtype)
return [combined(node_inps.owner.inputs[0])] return [combined(node_inps.owner.inputs[0])]
ALL_REDUCE = [ ALL_REDUCE = [
...@@ -1669,11 +1650,7 @@ def local_reduce_join(fgraph, node): ...@@ -1669,11 +1650,7 @@ def local_reduce_join(fgraph, node):
where we join and reduce on the same set of axis. where we join and reduce on the same set of axis.
""" """
if ( if node.inputs[0].owner and isinstance(node.inputs[0].owner.op, Join):
isinstance(node.op, CAReduce)
and node.inputs[0].owner
and isinstance(node.inputs[0].owner.op, Join)
):
join_node = node.inputs[0].owner join_node = node.inputs[0].owner
if extract_constant(join_node.inputs[0], only_process_constants=True) != 0: if extract_constant(join_node.inputs[0], only_process_constants=True) != 0:
return return
...@@ -1732,11 +1709,10 @@ def local_reduce_join(fgraph, node): ...@@ -1732,11 +1709,10 @@ def local_reduce_join(fgraph, node):
@node_rewriter(ALL_REDUCE) @node_rewriter(ALL_REDUCE)
def local_useless_reduce(fgraph, node): def local_useless_reduce(fgraph, node):
"""Sum(a, axis=[]) -> a""" """Sum(a, axis=[]) -> a"""
if isinstance(node.op, CAReduce): (summed,) = node.inputs
(summed,) = node.inputs # if reduce were doing anything, the output ndim would be reduced
# if reduce were doing anything, the output ndim would be reduced if summed.type == node.outputs[0].type:
if summed.type == node.outputs[0].type: return [summed]
return [summed]
@register_canonicalize @register_canonicalize
...@@ -1745,42 +1721,41 @@ def local_useless_reduce(fgraph, node): ...@@ -1745,42 +1721,41 @@ def local_useless_reduce(fgraph, node):
@node_rewriter(ALL_REDUCE) @node_rewriter(ALL_REDUCE)
def local_reduce_broadcastable(fgraph, node): def local_reduce_broadcastable(fgraph, node):
"""Remove reduction over broadcastable dimensions.""" """Remove reduction over broadcastable dimensions."""
if isinstance(node.op, CAReduce): (reduced,) = node.inputs
(reduced,) = node.inputs odtype = node.outputs[0].dtype
odtype = node.outputs[0].dtype if node.op.axis is None:
if node.op.axis is None: if all(reduced.broadcastable):
if all(reduced.broadcastable): return [reduced.dimshuffle().astype(odtype)]
return [reduced.dimshuffle().astype(odtype)] else:
else: axis = list(node.op.axis)
axis = list(node.op.axis) cuttable = [a for a in axis if reduced.broadcastable[a]]
cuttable = [a for a in axis if reduced.broadcastable[a]] if cuttable:
if cuttable: # -- we can remove some axes of summation.
# -- we can remove some axes of summation. new_axis = []
new_axis = [] pattern = []
pattern = [] ii = 0
ii = 0 for p in range(reduced.ndim):
for p in range(reduced.ndim): if p not in cuttable:
if p not in cuttable: if p in axis:
if p in axis: new_axis.append(ii)
new_axis.append(ii) pattern.append(p)
pattern.append(p) ii += 1
ii += 1 new_reduced = reduced.dimshuffle(*pattern)
new_reduced = reduced.dimshuffle(*pattern) if new_axis:
if new_axis: if type(node.op) == CAReduce:
if type(node.op) == CAReduce: # This case handles `CAReduce` instances
# This case handles `CAReduce` instances # (e.g. generated by `scalar_elemwise`), and not the
# (e.g. generated by `scalar_elemwise`), and not the # scalar `Op`-specific subclasses
# scalar `Op`-specific subclasses # TODO FIXME: This highlights a major design flaw in
# TODO FIXME: This highlights a major design flaw in # `CAReduce` (or at least our use of it), and it needs
# `CAReduce` (or at least our use of it), and it needs # to be fixed
# to be fixed new_op = node.op.__class__(node.op.scalar_op, axis=new_axis)
new_op = node.op.__class__(node.op.scalar_op, axis=new_axis)
else:
new_op = node.op.__class__(axis=new_axis)
return [new_op(new_reduced)]
else: else:
# -- in this case we can remove the reduction completely new_op = node.op.__class__(axis=new_axis)
return [new_reduced.astype(odtype)] return [new_op(new_reduced)]
else:
# -- in this case we can remove the reduction completely
return [new_reduced.astype(odtype)]
@register_specialize @register_specialize
...@@ -1792,61 +1767,54 @@ def local_opt_alloc(fgraph, node): ...@@ -1792,61 +1767,54 @@ def local_opt_alloc(fgraph, node):
prod(alloc(constant,shapes...)) => constant**prod(shapes) prod(alloc(constant,shapes...)) => constant**prod(shapes)
""" """
if isinstance(node.op, Sum) or isinstance(node.op, Prod): (node_inps,) = node.inputs
(node_inps,) = node.inputs if node_inps.owner and isinstance(node_inps.owner.op, Alloc):
if node_inps.owner and isinstance(node_inps.owner.op, Alloc): inp = node_inps.owner.inputs[0]
inp = node_inps.owner.inputs[0] shapes = node_inps.owner.inputs[1:]
shapes = node_inps.owner.inputs[1:] try:
try: val = get_underlying_scalar_constant_value(inp, only_process_constants=True)
val = get_underlying_scalar_constant_value( assert val.size == 1
inp, only_process_constants=True val = val.reshape(1)[0]
) # check which type of op
assert val.size == 1 size = mul(*shapes)
val = val.reshape(1)[0] if inp.dtype in ("float16", "float32"):
# check which type of op # shapes are ints and normally int64.
size = mul(*shapes) # We don't want to have a float64 upcast
if inp.dtype in ("float16", "float32"): # We don't want to downcast to float16
# shapes are ints and normally int64. # as we fear it could loose too much precision
# We don't want to have a float64 upcast # that will be amplified by the mul/pow below.
# We don't want to downcast to float16 size = size.astype("float32")
# as we fear it could loose too much precision if node.op.axis is None or node.op.axis == tuple(range(inp.ndim)):
# that will be amplified by the mul/pow below. if isinstance(node.op, Sum):
size = size.astype("float32") val = val * size
if node.op.axis is None or node.op.axis == tuple(range(inp.ndim)): else:
if isinstance(node.op, Sum): val = val**size
val = val * size # Sum can change the input dtype (upcast or bool
else: # -> float32) by default or by user request.
val = val**size # We can ignore the acc_dtype, as there is only 1
# Sum can change the input dtype (upcast or bool # elemwise we will do and not a sequence, so there is no
# -> float32) by default or by user request. # accumulation of errors.
# We can ignore the acc_dtype, as there is only 1 # So mostly, we just need to cast the output to the old
# elemwise we will do and not a sequence, so there is no # dtype.
# accumulation of errors.
# So mostly, we just need to cast the output to the old
# dtype.
val = val.astype(node.outputs[0].dtype)
return [val]
to_prod = [shapes[i] for i in range(len(shapes)) if i in node.op.axis]
if to_prod:
size = mul(*to_prod)
if isinstance(node.op, Sum):
val *= size
else:
val = val**size
# See comments above.
val = val.astype(node.outputs[0].dtype) val = val.astype(node.outputs[0].dtype)
return [ return [val]
alloc( to_prod = [shapes[i] for i in range(len(shapes)) if i in node.op.axis]
val, if to_prod:
*[ size = mul(*to_prod)
shapes[i] if isinstance(node.op, Sum):
for i in range(len(shapes)) val *= size
if i not in node.op.axis else:
], val = val**size
) # See comments above.
] val = val.astype(node.outputs[0].dtype)
except NotScalarConstantError: return [
pass alloc(
val,
*[shapes[i] for i in range(len(shapes)) if i not in node.op.axis],
)
]
except NotScalarConstantError:
pass
@register_specialize @register_specialize
...@@ -1858,19 +1826,18 @@ def local_neg_div_neg(fgraph, node): ...@@ -1858,19 +1826,18 @@ def local_neg_div_neg(fgraph, node):
Also performs - (c / b) -> ((-c) / b) when c is a scalar constant. Also performs - (c / b) -> ((-c) / b) when c is a scalar constant.
""" """
if node.op == neg: if node.inputs[0].owner and node.inputs[0].owner.op == true_div:
if node.inputs[0].owner and node.inputs[0].owner.op == true_div: frac = node.inputs[0]
frac = node.inputs[0] num, denom = frac.owner.inputs
num, denom = frac.owner.inputs if num.owner and num.owner.op == neg:
if num.owner and num.owner.op == neg: if len(fgraph.clients[frac]) == 1:
if len(fgraph.clients[frac]) == 1: # No other clients of the original division
# No other clients of the original division new_num = num.owner.inputs[0]
new_num = num.owner.inputs[0] return [true_div(new_num, denom)]
return [true_div(new_num, denom)] elif all(num.broadcastable) and isinstance(num, Constant):
elif all(num.broadcastable) and isinstance(num, Constant): if len(fgraph.clients[frac]) == 1:
if len(fgraph.clients[frac]) == 1: new_num = -num.data
new_num = -num.data return [true_div(new_num, denom)]
return [true_div(new_num, denom)]
@register_canonicalize @register_canonicalize
...@@ -1881,14 +1848,13 @@ def local_sub_neg_to_add(fgraph, node): ...@@ -1881,14 +1848,13 @@ def local_sub_neg_to_add(fgraph, node):
x - (-y) -> x + y x - (-y) -> x + y
""" """
if node.op == sub: minuend, subtrahend = node.inputs
minuend, subtrahend = node.inputs
if subtrahend.owner: if subtrahend.owner:
if subtrahend.owner.op == neg: if subtrahend.owner.op == neg:
pre_neg = subtrahend.owner.inputs[0] pre_neg = subtrahend.owner.inputs[0]
new_out = add(minuend, pre_neg) new_out = add(minuend, pre_neg)
return [new_out] return [new_out]
@register_specialize @register_specialize
...@@ -1903,7 +1869,7 @@ def local_add_neg_to_sub(fgraph, node): ...@@ -1903,7 +1869,7 @@ def local_add_neg_to_sub(fgraph, node):
# `local_neg_to_mul` rewrite modifies the relevant pattern during canonicalization # `local_neg_to_mul` rewrite modifies the relevant pattern during canonicalization
# Rewrite is only applicable when there are two inputs to add # Rewrite is only applicable when there are two inputs to add
if node.op == add and len(node.inputs) == 2: if len(node.inputs) == 2:
# Look for pattern with either input order # Look for pattern with either input order
for first, second in (node.inputs, reversed(node.inputs)): for first, second in (node.inputs, reversed(node.inputs)):
if second.owner: if second.owner:
...@@ -1927,27 +1893,24 @@ def local_mul_zero(fgraph, node): ...@@ -1927,27 +1893,24 @@ def local_mul_zero(fgraph, node):
with zero. with zero.
""" """
if node.op == mul: otype = node.outputs[0].type
otype = node.outputs[0].type
for i in node.inputs: for i in node.inputs:
try: try:
value = get_underlying_scalar_constant_value(i) value = get_underlying_scalar_constant_value(i)
except NotScalarConstantError: except NotScalarConstantError:
continue continue
# print 'MUL by value', value, node.inputs # print 'MUL by value', value, node.inputs
if value == 0: if value == 0:
# print '... returning zeros' # print '... returning zeros'
return [ return [broadcast_arrays(_asarray(0, dtype=otype.dtype), *node.inputs)[0]]
broadcast_arrays(_asarray(0, dtype=otype.dtype), *node.inputs)[0]
]
# TODO: Add this to the canonicalization to reduce redundancy. # TODO: Add this to the canonicalization to reduce redundancy.
@register_specialize @register_specialize
@node_rewriter([true_div]) @node_rewriter([true_div])
def local_div_to_reciprocal(fgraph, node): def local_div_to_reciprocal(fgraph, node):
if node.op == true_div and np.all(get_constant(node.inputs[0]) == 1.0): if np.all(get_constant(node.inputs[0]) == 1.0):
out = node.outputs[0] out = node.outputs[0]
new_out = reciprocal(local_mul_canonizer.merge_num_denum(node.inputs[1:], [])) new_out = reciprocal(local_mul_canonizer.merge_num_denum(node.inputs[1:], []))
# The ones could have forced upcasting # The ones could have forced upcasting
...@@ -1957,30 +1920,22 @@ def local_div_to_reciprocal(fgraph, node): ...@@ -1957,30 +1920,22 @@ def local_div_to_reciprocal(fgraph, node):
if not out.type.is_super(new_out.type): if not out.type.is_super(new_out.type):
new_out = alloc_like(new_out, out, fgraph) new_out = alloc_like(new_out, out, fgraph)
return [new_out] return [new_out]
else:
return False
@register_canonicalize @register_canonicalize
@node_rewriter([reciprocal]) @node_rewriter([reciprocal])
def local_reciprocal_canon(fgraph, node): def local_reciprocal_canon(fgraph, node):
if node.op == reciprocal: return [pt_pow(node.inputs[0], -1.0)]
return [pt_pow(node.inputs[0], -1.0)]
else:
return False
@register_canonicalize @register_canonicalize
@node_rewriter([pt_pow]) @node_rewriter([pt_pow])
def local_pow_canonicalize(fgraph, node): def local_pow_canonicalize(fgraph, node):
if node.op == pt_pow: cst = get_constant(node.inputs[1])
cst = get_constant(node.inputs[1]) if cst == 0:
if cst == 0: return [alloc_like(1, node.outputs[0], fgraph)]
return [alloc_like(1, node.outputs[0], fgraph)] if cst == 1:
if cst == 1: return [alloc_like(node.inputs[0], node.outputs[0], fgraph)]
return [alloc_like(node.inputs[0], node.outputs[0], fgraph)]
else:
return False
@register_specialize @register_specialize
...@@ -1989,21 +1944,17 @@ def local_mul_to_sqr(fgraph, node): ...@@ -1989,21 +1944,17 @@ def local_mul_to_sqr(fgraph, node):
""" """
x*x -> sqr(x) x*x -> sqr(x)
""" """
if node.op == mul: if len(node.inputs) == 2:
if len(node.inputs) == 2: if node.inputs[0] is node.inputs[1]:
if node.inputs[0] is node.inputs[1]: return [sqr(node.inputs[0])]
return [sqr(node.inputs[0])]
@register_canonicalize @register_canonicalize
@node_rewriter([int_div]) @node_rewriter([int_div])
def local_intdiv_by_one(fgraph, node): def local_intdiv_by_one(fgraph, node):
"""x // 1 -> x""" """x // 1 -> x"""
if node.op in [int_div]: if isinstance(node.inputs[1], TensorConstant) and np.all(node.inputs[1].value == 1):
if isinstance(node.inputs[1], TensorConstant) and np.all( return [node.inputs[0].astype(node.outputs[0].dtype)]
node.inputs[1].value == 1
):
return [node.inputs[0].astype(node.outputs[0].dtype)]
@register_canonicalize @register_canonicalize
...@@ -2011,49 +1962,43 @@ def local_intdiv_by_one(fgraph, node): ...@@ -2011,49 +1962,43 @@ def local_intdiv_by_one(fgraph, node):
@node_rewriter([int_div, true_div]) @node_rewriter([int_div, true_div])
def local_zero_div(fgraph, node): def local_zero_div(fgraph, node):
"""0 / x -> 0""" """0 / x -> 0"""
if isinstance(node.op, Elemwise) and isinstance( if get_constant(node.inputs[0]) == 0:
node.op.scalar_op, ps.IntDiv | ps.TrueDiv ret = alloc_like(0, node.outputs[0], fgraph)
): ret.tag.values_eq_approx = values_eq_approx_remove_nan
if get_constant(node.inputs[0]) == 0: return [ret]
ret = alloc_like(0, node.outputs[0], fgraph)
ret.tag.values_eq_approx = values_eq_approx_remove_nan
return [ret]
@register_specialize @register_specialize
@node_rewriter([pt_pow]) @node_rewriter([pt_pow])
def local_pow_specialize(fgraph, node): def local_pow_specialize(fgraph, node):
if node.op == pt_pow: # the idea here is that we have pow(x, y)
# the idea here is that we have pow(x, y) odtype = node.outputs[0].dtype
odtype = node.outputs[0].dtype xsym = node.inputs[0]
xsym = node.inputs[0] ysym = node.inputs[1]
ysym = node.inputs[1] y = get_constant(ysym)
y = get_constant(ysym) if (y is not None) and not broadcasted_by(xsym, ysym):
if (y is not None) and not broadcasted_by(xsym, ysym): rval = None
rval = None
if np.all(y == 2):
if np.all(y == 2): rval = [sqr(xsym)]
rval = [sqr(xsym)] if np.all(y == 1):
if np.all(y == 1): rval = [xsym]
rval = [xsym] if np.all(y == 0):
if np.all(y == 0): rval = [alloc_like(1, xsym, fgraph)]
rval = [alloc_like(1, xsym, fgraph)] if np.all(y == 0.5):
if np.all(y == 0.5): rval = [sqrt(xsym)]
rval = [sqrt(xsym)] if np.all(y == -0.5):
if np.all(y == -0.5): rval = [reciprocal(sqrt(xsym))]
rval = [reciprocal(sqrt(xsym))] if np.all(y == -1):
if np.all(y == -1): rval = [reciprocal(xsym)]
rval = [reciprocal(xsym)] if np.all(y == -2):
if np.all(y == -2): rval = [reciprocal(sqr(xsym))]
rval = [reciprocal(sqr(xsym))] if rval:
if rval: if not rval[0].type.broadcastable == node.outputs[0].type.broadcastable:
if not rval[0].type.broadcastable == node.outputs[0].type.broadcastable: return None
return None rval[0] = cast(rval[0], odtype)
rval[0] = cast(rval[0], odtype) assert rval[0].type.dtype == node.outputs[0].type.dtype
assert rval[0].type.dtype == node.outputs[0].type.dtype return rval
return rval
else:
return False
@register_specialize @register_specialize
...@@ -2138,61 +2083,60 @@ def local_mul_specialize(fgraph, node): ...@@ -2138,61 +2083,60 @@ def local_mul_specialize(fgraph, node):
""" """
# at this point [post canonicalize], mul() may have many inputs. # at this point [post canonicalize], mul() may have many inputs.
if node.op == mul: # the idea here is that we have pow(x, y)
# the idea here is that we have pow(x, y) has_neg = False
has_neg = False new_inputs = []
new_inputs = [] nb_neg_node = 0
nb_neg_node = 0 nb_cst = 0
nb_cst = 0 for inp in node.inputs:
for inp in node.inputs: # remove any neg arguments
# remove any neg arguments while inp.owner and inp.owner.op == neg:
while inp.owner and inp.owner.op == neg: has_neg ^= True
has_neg ^= True inp = inp.owner.inputs[0]
inp = inp.owner.inputs[0] nb_neg_node += 1
nb_neg_node += 1
# remove special case arguments of 1, -1 or 0
# remove special case arguments of 1, -1 or 0 y = get_constant(inp)
y = get_constant(inp) if y == 1.0:
if y == 1.0: nb_cst += 1
nb_cst += 1 elif y == -1.0:
elif y == -1.0: nb_cst += 1
nb_cst += 1 has_neg ^= True # toggles
has_neg ^= True # toggles elif y == 0.0:
elif y == 0.0: # if we find any zero, we just return right away
# if we find any zero, we just return right away return [alloc_like(0, node.outputs[0], fgraph)]
return [alloc_like(0, node.outputs[0], fgraph)] else:
else: new_inputs.append(inp)
new_inputs.append(inp)
if new_inputs != node.inputs:
if new_inputs != node.inputs: if new_inputs:
if new_inputs: if len(new_inputs) == 1:
if len(new_inputs) == 1: if has_neg:
if has_neg: if new_inputs[0].dtype in ([*uint_dtypes, "bool"]):
if new_inputs[0].dtype in ([*uint_dtypes, "bool"]): return
return
else:
rval = -new_inputs[0]
else: else:
rval = new_inputs[0] rval = -new_inputs[0]
else: else:
# The next case would cause a replace by an equivalent case. rval = new_inputs[0]
if has_neg and nb_neg_node == 0 and nb_cst == 1:
return
elif has_neg:
# Don't add an extra neg node as we can't
# fully replace this mul by a neg.
m1 = np.asarray(-1, dtype=node.outputs[0].dtype)
new_inputs = [m1, *new_inputs]
rval = mul(*new_inputs)
return [alloc_like(rval, node.outputs[0], fgraph)]
else: else:
# there are no variable inputs to mul # The next case would cause a replace by an equivalent case.
# N.B. this could have been constant-folded... if has_neg and nb_neg_node == 0 and nb_cst == 1:
if has_neg: return
return [alloc_like(-1, node.outputs[0], fgraph)] elif has_neg:
else: # Don't add an extra neg node as we can't
return [alloc_like(1, node.outputs[0], fgraph)] # fully replace this mul by a neg.
m1 = np.asarray(-1, dtype=node.outputs[0].dtype)
new_inputs = [m1, *new_inputs]
rval = mul(*new_inputs)
return [alloc_like(rval, node.outputs[0], fgraph)]
else:
# there are no variable inputs to mul
# N.B. this could have been constant-folded...
if has_neg:
return [alloc_like(-1, node.outputs[0], fgraph)]
else:
return [alloc_like(1, node.outputs[0], fgraph)]
@register_specialize @register_specialize
...@@ -2276,7 +2220,7 @@ def local_abs_lift(fgraph, node): ...@@ -2276,7 +2220,7 @@ def local_abs_lift(fgraph, node):
This is needed for check_for_x_over_absX to apply in more case. This is needed for check_for_x_over_absX to apply in more case.
""" """
if node.op == pt_abs and node.inputs[0].owner: if node.inputs[0].owner:
assert node.nin == 1 assert node.nin == 1
if node.inputs[0].owner.op == mul: if node.inputs[0].owner.op == mul:
return [mul(*[pt_abs(i) for i in node.inputs[0].owner.inputs])] return [mul(*[pt_abs(i) for i in node.inputs[0].owner.inputs])]
...@@ -2328,31 +2272,30 @@ def local_abs_merge(fgraph, node): ...@@ -2328,31 +2272,30 @@ def local_abs_merge(fgraph, node):
def local_log1p(fgraph, node): def local_log1p(fgraph, node):
# log(1+x) -> log1p(x) # log(1+x) -> log1p(x)
# log(1-x) -> log1p(-x) # log(1-x) -> log1p(-x)
if node.op == log: (log_arg,) = node.inputs
(log_arg,) = node.inputs if log_arg.owner and log_arg.owner.op == add:
if log_arg.owner and log_arg.owner.op == add: scalars, scalar_inputs, nonconsts = scalarconsts_rest(
scalars, scalar_inputs, nonconsts = scalarconsts_rest( log_arg.owner.inputs, only_process_constants=True
log_arg.owner.inputs, only_process_constants=True )
) # scalar_inputs are potentially dimshuffled and fill'd scalars
# scalar_inputs are potentially dimshuffled and fill'd scalars if scalars and np.allclose(np.sum(scalars), 1):
if scalars and np.allclose(np.sum(scalars), 1): if nonconsts:
if nonconsts: if len(nonconsts) > 1:
if len(nonconsts) > 1: ninp = add(*nonconsts)
ninp = add(*nonconsts) else:
else: ninp = nonconsts[0]
ninp = nonconsts[0] if ninp.dtype != log_arg.type.dtype:
if ninp.dtype != log_arg.type.dtype: ninp = ninp.astype(node.outputs[0].dtype)
ninp = ninp.astype(node.outputs[0].dtype) return [alloc_like(log1p(ninp), node.outputs[0], fgraph)]
return [alloc_like(log1p(ninp), node.outputs[0], fgraph)]
elif log_arg.owner and log_arg.owner.op == sub:
elif log_arg.owner and log_arg.owner.op == sub: one = extract_constant(log_arg.owner.inputs[0], only_process_constants=True)
one = extract_constant(log_arg.owner.inputs[0], only_process_constants=True) if one != 1:
if one != 1: return
return other = log_arg.owner.inputs[1]
other = log_arg.owner.inputs[1] if other.dtype != log_arg.dtype:
if other.dtype != log_arg.dtype: other = other.astype(log_arg.dtype)
other = other.astype(log_arg.dtype) return [log1p(neg(other))]
return [log1p(neg(other))]
@register_stabilize @register_stabilize
...@@ -2365,26 +2308,25 @@ def local_log_add_exp(fgraph, node): ...@@ -2365,26 +2308,25 @@ def local_log_add_exp(fgraph, node):
TODO: in canonicalize, change log10 and log2 -> log TODO: in canonicalize, change log10 and log2 -> log
""" """
if node.op == log: z = node.inputs[0]
z = node.inputs[0] if z.owner and z.owner.op == add:
if z.owner and z.owner.op == add: zi = z.owner.inputs
zi = z.owner.inputs pre_exp = [x.owner.inputs[0] for x in zi if x.owner and x.owner.op == exp]
pre_exp = [x.owner.inputs[0] for x in zi if x.owner and x.owner.op == exp] # all arguments to add are exp(<something>)
# all arguments to add are exp(<something>) if len(pre_exp) == len(zi):
if len(pre_exp) == len(zi): # Do not offset when max_pre = -np.inf, to avoid nan in the output
# Do not offset when max_pre = -np.inf, to avoid nan in the output # Switch statement is placed directly inside add to break the self-symmetry
# Switch statement is placed directly inside add to break the self-symmetry # of the returned output (otherwise the rewrite would not stabilize)
# of the returned output (otherwise the rewrite would not stabilize) max_pre = reduce(maximum, pre_exp)
max_pre = reduce(maximum, pre_exp) ret = max_pre + log(
ret = max_pre + log( add(
add( *[
*[ switch(isinf(max_pre), exp(max_pre), exp(p - max_pre))
switch(isinf(max_pre), exp(max_pre), exp(p - max_pre)) for p in pre_exp
for p in pre_exp ]
]
)
) )
return [ret] )
return [ret]
@register_stabilize @register_stabilize
...@@ -2393,9 +2335,6 @@ def local_log_add_exp(fgraph, node): ...@@ -2393,9 +2335,6 @@ def local_log_add_exp(fgraph, node):
def local_log_sum_exp(fgraph, node): def local_log_sum_exp(fgraph, node):
# log(sum_i(exp(x_i))) = x_max + log(sum_i(exp(x_i - x_max))) # log(sum_i(exp(x_i))) = x_max + log(sum_i(exp(x_i - x_max)))
if node.op != log:
return
sum_node = node.inputs[0].owner sum_node = node.inputs[0].owner
# If the sum has keepdims=True, there might be a dimshuffle # If the sum has keepdims=True, there might be a dimshuffle
if sum_node and isinstance(sum_node.op, DimShuffle): if sum_node and isinstance(sum_node.op, DimShuffle):
...@@ -2720,8 +2659,7 @@ def local_log_erfc(fgraph, node): ...@@ -2720,8 +2659,7 @@ def local_log_erfc(fgraph, node):
numpy.asarray([i],dtype='float32')))) for i in numpy.arange( numpy.asarray([i],dtype='float32')))) for i in numpy.arange(
10.0541948,10.0541951,.0000001)] 10.0541948,10.0541951,.0000001)]
""" """
if node.op != log:
return False
if not node.inputs[0].owner or node.inputs[0].owner.op != erfc: if not node.inputs[0].owner or node.inputs[0].owner.op != erfc:
return False return False
...@@ -2773,8 +2711,6 @@ def local_grad_log_erfc_neg(fgraph, node): ...@@ -2773,8 +2711,6 @@ def local_grad_log_erfc_neg(fgraph, node):
Make it so that the test does not generate an error in that case! Make it so that the test does not generate an error in that case!
""" """
if node.op != true_div:
return False
if not node.inputs[1].owner or node.inputs[1].owner.op != erfc: if not node.inputs[1].owner or node.inputs[1].owner.op != erfc:
return False return False
def local_exp_over_1_plus_exp(fgraph, node):
    """Rewrite ``exp(x) / (1 + exp(x))`` into ``sigmoid(x)`` and
    ``c / (1 + exp(x))`` into ``c * sigmoid(-x)``.
    """
    # This rewrite is done for numerical stability, so we do not bother
    # checking client counts.
    num, denom = node.inputs
    # Split the numerator into exp(...) factors and everything else.
    num_exp_x, num_rest, num_neg = partition_num_or_denom(num, is_exp)
    # Split the denominator into (1 + exp(...)) factors and everything else.
    denom_1pexp, denom_rest, denom_neg = partition_num_or_denom(denom, is_1pexp)

    sig_factors = []
    for term in denom_1pexp:
        if term in num_exp_x:
            # case: exp(x) / (1 + exp(x)) -> sigmoid(x); consume the matching exp.
            num_exp_x.remove(term)
            sig_factors.append(sigmoid(term))
        else:
            # case: 1 / (1 + exp(x)) -> sigmoid(-x)
            sig_factors.append(sigmoid(-term))
        copy_stack_trace(node.outputs[0], sig_factors[-1])

    if not sig_factors:
        # No (1 + exp(x)) factor was found: nothing to rewrite.
        return

    # Put the new numerator together: sigmoids, leftover exps, and the rest.
    num_factors = sig_factors + [exp(term) for term in num_exp_x] + num_rest
    new_num = num_factors[0] if len(num_factors) == 1 else mul(*num_factors)

    # An odd total number of negations flips the overall sign once.
    if num_neg ^ denom_neg:
        new_num = -new_num

    copy_stack_trace(num, new_num)

    if not denom_rest:
        return [new_num]
    remaining_denom = denom_rest[0] if len(denom_rest) == 1 else mul(*denom_rest)
    out = new_num / remaining_denom
    copy_stack_trace(node.outputs[0], out)
    return [out]
def parse_mul_tree(root): def parse_mul_tree(root):
...@@ -3498,9 +3433,6 @@ def local_sigm_times_exp(fgraph, node): ...@@ -3498,9 +3433,6 @@ def local_sigm_times_exp(fgraph, node):
todo: add stack traces to the intermediate variables todo: add stack traces to the intermediate variables
""" """
# Bail early if it is not a multiplication.
if node.op != mul:
return None
# Obtain tree of multiplications starting at this node. # Obtain tree of multiplications starting at this node.
mul_tree = parse_mul_tree(node.outputs[0]) mul_tree = parse_mul_tree(node.outputs[0])
did_something = perform_sigm_times_exp(mul_tree) did_something = perform_sigm_times_exp(mul_tree)
def local_reciprocal_1_plus_exp(fgraph, node):
    """Rewrite ``reciprocal(1 + exp(x))`` into ``sigmoid(-x)``.

    This rewrite is done for numerical stability, so client counts are
    not checked.
    """
    add_arg = node.inputs[0]
    # The reciprocal's argument must be an add(...).
    if add_arg.owner is None or add_arg.owner.op != add:
        return None

    # Split the add's inputs into scalar constants and remaining terms.
    # scalar_inputs are potentially dimshuffled and fill'd scalars.
    scalars_, scalar_inputs, nonconsts = scalarconsts_rest(
        add_arg.owner.inputs, only_process_constants=True
    )
    if len(nonconsts) != 1:
        return None
    (nonconst,) = nonconsts
    if nonconst.owner is None or nonconst.owner.op != exp:
        return None
    # The constant part must sum to 1, i.e. the pattern is 1 + exp(x).
    if not (scalars_ and np.allclose(np.sum(scalars_), 1)):
        return None

    out = [
        alloc_like(
            sigmoid(neg(nonconst.owner.inputs[0])),
            node.outputs[0],
            fgraph,
        )
    ]
    # keep combined stack traces of
    #   exp(x): nonconst,
    #   1 + exp(x): add_arg,
    #   1 / (1 + exp(x)): node.outputs[0]
    copy_stack_trace([nonconst, add_arg, node.outputs[0]], out)
    return out
# 1 - sigmoid(x) -> sigmoid(-x) # 1 - sigmoid(x) -> sigmoid(-x)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论