提交 c655b028 authored 作者: David Horsley's avatar David Horsley 提交者: Ricardo Vieira

Split blas Ops and rewrites

Having Ops and rewrites in the same files was causing circular imports.
上级 86cbde57
差异被折叠。
from pytensor.configdefaults import config
from pytensor.graph.rewriting.basic import in2out
from pytensor.link.c.op import COp from pytensor.link.c.op import COp
from pytensor.link.c.params_type import ParamsType from pytensor.link.c.params_type import ParamsType
from pytensor.scalar import bool as bool_t from pytensor.scalar import bool as bool_t
from pytensor.tensor import basic as at
from pytensor.tensor.blas import ( from pytensor.tensor.blas import (
Gemv, Gemv,
Ger, Ger,
blas_header_text, blas_header_text,
blas_header_version, blas_header_version,
blas_optdb,
gemv_inplace,
gemv_no_inplace,
ger,
ger_destructive,
ldflags, ldflags,
node_rewriter,
optdb,
) )
...@@ -344,23 +334,6 @@ cger_inplace = CGer(True) ...@@ -344,23 +334,6 @@ cger_inplace = CGer(True)
cger_no_inplace = CGer(False) cger_no_inplace = CGer(False)
@node_rewriter([ger, ger_destructive])
def use_c_ger(fgraph, node):
if not config.blas__ldflags:
return
# Only float32 and float64 are supported for now.
if node.op == ger and node.outputs[0].dtype in ("float32", "float64"):
return [CGer(False)(*node.inputs)]
if node.op == ger_destructive and node.outputs[0].dtype in ("float32", "float64"):
return [CGer(True)(*node.inputs)]
@node_rewriter([CGer(False)])
def make_c_ger_destructive(fgraph, node):
if isinstance(node.op, CGer) and not node.op.destructive:
return [cger_inplace(*node.inputs)]
# ##### ####### ####### # ##### ####### #######
# GEMV # GEMV
# ##### ####### ####### # ##### ####### #######
...@@ -697,48 +670,3 @@ int main() { ...@@ -697,48 +670,3 @@ int main() {
check_force_gemv_init._force_init_beta = None check_force_gemv_init._force_init_beta = None
@node_rewriter([gemv_inplace, gemv_no_inplace])
def use_c_gemv(fgraph, node):
if not config.blas__ldflags:
return
# Only float32 and float64 are supported for now.
if node.op == gemv_no_inplace and node.outputs[0].dtype in ("float32", "float64"):
return [cgemv_no_inplace(*node.inputs)]
if node.op == gemv_inplace and node.outputs[0].dtype in ("float32", "float64"):
return [cgemv_inplace(*node.inputs)]
@node_rewriter([CGemv(inplace=False)])
def make_c_gemv_destructive(fgraph, node):
if isinstance(node.op, CGemv) and not node.op.inplace:
inputs = list(node.inputs)
dest = inputs[0]
if (
dest.owner
and isinstance(dest.owner.op, at.AllocEmpty)
and len(fgraph.clients[dest]) > 1
):
inputs[0] = at.AllocEmpty(dest.dtype)(*dest.owner.inputs)
return [cgemv_inplace(*inputs)]
# ##### ####### #######
# Optimizers
# ##### ####### #######
blas_optdb.register(
"use_c_blas", in2out(use_c_ger, use_c_gemv), "fast_run", "c_blas", position=20
)
# this matches the InplaceBlasOpt defined in blas.py
optdb.register(
"c_blas_destructive",
in2out(make_c_ger_destructive, make_c_gemv_destructive, name="c_blas_destructive"),
"fast_run",
"inplace",
"c_blas",
position=70.0,
)
...@@ -4,16 +4,7 @@ Implementations of BLAS Ops based on scipy's BLAS bindings. ...@@ -4,16 +4,7 @@ Implementations of BLAS Ops based on scipy's BLAS bindings.
import numpy as np import numpy as np
from pytensor.graph.rewriting.basic import in2out from pytensor.tensor.blas import Ger, have_fblas
from pytensor.tensor.blas import (
Ger,
blas_optdb,
ger,
ger_destructive,
have_fblas,
node_rewriter,
optdb,
)
if have_fblas: if have_fblas:
...@@ -56,36 +47,3 @@ class ScipyGer(Ger): ...@@ -56,36 +47,3 @@ class ScipyGer(Ger):
scipy_ger_no_inplace = ScipyGer(False) scipy_ger_no_inplace = ScipyGer(False)
scipy_ger_inplace = ScipyGer(True) scipy_ger_inplace = ScipyGer(True)
@node_rewriter([ger, ger_destructive])
def use_scipy_ger(fgraph, node):
if node.op == ger:
return [scipy_ger_no_inplace(*node.inputs)]
@node_rewriter([scipy_ger_no_inplace])
def make_ger_destructive(fgraph, node):
if node.op == scipy_ger_no_inplace:
return [scipy_ger_inplace(*node.inputs)]
use_scipy_blas = in2out(use_scipy_ger)
make_scipy_blas_destructive = in2out(make_ger_destructive)
if have_fblas:
# scipy_blas is scheduled in the blas_optdb very late, because scipy sortof
# sucks, but it is almost always present.
# C implementations should be scheduled earlier than this, so that they take
# precedence. Once the original Ger is replaced, then these optimizations
# have no effect.
blas_optdb.register("scipy_blas", use_scipy_blas, "fast_run", position=100)
# this matches the InplaceBlasOpt defined in blas.py
optdb.register(
"make_scipy_blas_destructive",
make_scipy_blas_destructive,
"fast_run",
"inplace",
position=70.0,
)
import pytensor.tensor.rewriting.basic import pytensor.tensor.rewriting.basic
import pytensor.tensor.rewriting.blas
import pytensor.tensor.rewriting.blas_c
import pytensor.tensor.rewriting.blas_scipy
import pytensor.tensor.rewriting.elemwise import pytensor.tensor.rewriting.elemwise
import pytensor.tensor.rewriting.extra_ops import pytensor.tensor.rewriting.extra_ops
......
差异被折叠。
from pytensor.configdefaults import config
from pytensor.graph.rewriting.basic import in2out
from pytensor.tensor import basic as at
from pytensor.tensor.blas import gemv_inplace, gemv_no_inplace, ger, ger_destructive
from pytensor.tensor.blas_c import (
CGemv,
CGer,
cgemv_inplace,
cgemv_no_inplace,
cger_inplace,
)
from pytensor.tensor.rewriting.blas import blas_optdb, node_rewriter, optdb
@node_rewriter([ger, ger_destructive])
def use_c_ger(fgraph, node):
    """Swap a ``ger``/``ger_destructive`` node for its C counterpart.

    Does nothing unless a BLAS library is linked (``config.blas__ldflags``)
    and the output dtype is one the C kernel supports.
    """
    if not config.blas__ldflags:
        return
    # Only float32 and float64 are supported for now.
    if node.outputs[0].dtype not in ("float32", "float64"):
        return
    if node.op == ger:
        return [CGer(False)(*node.inputs)]
    if node.op == ger_destructive:
        return [CGer(True)(*node.inputs)]
@node_rewriter([CGer(False)])
def make_c_ger_destructive(fgraph, node):
    """Turn a non-destructive ``CGer`` into the in-place variant."""
    op = node.op
    if isinstance(op, CGer) and not op.destructive:
        return [cger_inplace(*node.inputs)]
@node_rewriter([gemv_inplace, gemv_no_inplace])
def use_c_gemv(fgraph, node):
    """Swap a ``gemv_no_inplace``/``gemv_inplace`` node for its C counterpart.

    Does nothing unless a BLAS library is linked (``config.blas__ldflags``)
    and the output dtype is one the C kernel supports.
    """
    if not config.blas__ldflags:
        return
    # Only float32 and float64 are supported for now.
    if node.outputs[0].dtype not in ("float32", "float64"):
        return
    if node.op == gemv_no_inplace:
        return [cgemv_no_inplace(*node.inputs)]
    if node.op == gemv_inplace:
        return [cgemv_inplace(*node.inputs)]
@node_rewriter([CGemv(inplace=False)])
def make_c_gemv_destructive(fgraph, node):
    """Turn a non-in-place ``CGemv`` into the in-place variant.

    If the destination buffer is an ``AllocEmpty`` that other nodes also
    read, a fresh ``AllocEmpty`` is allocated first so the in-place write
    does not clobber a shared value.
    """
    op = node.op
    if not (isinstance(op, CGemv) and not op.inplace):
        return
    new_inputs = list(node.inputs)
    dest = new_inputs[0]
    owner = dest.owner
    if (
        owner is not None
        and isinstance(owner.op, at.AllocEmpty)
        and len(fgraph.clients[dest]) > 1
    ):
        # Destination has other consumers: give the in-place op its own buffer.
        new_inputs[0] = at.AllocEmpty(dest.dtype)(*owner.inputs)
    return [cgemv_inplace(*new_inputs)]
# Register the C Ger/Gemv substitutions with the rewrite databases.
# position=20 schedules them ahead of the scipy fallback (registered at
# position=100 in the scipy rewrite module), so the C kernels take precedence.
blas_optdb.register(
    "use_c_blas", in2out(use_c_ger, use_c_gemv), "fast_run", "c_blas", position=20
)

# this matches the InplaceBlasOpt defined in blas.py
optdb.register(
    "c_blas_destructive",
    in2out(make_c_ger_destructive, make_c_gemv_destructive, name="c_blas_destructive"),
    "fast_run",
    "inplace",
    "c_blas",
    position=70.0,
)
from pytensor.graph.rewriting.basic import in2out
from pytensor.tensor.blas import ger, ger_destructive, have_fblas
from pytensor.tensor.blas_scipy import scipy_ger_inplace, scipy_ger_no_inplace
from pytensor.tensor.rewriting.blas import blas_optdb, node_rewriter, optdb
@node_rewriter([ger, ger_destructive])
def use_scipy_ger(fgraph, node):
    """Swap a non-destructive ``ger`` node for the scipy-backed implementation.

    ``ger_destructive`` nodes are deliberately left untouched here; only the
    plain ``ger`` is replaced.
    """
    if node.op != ger:
        return
    return [scipy_ger_no_inplace(*node.inputs)]
@node_rewriter([scipy_ger_no_inplace])
def make_ger_destructive(fgraph, node):
    """Turn a non-in-place ``ScipyGer`` into the in-place variant."""
    if node.op != scipy_ger_no_inplace:
        return
    return [scipy_ger_inplace(*node.inputs)]
# Wrap the node rewriters as graph-level (in2out) rewrites.
use_scipy_blas = in2out(use_scipy_ger)
make_scipy_blas_destructive = in2out(make_ger_destructive)

# Only register when scipy's Fortran BLAS bindings are actually importable.
if have_fblas:
    # scipy_blas is scheduled in the blas_optdb very late, because scipy sortof
    # sucks, but it is almost always present.
    # C implementations should be scheduled earlier than this, so that they take
    # precedence. Once the original Ger is replaced, then these optimizations
    # have no effect.
    blas_optdb.register("scipy_blas", use_scipy_blas, "fast_run", position=100)

    # this matches the InplaceBlasOpt defined in blas.py
    optdb.register(
        "make_scipy_blas_destructive",
        make_scipy_blas_destructive,
        "fast_run",
        "inplace",
        position=70.0,
    )
...@@ -44,12 +44,11 @@ from pytensor.tensor.blas import ( ...@@ -44,12 +44,11 @@ from pytensor.tensor.blas import (
gemv_no_inplace, gemv_no_inplace,
ger, ger,
ger_destructive, ger_destructive,
local_dot22_to_dot22scalar,
local_gemm_to_ger,
res_is_a, res_is_a,
) )
from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import Dot, dot, mean, mul, neg, outer, sigmoid, sqrt from pytensor.tensor.math import Dot, dot, mean, mul, neg, outer, sigmoid, sqrt
from pytensor.tensor.rewriting.blas import local_dot22_to_dot22scalar, local_gemm_to_ger
from pytensor.tensor.type import ( from pytensor.tensor.type import (
cmatrix, cmatrix,
col, col,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论