提交 380fca03 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename aesara.typed_list.opt to aesara.typed_list.rewriting

上级 8616d398
from . import opt from aesara.typed_list import rewriting
from .basic import * from aesara.typed_list.basic import *
from .type import TypedListType from aesara.typed_list.type import TypedListType
...@@ -4,7 +4,7 @@ from aesara.typed_list.basic import Append, Extend, Insert, Remove, Reverse ...@@ -4,7 +4,7 @@ from aesara.typed_list.basic import Append, Extend, Insert, Remove, Reverse
@node_rewriter([Append, Extend, Insert, Reverse, Remove], inplace=True) @node_rewriter([Append, Extend, Insert, Reverse, Remove], inplace=True)
def typed_list_inplace_opt(fgraph, node): def typed_list_inplace_rewrite(fgraph, node):
if ( if (
isinstance(node.op, (Append, Extend, Insert, Reverse, Remove)) isinstance(node.op, (Append, Extend, Insert, Reverse, Remove))
and not node.op.inplace and not node.op.inplace
...@@ -17,9 +17,9 @@ def typed_list_inplace_opt(fgraph, node): ...@@ -17,9 +17,9 @@ def typed_list_inplace_opt(fgraph, node):
optdb.register( optdb.register(
"typed_list_inplace_opt", "typed_list_inplace_rewrite",
WalkingGraphRewriter( WalkingGraphRewriter(
typed_list_inplace_opt, failure_callback=WalkingGraphRewriter.warn_inplace typed_list_inplace_rewrite, failure_callback=WalkingGraphRewriter.warn_inplace
), ),
"fast_run", "fast_run",
"inplace", "inplace",
......
...@@ -17,7 +17,9 @@ class TestInplace: ...@@ -17,7 +17,9 @@ class TestInplace:
)() )()
z = Reverse()(mySymbolicMatricesList) z = Reverse()(mySymbolicMatricesList)
m = aesara.compile.mode.get_default_mode().including("typed_list_inplace_opt") m = aesara.compile.mode.get_default_mode().including(
"typed_list_inplace_rewrite"
)
f = aesara.function( f = aesara.function(
[In(mySymbolicMatricesList, borrow=True, mutable=True)], [In(mySymbolicMatricesList, borrow=True, mutable=True)],
z, z,
...@@ -38,7 +40,9 @@ class TestInplace: ...@@ -38,7 +40,9 @@ class TestInplace:
)() )()
mySymbolicMatrix = matrix() mySymbolicMatrix = matrix()
z = Append()(mySymbolicMatricesList, mySymbolicMatrix) z = Append()(mySymbolicMatricesList, mySymbolicMatrix)
m = aesara.compile.mode.get_default_mode().including("typed_list_inplace_opt") m = aesara.compile.mode.get_default_mode().including(
"typed_list_inplace_rewrite"
)
f = aesara.function( f = aesara.function(
[ [
In(mySymbolicMatricesList, borrow=True, mutable=True), In(mySymbolicMatricesList, borrow=True, mutable=True),
...@@ -66,7 +70,9 @@ class TestInplace: ...@@ -66,7 +70,9 @@ class TestInplace:
)() )()
z = Extend()(mySymbolicMatricesList1, mySymbolicMatricesList2) z = Extend()(mySymbolicMatricesList1, mySymbolicMatricesList2)
m = aesara.compile.mode.get_default_mode().including("typed_list_inplace_opt") m = aesara.compile.mode.get_default_mode().including(
"typed_list_inplace_rewrite"
)
f = aesara.function( f = aesara.function(
[ [
In(mySymbolicMatricesList1, borrow=True, mutable=True), In(mySymbolicMatricesList1, borrow=True, mutable=True),
...@@ -91,7 +97,9 @@ class TestInplace: ...@@ -91,7 +97,9 @@ class TestInplace:
mySymbolicMatrix = matrix() mySymbolicMatrix = matrix()
z = Insert()(mySymbolicMatricesList, mySymbolicIndex, mySymbolicMatrix) z = Insert()(mySymbolicMatricesList, mySymbolicIndex, mySymbolicMatrix)
m = aesara.compile.mode.get_default_mode().including("typed_list_inplace_opt") m = aesara.compile.mode.get_default_mode().including(
"typed_list_inplace_rewrite"
)
f = aesara.function( f = aesara.function(
[ [
...@@ -117,7 +125,9 @@ class TestInplace: ...@@ -117,7 +125,9 @@ class TestInplace:
)() )()
mySymbolicMatrix = matrix() mySymbolicMatrix = matrix()
z = Remove()(mySymbolicMatricesList, mySymbolicMatrix) z = Remove()(mySymbolicMatricesList, mySymbolicMatrix)
m = aesara.compile.mode.get_default_mode().including("typed_list_inplace_opt") m = aesara.compile.mode.get_default_mode().including(
"typed_list_inplace_rewrite"
)
f = aesara.function( f = aesara.function(
[ [
In(mySymbolicMatricesList, borrow=True, mutable=True), In(mySymbolicMatricesList, borrow=True, mutable=True),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论