Use stricter numerical tolerance in rewrites and allow casting in `PatternNodeRewriter` (#1526)
* Implemented allow_cast in PatternNodeRewriter
to allow rewrites that would otherwise fail when the new and old dtype differ.
Example:
`np.array(1., "float64") - sigmoid(x)` cannot be rewritten as
`sigmoid(-x)` (where x is an fmatrix) because the type would change.
This commit allows an automatic cast to be added so the expression
is rewritten as `cast(sigmoid(-x), "float64")`.
Relevant tests added.
* Added test cases for which issue #1497 fails
* Changed PatternNodeRewriter::transform to allow types that do not contain dtype
like MyType in the tests
* Address #1497 by changing instances of np.isclose to a function isclose, which uses 10 ULPs by default
* Addressed failed tests (with older python/numpy versions)
* Addressed feedback by ricardoV94
* Test PatternNodeRewriter doesn't support multi-output nodes in pattern
But it's fine if they're just root inputs
---------
Co-authored-by:
Luca Citi <lciti@ieee.org>
Co-authored-by:
Ricardo Vieira <ricardo.vieira1994@gmail.com>
正在显示
请
注册
或者
登录
后发表评论