Commit c49e395c authored by Ricardo Vieira, committed by Ricardo Vieira

Merge rewrite for sum/prod of div with that of mul

Parent 3b0a97b7
...@@ -1190,7 +1190,7 @@ def local_neg_to_mul(fgraph, node): ...@@ -1190,7 +1190,7 @@ def local_neg_to_mul(fgraph, node):
@register_specialize @register_specialize
@node_rewriter([Sum, Prod]) @node_rewriter([Sum, Prod])
def local_sum_prod_of_mul(fgraph, node): def local_sum_prod_of_mul_or_div(fgraph, node):
""" """
sum(a * X) -> a * sum(X), when a is broadcasted along the sum dimensions sum(a * X) -> a * sum(X), when a is broadcasted along the sum dimensions
...@@ -1198,15 +1198,20 @@ def local_sum_prod_of_mul(fgraph, node): ...@@ -1198,15 +1198,20 @@ def local_sum_prod_of_mul(fgraph, node):
prod(a * X) -> (a ** size(X)) * prod(X) prod(a * X) -> (a ** size(X)) * prod(X)
It also applies to reduction of X / a,
but not a / X, as that would still require inverting every value in X before the reduction
TODO: In the case where not all axis overlap with broadcast dimensions, TODO: In the case where not all axis overlap with broadcast dimensions,
consider introducing an outer reduction after factoring out the compatible reduced dimensions consider introducing an outer reduction after factoring out the compatible reduced dimensions
E.g. sum(arange(5) * X, axis=(0, 2)) -> sum(sum(X, axis=0) * arange(5), axis=1) E.g. sum(arange(5) * X, axis=(0, 2)) -> sum(sum(X, axis=0) * arange(5), axis=1)
""" """
# TODO: if the the thing inside the Sum is a division,
# we should get at the numerator....
[node_inps] = node.inputs [node_inps] = node.inputs
if not (node_inps.owner and node_inps.owner.op == mul): if not node_inps.owner:
return None
inner_op = node_inps.owner.op
if not (inner_op == mul or inner_op == true_div):
return None return None
reduced_axes = node.op.axis reduced_axes = node.op.axis
...@@ -1214,28 +1219,40 @@ def local_sum_prod_of_mul(fgraph, node): ...@@ -1214,28 +1219,40 @@ def local_sum_prod_of_mul(fgraph, node):
reduced_axes = tuple(range(node_inps.type.ndim)) reduced_axes = tuple(range(node_inps.type.ndim))
# Separate terms that can be moved out of the Sum/Prod and those that cannot # Separate terms that can be moved out of the Sum/Prod and those that cannot
outer_terms = [] if inner_op == mul:
inner_terms = [] # Mul accepts arbitrary inputs, so we need to separate into two groups
for term in node_inps.owner.inputs: outer_terms = []
term_bcast = term.type.broadcastable inner_terms = []
if all(term_bcast[i] for i in reduced_axes): for term in node_inps.owner.inputs:
outer_terms.append(term.squeeze(reduced_axes)) term_bcast = term.type.broadcastable
else: if all(term_bcast[i] for i in reduced_axes):
inner_terms.append(term) outer_terms.append(term.squeeze(reduced_axes))
else:
inner_terms.append(term)
if not outer_terms: if not outer_terms:
return None return None
elif len(outer_terms) == 1: elif len(outer_terms) == 1:
[outer_term] = outer_terms [outer_term] = outer_terms
else: else:
outer_term = mul(*outer_terms) outer_term = mul(*outer_terms)
if not inner_terms: if not inner_terms:
inner_term = None inner_term = None
elif len(inner_terms) == 1: elif len(inner_terms) == 1:
[inner_term] = inner_terms [inner_term] = inner_terms
else: else:
inner_term = mul(*inner_terms) inner_term = mul(*inner_terms)
else: # true_div
# We only care about removing the denominator out of the reduction
numerator, denominator = node_inps.owner.inputs
denominator_bcast = denominator.type.broadcastable
if all(denominator_bcast[i] for i in reduced_axes):
outer_term = denominator.squeeze(reduced_axes)
inner_term = numerator
else:
return None
# If we have a `Prod`, then the outside terms need to be raised to the power of the number of elements # If we have a `Prod`, then the outside terms need to be raised to the power of the number of elements
# that were contracted in the input # that were contracted in the input
...@@ -1246,12 +1263,16 @@ def local_sum_prod_of_mul(fgraph, node): ...@@ -1246,12 +1263,16 @@ def local_sum_prod_of_mul(fgraph, node):
) )
outer_term = outer_term**n_reduced_elements outer_term = outer_term**n_reduced_elements
# Sum/Prod is useless, just return the outer_term
if not inner_term: if not inner_term:
# Sum/Prod is useless, just return the outer_term
# (This can only happen for mul, not division)
new_out = outer_term new_out = outer_term
else: else:
reduced_inner_term = node.op(inner_term) reduced_inner_term = node.op(inner_term)
new_out = outer_term * reduced_inner_term if inner_op == mul:
new_out = outer_term * reduced_inner_term
else:
new_out = reduced_inner_term / outer_term
copy_stack_trace(node.outputs, [inner_term, reduced_inner_term, outer_term]) copy_stack_trace(node.outputs, [inner_term, reduced_inner_term, outer_term])
copy_stack_trace(node.outputs, new_out) copy_stack_trace(node.outputs, new_out)
...@@ -1510,99 +1531,6 @@ def local_useless_elemwise_comparison(fgraph, node): ...@@ -1510,99 +1531,6 @@ def local_useless_elemwise_comparison(fgraph, node):
return return
@register_canonicalize
@register_specialize
@node_rewriter([Sum, Prod])
def local_sum_prod_div_dimshuffle(fgraph, node):
    """
    sum(a / dimshuffle{...}(b), axis=l) -> sum(a, axis={...}) / b,
    if dimension l of the DimShuffle is 'x'

    or

    prod(a / dimshuffle{...}(b), axis=l) ->
    prod(a, axis={...}) / b ** a.shape[l],
    if dimension l of the DimShuffle is 'x'

    Parameters
    ----------
    fgraph
        The function graph being rewritten (unused here, required by the
        ``node_rewriter`` interface).
    node
        A ``Sum`` or ``Prod`` apply node whose input may be a ``true_div``.

    Returns
    -------
    list or None
        A single-element list with the rewritten output, or ``None``
        (implicitly) when the pattern does not match.
    """
    # It does not make much sense now to extend it to the case where the
    # dimshuffle is in the numerator, since elemwise inversion of the
    # denominator would still be needed before the summation or production.
    if isinstance(node.op, (Sum, Prod)):
        axis = node.op.axis
        # axis=None means "reduce over every dimension"; normalize it to an
        # explicit list of all axes of the input.
        if axis is None:
            axis = list(range(node.inputs[0].ndim))
        node_input = node.inputs[0]
        # Only fire on reduce(numerator / dimshuffle(denominator)).
        if node_input.owner and node_input.owner.op == true_div:
            numerator, denominator = node_input.owner.inputs
            if denominator.owner and isinstance(denominator.owner.op, DimShuffle):
                dimshuffle_input = denominator.owner.inputs[0]
                dimshuffle_order = denominator.owner.op.new_order
                # Partition the reduced axes: an axis is "compatible" when the
                # DimShuffle inserted a broadcast dimension ('x') there, i.e.
                # the denominator is constant along it and can be moved out of
                # the reduction.
                compatible_dims = []
                incompatible_dims = []
                for ax in axis:
                    if ax < len(dimshuffle_order) and dimshuffle_order[ax] == "x":
                        compatible_dims.append(ax)
                    else:
                        incompatible_dims.append(ax)
                # Re-index the incompatible axes to account for the compatible
                # axes that will already have been reduced away (each earlier
                # reduced axis shifts later axis indices down by one).
                reordered_incompatible_dims = []
                for ic_ax in incompatible_dims:
                    reordered_incompatible_dims.append(
                        ic_ax - sum(1 for c_ax in compatible_dims if c_ax < ic_ax)
                    )
                if len(compatible_dims) > 0:
                    # Drop the 'x' entries of the DimShuffle that correspond to
                    # axes being reduced; those broadcast dims no longer exist
                    # once the denominator is factored out of the reduction.
                    optimized_dimshuffle_order = [
                        ax
                        for i, ax in enumerate(dimshuffle_order)
                        if (i not in axis) or (ax != "x")
                    ]
                    # Removing leading 'x' (since it will be done automatically)
                    while (
                        len(optimized_dimshuffle_order) > 0
                        and optimized_dimshuffle_order[0] == "x"
                    ):
                        del optimized_dimshuffle_order[0]
                    # if optimized_dimshuffle_order is sorted with
                    # not 'x', then dimshuffle is useless.
                    if all(i == e for i, e in enumerate(optimized_dimshuffle_order)):
                        optimized_dimshuffle = dimshuffle_input
                    else:
                        optimized_dimshuffle = DimShuffle(
                            dimshuffle_input.type.broadcastable,
                            optimized_dimshuffle_order,
                        )(dimshuffle_input)
                    if isinstance(node.op, Sum):
                        # sum(a / b) over broadcast axes of b
                        # -> sum(a, compatible) / b, then finish any remaining
                        # (incompatible) axes with an outer sum.
                        op_on_compatible_dims = at_sum(numerator, axis=compatible_dims)
                        rval = true_div(op_on_compatible_dims, optimized_dimshuffle)
                        if len(reordered_incompatible_dims) > 0:
                            rval = at_sum(rval, axis=reordered_incompatible_dims)
                    elif isinstance(node.op, Prod):
                        # prod(a / b) over broadcast axes of b
                        # -> prod(a, compatible) / b ** (number of reduced
                        # elements), since b contributes one factor per
                        # reduced element.
                        op_on_compatible_dims = prod(numerator, axis=compatible_dims)
                        dtype = numerator.dtype
                        rval = true_div(
                            op_on_compatible_dims,
                            (
                                optimized_dimshuffle
                                ** prod(
                                    [
                                        numerator.shape[ax].astype(dtype)
                                        for ax in compatible_dims
                                    ]
                                )
                            ),
                        )
                        if len(reordered_incompatible_dims) > 0:
                            rval = prod(rval, axis=reordered_incompatible_dims)
                    return [rval]
@register_canonicalize @register_canonicalize
@node_rewriter([Sum, Prod]) @node_rewriter([Sum, Prod])
def local_sum_prod_all_to_none(fgraph, node): def local_sum_prod_all_to_none(fgraph, node):
......
...@@ -899,7 +899,7 @@ class TestFusion: ...@@ -899,7 +899,7 @@ class TestFusion:
), ),
(fx, fy), (fx, fy),
(fxv, fyv), (fxv, fyv),
3, 2,
( (
np.sum(-((fxv - fyv) ** 2) / 2), np.sum(-((fxv - fyv) ** 2) / 2),
-(fxv - fyv), -(fxv - fyv),
......
...@@ -92,7 +92,7 @@ from pytensor.tensor.rewriting.math import ( ...@@ -92,7 +92,7 @@ from pytensor.tensor.rewriting.math import (
local_grad_log_erfc_neg, local_grad_log_erfc_neg,
local_greedy_distributor, local_greedy_distributor,
local_mul_canonizer, local_mul_canonizer,
local_sum_prod_of_mul, local_sum_prod_of_mul_or_div,
mul_canonizer, mul_canonizer,
parse_mul_tree, parse_mul_tree,
perform_sigm_times_exp, perform_sigm_times_exp,
...@@ -2656,7 +2656,7 @@ class TestLocalSumProd: ...@@ -2656,7 +2656,7 @@ class TestLocalSumProd:
def test_sum_of_non_scalar_mul(self): def test_sum_of_non_scalar_mul(self):
mode = Mode("vm", optimizer="None") mode = Mode("vm", optimizer="None")
rewrite = out2in(local_sum_prod_of_mul) rewrite = out2in(local_sum_prod_of_mul_or_div)
row1 = matrix(shape=(1, None), dtype="float64") row1 = matrix(shape=(1, None), dtype="float64")
row2 = matrix(shape=(1, None), dtype="float64") row2 = matrix(shape=(1, None), dtype="float64")
...@@ -2726,7 +2726,7 @@ class TestLocalSumProd: ...@@ -2726,7 +2726,7 @@ class TestLocalSumProd:
def test_prod_of_non_scalar_mul(self): def test_prod_of_non_scalar_mul(self):
mode = Mode("vm", optimizer="None") mode = Mode("vm", optimizer="None")
rewrite = out2in(local_sum_prod_of_mul) rewrite = out2in(local_sum_prod_of_mul_or_div)
scl1 = matrix(shape=(1, 1), dtype="float64") scl1 = matrix(shape=(1, 1), dtype="float64")
row1 = matrix(shape=(1, None), dtype="float64") row1 = matrix(shape=(1, None), dtype="float64")
...@@ -2756,14 +2756,15 @@ class TestLocalSumProd: ...@@ -2756,14 +2756,15 @@ class TestLocalSumProd:
mul(row1, row2, mat1, mat2, col1, col2).prod(axis=0), mul(row1, row2, mat1, mat2, col1, col2).prod(axis=0),
( (
mul(row1.squeeze(), row2.squeeze()) mul(row1.squeeze(), row2.squeeze())
** prod([mul(mat1, mat2, col1, col2).shape[0]]) ** prod([mul(mat1, mat2, col1, col2).shape[0].astype("float64")])
* mul(mat1, mat2, col1, col2).prod(axis=0) * mul(mat1, mat2, col1, col2).prod(axis=0)
), ),
), ),
( (
mul(row1, mat1, mat2, col1, col2).prod(axis=0), mul(row1, mat1, mat2, col1, col2).prod(axis=0),
( (
row1.squeeze() ** prod([mul(mat1, mat2, col1, col2).shape[0]]) row1.squeeze()
** prod([mul(mat1, mat2, col1, col2).shape[0].astype("float64")])
* mul(mat1, mat2, col1, col2).prod(axis=0) * mul(mat1, mat2, col1, col2).prod(axis=0)
), ),
), ),
...@@ -2771,7 +2772,7 @@ class TestLocalSumProd: ...@@ -2771,7 +2772,7 @@ class TestLocalSumProd:
mul(row1, row2, mat1, mat2, col1, col2).prod(axis=1), mul(row1, row2, mat1, mat2, col1, col2).prod(axis=1),
( (
mul(col1.squeeze(), col2.squeeze()) mul(col1.squeeze(), col2.squeeze())
** prod([mul(row1, row2, mat1, mat2).shape[1]]) ** prod([mul(row1, row2, mat1, mat2).shape[1].astype("float64")])
* mul(row1, row2, mat1, mat2).prod(axis=1) * mul(row1, row2, mat1, mat2).prod(axis=1)
), ),
), ),
...@@ -2781,13 +2782,21 @@ class TestLocalSumProd: ...@@ -2781,13 +2782,21 @@ class TestLocalSumProd:
), ),
( (
mul(row1, col1).prod(axis=0), mul(row1, col1).prod(axis=0),
(row1.squeeze() ** prod([col1.shape[0]]) * col1.prod(axis=0)), (
row1.squeeze() ** prod([col1.shape[0].astype("float64")])
* col1.prod(axis=0)
),
), ),
( (
mul(scl1, mat1, row1).prod(axis=None), mul(scl1, mat1, row1).prod(axis=None),
( (
scl1.squeeze() scl1.squeeze()
** prod([mul(mat1, row1).shape[0], mul(mat1, row1).shape[1]]) ** prod(
[
mul(mat1, row1).shape[0].astype("float64"),
mul(mat1, row1).shape[1].astype("float64"),
]
)
* mul(mat1, row1).prod(axis=None) * mul(mat1, row1).prod(axis=None)
), ),
), ),
...@@ -3050,146 +3059,7 @@ class TestLocalSumProd: ...@@ -3050,146 +3059,7 @@ class TestLocalSumProd:
f = function([mat], at_sum(-mat), mode=m0) f = function([mat], at_sum(-mat), mode=m0)
assert check_stack_trace(f, ops_to_check=[Sum]) assert check_stack_trace(f, ops_to_check=[Sum])
def test_local_sum_of_div(self):
class TestLocalReduce:
    """Tests for rewrites that simplify or remove ``CAReduce`` nodes."""

    def setup_method(self):
        # Include the rewrite passes these tests exercise.
        self.mode = get_default_mode().including(
            "canonicalize", "specialize", "uncanonicalize", "local_max_and_argmax"
        )

    def test_local_reduce_broadcast_all_0(self):
        # Reducing (axis=None) an input whose every dimension is broadcastable
        # should leave no CAReduce node in the compiled graph.
        for fct in [
            at_sum,
            at_all,
            at_any,
            prod,
            at_max,
            at_min,
        ]:
            x = TensorType("int64", shape=(1, 1, 1))()
            f = function([x], [fct(x)], mode=self.mode)
            assert not any(
                isinstance(node.op, CAReduce) for node in f.maker.fgraph.toposort()
            )

    def test_local_reduce_broadcast_all_1(self):
        # Same as above, but with the axes given explicitly.
        for fct in [
            at_sum,
            at_all,
            at_any,
            prod,
            at_max,
            at_min,
        ]:
            x = TensorType("int64", shape=(1, 1))()
            f = function([x], [fct(x, axis=[0, 1])], mode=self.mode)
            assert not any(
                isinstance(node.op, CAReduce) for node in f.maker.fgraph.toposort()
            )

    def test_local_reduce_broadcast_some_0(self):
        # Only one reduced axis is broadcastable: the broadcastable axis is
        # dropped, but a CAReduce must remain for the non-broadcastable one.
        for fct in [
            at_sum,
            at_all,
            at_any,
            prod,
            at_max,
            at_min,
        ]:
            x = TensorType("int64", shape=(1, None, 1))()
            f = function([x], [fct(x, axis=[0, 1])], mode=self.mode)
            order = f.maker.fgraph.toposort()
            assert 1 == sum(isinstance(node.op, CAReduce) for node in order)
            node = [node for node in order if isinstance(node.op, CAReduce)][0]
            op = node.op
            assert isinstance(op, CAReduce)
            # The leading broadcastable dimension has been dropped by the
            # `local_reduce_broadcastable` rewrite. Now, summation is over
            # the original `x`'s dimension 1.
            assert node.inputs[0].ndim == 2, node
            assert op.axis == (0,), op.axis

    def test_local_reduce_broadcast_some_1(self):
        # All reduced axes are broadcastable, so the reduction disappears
        # even though one (non-reduced) axis remains.
        for fct in [
            at_sum,
            at_all,
            at_any,
            prod,
            at_max,
            at_min,
        ]:
            x = TensorType("int64", shape=(1, 1, 1))()
            f = function([x], [fct(x, axis=[0, 2])], mode=self.mode)
            assert not any(
                isinstance(node.op, CAReduce) for node in f.maker.fgraph.toposort()
            )

    def test_local_reduce_join(self):
        # reduce(join(...), axis=0) should be rewritten into an elementwise
        # combination of the joined inputs when the reduction is along the
        # join axis.
        vx = matrix()
        vy = matrix()
        vz = matrix()
        x = np.asarray([[1, 0], [3, 4]], dtype=config.floatX)
        y = np.asarray([[4, 0], [2, 1]], dtype=config.floatX)
        z = np.asarray([[5, 0], [1, 2]], dtype=config.floatX)

        # Test different reduction scalar operation
        for out, res in [
            (at_max((vx, vy), 0), np.max((x, y), 0)),
            (at_min((vx, vy), 0), np.min((x, y), 0)),
            (at_sum((vx, vy, vz), 0), np.sum((x, y, z), 0)),
            (prod((vx, vy, vz), 0), np.prod((x, y, z), 0)),
            (prod((vx, vy.T, vz), 0), np.prod((x, y.T, z), 0)),
        ]:
            f = function([vx, vy, vz], out, on_unused_input="ignore", mode=self.mode)
            assert (f(x, y, z) == res).all(), out
            topo = f.maker.fgraph.toposort()
            assert len(topo) <= 2, out
            assert isinstance(topo[-1].op, Elemwise), out

        # Test different axis for the join and the reduction
        # We must force the dtype, otherwise this test will fail
        # on 32 bit systems
        A = shared(np.array([1, 2, 3, 4, 5], dtype="int64"))

        f = function([], at_sum(at.stack([A, A]), axis=0), mode=self.mode)
        utt.assert_allclose(f(), [2, 4, 6, 8, 10])
        topo = f.maker.fgraph.toposort()
        assert isinstance(topo[-1].op, Elemwise)

        # Test a case that was previously buggy in PyTensor: reduction along
        # an axis other than the join axis must not be rewritten this way.
        f = function([], at_sum(at.stack([A, A]), axis=1), mode=self.mode)
        utt.assert_allclose(f(), [15, 15])
        topo = f.maker.fgraph.toposort()
        assert not isinstance(topo[-1].op, Elemwise)

        # This case could be rewritten
        A = shared(np.array([1, 2, 3, 4, 5]).reshape(5, 1))
        f = function([], at_sum(at.concatenate((A, A), axis=1), axis=1), mode=self.mode)
        utt.assert_allclose(f(), [2, 4, 6, 8, 10])
        topo = f.maker.fgraph.toposort()
        assert not isinstance(topo[-1].op, Elemwise)

        A = shared(np.array([1, 2, 3, 4, 5]).reshape(5, 1))
        f = function([], at_sum(at.concatenate((A, A), axis=1), axis=0), mode=self.mode)
        utt.assert_allclose(f(), [15, 15])
        topo = f.maker.fgraph.toposort()
        assert not isinstance(topo[-1].op, Elemwise)

        # Test that the rewrite does not crash in one case where it
        # is not applied. Reported at
        # https://groups.google.com/d/topic/theano-users/EDgyCU00fFA/discussion
        out = at_sum([vx, vy, vz], axis=None)
        f = function([vx, vy, vz], out)
class TestLocalSumProdDimshuffle:
def setup_method(self):
self.mode = get_default_mode().including("canonicalize")
def test_local_sum_div_dimshuffle(self):
a = matrix("a") a = matrix("a")
b = vector("b") b = vector("b")
c = tensor3("c") c = tensor3("c")
...@@ -3242,7 +3112,7 @@ class TestLocalSumProdDimshuffle: ...@@ -3242,7 +3112,7 @@ class TestLocalSumProdDimshuffle:
assert isinstance(g[-1].op.scalar_op, aes.basic.TrueDiv) assert isinstance(g[-1].op.scalar_op, aes.basic.TrueDiv)
f(a_val, b_val, c_val, d_val) f(a_val, b_val, c_val, d_val)
def test_local_prod_div_dimshuffle(self): def test_local_prod_of_div(self):
a = matrix("a") a = matrix("a")
b = vector("b") b = vector("b")
c = tensor3("c") c = tensor3("c")
...@@ -3295,9 +3165,9 @@ class TestLocalSumProdDimshuffle: ...@@ -3295,9 +3165,9 @@ class TestLocalSumProdDimshuffle:
# `FusionOptimizer` is included to make sure that `expected_outer_operator` # `FusionOptimizer` is included to make sure that `expected_outer_operator`
# remains the same for all rewrite modes. # remains the same for all rewrite modes.
mode_with_rewrite = default_mode.including( mode_with_rewrite = default_mode.including(
"local_sum_prod_div_dimshuffle", "FusionOptimizer" "local_sum_prod_of_mul_or_div", "FusionOptimizer"
) )
mode_without_rewrite = default_mode.excluding("local_sum_prod_div_dimshuffle") mode_without_rewrite = default_mode.excluding("local_sum_prod_of_mul_or_div")
# Numerical tests: tests whether the numerical values with and without # Numerical tests: tests whether the numerical values with and without
# rewrites are equal or not. # rewrites are equal or not.
...@@ -3345,9 +3215,139 @@ class TestLocalSumProdDimshuffle: ...@@ -3345,9 +3215,139 @@ class TestLocalSumProdDimshuffle:
g.maker.fgraph.toposort()[-1].op.scalar_op, expected_outer_operator[i] g.maker.fgraph.toposort()[-1].op.scalar_op, expected_outer_operator[i]
) )
# TODO:
# test_local_sum_prod_dimshuffle (a * b * c) class TestLocalReduce:
# test_local_sum_divprod_dimshuffle ((a * b) / (c * d)) def setup_method(self):
self.mode = get_default_mode().including(
"canonicalize", "specialize", "uncanonicalize", "local_max_and_argmax"
)
def test_local_reduce_broadcast_all_0(self):
for fct in [
at_sum,
at_all,
at_any,
prod,
at_max,
at_min,
]:
x = TensorType("int64", shape=(1, 1, 1))()
f = function([x], [fct(x)], mode=self.mode)
assert not any(
isinstance(node.op, CAReduce) for node in f.maker.fgraph.toposort()
)
def test_local_reduce_broadcast_all_1(self):
for fct in [
at_sum,
at_all,
at_any,
prod,
at_max,
at_min,
]:
x = TensorType("int64", shape=(1, 1))()
f = function([x], [fct(x, axis=[0, 1])], mode=self.mode)
assert not any(
isinstance(node.op, CAReduce) for node in f.maker.fgraph.toposort()
)
def test_local_reduce_broadcast_some_0(self):
for fct in [
at_sum,
at_all,
at_any,
prod,
at_max,
at_min,
]:
x = TensorType("int64", shape=(1, None, 1))()
f = function([x], [fct(x, axis=[0, 1])], mode=self.mode)
order = f.maker.fgraph.toposort()
assert 1 == sum(isinstance(node.op, CAReduce) for node in order)
node = [node for node in order if isinstance(node.op, CAReduce)][0]
op = node.op
assert isinstance(op, CAReduce)
# The leading broadcastable dimension has been dropped by the
# `local_reduce_broadcastable` rewrite. Now, summation is over
# the original `x`'s dimension 1.
assert node.inputs[0].ndim == 2, node
assert op.axis == (0,), op.axis
def test_local_reduce_broadcast_some_1(self):
for fct in [
at_sum,
at_all,
at_any,
prod,
at_max,
at_min,
]:
x = TensorType("int64", shape=(1, 1, 1))()
f = function([x], [fct(x, axis=[0, 2])], mode=self.mode)
assert not any(
isinstance(node.op, CAReduce) for node in f.maker.fgraph.toposort()
)
def test_local_reduce_join(self):
vx = matrix()
vy = matrix()
vz = matrix()
x = np.asarray([[1, 0], [3, 4]], dtype=config.floatX)
y = np.asarray([[4, 0], [2, 1]], dtype=config.floatX)
z = np.asarray([[5, 0], [1, 2]], dtype=config.floatX)
# Test different reduction scalar operation
for out, res in [
(at_max((vx, vy), 0), np.max((x, y), 0)),
(at_min((vx, vy), 0), np.min((x, y), 0)),
(at_sum((vx, vy, vz), 0), np.sum((x, y, z), 0)),
(prod((vx, vy, vz), 0), np.prod((x, y, z), 0)),
(prod((vx, vy.T, vz), 0), np.prod((x, y.T, z), 0)),
]:
f = function([vx, vy, vz], out, on_unused_input="ignore", mode=self.mode)
assert (f(x, y, z) == res).all(), out
topo = f.maker.fgraph.toposort()
assert len(topo) <= 2, out
assert isinstance(topo[-1].op, Elemwise), out
# Test different axis for the join and the reduction
# We must force the dtype, of otherwise, this tests will fail
# on 32 bit systems
A = shared(np.array([1, 2, 3, 4, 5], dtype="int64"))
f = function([], at_sum(at.stack([A, A]), axis=0), mode=self.mode)
utt.assert_allclose(f(), [2, 4, 6, 8, 10])
topo = f.maker.fgraph.toposort()
assert isinstance(topo[-1].op, Elemwise)
# Test a case that was bugged in a old PyTensor bug
f = function([], at_sum(at.stack([A, A]), axis=1), mode=self.mode)
utt.assert_allclose(f(), [15, 15])
topo = f.maker.fgraph.toposort()
assert not isinstance(topo[-1].op, Elemwise)
# This case could be rewritten
A = shared(np.array([1, 2, 3, 4, 5]).reshape(5, 1))
f = function([], at_sum(at.concatenate((A, A), axis=1), axis=1), mode=self.mode)
utt.assert_allclose(f(), [2, 4, 6, 8, 10])
topo = f.maker.fgraph.toposort()
assert not isinstance(topo[-1].op, Elemwise)
A = shared(np.array([1, 2, 3, 4, 5]).reshape(5, 1))
f = function([], at_sum(at.concatenate((A, A), axis=1), axis=0), mode=self.mode)
utt.assert_allclose(f(), [15, 15])
topo = f.maker.fgraph.toposort()
assert not isinstance(topo[-1].op, Elemwise)
# Test that the rewrite does not crash in one case where it
# is not applied. Reported at
# https://groups.google.com/d/topic/theano-users/EDgyCU00fFA/discussion
out = at_sum([vx, vy, vz], axis=None)
f = function([vx, vy, vz], out)
def test_local_useless_adds(): def test_local_useless_adds():
...@@ -3534,7 +3534,6 @@ def test_local_mul_exp_to_exp_add(): ...@@ -3534,7 +3534,6 @@ def test_local_mul_exp_to_exp_add():
# e^x * e^y * e^z * e^w = e^(x+y+z+w) # e^x * e^y * e^z * e^w = e^(x+y+z+w)
op = expx * expy * expz * expw op = expx * expy * expz * expw
f = function([x, y, z, w], op, mode) f = function([x, y, z, w], op, mode)
pytensor.dprint(f)
utt.assert_allclose(f(3, 4, 5, 6), np.exp(3 + 4 + 5 + 6)) utt.assert_allclose(f(3, 4, 5, 6), np.exp(3 + 4 + 5 + 6))
graph = f.maker.fgraph.toposort() graph = f.maker.fgraph.toposort()
assert all(isinstance(n.op, Elemwise) for n in graph) assert all(isinstance(n.op, Elemwise) for n in graph)
......
Markdown format supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment