import numpy as np
|
import pytest
|
|
from pandas.compat import HAS_PYARROW
|
|
from pandas.core.dtypes.cast import find_common_type
|
|
import pandas as pd
|
import pandas._testing as tm
|
from pandas.util.version import Version
|
|
|
@pytest.mark.parametrize(
|
"to_concat_dtypes, result_dtype",
|
[
|
# same types
|
([("pyarrow", pd.NA), ("pyarrow", pd.NA)], ("pyarrow", pd.NA)),
|
([("pyarrow", np.nan), ("pyarrow", np.nan)], ("pyarrow", np.nan)),
|
([("python", pd.NA), ("python", pd.NA)], ("python", pd.NA)),
|
([("python", np.nan), ("python", np.nan)], ("python", np.nan)),
|
# pyarrow preference
|
([("pyarrow", pd.NA), ("python", pd.NA)], ("pyarrow", pd.NA)),
|
# NA preference
|
([("python", pd.NA), ("python", np.nan)], ("python", pd.NA)),
|
],
|
)
|
def test_concat_series(request, to_concat_dtypes, result_dtype):
|
if any(storage == "pyarrow" for storage, _ in to_concat_dtypes) and not HAS_PYARROW:
|
pytest.skip("Could not import 'pyarrow'")
|
|
ser_list = [
|
pd.Series(["a", "b", None], dtype=pd.StringDtype(storage, na_value))
|
for storage, na_value in to_concat_dtypes
|
]
|
|
result = pd.concat(ser_list, ignore_index=True)
|
expected = pd.Series(
|
["a", "b", None, "a", "b", None], dtype=pd.StringDtype(*result_dtype)
|
)
|
tm.assert_series_equal(result, expected)
|
|
# order doesn't matter for result
|
result = pd.concat(ser_list[::1], ignore_index=True)
|
tm.assert_series_equal(result, expected)
|
|
|
def test_concat_with_object(string_dtype_arguments):
|
# _get_common_dtype cannot inspect values, so object dtype with strings still
|
# results in object dtype
|
result = pd.concat(
|
[
|
pd.Series(["a", "b", None], dtype=pd.StringDtype(*string_dtype_arguments)),
|
pd.Series(["a", "b", None], dtype=object),
|
]
|
)
|
assert result.dtype == np.dtype("object")
|
|
|
def test_concat_with_numpy(string_dtype_arguments):
|
# common type with a numpy string dtype always preserves the pandas string dtype
|
dtype = pd.StringDtype(*string_dtype_arguments)
|
assert find_common_type([dtype, np.dtype("U")]) == dtype
|
assert find_common_type([np.dtype("U"), dtype]) == dtype
|
assert find_common_type([dtype, np.dtype("U10")]) == dtype
|
assert find_common_type([np.dtype("U10"), dtype]) == dtype
|
|
# with any other numpy dtype -> object
|
assert find_common_type([dtype, np.dtype("S")]) == np.dtype("object")
|
assert find_common_type([dtype, np.dtype("int64")]) == np.dtype("object")
|
|
if Version(np.__version__) >= Version("2"):
|
assert find_common_type([dtype, np.dtypes.StringDType()]) == dtype
|
assert find_common_type([np.dtypes.StringDType(), dtype]) == dtype
|