提交 f15258d9 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Reorder functions in numba/dispatch/basic.py

Helpers before dispatchers
上级 351ce53e
......@@ -166,6 +166,55 @@ def create_arg_string(x):
return args
@numba.extending.intrinsic
def direct_cast(typingctx, val, typ):
if isinstance(typ, numba.types.TypeRef):
casted = typ.instance_type
elif isinstance(typ, numba.types.DTypeSpec):
casted = typ.dtype
else:
casted = typ
sig = casted(casted, typ)
def codegen(context, builder, signature, args):
val, _ = args
context.nrt.incref(builder, signature.return_type, val)
return val
return sig, codegen
def int_to_float_fn(inputs, out_dtype):
"""Create a Numba function that converts integer and boolean ``ndarray``s to floats."""
if (
all(inp.type.dtype == out_dtype for inp in inputs)
and np.dtype(out_dtype).kind == "f"
):
@numba_njit(inline="always")
def inputs_cast(x):
return x
elif any(i.type.numpy_dtype.kind in "uib" for i in inputs):
args_dtype = np.dtype(f"f{out_dtype.itemsize}")
@numba_njit(inline="always")
def inputs_cast(x):
return x.astype(args_dtype)
else:
args_dtype_sz = max(_arg.type.numpy_dtype.itemsize for _arg in inputs)
args_dtype = np.dtype(f"f{args_dtype_sz}")
@numba_njit(inline="always")
def inputs_cast(x):
return x.astype(args_dtype)
return inputs_cast
@singledispatch
def numba_typify(data, dtype=None, **kwargs):
return data
......@@ -231,6 +280,22 @@ def numba_funcify(op, node=None, storage_map=None, **kwargs):
return generate_fallback_impl(op, node, storage_map, **kwargs)
@numba_funcify.register(FunctionGraph)
def numba_funcify_FunctionGraph(
fgraph,
node=None,
fgraph_name="numba_funcified_fgraph",
**kwargs,
):
return fgraph_to_python(
fgraph,
numba_funcify,
type_conversion_fn=numba_typify,
fgraph_name=fgraph_name,
**kwargs,
)
@numba_funcify.register(OpFromGraph)
def numba_funcify_OpFromGraph(op, node=None, **kwargs):
_ = kwargs.pop("storage_map", None)
......@@ -263,22 +328,6 @@ def numba_funcify_OpFromGraph(op, node=None, **kwargs):
return opfromgraph
@numba_funcify.register(FunctionGraph)
def numba_funcify_FunctionGraph(
fgraph,
node=None,
fgraph_name="numba_funcified_fgraph",
**kwargs,
):
return fgraph_to_python(
fgraph,
numba_funcify,
type_conversion_fn=numba_typify,
fgraph_name=fgraph_name,
**kwargs,
)
@numba_funcify.register(DeepCopyOp)
def numba_funcify_DeepCopyOp(op, node, **kwargs):
if isinstance(node.inputs[0].type, TensorType):
......@@ -296,55 +345,6 @@ def numba_funcify_DeepCopyOp(op, node, **kwargs):
return deepcopy
@numba.extending.intrinsic
def direct_cast(typingctx, val, typ):
if isinstance(typ, numba.types.TypeRef):
casted = typ.instance_type
elif isinstance(typ, numba.types.DTypeSpec):
casted = typ.dtype
else:
casted = typ
sig = casted(casted, typ)
def codegen(context, builder, signature, args):
val, _ = args
context.nrt.incref(builder, signature.return_type, val)
return val
return sig, codegen
def int_to_float_fn(inputs, out_dtype):
"""Create a Numba function that converts integer and boolean ``ndarray``s to floats."""
if (
all(inp.type.dtype == out_dtype for inp in inputs)
and np.dtype(out_dtype).kind == "f"
):
@numba_njit(inline="always")
def inputs_cast(x):
return x
elif any(i.type.numpy_dtype.kind in "uib" for i in inputs):
args_dtype = np.dtype(f"f{out_dtype.itemsize}")
@numba_njit(inline="always")
def inputs_cast(x):
return x.astype(args_dtype)
else:
args_dtype_sz = max(_arg.type.numpy_dtype.itemsize for _arg in inputs)
args_dtype = np.dtype(f"f{args_dtype_sz}")
@numba_njit(inline="always")
def inputs_cast(x):
return x.astype(args_dtype)
return inputs_cast
@numba_funcify.register(IfElse)
def numba_funcify_IfElse(op, **kwargs):
n_outs = op.n_outs
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论