Commit 3685fd8e authored by Brandon T. Willard, committed by Brandon T. Willard

Move GraphToGpuLocalOptGroup to aesara.gpuarray.optdb

Parent 18b6925b
import time
from aesara.compile import optdb
from aesara.graph.opt import GraphToGPULocalOptGroup, TopoOptimizer, local_optimizer
from aesara.graph.basic import applys_between
from aesara.graph.opt import LocalOptGroup, TopoOptimizer, local_optimizer
from aesara.graph.optdb import (
EquilibriumDB,
LocalGroupDB,
......@@ -8,6 +11,43 @@ from aesara.graph.optdb import (
)
class GraphToGPULocalOptGroup(LocalOptGroup):
    """This is the equivalent of `LocalOptGroup` for `GraphToGPU`.

    The main difference is the function signature of the local
    optimizer, which uses the `GraphToGPU` signature and not the normal
    `LocalOptimizer` signature.

    ``apply_all_opts=True`` is not supported.
    """

    def __init__(self, *optimizers, **kwargs):
        super().__init__(*optimizers, **kwargs)
        # `transform` returns after the first optimizer that produces a
        # replacement, so applying all optimizers in a single pass is
        # not supported by this group.
        assert self.apply_all_opts is False

    def transform(self, fgraph, op, context_name, inputs, outputs):
        """Apply the first matching optimizer and return its replacement.

        Returns ``None`` when no optimizer is registered or when none of
        the candidate optimizers produces a replacement.
        """
        if len(self.opts) == 0:
            return

        # Candidates: optimizers tracking this exact `Op` instance, its
        # type, or everything (the `None` key).
        opts = self.track_map[type(op)] + self.track_map[op] + self.track_map[None]
        for opt in opts:
            opt_start = time.time()
            new_repl = opt.transform(fgraph, op, context_name, inputs, outputs)
            opt_finish = time.time()
            if self.profile:
                # BUG FIX: accumulate the elapsed time (finish - start).
                # The original added `opt_start - opt_finish`, which
                # accumulates a negative duration in the profile.
                self.time_opts[opt] += opt_finish - opt_start
                self.process_count[opt] += 1
            if not new_repl:
                continue
            if self.profile:
                self.node_created[opt] += len(
                    list(applys_between(fgraph.variables, new_repl))
                )
                self.applied_true[opt] += 1
            return new_repl
# Optimization database for GPU-specific rewrites, applied to equilibrium
# (i.e. repeatedly until no further rewrite fires).
gpu_optimizer = EquilibriumDB()
# Equilibrium database for rewrites that remove redundant GPU<->host copies.
gpu_cut_copies = EquilibriumDB()
......
......@@ -1444,43 +1444,6 @@ class LocalOptGroup(LocalOptimizer):
opt.add_requirements(fgraph)
class GraphToGPULocalOptGroup(LocalOptGroup):
    """This is the equivalent of `LocalOptGroup` for `GraphToGPU`.

    The main difference is the function signature of the local
    optimizer, which uses the `GraphToGPU` signature and not the normal
    `LocalOptimizer` signature.

    ``apply_all_opts=True`` is not supported.
    """

    def __init__(self, *optimizers, **kwargs):
        super().__init__(*optimizers, **kwargs)
        # This group stops at the first successful rewrite, so applying
        # every optimizer in one pass is not supported.
        assert self.apply_all_opts is False

    def transform(self, fgraph, op, context_name, inputs, outputs):
        """Apply the first matching optimizer and return its replacement.

        Returns ``None`` when no optimizer is registered or when no
        candidate produces a replacement.
        """
        if len(self.opts) == 0:
            return

        # Optimizers registered for this exact `Op` instance, its type,
        # or everything (`None`) are all candidates, in that order.
        opts = self.track_map[type(op)] + self.track_map[op] + self.track_map[None]
        for opt in opts:
            opt_start = time.time()
            new_repl = opt.transform(fgraph, op, context_name, inputs, outputs)
            opt_finish = time.time()
            if self.profile:
                # BUG FIX: record elapsed time as (finish - start); the
                # original accumulated the negated duration.
                self.time_opts[opt] += opt_finish - opt_start
                self.process_count[opt] += 1
            if not new_repl:
                continue
            if self.profile:
                self.node_created[opt] += len(
                    list(applys_between(fgraph.variables, new_repl))
                )
                self.applied_true[opt] += 1
            return new_repl
class OpSub(LocalOptimizer):
"""
......
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to comment