"""Common utilities for Numba operations"""
|
from __future__ import annotations
|
|
import types
|
from typing import (
|
TYPE_CHECKING,
|
Callable,
|
)
|
|
import numpy as np
|
|
from pandas.compat._optional import import_optional_dependency
|
from pandas.errors import NumbaUtilError
|
|
GLOBAL_USE_NUMBA: bool = False
|
|
|
def maybe_use_numba(engine: str | None) -> bool:
|
"""Signal whether to use numba routines."""
|
return engine == "numba" or (engine is None and GLOBAL_USE_NUMBA)
|
|
|
def set_use_numba(enable: bool = False) -> None:
|
global GLOBAL_USE_NUMBA
|
if enable:
|
import_optional_dependency("numba")
|
GLOBAL_USE_NUMBA = enable
|
|
|
def get_jit_arguments(
|
engine_kwargs: dict[str, bool] | None = None, kwargs: dict | None = None
|
) -> dict[str, bool]:
|
"""
|
Return arguments to pass to numba.JIT, falling back on pandas default JIT settings.
|
|
Parameters
|
----------
|
engine_kwargs : dict, default None
|
user passed keyword arguments for numba.JIT
|
kwargs : dict, default None
|
user passed keyword arguments to pass into the JITed function
|
|
Returns
|
-------
|
dict[str, bool]
|
nopython, nogil, parallel
|
|
Raises
|
------
|
NumbaUtilError
|
"""
|
if engine_kwargs is None:
|
engine_kwargs = {}
|
|
nopython = engine_kwargs.get("nopython", True)
|
if kwargs and nopython:
|
raise NumbaUtilError(
|
"numba does not support kwargs with nopython=True: "
|
"https://github.com/numba/numba/issues/2916"
|
)
|
nogil = engine_kwargs.get("nogil", False)
|
parallel = engine_kwargs.get("parallel", False)
|
return {"nopython": nopython, "nogil": nogil, "parallel": parallel}
|
|
|
def jit_user_function(func: Callable) -> Callable:
|
"""
|
If user function is not jitted already, mark the user's function
|
as jitable.
|
|
Parameters
|
----------
|
func : function
|
user defined function
|
|
Returns
|
-------
|
function
|
Numba JITed function, or function marked as JITable by numba
|
"""
|
if TYPE_CHECKING:
|
import numba
|
else:
|
numba = import_optional_dependency("numba")
|
|
if numba.extending.is_jitted(func):
|
# Don't jit a user passed jitted function
|
numba_func = func
|
elif getattr(np, func.__name__, False) is func or isinstance(
|
func, types.BuiltinFunctionType
|
):
|
# Not necessary to jit builtins or np functions
|
# This will mess up register_jitable
|
numba_func = func
|
else:
|
numba_func = numba.extending.register_jitable(func)
|
|
return numba_func
|