提交 219d6431 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

added random->inplace random optimizer

上级 8a1b6d8a
...@@ -94,7 +94,7 @@ inplace_optimizer = gof.InplaceOptimizer( ...@@ -94,7 +94,7 @@ inplace_optimizer = gof.InplaceOptimizer(
gof.SeqOptimizer(out2in(gemm_pattern_1), gof.SeqOptimizer(out2in(gemm_pattern_1),
insert_inplace_optimizer, insert_inplace_optimizer,
failure_callback = gof.warn)) failure_callback = gof.warn))
compile.optdb.register('inplace', inplace_optimizer, 99, 'fast_run') compile.optdb.register('inplace_opt', inplace_optimizer, 99, 'fast_run', 'inplace')
def register_canonicalize(lopt, *tags, **kwargs): def register_canonicalize(lopt, *tags, **kwargs):
......
...@@ -4,6 +4,7 @@ import basic as tensor ...@@ -4,6 +4,7 @@ import basic as tensor
import numpy import numpy
import functools import functools
import opt
from .. import compile from .. import compile
from ..compile import SymbolicInputKit, SymbolicInput from ..compile import SymbolicInputKit, SymbolicInput
from copy import copy from copy import copy
...@@ -178,12 +179,15 @@ to supplement the missing information. ...@@ -178,12 +179,15 @@ to supplement the missing information.
""" """
@gof.local_optimizer @gof.local_optimizer([None])
def random_make_inplace(node): def random_make_inplace(node):
op = node.op op = node.op
if isinstance(op, RandomFunction) and not op.inplace: if isinstance(op, RandomFunction) and not op.inplace:
return RandomFunction(op.fn, op.outtype, *op.args, **dict(inplace=True)).make_node(*node.inputs).outputs return RandomFunction(op.fn, op.outtype, *op.args, **dict(inplace=True)).make_node(*node.inputs).outputs
compile.optdb.register('random_make_inplace', opt.in2out(random_make_inplace), 99, 'fast_run', 'inplace')
import sys import sys
from functools import partial from functools import partial
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论