    Use stricter numerical tolerance in rewrites and allow casting in `PatternNodeRewriter` (#1526) · d4e8f736
    Luca Citi committed
    * Implemented `allow_cast` in `PatternNodeRewriter`
    to allow rewrites that would otherwise fail because the old and new dtypes differ.
    Example:
    `np.array(1., "float64") - sigmoid(x)` cannot be rewritten as
    `sigmoid(-x)` (where x is an fmatrix) because the output dtype would change
    from float64 to float32.
    This commit allows an automatic cast to be added so the expression
    is rewritten as `cast(sigmoid(-x), "float64")`.
    Relevant tests added.
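A minimal sketch of the cast-on-dtype-mismatch idea described above, using a toy expression tree instead of the real rewriter. All class and function names here (`Var`, `Const`, `Sub`, `Neg`, `Sigmoid`, `Cast`, `dtype_of`, `rewrite_one_minus_sigmoid`) are hypothetical. The only behaviour taken from the commit is: if the rewritten expression has a different dtype, wrap it in a cast back to the original dtype when casting is allowed, and otherwise reject the rewrite.

```python
from dataclasses import dataclass
from typing import Optional

# Toy expression nodes for illustration only; they are NOT the real graph classes.
@dataclass(frozen=True)
class Var:
    name: str
    dtype: str

@dataclass(frozen=True)
class Const:
    value: float
    dtype: str

@dataclass(frozen=True)
class Sub:
    a: object
    b: object

@dataclass(frozen=True)
class Neg:
    x: object

@dataclass(frozen=True)
class Sigmoid:
    x: object

@dataclass(frozen=True)
class Cast:
    x: object
    dtype: str

_WIDTH = {"float32": 32, "float64": 64}

def dtype_of(e) -> str:
    """Infer the output dtype of a toy expression (binary ops upcast)."""
    if isinstance(e, (Var, Const, Cast)):
        return e.dtype
    if isinstance(e, Sub):
        return max(dtype_of(e.a), dtype_of(e.b), key=_WIDTH.__getitem__)
    if isinstance(e, (Neg, Sigmoid)):
        return dtype_of(e.x)
    raise TypeError(f"unknown expression: {e!r}")

def rewrite_one_minus_sigmoid(e, allow_cast: bool = True) -> Optional[object]:
    """Rewrite `1 - sigmoid(x)` into `sigmoid(-x)`.

    Returns None when the pattern does not match, or when the rewrite would
    change the output dtype and casting is not allowed.
    """
    if not (
        isinstance(e, Sub)
        and isinstance(e.a, Const)
        and e.a.value == 1.0
        and isinstance(e.b, Sigmoid)
    ):
        return None
    new = Sigmoid(Neg(e.b.x))
    old_dtype = dtype_of(e)
    if dtype_of(new) == old_dtype:
        return new
    # The replacement has a different dtype: cast back, or give up.
    return Cast(new, old_dtype) if allow_cast else None

x = Var("x", "float32")  # plays the role of an fmatrix
expr = Sub(Const(1.0, "float64"), Sigmoid(x))
print(rewrite_one_minus_sigmoid(expr))
print(rewrite_one_minus_sigmoid(expr, allow_cast=False))
```

With `allow_cast=True` the float64 expression `1 - sigmoid(x)` over a float32 `x` becomes `Cast(Sigmoid(Neg(x)), "float64")`; with `allow_cast=False` the rewrite is refused, which mirrors the failure the commit describes.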
    
    * Added test cases that reproduce issue #1497
    
    * Changed `PatternNodeRewriter::transform` to allow types that have no `dtype` attribute,
    such as `MyType` in the tests
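Casting only makes sense for types that carry a dtype. A hedged sketch of a guard that does not assume one: the decision function `plan_replacement` and the `TensorLikeType`/`MyType` classes are hypothetical stand-ins (the real `MyType` lives in the tests and is not reproduced here), and reading `dtype` through `getattr(..., None)` is one way to avoid an `AttributeError`, not necessarily how `PatternNodeRewriter::transform` does it.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class TensorLikeType:
    """Hypothetical type that carries a dtype."""
    dtype: str

@dataclass(frozen=True)
class MyType:
    """Hypothetical custom type with no dtype attribute."""
    name: str

def plan_replacement(old_type, new_type, allow_cast: bool = True) -> str:
    """Return 'keep', 'cast' or 'reject' for replacing old_type by new_type."""
    if old_type == new_type:
        return "keep"
    # getattr with a default avoids AttributeError on dtype-less types.
    old_dtype = getattr(old_type, "dtype", None)
    new_dtype = getattr(new_type, "dtype", None)
    if allow_cast and old_dtype is not None and new_dtype is not None:
        return "cast"
    # Different types and no dtype to cast through: refuse the rewrite.
    return "reject"
```

Dtype-less types are then only accepted when the old and new types are equal, instead of crashing on a missing attribute.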
    
    * Addressed #1497 by replacing uses of `np.isclose` with an `isclose` function that uses a tolerance of 10 ULPs by default
    
    * Fixed tests that failed with older Python/NumPy versions
    
    * Addressed feedback from ricardoV94
    
    * Test that `PatternNodeRewriter` does not support multi-output nodes in the pattern

    Such nodes are fine, however, when they appear only as root inputs
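A toy version of the restriction being tested, under the assumption (ours, not the commit's) that patterns are nested tuples `(op, *sub_patterns)` whose string leaves are the root inputs. `ToyOp`, its `nout` field and `check_pattern` are hypothetical; the sketch only encodes the rule stated above: a multi-output op may not appear as an operation in the pattern, while a root input may bind to anything, including an output of a multi-output node.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class ToyOp:
    """Hypothetical op descriptor; nout is the number of outputs it produces."""
    name: str
    nout: int = 1

def check_pattern(pattern) -> None:
    """Raise ValueError if a multi-output op appears as an operation in the pattern.

    Strings are pattern variables (root inputs). They may later bind to any
    variable, including an output of a multi-output node, so they are accepted.
    """
    if isinstance(pattern, str):
        return  # root input: no restriction
    op, *args = pattern
    if op.nout > 1:
        raise ValueError(f"multi-output op {op.name!r} is not supported in patterns")
    for arg in args:
        check_pattern(arg)
```

For example, `check_pattern((ToyOp("exp"), "x"))` passes, while nesting `ToyOp("split", nout=2)` inside the pattern raises.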
    
    ---------
    Co-authored-by: Luca Citi <lciti@ieee.org>
    Co-authored-by: Ricardo Vieira <ricardo.vieira1994@gmail.com>
unittest_tools.py 14.7 KB