提交 06632882 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Thomas Wiecki

Move math-related optimizations to theano.tensor.math_opt

上级 980451b5
===================================================================
:mod:`tensor.math_opt` -- Tensor Optimizations for Math Operations
===================================================================
.. module:: tensor.math_opt
:platform: Unix, Windows
:synopsis: Tensor Optimizations for Math Operations
.. moduleauthor:: LISA, PyMC Developers
.. automodule:: theano.tensor.math_opt
:members:
...@@ -143,7 +143,7 @@ Optimization o4 o3 o2 ...@@ -143,7 +143,7 @@ Optimization o4 o3 o2
(a+b+c+...) - (z + x + y + ....) (a+b+c+...) - (z + x + y + ....)
See :class:`Canonizer`, :attr:`local_add_canonizer` See :class:`AlgebraicCanonizer`, :attr:`local_add_canonizer`
mul canonicalization mul canonicalization
Rearrange expressions of multiplication and division to a canonical Rearrange expressions of multiplication and division to a canonical
...@@ -153,7 +153,7 @@ Optimization o4 o3 o2 ...@@ -153,7 +153,7 @@ Optimization o4 o3 o2
\frac{a * b * c * ...}{z * x * y * ....} \frac{a * b * c * ...}{z * x * y * ....}
See :class:`Canonizer`, :attr:`local_mul_canonizer` See :class:`AlgebraicCanonizer`, :attr:`local_mul_canonizer`
dot22 dot22
This simple optimization replaces dot(matrix, matrix) with a special This simple optimization replaces dot(matrix, matrix) with a special
...@@ -288,4 +288,3 @@ Optimization o4 o3 o2 ...@@ -288,4 +288,3 @@ Optimization o4 o3 o2
Use this optimization if you are sure everything is valid in your graph. Use this optimization if you are sure everything is valid in your graph.
See :ref:`unsafe_optimization` See :ref:`unsafe_optimization`
...@@ -47,11 +47,8 @@ from theano.tensor.basic_opt import ( ...@@ -47,11 +47,8 @@ from theano.tensor.basic_opt import (
MakeVector, MakeVector,
ShapeFeature, ShapeFeature,
assert_op, assert_op,
local_add_specialize,
local_canonicalize_alloc, local_canonicalize_alloc,
local_dimshuffle_lift, local_dimshuffle_lift,
local_greedy_distributor,
local_lift_transpose_through_dot,
local_merge_alloc, local_merge_alloc,
local_reshape_to_dimshuffle, local_reshape_to_dimshuffle,
local_useless_alloc, local_useless_alloc,
...@@ -59,7 +56,6 @@ from theano.tensor.basic_opt import ( ...@@ -59,7 +56,6 @@ from theano.tensor.basic_opt import (
local_useless_elemwise, local_useless_elemwise,
local_useless_reshape, local_useless_reshape,
make_vector, make_vector,
mul_canonizer,
register_specialize, register_specialize,
) )
from theano.tensor.blas import Dot22, Gemv from theano.tensor.blas import Dot22, Gemv
...@@ -109,6 +105,12 @@ from theano.tensor.math import round as tt_round ...@@ -109,6 +105,12 @@ from theano.tensor.math import round as tt_round
from theano.tensor.math import sgn, sin, sinh, sqr, sqrt, sub from theano.tensor.math import sgn, sin, sinh, sqr, sqrt, sub
from theano.tensor.math import sum as tt_sum from theano.tensor.math import sum as tt_sum
from theano.tensor.math import tan, tanh, true_div, xor from theano.tensor.math import tan, tanh, true_div, xor
from theano.tensor.math_opt import (
local_add_specialize,
local_greedy_distributor,
local_lift_transpose_through_dot,
mul_canonizer,
)
from theano.tensor.nnet.sigm import softplus from theano.tensor.nnet.sigm import softplus
from theano.tensor.shape import Reshape, Shape_i, SpecifyShape, reshape, specify_shape from theano.tensor.shape import Reshape, Shape_i, SpecifyShape, reshape, specify_shape
from theano.tensor.subtensor import ( from theano.tensor.subtensor import (
...@@ -465,7 +467,7 @@ class TestCanonize: ...@@ -465,7 +467,7 @@ class TestCanonize:
print(pprint(g.outputs[0])) print(pprint(g.outputs[0]))
def test_elemwise_multiple_inputs_optimisation(self): def test_elemwise_multiple_inputs_optimisation(self):
# verify that the Canonizer merge sequential Elemwise({mul,add}) part 1 # verify that the AlgebraicCanonizer merge sequential Elemwise({mul,add}) part 1
# #
# This part are that case that is done, but don't include case # This part are that case that is done, but don't include case
# that are not implemented but are supposed to be. # that are not implemented but are supposed to be.
...@@ -574,8 +576,8 @@ class TestCanonize: ...@@ -574,8 +576,8 @@ class TestCanonize:
] # [10:11] ] # [10:11]
# print cases # print cases
# We must be sure that the Canonizer is working, but that we don't have other # We must be sure that the AlgebraicCanonizer is working, but that we don't have other
# optimisation that could hide bug in the Canonizer as local_elemwise_fusion # optimisation that could hide bug in the AlgebraicCanonizer as local_elemwise_fusion
mode = get_default_mode() mode = get_default_mode()
opt = Query(["canonicalize"]) opt = Query(["canonicalize"])
opt = opt.excluding("local_elemwise_fusion") opt = opt.excluding("local_elemwise_fusion")
...@@ -595,11 +597,11 @@ class TestCanonize: ...@@ -595,11 +597,11 @@ class TestCanonize:
assert out_dtype == out.dtype assert out_dtype == out.dtype
@pytest.mark.skip( @pytest.mark.skip(
reason="Current implementation of Canonizer does not " reason="Current implementation of AlgebraicCanonizer does not "
"implement all cases. Skip the corresponding test." "implement all cases. Skip the corresponding test."
) )
def test_elemwise_multiple_inputs_optimisation2(self): def test_elemwise_multiple_inputs_optimisation2(self):
# verify that the Canonizer merge sequential Elemwise({mul,add}) part 2. # verify that the AlgebraicCanonizer merge sequential Elemwise({mul,add}) part 2.
# This part are that case that should have been done, but that are not implemented. # This part are that case that should have been done, but that are not implemented.
# Test with and without DimShuffle # Test with and without DimShuffle
...@@ -709,8 +711,8 @@ class TestCanonize: ...@@ -709,8 +711,8 @@ class TestCanonize:
] # [10:11] ] # [10:11]
# print cases # print cases
# We must be sure that the Canonizer is working, but that we don't have other # We must be sure that the AlgebraicCanonizer is working, but that we don't have other
# optimisation that could hide bug in the Canonizer as local_elemwise_fusion # optimisation that could hide bug in the AlgebraicCanonizer as local_elemwise_fusion
mode = get_default_mode() mode = get_default_mode()
mode._optimizer = Query(["canonicalize"]) mode._optimizer = Query(["canonicalize"])
mode._optimizer = mode._optimizer.excluding("local_elemwise_fusion") mode._optimizer = mode._optimizer.excluding("local_elemwise_fusion")
...@@ -728,7 +730,7 @@ class TestCanonize: ...@@ -728,7 +730,7 @@ class TestCanonize:
@pytest.mark.slow @pytest.mark.slow
def test_multiple_case(self): def test_multiple_case(self):
# test those case take from the comment in Canonizer # test those case take from the comment in AlgebraicCanonizer
# x / x -> 1 # x / x -> 1
# (x * y) / x -> y # (x * y) / x -> y
# x / y / x -> 1 / y # x / y / x -> 1 / y
...@@ -756,8 +758,8 @@ class TestCanonize: ...@@ -756,8 +758,8 @@ class TestCanonize:
dwv = _asarray(np.random.rand(*shp), dtype="float64") dwv = _asarray(np.random.rand(*shp), dtype="float64")
dvv = _asarray(np.random.rand(shp[0]), dtype="float64").reshape(1, shp[0]) dvv = _asarray(np.random.rand(shp[0]), dtype="float64").reshape(1, shp[0])
# We must be sure that the Canonizer is working, but that we don't have other # We must be sure that the AlgebraicCanonizer is working, but that we don't have other
# optimisation that could hide bug in the Canonizer as local_elemwise_fusion # optimisation that could hide bug in the AlgebraicCanonizer as local_elemwise_fusion
mode = get_default_mode() mode = get_default_mode()
opt = Query(["canonicalize"]) opt = Query(["canonicalize"])
...@@ -1109,7 +1111,7 @@ class TestCanonize: ...@@ -1109,7 +1111,7 @@ class TestCanonize:
assert f.maker.fgraph.toposort()[0].op == sgn assert f.maker.fgraph.toposort()[0].op == sgn
@pytest.mark.skip( @pytest.mark.skip(
reason="Current implementation of Canonizer does not " reason="Current implementation of AlgebraicCanonizer does not "
"implement all cases. Skip the corresponding test." "implement all cases. Skip the corresponding test."
) )
def test_multiple_case_that_fail(self): def test_multiple_case_that_fail(self):
...@@ -1123,8 +1125,8 @@ class TestCanonize: ...@@ -1123,8 +1125,8 @@ class TestCanonize:
dyv = _asarray(np.random.rand(*shp), dtype="float32") dyv = _asarray(np.random.rand(*shp), dtype="float32")
dzv = _asarray(np.random.rand(*shp), dtype="float32") dzv = _asarray(np.random.rand(*shp), dtype="float32")
# fvv = _asarray(np.random.rand(shp[0]), dtype='float32').reshape(1, shp[0]) # fvv = _asarray(np.random.rand(shp[0]), dtype='float32').reshape(1, shp[0])
# We must be sure that the Canonizer is working, but that we don't have other # We must be sure that the AlgebraicCanonizer is working, but that we don't have other
# optimisation that could hide bug in the Canonizer as local_elemwise_fusion # optimisation that could hide bug in the AlgebraicCanonizer as local_elemwise_fusion
mode = get_default_mode() mode = get_default_mode()
opt = Query(["canonicalize"]) opt = Query(["canonicalize"])
......
...@@ -86,7 +86,7 @@ from theano.scan.utils import ( ...@@ -86,7 +86,7 @@ from theano.scan.utils import (
scan_args, scan_args,
scan_can_remove_outs, scan_can_remove_outs,
) )
from theano.tensor import basic_opt from theano.tensor import basic_opt, math_opt
from theano.tensor.basic import Alloc, AllocEmpty, get_scalar_constant_value from theano.tensor.basic import Alloc, AllocEmpty, get_scalar_constant_value
from theano.tensor.elemwise import DimShuffle, Elemwise from theano.tensor.elemwise import DimShuffle, Elemwise
from theano.tensor.exceptions import NotScalarConstantError from theano.tensor.exceptions import NotScalarConstantError
...@@ -118,8 +118,8 @@ __copyright__ = "(c) 2010, Universite de Montreal" ...@@ -118,8 +118,8 @@ __copyright__ = "(c) 2010, Universite de Montreal"
_logger = logging.getLogger("theano.scan.opt") _logger = logging.getLogger("theano.scan.opt")
list_opt_slice = [ list_opt_slice = [
basic_opt.local_abs_merge, math_opt.local_abs_merge,
basic_opt.local_mul_switch_sink, math_opt.local_mul_switch_sink,
basic_opt.local_upcast_elemwise_constant_inputs, basic_opt.local_upcast_elemwise_constant_inputs,
basic_opt.local_useless_switch, basic_opt.local_useless_switch,
basic_opt.constant_folding, basic_opt.constant_folding,
......
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
差异被折叠。
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论