reconnect moved files to git repo
This commit is contained in:
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
645
venv/lib/python3.11/site-packages/narwhals/_polars/dataframe.py
Normal file
645
venv/lib/python3.11/site-packages/narwhals/_polars/dataframe.py
Normal file
@ -0,0 +1,645 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator, Mapping, Sequence, Sized
|
||||
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, cast, overload
|
||||
|
||||
import polars as pl
|
||||
|
||||
from narwhals._polars.namespace import PolarsNamespace
|
||||
from narwhals._polars.series import PolarsSeries
|
||||
from narwhals._polars.utils import (
|
||||
catch_polars_exception,
|
||||
extract_args_kwargs,
|
||||
native_to_narwhals_dtype,
|
||||
)
|
||||
from narwhals._utils import (
|
||||
Implementation,
|
||||
_into_arrow_table,
|
||||
check_columns_exist,
|
||||
convert_str_slice_to_int_slice,
|
||||
is_compliant_series,
|
||||
is_index_selector,
|
||||
is_range,
|
||||
is_sequence_like,
|
||||
is_slice_index,
|
||||
is_slice_none,
|
||||
parse_columns_to_drop,
|
||||
requires,
|
||||
)
|
||||
from narwhals.dependencies import is_numpy_array_1d
|
||||
from narwhals.exceptions import ColumnNotFoundError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from types import ModuleType
|
||||
from typing import Callable
|
||||
|
||||
import pandas as pd
|
||||
import pyarrow as pa
|
||||
from typing_extensions import Self, TypeAlias, TypeIs
|
||||
|
||||
from narwhals._compliant.typing import CompliantDataFrameAny, CompliantLazyFrameAny
|
||||
from narwhals._polars.expr import PolarsExpr
|
||||
from narwhals._polars.group_by import PolarsGroupBy, PolarsLazyGroupBy
|
||||
from narwhals._translate import IntoArrowTable
|
||||
from narwhals._utils import Version, _LimitedContext
|
||||
from narwhals.dataframe import DataFrame, LazyFrame
|
||||
from narwhals.dtypes import DType
|
||||
from narwhals.schema import Schema
|
||||
from narwhals.typing import (
|
||||
JoinStrategy,
|
||||
MultiColSelector,
|
||||
MultiIndexSelector,
|
||||
PivotAgg,
|
||||
SingleIndexSelector,
|
||||
_2DArray,
|
||||
)
|
||||
|
||||
T = TypeVar("T")
|
||||
R = TypeVar("R")
|
||||
|
||||
Method: TypeAlias = "Callable[..., R]"
|
||||
"""Generic alias representing all methods implemented via `__getattr__`.
|
||||
|
||||
Where `R` is the return type.
|
||||
"""
|
||||
|
||||
# DataFrame methods where PolarsDataFrame just defers to Polars.DataFrame directly.
|
||||
INHERITED_METHODS = frozenset(
|
||||
[
|
||||
"clone",
|
||||
"drop_nulls",
|
||||
"estimated_size",
|
||||
"explode",
|
||||
"filter",
|
||||
"gather_every",
|
||||
"head",
|
||||
"is_unique",
|
||||
"item",
|
||||
"iter_rows",
|
||||
"join_asof",
|
||||
"rename",
|
||||
"row",
|
||||
"rows",
|
||||
"sample",
|
||||
"select",
|
||||
"sink_parquet",
|
||||
"sort",
|
||||
"tail",
|
||||
"to_arrow",
|
||||
"to_pandas",
|
||||
"unique",
|
||||
"with_columns",
|
||||
"write_csv",
|
||||
"write_parquet",
|
||||
]
|
||||
)
|
||||
|
||||
NativePolarsFrame = TypeVar("NativePolarsFrame", pl.DataFrame, pl.LazyFrame)
|
||||
|
||||
|
||||
class PolarsBaseFrame(Generic[NativePolarsFrame]):
|
||||
drop_nulls: Method[Self]
|
||||
explode: Method[Self]
|
||||
filter: Method[Self]
|
||||
gather_every: Method[Self]
|
||||
head: Method[Self]
|
||||
join_asof: Method[Self]
|
||||
rename: Method[Self]
|
||||
select: Method[Self]
|
||||
sort: Method[Self]
|
||||
tail: Method[Self]
|
||||
unique: Method[Self]
|
||||
with_columns: Method[Self]
|
||||
|
||||
_implementation = Implementation.POLARS
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
df: NativePolarsFrame,
|
||||
*,
|
||||
version: Version,
|
||||
validate_backend_version: bool = False,
|
||||
) -> None:
|
||||
self._native_frame: NativePolarsFrame = df
|
||||
self._version = version
|
||||
if validate_backend_version:
|
||||
self._validate_backend_version()
|
||||
|
||||
def _validate_backend_version(self) -> None:
|
||||
"""Raise if installed version below `nw._utils.MIN_VERSIONS`.
|
||||
|
||||
**Only use this when moving between backends.**
|
||||
Otherwise, the validation will have taken place already.
|
||||
"""
|
||||
_ = self._implementation._backend_version()
|
||||
|
||||
@property
|
||||
def _backend_version(self) -> tuple[int, ...]:
|
||||
return self._implementation._backend_version()
|
||||
|
||||
@property
|
||||
def native(self) -> NativePolarsFrame:
|
||||
return self._native_frame
|
||||
|
||||
@property
|
||||
def columns(self) -> list[str]:
|
||||
return self.native.columns
|
||||
|
||||
def __narwhals_namespace__(self) -> PolarsNamespace:
|
||||
return PolarsNamespace(version=self._version)
|
||||
|
||||
def __native_namespace__(self) -> ModuleType:
|
||||
if self._implementation is Implementation.POLARS:
|
||||
return self._implementation.to_native_namespace()
|
||||
|
||||
msg = f"Expected polars, got: {type(self._implementation)}" # pragma: no cover
|
||||
raise AssertionError(msg)
|
||||
|
||||
def _with_native(self, df: NativePolarsFrame) -> Self:
|
||||
return self.__class__(df, version=self._version)
|
||||
|
||||
def _with_version(self, version: Version) -> Self:
|
||||
return self.__class__(self.native, version=version)
|
||||
|
||||
@classmethod
|
||||
def from_native(cls, data: NativePolarsFrame, /, *, context: _LimitedContext) -> Self:
|
||||
return cls(data, version=context._version)
|
||||
|
||||
def _check_columns_exist(self, subset: Sequence[str]) -> ColumnNotFoundError | None:
|
||||
return check_columns_exist( # pragma: no cover
|
||||
subset, available=self.columns
|
||||
)
|
||||
|
||||
def simple_select(self, *column_names: str) -> Self:
|
||||
return self._with_native(self.native.select(*column_names))
|
||||
|
||||
def aggregate(self, *exprs: Any) -> Self:
|
||||
return self.select(*exprs)
|
||||
|
||||
@property
|
||||
def schema(self) -> dict[str, DType]:
|
||||
return self.collect_schema()
|
||||
|
||||
def join(
|
||||
self,
|
||||
other: Self,
|
||||
*,
|
||||
how: JoinStrategy,
|
||||
left_on: Sequence[str] | None,
|
||||
right_on: Sequence[str] | None,
|
||||
suffix: str,
|
||||
) -> Self:
|
||||
how_native = (
|
||||
"outer" if (self._backend_version < (0, 20, 29) and how == "full") else how
|
||||
)
|
||||
return self._with_native(
|
||||
self.native.join(
|
||||
other=other.native,
|
||||
how=how_native, # type: ignore[arg-type]
|
||||
left_on=left_on,
|
||||
right_on=right_on,
|
||||
suffix=suffix,
|
||||
)
|
||||
)
|
||||
|
||||
def unpivot(
|
||||
self,
|
||||
on: Sequence[str] | None,
|
||||
index: Sequence[str] | None,
|
||||
variable_name: str,
|
||||
value_name: str,
|
||||
) -> Self:
|
||||
if self._backend_version < (1, 0, 0):
|
||||
return self._with_native(
|
||||
self.native.melt(
|
||||
id_vars=index,
|
||||
value_vars=on,
|
||||
variable_name=variable_name,
|
||||
value_name=value_name,
|
||||
)
|
||||
)
|
||||
return self._with_native(
|
||||
self.native.unpivot(
|
||||
on=on, index=index, variable_name=variable_name, value_name=value_name
|
||||
)
|
||||
)
|
||||
|
||||
def collect_schema(self) -> dict[str, DType]:
|
||||
df = self.native
|
||||
schema = df.schema if self._backend_version < (1,) else df.collect_schema()
|
||||
return {
|
||||
name: native_to_narwhals_dtype(dtype, self._version)
|
||||
for name, dtype in schema.items()
|
||||
}
|
||||
|
||||
def with_row_index(self, name: str, order_by: Sequence[str] | None) -> Self:
|
||||
frame = self.native
|
||||
if order_by is None:
|
||||
result = frame.with_row_index(name)
|
||||
else:
|
||||
end = pl.count() if self._backend_version < (0, 20, 5) else pl.len()
|
||||
result = frame.select(
|
||||
pl.int_range(start=0, end=end).sort_by(order_by).alias(name), pl.all()
|
||||
)
|
||||
|
||||
return self._with_native(result)
|
||||
|
||||
|
||||
class PolarsDataFrame(PolarsBaseFrame[pl.DataFrame]):
|
||||
clone: Method[Self]
|
||||
collect: Method[CompliantDataFrameAny]
|
||||
estimated_size: Method[int | float]
|
||||
gather_every: Method[Self]
|
||||
item: Method[Any]
|
||||
iter_rows: Method[Iterator[tuple[Any, ...]] | Iterator[Mapping[str, Any]]]
|
||||
is_unique: Method[PolarsSeries]
|
||||
row: Method[tuple[Any, ...]]
|
||||
rows: Method[Sequence[tuple[Any, ...]] | Sequence[Mapping[str, Any]]]
|
||||
sample: Method[Self]
|
||||
to_arrow: Method[pa.Table]
|
||||
to_pandas: Method[pd.DataFrame]
|
||||
# NOTE: `write_csv` requires an `@overload` for `str | None`
|
||||
# Can't do that here 😟
|
||||
write_csv: Method[Any]
|
||||
write_parquet: Method[None]
|
||||
|
||||
# CompliantDataFrame
|
||||
_evaluate_aliases: Any
|
||||
|
||||
@classmethod
|
||||
def from_arrow(cls, data: IntoArrowTable, /, *, context: _LimitedContext) -> Self:
|
||||
if context._implementation._backend_version() >= (1, 3):
|
||||
native = pl.DataFrame(data)
|
||||
else: # pragma: no cover
|
||||
native = cast("pl.DataFrame", pl.from_arrow(_into_arrow_table(data, context)))
|
||||
return cls.from_native(native, context=context)
|
||||
|
||||
@classmethod
|
||||
def from_dict(
|
||||
cls,
|
||||
data: Mapping[str, Any],
|
||||
/,
|
||||
*,
|
||||
context: _LimitedContext,
|
||||
schema: Mapping[str, DType] | Schema | None,
|
||||
) -> Self:
|
||||
from narwhals.schema import Schema
|
||||
|
||||
pl_schema = Schema(schema).to_polars() if schema is not None else schema
|
||||
return cls.from_native(pl.from_dict(data, pl_schema), context=context)
|
||||
|
||||
@staticmethod
|
||||
def _is_native(obj: pl.DataFrame | Any) -> TypeIs[pl.DataFrame]:
|
||||
return isinstance(obj, pl.DataFrame)
|
||||
|
||||
@classmethod
|
||||
def from_numpy(
|
||||
cls,
|
||||
data: _2DArray,
|
||||
/,
|
||||
*,
|
||||
context: _LimitedContext, # NOTE: Maybe only `Implementation`?
|
||||
schema: Mapping[str, DType] | Schema | Sequence[str] | None,
|
||||
) -> Self:
|
||||
from narwhals.schema import Schema
|
||||
|
||||
pl_schema = (
|
||||
Schema(schema).to_polars()
|
||||
if isinstance(schema, (Mapping, Schema))
|
||||
else schema
|
||||
)
|
||||
return cls.from_native(pl.from_numpy(data, pl_schema), context=context)
|
||||
|
||||
def to_narwhals(self) -> DataFrame[pl.DataFrame]:
|
||||
return self._version.dataframe(self, level="full")
|
||||
|
||||
def __repr__(self) -> str: # pragma: no cover
|
||||
return "PolarsDataFrame"
|
||||
|
||||
def __narwhals_dataframe__(self) -> Self:
|
||||
return self
|
||||
|
||||
@overload
|
||||
def _from_native_object(self, obj: pl.Series) -> PolarsSeries: ...
|
||||
|
||||
@overload
|
||||
def _from_native_object(self, obj: pl.DataFrame) -> Self: ...
|
||||
|
||||
@overload
|
||||
def _from_native_object(self, obj: T) -> T: ...
|
||||
|
||||
def _from_native_object(
|
||||
self, obj: pl.Series | pl.DataFrame | T
|
||||
) -> Self | PolarsSeries | T:
|
||||
if isinstance(obj, pl.Series):
|
||||
return PolarsSeries.from_native(obj, context=self)
|
||||
if self._is_native(obj):
|
||||
return self._with_native(obj)
|
||||
# scalar
|
||||
return obj
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.native)
|
||||
|
||||
def __getattr__(self, attr: str) -> Any:
|
||||
if attr not in INHERITED_METHODS: # pragma: no cover
|
||||
msg = f"{self.__class__.__name__} has not attribute '{attr}'."
|
||||
raise AttributeError(msg)
|
||||
|
||||
def func(*args: Any, **kwargs: Any) -> Any:
|
||||
pos, kwds = extract_args_kwargs(args, kwargs)
|
||||
try:
|
||||
return self._from_native_object(getattr(self.native, attr)(*pos, **kwds))
|
||||
except pl.exceptions.ColumnNotFoundError as e: # pragma: no cover
|
||||
msg = f"{e!s}\n\nHint: Did you mean one of these columns: {self.columns}?"
|
||||
raise ColumnNotFoundError(msg) from e
|
||||
except Exception as e: # noqa: BLE001
|
||||
raise catch_polars_exception(e) from None
|
||||
|
||||
return func
|
||||
|
||||
def __array__(
|
||||
self, dtype: Any | None = None, *, copy: bool | None = None
|
||||
) -> _2DArray:
|
||||
if self._backend_version < (0, 20, 28) and copy is not None:
|
||||
msg = "`copy` in `__array__` is only supported for 'polars>=0.20.28'"
|
||||
raise NotImplementedError(msg)
|
||||
if self._backend_version < (0, 20, 28):
|
||||
return self.native.__array__(dtype)
|
||||
return self.native.__array__(dtype)
|
||||
|
||||
def to_numpy(self, dtype: Any = None, *, copy: bool | None = None) -> _2DArray:
|
||||
return self.native.to_numpy()
|
||||
|
||||
@property
|
||||
def shape(self) -> tuple[int, int]:
|
||||
return self.native.shape
|
||||
|
||||
def __getitem__( # noqa: C901, PLR0912
|
||||
self,
|
||||
item: tuple[
|
||||
SingleIndexSelector | MultiIndexSelector[PolarsSeries],
|
||||
MultiColSelector[PolarsSeries],
|
||||
],
|
||||
) -> Any:
|
||||
rows, columns = item
|
||||
if self._backend_version > (0, 20, 30):
|
||||
rows_native = rows.native if is_compliant_series(rows) else rows
|
||||
columns_native = columns.native if is_compliant_series(columns) else columns
|
||||
selector = rows_native, columns_native
|
||||
selected = self.native.__getitem__(selector) # type: ignore[index]
|
||||
return self._from_native_object(selected)
|
||||
else: # pragma: no cover
|
||||
# TODO(marco): we can delete this branch after Polars==0.20.30 becomes the minimum
|
||||
# Polars version we support
|
||||
# This mostly mirrors the logic in `EagerDataFrame.__getitem__`.
|
||||
rows = list(rows) if isinstance(rows, tuple) else rows
|
||||
columns = list(columns) if isinstance(columns, tuple) else columns
|
||||
if is_numpy_array_1d(columns):
|
||||
columns = columns.tolist()
|
||||
|
||||
native = self.native
|
||||
if not is_slice_none(columns):
|
||||
if isinstance(columns, Sized) and len(columns) == 0:
|
||||
return self.select()
|
||||
if is_index_selector(columns):
|
||||
if is_slice_index(columns) or is_range(columns):
|
||||
native = native.select(
|
||||
self.columns[slice(columns.start, columns.stop, columns.step)]
|
||||
)
|
||||
elif is_compliant_series(columns):
|
||||
native = native[:, columns.native.to_list()]
|
||||
else:
|
||||
native = native[:, columns]
|
||||
elif isinstance(columns, slice):
|
||||
native = native.select(
|
||||
self.columns[
|
||||
slice(*convert_str_slice_to_int_slice(columns, self.columns))
|
||||
]
|
||||
)
|
||||
elif is_compliant_series(columns):
|
||||
native = native.select(columns.native.to_list())
|
||||
elif is_sequence_like(columns):
|
||||
native = native.select(columns)
|
||||
else:
|
||||
msg = f"Unreachable code, got unexpected type: {type(columns)}"
|
||||
raise AssertionError(msg)
|
||||
|
||||
if not is_slice_none(rows):
|
||||
if isinstance(rows, int):
|
||||
native = native[[rows], :]
|
||||
elif isinstance(rows, (slice, range)):
|
||||
native = native[rows, :]
|
||||
elif is_compliant_series(rows):
|
||||
native = native[rows.native, :]
|
||||
elif is_sequence_like(rows):
|
||||
native = native[rows, :]
|
||||
else:
|
||||
msg = f"Unreachable code, got unexpected type: {type(rows)}"
|
||||
raise AssertionError(msg)
|
||||
|
||||
return self._with_native(native)
|
||||
|
||||
def get_column(self, name: str) -> PolarsSeries:
|
||||
return PolarsSeries.from_native(self.native.get_column(name), context=self)
|
||||
|
||||
def iter_columns(self) -> Iterator[PolarsSeries]:
|
||||
for series in self.native.iter_columns():
|
||||
yield PolarsSeries.from_native(series, context=self)
|
||||
|
||||
def lazy(self, *, backend: Implementation | None = None) -> CompliantLazyFrameAny:
|
||||
if backend is None or backend is Implementation.POLARS:
|
||||
return PolarsLazyFrame.from_native(self.native.lazy(), context=self)
|
||||
elif backend is Implementation.DUCKDB:
|
||||
import duckdb # ignore-banned-import
|
||||
|
||||
from narwhals._duckdb.dataframe import DuckDBLazyFrame
|
||||
|
||||
# NOTE: (F841) is a false positive
|
||||
df = self.native # noqa: F841
|
||||
return DuckDBLazyFrame(
|
||||
duckdb.table("df"), validate_backend_version=True, version=self._version
|
||||
)
|
||||
elif backend is Implementation.DASK:
|
||||
import dask.dataframe as dd # ignore-banned-import
|
||||
|
||||
from narwhals._dask.dataframe import DaskLazyFrame
|
||||
|
||||
return DaskLazyFrame(
|
||||
dd.from_pandas(self.native.to_pandas()),
|
||||
validate_backend_version=True,
|
||||
version=self._version,
|
||||
)
|
||||
elif backend.is_ibis():
|
||||
import ibis # ignore-banned-import
|
||||
|
||||
from narwhals._ibis.dataframe import IbisLazyFrame
|
||||
|
||||
return IbisLazyFrame(
|
||||
ibis.memtable(self.native, columns=self.columns),
|
||||
validate_backend_version=True,
|
||||
version=self._version,
|
||||
)
|
||||
|
||||
raise AssertionError # pragma: no cover
|
||||
|
||||
@overload
|
||||
def to_dict(self, *, as_series: Literal[True]) -> dict[str, PolarsSeries]: ...
|
||||
|
||||
@overload
|
||||
def to_dict(self, *, as_series: Literal[False]) -> dict[str, list[Any]]: ...
|
||||
|
||||
def to_dict(
|
||||
self, *, as_series: bool
|
||||
) -> dict[str, PolarsSeries] | dict[str, list[Any]]:
|
||||
if as_series:
|
||||
return {
|
||||
name: PolarsSeries.from_native(col, context=self)
|
||||
for name, col in self.native.to_dict().items()
|
||||
}
|
||||
else:
|
||||
return self.native.to_dict(as_series=False)
|
||||
|
||||
def group_by(
|
||||
self, keys: Sequence[str] | Sequence[PolarsExpr], *, drop_null_keys: bool
|
||||
) -> PolarsGroupBy:
|
||||
from narwhals._polars.group_by import PolarsGroupBy
|
||||
|
||||
return PolarsGroupBy(self, keys, drop_null_keys=drop_null_keys)
|
||||
|
||||
def drop(self, columns: Sequence[str], *, strict: bool) -> Self:
|
||||
to_drop = parse_columns_to_drop(self, columns, strict=strict)
|
||||
return self._with_native(self.native.drop(to_drop))
|
||||
|
||||
@requires.backend_version((1,))
|
||||
def pivot(
|
||||
self,
|
||||
on: Sequence[str],
|
||||
*,
|
||||
index: Sequence[str] | None,
|
||||
values: Sequence[str] | None,
|
||||
aggregate_function: PivotAgg | None,
|
||||
sort_columns: bool,
|
||||
separator: str,
|
||||
) -> Self:
|
||||
try:
|
||||
result = self.native.pivot(
|
||||
on,
|
||||
index=index,
|
||||
values=values,
|
||||
aggregate_function=aggregate_function,
|
||||
sort_columns=sort_columns,
|
||||
separator=separator,
|
||||
)
|
||||
except Exception as e: # noqa: BLE001
|
||||
raise catch_polars_exception(e) from None
|
||||
return self._from_native_object(result)
|
||||
|
||||
def to_polars(self) -> pl.DataFrame:
|
||||
return self.native
|
||||
|
||||
def join(
|
||||
self,
|
||||
other: Self,
|
||||
*,
|
||||
how: JoinStrategy,
|
||||
left_on: Sequence[str] | None,
|
||||
right_on: Sequence[str] | None,
|
||||
suffix: str,
|
||||
) -> Self:
|
||||
try:
|
||||
return super().join(
|
||||
other=other, how=how, left_on=left_on, right_on=right_on, suffix=suffix
|
||||
)
|
||||
except Exception as e: # noqa: BLE001
|
||||
raise catch_polars_exception(e) from None
|
||||
|
||||
|
||||
class PolarsLazyFrame(PolarsBaseFrame[pl.LazyFrame]):
|
||||
# CompliantLazyFrame
|
||||
sink_parquet: Method[None]
|
||||
_evaluate_expr: Any
|
||||
_evaluate_aliases: Any
|
||||
|
||||
@staticmethod
|
||||
def _is_native(obj: pl.LazyFrame | Any) -> TypeIs[pl.LazyFrame]:
|
||||
return isinstance(obj, pl.LazyFrame)
|
||||
|
||||
def to_narwhals(self) -> LazyFrame[pl.LazyFrame]:
|
||||
return self._version.lazyframe(self, level="lazy")
|
||||
|
||||
def __repr__(self) -> str: # pragma: no cover
|
||||
return "PolarsLazyFrame"
|
||||
|
||||
def __narwhals_lazyframe__(self) -> Self:
|
||||
return self
|
||||
|
||||
def __getattr__(self, attr: str) -> Any:
|
||||
if attr not in INHERITED_METHODS: # pragma: no cover
|
||||
msg = f"{self.__class__.__name__} has not attribute '{attr}'."
|
||||
raise AttributeError(msg)
|
||||
|
||||
def func(*args: Any, **kwargs: Any) -> Any:
|
||||
pos, kwds = extract_args_kwargs(args, kwargs)
|
||||
try:
|
||||
return self._with_native(getattr(self.native, attr)(*pos, **kwds))
|
||||
except pl.exceptions.ColumnNotFoundError as e: # pragma: no cover
|
||||
raise ColumnNotFoundError(str(e)) from e
|
||||
|
||||
return func
|
||||
|
||||
def _iter_columns(self) -> Iterator[PolarsSeries]: # pragma: no cover
|
||||
yield from self.collect(self._implementation).iter_columns()
|
||||
|
||||
def collect_schema(self) -> dict[str, DType]:
|
||||
try:
|
||||
return super().collect_schema()
|
||||
except Exception as e: # noqa: BLE001
|
||||
raise catch_polars_exception(e) from None
|
||||
|
||||
def collect(
|
||||
self, backend: Implementation | None, **kwargs: Any
|
||||
) -> CompliantDataFrameAny:
|
||||
try:
|
||||
result = self.native.collect(**kwargs)
|
||||
except Exception as e: # noqa: BLE001
|
||||
raise catch_polars_exception(e) from None
|
||||
|
||||
if backend is None or backend is Implementation.POLARS:
|
||||
return PolarsDataFrame.from_native(result, context=self)
|
||||
|
||||
if backend is Implementation.PANDAS:
|
||||
from narwhals._pandas_like.dataframe import PandasLikeDataFrame
|
||||
|
||||
return PandasLikeDataFrame(
|
||||
result.to_pandas(),
|
||||
implementation=Implementation.PANDAS,
|
||||
validate_backend_version=True,
|
||||
version=self._version,
|
||||
validate_column_names=False,
|
||||
)
|
||||
|
||||
if backend is Implementation.PYARROW:
|
||||
from narwhals._arrow.dataframe import ArrowDataFrame
|
||||
|
||||
return ArrowDataFrame(
|
||||
result.to_arrow(),
|
||||
validate_backend_version=True,
|
||||
version=self._version,
|
||||
validate_column_names=False,
|
||||
)
|
||||
|
||||
msg = f"Unsupported `backend` value: {backend}" # pragma: no cover
|
||||
raise ValueError(msg) # pragma: no cover
|
||||
|
||||
def group_by(
|
||||
self, keys: Sequence[str] | Sequence[PolarsExpr], *, drop_null_keys: bool
|
||||
) -> PolarsLazyGroupBy:
|
||||
from narwhals._polars.group_by import PolarsLazyGroupBy
|
||||
|
||||
return PolarsLazyGroupBy(self, keys, drop_null_keys=drop_null_keys)
|
||||
|
||||
def drop(self, columns: Sequence[str], *, strict: bool) -> Self:
|
||||
if self._backend_version < (1, 0, 0):
|
||||
return self._with_native(self.native.drop(columns))
|
||||
return self._with_native(self.native.drop(columns, strict=strict))
|
||||
488
venv/lib/python3.11/site-packages/narwhals/_polars/expr.py
Normal file
488
venv/lib/python3.11/site-packages/narwhals/_polars/expr.py
Normal file
@ -0,0 +1,488 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Callable, Literal
|
||||
|
||||
import polars as pl
|
||||
|
||||
from narwhals._duration import Interval
|
||||
from narwhals._polars.utils import (
|
||||
extract_args_kwargs,
|
||||
extract_native,
|
||||
narwhals_to_native_dtype,
|
||||
)
|
||||
from narwhals._utils import Implementation, requires
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Mapping, Sequence
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
from narwhals._expression_parsing import ExprKind, ExprMetadata
|
||||
from narwhals._polars.dataframe import Method
|
||||
from narwhals._polars.namespace import PolarsNamespace
|
||||
from narwhals._utils import Version
|
||||
from narwhals.typing import IntoDType
|
||||
|
||||
|
||||
class PolarsExpr:
|
||||
_implementation = Implementation.POLARS
|
||||
|
||||
def __init__(self, expr: pl.Expr, version: Version) -> None:
|
||||
self._native_expr = expr
|
||||
self._version = version
|
||||
self._metadata: ExprMetadata | None = None
|
||||
|
||||
@property
|
||||
def _backend_version(self) -> tuple[int, ...]:
|
||||
return self._implementation._backend_version()
|
||||
|
||||
@property
|
||||
def native(self) -> pl.Expr:
|
||||
return self._native_expr
|
||||
|
||||
def __repr__(self) -> str: # pragma: no cover
|
||||
return "PolarsExpr"
|
||||
|
||||
def _with_native(self, expr: pl.Expr) -> Self:
|
||||
return self.__class__(expr, self._version)
|
||||
|
||||
@classmethod
|
||||
def _from_series(cls, series: Any) -> Self:
|
||||
return cls(series.native, series._version)
|
||||
|
||||
def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Self:
|
||||
# Let Polars do its thing.
|
||||
return self
|
||||
|
||||
def __getattr__(self, attr: str) -> Any:
|
||||
def func(*args: Any, **kwargs: Any) -> Any:
|
||||
pos, kwds = extract_args_kwargs(args, kwargs)
|
||||
return self._with_native(getattr(self.native, attr)(*pos, **kwds))
|
||||
|
||||
return func
|
||||
|
||||
def _renamed_min_periods(self, min_samples: int, /) -> dict[str, Any]:
|
||||
name = "min_periods" if self._backend_version < (1, 21, 0) else "min_samples"
|
||||
return {name: min_samples}
|
||||
|
||||
def cast(self, dtype: IntoDType) -> Self:
|
||||
dtype_pl = narwhals_to_native_dtype(dtype, self._version)
|
||||
return self._with_native(self.native.cast(dtype_pl))
|
||||
|
||||
def ewm_mean(
|
||||
self,
|
||||
*,
|
||||
com: float | None,
|
||||
span: float | None,
|
||||
half_life: float | None,
|
||||
alpha: float | None,
|
||||
adjust: bool,
|
||||
min_samples: int,
|
||||
ignore_nulls: bool,
|
||||
) -> Self:
|
||||
native = self.native.ewm_mean(
|
||||
com=com,
|
||||
span=span,
|
||||
half_life=half_life,
|
||||
alpha=alpha,
|
||||
adjust=adjust,
|
||||
ignore_nulls=ignore_nulls,
|
||||
**self._renamed_min_periods(min_samples),
|
||||
)
|
||||
if self._backend_version < (1,): # pragma: no cover
|
||||
native = pl.when(~self.native.is_null()).then(native).otherwise(None)
|
||||
return self._with_native(native)
|
||||
|
||||
def is_nan(self) -> Self:
|
||||
if self._backend_version >= (1, 18):
|
||||
native = self.native.is_nan()
|
||||
else: # pragma: no cover
|
||||
native = pl.when(self.native.is_not_null()).then(self.native.is_nan())
|
||||
return self._with_native(native)
|
||||
|
||||
def over(self, partition_by: Sequence[str], order_by: Sequence[str]) -> Self:
|
||||
if self._backend_version < (1, 9):
|
||||
if order_by:
|
||||
msg = "`order_by` in Polars requires version 1.10 or greater"
|
||||
raise NotImplementedError(msg)
|
||||
native = self.native.over(partition_by or pl.lit(1))
|
||||
else:
|
||||
native = self.native.over(
|
||||
partition_by or pl.lit(1), order_by=order_by or None
|
||||
)
|
||||
return self._with_native(native)
|
||||
|
||||
@requires.backend_version((1,))
|
||||
def rolling_var(
|
||||
self, window_size: int, *, min_samples: int, center: bool, ddof: int
|
||||
) -> Self:
|
||||
kwds = self._renamed_min_periods(min_samples)
|
||||
native = self.native.rolling_var(
|
||||
window_size=window_size, center=center, ddof=ddof, **kwds
|
||||
)
|
||||
return self._with_native(native)
|
||||
|
||||
@requires.backend_version((1,))
|
||||
def rolling_std(
|
||||
self, window_size: int, *, min_samples: int, center: bool, ddof: int
|
||||
) -> Self:
|
||||
kwds = self._renamed_min_periods(min_samples)
|
||||
native = self.native.rolling_std(
|
||||
window_size=window_size, center=center, ddof=ddof, **kwds
|
||||
)
|
||||
return self._with_native(native)
|
||||
|
||||
def rolling_sum(self, window_size: int, *, min_samples: int, center: bool) -> Self:
|
||||
kwds = self._renamed_min_periods(min_samples)
|
||||
native = self.native.rolling_sum(window_size=window_size, center=center, **kwds)
|
||||
return self._with_native(native)
|
||||
|
||||
def rolling_mean(self, window_size: int, *, min_samples: int, center: bool) -> Self:
|
||||
kwds = self._renamed_min_periods(min_samples)
|
||||
native = self.native.rolling_mean(window_size=window_size, center=center, **kwds)
|
||||
return self._with_native(native)
|
||||
|
||||
def map_batches(
|
||||
self, function: Callable[[Any], Any], return_dtype: IntoDType | None
|
||||
) -> Self:
|
||||
return_dtype_pl = (
|
||||
narwhals_to_native_dtype(return_dtype, self._version)
|
||||
if return_dtype
|
||||
else None
|
||||
)
|
||||
native = self.native.map_batches(function, return_dtype_pl)
|
||||
return self._with_native(native)
|
||||
|
||||
@requires.backend_version((1,))
|
||||
def replace_strict(
|
||||
self,
|
||||
old: Sequence[Any] | Mapping[Any, Any],
|
||||
new: Sequence[Any],
|
||||
*,
|
||||
return_dtype: IntoDType | None,
|
||||
) -> Self:
|
||||
return_dtype_pl = (
|
||||
narwhals_to_native_dtype(return_dtype, self._version)
|
||||
if return_dtype
|
||||
else None
|
||||
)
|
||||
native = self.native.replace_strict(old, new, return_dtype=return_dtype_pl)
|
||||
return self._with_native(native)
|
||||
|
||||
def __eq__(self, other: object) -> Self: # type: ignore[override]
|
||||
return self._with_native(self.native.__eq__(extract_native(other))) # type: ignore[operator]
|
||||
|
||||
def __ne__(self, other: object) -> Self: # type: ignore[override]
|
||||
return self._with_native(self.native.__ne__(extract_native(other))) # type: ignore[operator]
|
||||
|
||||
def __ge__(self, other: Any) -> Self:
|
||||
return self._with_native(self.native.__ge__(extract_native(other)))
|
||||
|
||||
def __gt__(self, other: Any) -> Self:
|
||||
return self._with_native(self.native.__gt__(extract_native(other)))
|
||||
|
||||
def __le__(self, other: Any) -> Self:
|
||||
return self._with_native(self.native.__le__(extract_native(other)))
|
||||
|
||||
def __lt__(self, other: Any) -> Self:
|
||||
return self._with_native(self.native.__lt__(extract_native(other)))
|
||||
|
||||
def __and__(self, other: PolarsExpr | bool | Any) -> Self:
|
||||
return self._with_native(self.native.__and__(extract_native(other))) # type: ignore[operator]
|
||||
|
||||
def __or__(self, other: PolarsExpr | bool | Any) -> Self:
|
||||
return self._with_native(self.native.__or__(extract_native(other))) # type: ignore[operator]
|
||||
|
||||
def __add__(self, other: Any) -> Self:
|
||||
return self._with_native(self.native.__add__(extract_native(other)))
|
||||
|
||||
def __sub__(self, other: Any) -> Self:
|
||||
return self._with_native(self.native.__sub__(extract_native(other)))
|
||||
|
||||
def __mul__(self, other: Any) -> Self:
|
||||
return self._with_native(self.native.__mul__(extract_native(other)))
|
||||
|
||||
def __pow__(self, other: Any) -> Self:
|
||||
return self._with_native(self.native.__pow__(extract_native(other)))
|
||||
|
||||
def __truediv__(self, other: Any) -> Self:
|
||||
return self._with_native(self.native.__truediv__(extract_native(other)))
|
||||
|
||||
def __floordiv__(self, other: Any) -> Self:
|
||||
return self._with_native(self.native.__floordiv__(extract_native(other)))
|
||||
|
||||
def __mod__(self, other: Any) -> Self:
|
||||
return self._with_native(self.native.__mod__(extract_native(other)))
|
||||
|
||||
def __invert__(self) -> Self:
|
||||
return self._with_native(self.native.__invert__())
|
||||
|
||||
def cum_count(self, *, reverse: bool) -> Self:
|
||||
return self._with_native(self.native.cum_count(reverse=reverse))
|
||||
|
||||
def __narwhals_expr__(self) -> None: ...
|
||||
def __narwhals_namespace__(self) -> PolarsNamespace: # pragma: no cover
|
||||
from narwhals._polars.namespace import PolarsNamespace
|
||||
|
||||
return PolarsNamespace(version=self._version)
|
||||
|
||||
@property
|
||||
def dt(self) -> PolarsExprDateTimeNamespace:
|
||||
return PolarsExprDateTimeNamespace(self)
|
||||
|
||||
@property
|
||||
def str(self) -> PolarsExprStringNamespace:
|
||||
return PolarsExprStringNamespace(self)
|
||||
|
||||
@property
|
||||
def cat(self) -> PolarsExprCatNamespace:
|
||||
return PolarsExprCatNamespace(self)
|
||||
|
||||
@property
|
||||
def name(self) -> PolarsExprNameNamespace:
|
||||
return PolarsExprNameNamespace(self)
|
||||
|
||||
@property
|
||||
def list(self) -> PolarsExprListNamespace:
|
||||
return PolarsExprListNamespace(self)
|
||||
|
||||
@property
|
||||
def struct(self) -> PolarsExprStructNamespace:
|
||||
return PolarsExprStructNamespace(self)
|
||||
|
||||
# CompliantExpr
|
||||
_alias_output_names: Any
|
||||
_evaluate_aliases: Any
|
||||
_evaluate_output_names: Any
|
||||
_is_multi_output_unnamed: Any
|
||||
__call__: Any
|
||||
from_column_names: Any
|
||||
from_column_indices: Any
|
||||
_eval_names_indices: Any
|
||||
|
||||
# Polars
|
||||
abs: Method[Self]
|
||||
all: Method[Self]
|
||||
any: Method[Self]
|
||||
alias: Method[Self]
|
||||
arg_max: Method[Self]
|
||||
arg_min: Method[Self]
|
||||
arg_true: Method[Self]
|
||||
clip: Method[Self]
|
||||
count: Method[Self]
|
||||
cum_max: Method[Self]
|
||||
cum_min: Method[Self]
|
||||
cum_prod: Method[Self]
|
||||
cum_sum: Method[Self]
|
||||
diff: Method[Self]
|
||||
drop_nulls: Method[Self]
|
||||
exp: Method[Self]
|
||||
fill_null: Method[Self]
|
||||
gather_every: Method[Self]
|
||||
head: Method[Self]
|
||||
is_finite: Method[Self]
|
||||
is_first_distinct: Method[Self]
|
||||
is_in: Method[Self]
|
||||
is_last_distinct: Method[Self]
|
||||
is_null: Method[Self]
|
||||
is_unique: Method[Self]
|
||||
kurtosis: Method[Self]
|
||||
len: Method[Self]
|
||||
log: Method[Self]
|
||||
max: Method[Self]
|
||||
mean: Method[Self]
|
||||
median: Method[Self]
|
||||
min: Method[Self]
|
||||
mode: Method[Self]
|
||||
n_unique: Method[Self]
|
||||
null_count: Method[Self]
|
||||
quantile: Method[Self]
|
||||
rank: Method[Self]
|
||||
round: Method[Self]
|
||||
sample: Method[Self]
|
||||
shift: Method[Self]
|
||||
skew: Method[Self]
|
||||
sqrt: Method[Self]
|
||||
std: Method[Self]
|
||||
sum: Method[Self]
|
||||
sort: Method[Self]
|
||||
tail: Method[Self]
|
||||
unique: Method[Self]
|
||||
var: Method[Self]
|
||||
|
||||
|
||||
class PolarsExprNamespace:
|
||||
def __init__(self, expr: PolarsExpr) -> None:
|
||||
self._expr = expr
|
||||
|
||||
@property
|
||||
def compliant(self) -> PolarsExpr:
|
||||
return self._expr
|
||||
|
||||
@property
|
||||
def native(self) -> pl.Expr:
|
||||
return self._expr.native
|
||||
|
||||
|
||||
class PolarsExprDateTimeNamespace(PolarsExprNamespace):
|
||||
def truncate(self, every: str) -> PolarsExpr:
|
||||
Interval.parse(every) # Ensure consistent error message is raised.
|
||||
return self.compliant._with_native(self.native.dt.truncate(every))
|
||||
|
||||
def offset_by(self, by: str) -> PolarsExpr:
|
||||
# Ensure consistent error message is raised.
|
||||
Interval.parse_no_constraints(by)
|
||||
return self.compliant._with_native(self.native.dt.offset_by(by))
|
||||
|
||||
def __getattr__(self, attr: str) -> Callable[[Any], PolarsExpr]:
|
||||
def func(*args: Any, **kwargs: Any) -> PolarsExpr:
|
||||
pos, kwds = extract_args_kwargs(args, kwargs)
|
||||
return self.compliant._with_native(
|
||||
getattr(self.native.dt, attr)(*pos, **kwds)
|
||||
)
|
||||
|
||||
return func
|
||||
|
||||
to_string: Method[PolarsExpr]
|
||||
replace_time_zone: Method[PolarsExpr]
|
||||
convert_time_zone: Method[PolarsExpr]
|
||||
timestamp: Method[PolarsExpr]
|
||||
date: Method[PolarsExpr]
|
||||
year: Method[PolarsExpr]
|
||||
month: Method[PolarsExpr]
|
||||
day: Method[PolarsExpr]
|
||||
hour: Method[PolarsExpr]
|
||||
minute: Method[PolarsExpr]
|
||||
second: Method[PolarsExpr]
|
||||
millisecond: Method[PolarsExpr]
|
||||
microsecond: Method[PolarsExpr]
|
||||
nanosecond: Method[PolarsExpr]
|
||||
ordinal_day: Method[PolarsExpr]
|
||||
weekday: Method[PolarsExpr]
|
||||
total_minutes: Method[PolarsExpr]
|
||||
total_seconds: Method[PolarsExpr]
|
||||
total_milliseconds: Method[PolarsExpr]
|
||||
total_microseconds: Method[PolarsExpr]
|
||||
total_nanoseconds: Method[PolarsExpr]
|
||||
|
||||
|
||||
class PolarsExprStringNamespace(PolarsExprNamespace):
|
||||
def zfill(self, width: int) -> PolarsExpr:
|
||||
backend_version = self.compliant._backend_version
|
||||
native_result = self.native.str.zfill(width)
|
||||
|
||||
if backend_version < (0, 20, 5): # pragma: no cover
|
||||
# Reason:
|
||||
# `TypeError: argument 'length': 'Expr' object cannot be interpreted as an integer`
|
||||
# in `native_expr.str.slice(1, length)`
|
||||
msg = "`zfill` is only available in 'polars>=0.20.5', found version '0.20.4'."
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
if backend_version <= (1, 30, 0):
|
||||
length = self.native.str.len_chars()
|
||||
less_than_width = length < width
|
||||
plus = "+"
|
||||
starts_with_plus = self.native.str.starts_with(plus)
|
||||
native_result = (
|
||||
pl.when(starts_with_plus & less_than_width)
|
||||
.then(
|
||||
self.native.str.slice(1, length)
|
||||
.str.zfill(width - 1)
|
||||
.str.pad_start(width, plus)
|
||||
)
|
||||
.otherwise(native_result)
|
||||
)
|
||||
|
||||
return self.compliant._with_native(native_result)
|
||||
|
||||
def __getattr__(self, attr: str) -> Callable[[Any], PolarsExpr]:
|
||||
def func(*args: Any, **kwargs: Any) -> PolarsExpr:
|
||||
pos, kwds = extract_args_kwargs(args, kwargs)
|
||||
return self.compliant._with_native(
|
||||
getattr(self.native.str, attr)(*pos, **kwds)
|
||||
)
|
||||
|
||||
return func
|
||||
|
||||
len_chars: Method[PolarsExpr]
|
||||
replace: Method[PolarsExpr]
|
||||
replace_all: Method[PolarsExpr]
|
||||
strip_chars: Method[PolarsExpr]
|
||||
starts_with: Method[PolarsExpr]
|
||||
ends_with: Method[PolarsExpr]
|
||||
contains: Method[PolarsExpr]
|
||||
slice: Method[PolarsExpr]
|
||||
split: Method[PolarsExpr]
|
||||
to_date: Method[PolarsExpr]
|
||||
to_datetime: Method[PolarsExpr]
|
||||
to_lowercase: Method[PolarsExpr]
|
||||
to_uppercase: Method[PolarsExpr]
|
||||
|
||||
|
||||
class PolarsExprCatNamespace(PolarsExprNamespace):
|
||||
def __getattr__(self, attr: str) -> Callable[[Any], PolarsExpr]:
|
||||
def func(*args: Any, **kwargs: Any) -> PolarsExpr:
|
||||
pos, kwds = extract_args_kwargs(args, kwargs)
|
||||
return self.compliant._with_native(
|
||||
getattr(self.native.cat, attr)(*pos, **kwds)
|
||||
)
|
||||
|
||||
return func
|
||||
|
||||
get_categories: Method[PolarsExpr]
|
||||
|
||||
|
||||
class PolarsExprNameNamespace(PolarsExprNamespace):
|
||||
def __getattr__(self, attr: str) -> Callable[[Any], PolarsExpr]:
|
||||
def func(*args: Any, **kwargs: Any) -> PolarsExpr:
|
||||
pos, kwds = extract_args_kwargs(args, kwargs)
|
||||
return self.compliant._with_native(
|
||||
getattr(self.native.name, attr)(*pos, **kwds)
|
||||
)
|
||||
|
||||
return func
|
||||
|
||||
keep: Method[PolarsExpr]
|
||||
map: Method[PolarsExpr]
|
||||
prefix: Method[PolarsExpr]
|
||||
suffix: Method[PolarsExpr]
|
||||
to_lowercase: Method[PolarsExpr]
|
||||
to_uppercase: Method[PolarsExpr]
|
||||
|
||||
|
||||
class PolarsExprListNamespace(PolarsExprNamespace):
|
||||
def len(self) -> PolarsExpr:
|
||||
native_expr = self.compliant._native_expr
|
||||
native_result = native_expr.list.len()
|
||||
|
||||
if self.compliant._backend_version < (1, 16): # pragma: no cover
|
||||
native_result = (
|
||||
pl.when(~native_expr.is_null()).then(native_result).cast(pl.UInt32())
|
||||
)
|
||||
elif self.compliant._backend_version < (1, 17): # pragma: no cover
|
||||
native_result = native_result.cast(pl.UInt32())
|
||||
|
||||
return self.compliant._with_native(native_result)
|
||||
|
||||
# TODO(FBruzzesi): Remove `pragma: no cover` once other namespace methods are added
|
||||
def __getattr__(self, attr: str) -> Callable[[Any], PolarsExpr]: # pragma: no cover
|
||||
def func(*args: Any, **kwargs: Any) -> PolarsExpr:
|
||||
pos, kwds = extract_args_kwargs(args, kwargs)
|
||||
return self.compliant._with_native(
|
||||
getattr(self.native.list, attr)(*pos, **kwds)
|
||||
)
|
||||
|
||||
return func
|
||||
|
||||
|
||||
class PolarsExprStructNamespace(PolarsExprNamespace):
|
||||
def __getattr__(self, attr: str) -> Callable[[Any], PolarsExpr]: # pragma: no cover
|
||||
def func(*args: Any, **kwargs: Any) -> PolarsExpr:
|
||||
pos, kwds = extract_args_kwargs(args, kwargs)
|
||||
return self.compliant._with_native(
|
||||
getattr(self.native.struct, attr)(*pos, **kwds)
|
||||
)
|
||||
|
||||
return func
|
||||
|
||||
field: Method[PolarsExpr]
|
||||
@ -0,0 +1,80 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, cast
|
||||
|
||||
from narwhals._utils import is_sequence_of
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterator, Sequence
|
||||
|
||||
from polars.dataframe.group_by import GroupBy as NativeGroupBy
|
||||
from polars.lazyframe.group_by import LazyGroupBy as NativeLazyGroupBy
|
||||
|
||||
from narwhals._polars.dataframe import PolarsDataFrame, PolarsLazyFrame
|
||||
from narwhals._polars.expr import PolarsExpr
|
||||
|
||||
|
||||
class PolarsGroupBy:
|
||||
_compliant_frame: PolarsDataFrame
|
||||
_grouped: NativeGroupBy
|
||||
_drop_null_keys: bool
|
||||
_output_names: Sequence[str]
|
||||
|
||||
@property
|
||||
def compliant(self) -> PolarsDataFrame:
|
||||
return self._compliant_frame
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
df: PolarsDataFrame,
|
||||
keys: Sequence[PolarsExpr] | Sequence[str],
|
||||
/,
|
||||
*,
|
||||
drop_null_keys: bool,
|
||||
) -> None:
|
||||
self._keys = list(keys)
|
||||
self._compliant_frame = df.drop_nulls(keys) if drop_null_keys else df
|
||||
self._grouped = (
|
||||
self.compliant.native.group_by(keys)
|
||||
if is_sequence_of(keys, str)
|
||||
else self.compliant.native.group_by(arg.native for arg in keys)
|
||||
)
|
||||
|
||||
def agg(self, *aggs: PolarsExpr) -> PolarsDataFrame:
|
||||
agg_result = self._grouped.agg(arg.native for arg in aggs)
|
||||
return self.compliant._with_native(agg_result)
|
||||
|
||||
def __iter__(self) -> Iterator[tuple[tuple[str, ...], PolarsDataFrame]]:
|
||||
for key, df in self._grouped:
|
||||
yield tuple(cast("str", key)), self.compliant._with_native(df)
|
||||
|
||||
|
||||
class PolarsLazyGroupBy:
|
||||
_compliant_frame: PolarsLazyFrame
|
||||
_grouped: NativeLazyGroupBy
|
||||
_drop_null_keys: bool
|
||||
_output_names: Sequence[str]
|
||||
|
||||
@property
|
||||
def compliant(self) -> PolarsLazyFrame:
|
||||
return self._compliant_frame
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
df: PolarsLazyFrame,
|
||||
keys: Sequence[PolarsExpr] | Sequence[str],
|
||||
/,
|
||||
*,
|
||||
drop_null_keys: bool,
|
||||
) -> None:
|
||||
self._keys = list(keys)
|
||||
self._compliant_frame = df.drop_nulls(keys) if drop_null_keys else df
|
||||
self._grouped = (
|
||||
self.compliant.native.group_by(keys)
|
||||
if is_sequence_of(keys, str)
|
||||
else self.compliant.native.group_by(arg.native for arg in keys)
|
||||
)
|
||||
|
||||
def agg(self, *aggs: PolarsExpr) -> PolarsLazyFrame:
|
||||
agg_result = self._grouped.agg(arg.native for arg in aggs)
|
||||
return self.compliant._with_native(agg_result)
|
||||
257
venv/lib/python3.11/site-packages/narwhals/_polars/namespace.py
Normal file
257
venv/lib/python3.11/site-packages/narwhals/_polars/namespace.py
Normal file
@ -0,0 +1,257 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import operator
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast, overload
|
||||
|
||||
import polars as pl
|
||||
|
||||
from narwhals._polars.expr import PolarsExpr
|
||||
from narwhals._polars.series import PolarsSeries
|
||||
from narwhals._polars.utils import extract_args_kwargs, narwhals_to_native_dtype
|
||||
from narwhals._utils import Implementation, requires
|
||||
from narwhals.dependencies import is_numpy_array_2d
|
||||
from narwhals.dtypes import DType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from datetime import timezone
|
||||
|
||||
from narwhals._compliant import CompliantSelectorNamespace, CompliantWhen
|
||||
from narwhals._polars.dataframe import Method, PolarsDataFrame, PolarsLazyFrame
|
||||
from narwhals._polars.typing import FrameT
|
||||
from narwhals._utils import Version, _LimitedContext
|
||||
from narwhals.schema import Schema
|
||||
from narwhals.typing import Into1DArray, IntoDType, TimeUnit, _2DArray
|
||||
|
||||
|
||||
class PolarsNamespace:
|
||||
all: Method[PolarsExpr]
|
||||
coalesce: Method[PolarsExpr]
|
||||
col: Method[PolarsExpr]
|
||||
exclude: Method[PolarsExpr]
|
||||
sum_horizontal: Method[PolarsExpr]
|
||||
min_horizontal: Method[PolarsExpr]
|
||||
max_horizontal: Method[PolarsExpr]
|
||||
|
||||
# NOTE: `pyright` accepts, `mypy` doesn't highlight the issue
|
||||
# error: Type argument "PolarsExpr" of "CompliantWhen" must be a subtype of "CompliantExpr[Any, Any]"
|
||||
when: Method[CompliantWhen[PolarsDataFrame, PolarsSeries, PolarsExpr]] # type: ignore[type-var]
|
||||
|
||||
_implementation = Implementation.POLARS
|
||||
|
||||
@property
|
||||
def _backend_version(self) -> tuple[int, ...]:
|
||||
return self._implementation._backend_version()
|
||||
|
||||
def __init__(self, *, version: Version) -> None:
|
||||
self._version = version
|
||||
|
||||
def __getattr__(self, attr: str) -> Any:
|
||||
def func(*args: Any, **kwargs: Any) -> Any:
|
||||
pos, kwds = extract_args_kwargs(args, kwargs)
|
||||
return self._expr(getattr(pl, attr)(*pos, **kwds), version=self._version)
|
||||
|
||||
return func
|
||||
|
||||
@property
|
||||
def _dataframe(self) -> type[PolarsDataFrame]:
|
||||
from narwhals._polars.dataframe import PolarsDataFrame
|
||||
|
||||
return PolarsDataFrame
|
||||
|
||||
@property
|
||||
def _lazyframe(self) -> type[PolarsLazyFrame]:
|
||||
from narwhals._polars.dataframe import PolarsLazyFrame
|
||||
|
||||
return PolarsLazyFrame
|
||||
|
||||
@property
|
||||
def _expr(self) -> type[PolarsExpr]:
|
||||
return PolarsExpr
|
||||
|
||||
@property
|
||||
def _series(self) -> type[PolarsSeries]:
|
||||
return PolarsSeries
|
||||
|
||||
@overload
|
||||
def from_native(self, data: pl.DataFrame, /) -> PolarsDataFrame: ...
|
||||
@overload
|
||||
def from_native(self, data: pl.LazyFrame, /) -> PolarsLazyFrame: ...
|
||||
@overload
|
||||
def from_native(self, data: pl.Series, /) -> PolarsSeries: ...
|
||||
def from_native(
|
||||
self, data: pl.DataFrame | pl.LazyFrame | pl.Series | Any, /
|
||||
) -> PolarsDataFrame | PolarsLazyFrame | PolarsSeries:
|
||||
if self._dataframe._is_native(data):
|
||||
return self._dataframe.from_native(data, context=self)
|
||||
elif self._series._is_native(data):
|
||||
return self._series.from_native(data, context=self)
|
||||
elif self._lazyframe._is_native(data):
|
||||
return self._lazyframe.from_native(data, context=self)
|
||||
else: # pragma: no cover
|
||||
msg = f"Unsupported type: {type(data).__name__!r}"
|
||||
raise TypeError(msg)
|
||||
|
||||
@overload
|
||||
def from_numpy(self, data: Into1DArray, /, schema: None = ...) -> PolarsSeries: ...
|
||||
|
||||
@overload
|
||||
def from_numpy(
|
||||
self,
|
||||
data: _2DArray,
|
||||
/,
|
||||
schema: Mapping[str, DType] | Schema | Sequence[str] | None,
|
||||
) -> PolarsDataFrame: ...
|
||||
|
||||
def from_numpy(
|
||||
self,
|
||||
data: Into1DArray | _2DArray,
|
||||
/,
|
||||
schema: Mapping[str, DType] | Schema | Sequence[str] | None = None,
|
||||
) -> PolarsDataFrame | PolarsSeries:
|
||||
if is_numpy_array_2d(data):
|
||||
return self._dataframe.from_numpy(data, schema=schema, context=self)
|
||||
return self._series.from_numpy(data, context=self) # pragma: no cover
|
||||
|
||||
@requires.backend_version(
|
||||
(1, 0, 0), "Please use `col` for columns selection instead."
|
||||
)
|
||||
def nth(self, *indices: int) -> PolarsExpr:
|
||||
return self._expr(pl.nth(*indices), version=self._version)
|
||||
|
||||
def len(self) -> PolarsExpr:
|
||||
if self._backend_version < (0, 20, 5):
|
||||
return self._expr(pl.count().alias("len"), self._version)
|
||||
return self._expr(pl.len(), self._version)
|
||||
|
||||
def all_horizontal(self, *exprs: PolarsExpr, ignore_nulls: bool) -> PolarsExpr:
|
||||
it = (expr.fill_null(True) for expr in exprs) if ignore_nulls else iter(exprs) # noqa: FBT003
|
||||
return self._expr(pl.all_horizontal(*(expr.native for expr in it)), self._version)
|
||||
|
||||
def any_horizontal(self, *exprs: PolarsExpr, ignore_nulls: bool) -> PolarsExpr:
|
||||
it = (expr.fill_null(False) for expr in exprs) if ignore_nulls else iter(exprs) # noqa: FBT003
|
||||
return self._expr(pl.any_horizontal(*(expr.native for expr in it)), self._version)
|
||||
|
||||
def concat(
|
||||
self,
|
||||
items: Iterable[FrameT],
|
||||
*,
|
||||
how: Literal["vertical", "horizontal", "diagonal"],
|
||||
) -> PolarsDataFrame | PolarsLazyFrame:
|
||||
result = pl.concat((item.native for item in items), how=how)
|
||||
if isinstance(result, pl.DataFrame):
|
||||
return self._dataframe(result, version=self._version)
|
||||
return self._lazyframe.from_native(result, context=self)
|
||||
|
||||
def lit(self, value: Any, dtype: IntoDType | None) -> PolarsExpr:
|
||||
if dtype is not None:
|
||||
return self._expr(
|
||||
pl.lit(value, dtype=narwhals_to_native_dtype(dtype, self._version)),
|
||||
version=self._version,
|
||||
)
|
||||
return self._expr(pl.lit(value), version=self._version)
|
||||
|
||||
def mean_horizontal(self, *exprs: PolarsExpr) -> PolarsExpr:
|
||||
if self._backend_version < (0, 20, 8):
|
||||
return self._expr(
|
||||
pl.sum_horizontal(e._native_expr for e in exprs)
|
||||
/ pl.sum_horizontal(1 - e.is_null()._native_expr for e in exprs),
|
||||
version=self._version,
|
||||
)
|
||||
|
||||
return self._expr(
|
||||
pl.mean_horizontal(e._native_expr for e in exprs), version=self._version
|
||||
)
|
||||
|
||||
def concat_str(
|
||||
self, *exprs: PolarsExpr, separator: str, ignore_nulls: bool
|
||||
) -> PolarsExpr:
|
||||
pl_exprs: list[pl.Expr] = [expr._native_expr for expr in exprs]
|
||||
|
||||
if self._backend_version < (0, 20, 6):
|
||||
null_mask = [expr.is_null() for expr in pl_exprs]
|
||||
sep = pl.lit(separator)
|
||||
|
||||
if not ignore_nulls:
|
||||
null_mask_result = pl.any_horizontal(*null_mask)
|
||||
output_expr = pl.reduce(
|
||||
lambda x, y: x.cast(pl.String()) + sep + y.cast(pl.String()), # type: ignore[arg-type,return-value]
|
||||
pl_exprs,
|
||||
)
|
||||
result = pl.when(~null_mask_result).then(output_expr)
|
||||
else:
|
||||
init_value, *values = [
|
||||
pl.when(nm).then(pl.lit("")).otherwise(expr.cast(pl.String()))
|
||||
for expr, nm in zip(pl_exprs, null_mask)
|
||||
]
|
||||
separators = [
|
||||
pl.when(~nm).then(sep).otherwise(pl.lit("")) for nm in null_mask[:-1]
|
||||
]
|
||||
|
||||
result = pl.fold( # type: ignore[assignment]
|
||||
acc=init_value,
|
||||
function=operator.add,
|
||||
exprs=[s + v for s, v in zip(separators, values)],
|
||||
)
|
||||
|
||||
return self._expr(result, version=self._version)
|
||||
|
||||
return self._expr(
|
||||
pl.concat_str(pl_exprs, separator=separator, ignore_nulls=ignore_nulls),
|
||||
version=self._version,
|
||||
)
|
||||
|
||||
# NOTE: Implementation is too different to annotate correctly (vs other `*SelectorNamespace`)
|
||||
# 1. Others have lots of private stuff for code reuse
|
||||
# i. None of that is useful here
|
||||
# 2. We don't have a `PolarsSelector` abstraction, and just use `PolarsExpr`
|
||||
@property
|
||||
def selectors(self) -> CompliantSelectorNamespace[PolarsDataFrame, PolarsSeries]:
|
||||
return cast(
|
||||
"CompliantSelectorNamespace[PolarsDataFrame, PolarsSeries]",
|
||||
PolarsSelectorNamespace(self),
|
||||
)
|
||||
|
||||
|
||||
class PolarsSelectorNamespace:
|
||||
_implementation = Implementation.POLARS
|
||||
|
||||
def __init__(self, context: _LimitedContext, /) -> None:
|
||||
self._version = context._version
|
||||
|
||||
def by_dtype(self, dtypes: Iterable[DType]) -> PolarsExpr:
|
||||
native_dtypes = [
|
||||
narwhals_to_native_dtype(dtype, self._version).__class__
|
||||
if isinstance(dtype, type) and issubclass(dtype, DType)
|
||||
else narwhals_to_native_dtype(dtype, self._version)
|
||||
for dtype in dtypes
|
||||
]
|
||||
return PolarsExpr(pl.selectors.by_dtype(native_dtypes), version=self._version)
|
||||
|
||||
def matches(self, pattern: str) -> PolarsExpr:
|
||||
return PolarsExpr(pl.selectors.matches(pattern=pattern), version=self._version)
|
||||
|
||||
def numeric(self) -> PolarsExpr:
|
||||
return PolarsExpr(pl.selectors.numeric(), version=self._version)
|
||||
|
||||
def boolean(self) -> PolarsExpr:
|
||||
return PolarsExpr(pl.selectors.boolean(), version=self._version)
|
||||
|
||||
def string(self) -> PolarsExpr:
|
||||
return PolarsExpr(pl.selectors.string(), version=self._version)
|
||||
|
||||
def categorical(self) -> PolarsExpr:
|
||||
return PolarsExpr(pl.selectors.categorical(), version=self._version)
|
||||
|
||||
def all(self) -> PolarsExpr:
|
||||
return PolarsExpr(pl.selectors.all(), version=self._version)
|
||||
|
||||
def datetime(
|
||||
self,
|
||||
time_unit: TimeUnit | Iterable[TimeUnit] | None,
|
||||
time_zone: str | timezone | Iterable[str | timezone | None] | None,
|
||||
) -> PolarsExpr:
|
||||
return PolarsExpr(
|
||||
pl.selectors.datetime(time_unit=time_unit, time_zone=time_zone), # type: ignore[arg-type]
|
||||
version=self._version,
|
||||
)
|
||||
772
venv/lib/python3.11/site-packages/narwhals/_polars/series.py
Normal file
772
venv/lib/python3.11/site-packages/narwhals/_polars/series.py
Normal file
@ -0,0 +1,772 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, cast, overload
|
||||
|
||||
import polars as pl
|
||||
|
||||
from narwhals._polars.utils import (
|
||||
BACKEND_VERSION,
|
||||
catch_polars_exception,
|
||||
extract_args_kwargs,
|
||||
extract_native,
|
||||
narwhals_to_native_dtype,
|
||||
native_to_narwhals_dtype,
|
||||
)
|
||||
from narwhals._utils import Implementation, requires
|
||||
from narwhals.dependencies import is_numpy_array_1d
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterable, Iterator, Mapping, Sequence
|
||||
from types import ModuleType
|
||||
from typing import Literal, TypeVar
|
||||
|
||||
import pandas as pd
|
||||
import pyarrow as pa
|
||||
from typing_extensions import Self, TypeAlias, TypeIs
|
||||
|
||||
from narwhals._polars.dataframe import Method, PolarsDataFrame
|
||||
from narwhals._polars.expr import PolarsExpr
|
||||
from narwhals._polars.namespace import PolarsNamespace
|
||||
from narwhals._utils import Version, _LimitedContext
|
||||
from narwhals.dtypes import DType
|
||||
from narwhals.series import Series
|
||||
from narwhals.typing import Into1DArray, IntoDType, MultiIndexSelector, _1DArray
|
||||
|
||||
T = TypeVar("T")
|
||||
IncludeBreakpoint: TypeAlias = Literal[False, True]
|
||||
|
||||
|
||||
# Series methods where PolarsSeries just defers to Polars.Series directly.
|
||||
INHERITED_METHODS = frozenset(
|
||||
[
|
||||
"__add__",
|
||||
"__and__",
|
||||
"__floordiv__",
|
||||
"__invert__",
|
||||
"__iter__",
|
||||
"__mod__",
|
||||
"__mul__",
|
||||
"__or__",
|
||||
"__pow__",
|
||||
"__radd__",
|
||||
"__rand__",
|
||||
"__rfloordiv__",
|
||||
"__rmod__",
|
||||
"__rmul__",
|
||||
"__ror__",
|
||||
"__rsub__",
|
||||
"__rtruediv__",
|
||||
"__sub__",
|
||||
"__truediv__",
|
||||
"abs",
|
||||
"all",
|
||||
"any",
|
||||
"arg_max",
|
||||
"arg_min",
|
||||
"arg_true",
|
||||
"clip",
|
||||
"count",
|
||||
"cum_max",
|
||||
"cum_min",
|
||||
"cum_prod",
|
||||
"cum_sum",
|
||||
"diff",
|
||||
"drop_nulls",
|
||||
"exp",
|
||||
"fill_null",
|
||||
"filter",
|
||||
"gather_every",
|
||||
"head",
|
||||
"is_between",
|
||||
"is_finite",
|
||||
"is_first_distinct",
|
||||
"is_in",
|
||||
"is_last_distinct",
|
||||
"is_null",
|
||||
"is_sorted",
|
||||
"is_unique",
|
||||
"item",
|
||||
"kurtosis",
|
||||
"len",
|
||||
"log",
|
||||
"max",
|
||||
"mean",
|
||||
"min",
|
||||
"mode",
|
||||
"n_unique",
|
||||
"null_count",
|
||||
"quantile",
|
||||
"rank",
|
||||
"round",
|
||||
"sample",
|
||||
"shift",
|
||||
"skew",
|
||||
"sqrt",
|
||||
"std",
|
||||
"sum",
|
||||
"tail",
|
||||
"to_arrow",
|
||||
"to_frame",
|
||||
"to_list",
|
||||
"to_pandas",
|
||||
"unique",
|
||||
"var",
|
||||
"zip_with",
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class PolarsSeries:
|
||||
_implementation = Implementation.POLARS
|
||||
|
||||
def __init__(self, series: pl.Series, *, version: Version) -> None:
|
||||
self._native_series: pl.Series = series
|
||||
self._version = version
|
||||
|
||||
@property
|
||||
def _backend_version(self) -> tuple[int, ...]:
|
||||
return self._implementation._backend_version()
|
||||
|
||||
def __repr__(self) -> str: # pragma: no cover
|
||||
return "PolarsSeries"
|
||||
|
||||
def __narwhals_namespace__(self) -> PolarsNamespace:
|
||||
from narwhals._polars.namespace import PolarsNamespace
|
||||
|
||||
return PolarsNamespace(version=self._version)
|
||||
|
||||
def __narwhals_series__(self) -> Self:
|
||||
return self
|
||||
|
||||
def __native_namespace__(self) -> ModuleType:
|
||||
if self._implementation is Implementation.POLARS:
|
||||
return self._implementation.to_native_namespace()
|
||||
|
||||
msg = f"Expected polars, got: {type(self._implementation)}" # pragma: no cover
|
||||
raise AssertionError(msg)
|
||||
|
||||
def _with_version(self, version: Version) -> Self:
|
||||
return self.__class__(self.native, version=version)
|
||||
|
||||
@classmethod
|
||||
def from_iterable(
|
||||
cls,
|
||||
data: Iterable[Any],
|
||||
*,
|
||||
context: _LimitedContext,
|
||||
name: str = "",
|
||||
dtype: IntoDType | None = None,
|
||||
) -> Self:
|
||||
version = context._version
|
||||
dtype_pl = narwhals_to_native_dtype(dtype, version) if dtype else None
|
||||
# NOTE: `Iterable` is fine, annotation is overly narrow
|
||||
# https://github.com/pola-rs/polars/blob/82d57a4ee41f87c11ca1b1af15488459727efdd7/py-polars/polars/series/series.py#L332-L333
|
||||
native = pl.Series(name=name, values=cast("Sequence[Any]", data), dtype=dtype_pl)
|
||||
return cls.from_native(native, context=context)
|
||||
|
||||
@staticmethod
|
||||
def _is_native(obj: pl.Series | Any) -> TypeIs[pl.Series]:
|
||||
return isinstance(obj, pl.Series)
|
||||
|
||||
@classmethod
|
||||
def from_native(cls, data: pl.Series, /, *, context: _LimitedContext) -> Self:
|
||||
return cls(data, version=context._version)
|
||||
|
||||
@classmethod
|
||||
def from_numpy(cls, data: Into1DArray, /, *, context: _LimitedContext) -> Self:
|
||||
native = pl.Series(data if is_numpy_array_1d(data) else [data])
|
||||
return cls.from_native(native, context=context)
|
||||
|
||||
def to_narwhals(self) -> Series[pl.Series]:
|
||||
return self._version.series(self, level="full")
|
||||
|
||||
def _with_native(self, series: pl.Series) -> Self:
|
||||
return self.__class__(series, version=self._version)
|
||||
|
||||
@overload
|
||||
def _from_native_object(self, series: pl.Series) -> Self: ...
|
||||
|
||||
@overload
|
||||
def _from_native_object(self, series: pl.DataFrame) -> PolarsDataFrame: ...
|
||||
|
||||
@overload
|
||||
def _from_native_object(self, series: T) -> T: ...
|
||||
|
||||
def _from_native_object(
|
||||
self, series: pl.Series | pl.DataFrame | T
|
||||
) -> Self | PolarsDataFrame | T:
|
||||
if self._is_native(series):
|
||||
return self._with_native(series)
|
||||
if isinstance(series, pl.DataFrame):
|
||||
from narwhals._polars.dataframe import PolarsDataFrame
|
||||
|
||||
return PolarsDataFrame.from_native(series, context=self)
|
||||
# scalar
|
||||
return series
|
||||
|
||||
def _to_expr(self) -> PolarsExpr:
|
||||
return self.__narwhals_namespace__()._expr._from_series(self)
|
||||
|
||||
def __getattr__(self, attr: str) -> Any:
|
||||
if attr not in INHERITED_METHODS:
|
||||
msg = f"{self.__class__.__name__} has not attribute '{attr}'."
|
||||
raise AttributeError(msg)
|
||||
|
||||
def func(*args: Any, **kwargs: Any) -> Any:
|
||||
pos, kwds = extract_args_kwargs(args, kwargs)
|
||||
return self._from_native_object(getattr(self.native, attr)(*pos, **kwds))
|
||||
|
||||
return func
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.native)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self.native.name
|
||||
|
||||
@property
|
||||
def dtype(self) -> DType:
|
||||
return native_to_narwhals_dtype(self.native.dtype, self._version)
|
||||
|
||||
@property
|
||||
def native(self) -> pl.Series:
|
||||
return self._native_series
|
||||
|
||||
def alias(self, name: str) -> Self:
|
||||
return self._from_native_object(self.native.alias(name))
|
||||
|
||||
def __getitem__(self, item: MultiIndexSelector[Self]) -> Any | Self:
|
||||
if isinstance(item, PolarsSeries):
|
||||
return self._from_native_object(self.native.__getitem__(item.native))
|
||||
return self._from_native_object(self.native.__getitem__(item))
|
||||
|
||||
def cast(self, dtype: IntoDType) -> Self:
|
||||
dtype_pl = narwhals_to_native_dtype(dtype, self._version)
|
||||
return self._with_native(self.native.cast(dtype_pl))
|
||||
|
||||
@requires.backend_version((1,))
|
||||
def replace_strict(
|
||||
self,
|
||||
old: Sequence[Any] | Mapping[Any, Any],
|
||||
new: Sequence[Any],
|
||||
*,
|
||||
return_dtype: IntoDType | None,
|
||||
) -> Self:
|
||||
ser = self.native
|
||||
dtype = (
|
||||
narwhals_to_native_dtype(return_dtype, self._version)
|
||||
if return_dtype
|
||||
else None
|
||||
)
|
||||
return self._with_native(ser.replace_strict(old, new, return_dtype=dtype))
|
||||
|
||||
def to_numpy(self, dtype: Any = None, *, copy: bool | None = None) -> _1DArray:
|
||||
return self.__array__(dtype, copy=copy)
|
||||
|
||||
def __array__(self, dtype: Any, *, copy: bool | None) -> _1DArray:
|
||||
if self._backend_version < (0, 20, 29):
|
||||
return self.native.__array__(dtype=dtype)
|
||||
return self.native.__array__(dtype=dtype, copy=copy)
|
||||
|
||||
def __eq__(self, other: object) -> Self: # type: ignore[override]
|
||||
return self._with_native(self.native.__eq__(extract_native(other)))
|
||||
|
||||
def __ne__(self, other: object) -> Self: # type: ignore[override]
|
||||
return self._with_native(self.native.__ne__(extract_native(other)))
|
||||
|
||||
# NOTE: `pyright` is being reasonable here
|
||||
def __ge__(self, other: Any) -> Self:
|
||||
return self._with_native(self.native.__ge__(extract_native(other))) # pyright: ignore[reportArgumentType]
|
||||
|
||||
def __gt__(self, other: Any) -> Self:
|
||||
return self._with_native(self.native.__gt__(extract_native(other))) # pyright: ignore[reportArgumentType]
|
||||
|
||||
def __le__(self, other: Any) -> Self:
|
||||
return self._with_native(self.native.__le__(extract_native(other))) # pyright: ignore[reportArgumentType]
|
||||
|
||||
def __lt__(self, other: Any) -> Self:
|
||||
return self._with_native(self.native.__lt__(extract_native(other))) # pyright: ignore[reportArgumentType]
|
||||
|
||||
def __rpow__(self, other: PolarsSeries | Any) -> Self:
|
||||
result = self.native.__rpow__(extract_native(other))
|
||||
if self._backend_version < (1, 16, 1):
|
||||
# Explicitly set alias to work around https://github.com/pola-rs/polars/issues/20071
|
||||
result = result.alias(self.name)
|
||||
return self._with_native(result)
|
||||
|
||||
def is_nan(self) -> Self:
|
||||
try:
|
||||
native_is_nan = self.native.is_nan()
|
||||
except Exception as e: # noqa: BLE001
|
||||
raise catch_polars_exception(e) from None
|
||||
if self._backend_version < (1, 18): # pragma: no cover
|
||||
select = pl.when(self.native.is_not_null()).then(native_is_nan)
|
||||
return self._with_native(pl.select(select)[self.name])
|
||||
return self._with_native(native_is_nan)
|
||||
|
||||
def median(self) -> Any:
|
||||
from narwhals.exceptions import InvalidOperationError
|
||||
|
||||
if not self.dtype.is_numeric():
|
||||
msg = "`median` operation not supported for non-numeric input type."
|
||||
raise InvalidOperationError(msg)
|
||||
|
||||
return self.native.median()
|
||||
|
||||
def to_dummies(self, *, separator: str, drop_first: bool) -> PolarsDataFrame:
|
||||
from narwhals._polars.dataframe import PolarsDataFrame
|
||||
|
||||
if self._backend_version < (0, 20, 15):
|
||||
has_nulls = self.native.is_null().any()
|
||||
result = self.native.to_dummies(separator=separator)
|
||||
output_columns = result.columns
|
||||
if drop_first:
|
||||
_ = output_columns.pop(int(has_nulls))
|
||||
|
||||
result = result.select(output_columns)
|
||||
else:
|
||||
result = self.native.to_dummies(separator=separator, drop_first=drop_first)
|
||||
result = result.with_columns(pl.all().cast(pl.Int8))
|
||||
return PolarsDataFrame.from_native(result, context=self)
|
||||
|
||||
def ewm_mean(
|
||||
self,
|
||||
*,
|
||||
com: float | None,
|
||||
span: float | None,
|
||||
half_life: float | None,
|
||||
alpha: float | None,
|
||||
adjust: bool,
|
||||
min_samples: int,
|
||||
ignore_nulls: bool,
|
||||
) -> Self:
|
||||
extra_kwargs = (
|
||||
{"min_periods": min_samples}
|
||||
if self._backend_version < (1, 21, 0)
|
||||
else {"min_samples": min_samples}
|
||||
)
|
||||
|
||||
native_result = self.native.ewm_mean(
|
||||
com=com,
|
||||
span=span,
|
||||
half_life=half_life,
|
||||
alpha=alpha,
|
||||
adjust=adjust,
|
||||
ignore_nulls=ignore_nulls,
|
||||
**extra_kwargs,
|
||||
)
|
||||
if self._backend_version < (1,): # pragma: no cover
|
||||
return self._with_native(
|
||||
pl.select(
|
||||
pl.when(~self.native.is_null()).then(native_result).otherwise(None)
|
||||
)[self.native.name]
|
||||
)
|
||||
|
||||
return self._with_native(native_result)
|
||||
|
||||
@requires.backend_version((1,))
|
||||
def rolling_var(
|
||||
self, window_size: int, *, min_samples: int, center: bool, ddof: int
|
||||
) -> Self:
|
||||
extra_kwargs: dict[str, Any] = (
|
||||
{"min_periods": min_samples}
|
||||
if self._backend_version < (1, 21, 0)
|
||||
else {"min_samples": min_samples}
|
||||
)
|
||||
return self._with_native(
|
||||
self.native.rolling_var(
|
||||
window_size=window_size, center=center, ddof=ddof, **extra_kwargs
|
||||
)
|
||||
)
|
||||
|
||||
@requires.backend_version((1,))
|
||||
def rolling_std(
|
||||
self, window_size: int, *, min_samples: int, center: bool, ddof: int
|
||||
) -> Self:
|
||||
extra_kwargs: dict[str, Any] = (
|
||||
{"min_periods": min_samples}
|
||||
if self._backend_version < (1, 21, 0)
|
||||
else {"min_samples": min_samples}
|
||||
)
|
||||
return self._with_native(
|
||||
self.native.rolling_std(
|
||||
window_size=window_size, center=center, ddof=ddof, **extra_kwargs
|
||||
)
|
||||
)
|
||||
|
||||
def rolling_sum(self, window_size: int, *, min_samples: int, center: bool) -> Self:
|
||||
extra_kwargs: dict[str, Any] = (
|
||||
{"min_periods": min_samples}
|
||||
if self._backend_version < (1, 21, 0)
|
||||
else {"min_samples": min_samples}
|
||||
)
|
||||
return self._with_native(
|
||||
self.native.rolling_sum(
|
||||
window_size=window_size, center=center, **extra_kwargs
|
||||
)
|
||||
)
|
||||
|
||||
def rolling_mean(self, window_size: int, *, min_samples: int, center: bool) -> Self:
|
||||
extra_kwargs: dict[str, Any] = (
|
||||
{"min_periods": min_samples}
|
||||
if self._backend_version < (1, 21, 0)
|
||||
else {"min_samples": min_samples}
|
||||
)
|
||||
return self._with_native(
|
||||
self.native.rolling_mean(
|
||||
window_size=window_size, center=center, **extra_kwargs
|
||||
)
|
||||
)
|
||||
|
||||
def sort(self, *, descending: bool, nulls_last: bool) -> Self:
|
||||
if self._backend_version < (0, 20, 6):
|
||||
result = self.native.sort(descending=descending)
|
||||
|
||||
if nulls_last:
|
||||
is_null = result.is_null()
|
||||
result = pl.concat([result.filter(~is_null), result.filter(is_null)])
|
||||
else:
|
||||
result = self.native.sort(descending=descending, nulls_last=nulls_last)
|
||||
|
||||
return self._with_native(result)
|
||||
|
||||
def scatter(self, indices: int | Sequence[int], values: Any) -> Self:
|
||||
s = self.native.clone().scatter(indices, extract_native(values))
|
||||
return self._with_native(s)
|
||||
|
||||
def value_counts(
|
||||
self, *, sort: bool, parallel: bool, name: str | None, normalize: bool
|
||||
) -> PolarsDataFrame:
|
||||
from narwhals._polars.dataframe import PolarsDataFrame
|
||||
|
||||
if self._backend_version < (1, 0, 0):
|
||||
value_name_ = name or ("proportion" if normalize else "count")
|
||||
|
||||
result = self.native.value_counts(sort=sort, parallel=parallel).select(
|
||||
**{
|
||||
(self.native.name): pl.col(self.native.name),
|
||||
value_name_: pl.col("count") / pl.sum("count")
|
||||
if normalize
|
||||
else pl.col("count"),
|
||||
}
|
||||
)
|
||||
else:
|
||||
result = self.native.value_counts(
|
||||
sort=sort, parallel=parallel, name=name, normalize=normalize
|
||||
)
|
||||
return PolarsDataFrame.from_native(result, context=self)
|
||||
|
||||
def cum_count(self, *, reverse: bool) -> Self:
|
||||
return self._with_native(self.native.cum_count(reverse=reverse))
|
||||
|
||||
def __contains__(self, other: Any) -> bool:
|
||||
try:
|
||||
return self.native.__contains__(other)
|
||||
except Exception as e: # noqa: BLE001
|
||||
raise catch_polars_exception(e) from None
|
||||
|
||||
_HIST_EMPTY_SCHEMA: ClassVar[Mapping[IncludeBreakpoint, Sequence[str]]] = {
|
||||
True: ["breakpoint", "count"],
|
||||
False: ["count"],
|
||||
}
|
||||
|
||||
def hist_from_bins(
|
||||
self, bins: list[float], *, include_breakpoint: bool
|
||||
) -> PolarsDataFrame:
|
||||
if len(bins) <= 1:
|
||||
native = pl.DataFrame(schema=self._HIST_EMPTY_SCHEMA[include_breakpoint])
|
||||
elif self.native.is_empty():
|
||||
if include_breakpoint:
|
||||
native = (
|
||||
pl.Series(bins[1:])
|
||||
.to_frame("breakpoint")
|
||||
.with_columns(count=pl.lit(0, pl.Int64))
|
||||
)
|
||||
else:
|
||||
native = pl.select(count=pl.zeros(len(bins) - 1, pl.Int64))
|
||||
else:
|
||||
return self._hist_from_data(
|
||||
bins=bins, bin_count=None, include_breakpoint=include_breakpoint
|
||||
)
|
||||
return self.__narwhals_namespace__()._dataframe.from_native(native, context=self)
|
||||
|
||||
def hist_from_bin_count(
|
||||
self, bin_count: int, *, include_breakpoint: bool
|
||||
) -> PolarsDataFrame:
|
||||
if bin_count == 0:
|
||||
native = pl.DataFrame(schema=self._HIST_EMPTY_SCHEMA[include_breakpoint])
|
||||
elif self.native.is_empty():
|
||||
if include_breakpoint:
|
||||
native = pl.select(
|
||||
breakpoint=pl.int_range(1, bin_count + 1) / bin_count,
|
||||
count=pl.lit(0, pl.Int64),
|
||||
)
|
||||
else:
|
||||
native = pl.select(count=pl.zeros(bin_count, pl.Int64))
|
||||
else:
|
||||
count: int | None
|
||||
if BACKEND_VERSION < (1, 15): # pragma: no cover
|
||||
count = None
|
||||
bins = self._bins_from_bin_count(bin_count=bin_count)
|
||||
else:
|
||||
count = bin_count
|
||||
bins = None
|
||||
return self._hist_from_data(
|
||||
bins=bins, # type: ignore[arg-type]
|
||||
bin_count=count,
|
||||
include_breakpoint=include_breakpoint,
|
||||
)
|
||||
return self.__narwhals_namespace__()._dataframe.from_native(native, context=self)
|
||||
|
||||
def _bins_from_bin_count(self, bin_count: int) -> pl.Series: # pragma: no cover
|
||||
"""Prepare bins based on backend version compatibility.
|
||||
|
||||
polars <1.15 does not adjust the bins when they have equivalent min/max
|
||||
polars <1.5 with bin_count=...
|
||||
returns bins that range from -inf to +inf and has bin_count + 1 bins.
|
||||
for compat: convert `bin_count=` call to `bins=`
|
||||
"""
|
||||
from typing import cast
|
||||
|
||||
lower = cast("float", self.native.min())
|
||||
upper = cast("float", self.native.max())
|
||||
|
||||
if lower == upper:
|
||||
lower -= 0.5
|
||||
upper += 0.5
|
||||
|
||||
width = (upper - lower) / bin_count
|
||||
return pl.int_range(0, bin_count + 1, eager=True) * width + lower
|
||||
|
||||
def _hist_from_data(
|
||||
self, bins: list[float] | None, bin_count: int | None, *, include_breakpoint: bool
|
||||
) -> PolarsDataFrame:
|
||||
"""Calculate histogram from non-empty data and post-process the results based on the backend version."""
|
||||
from narwhals._polars.dataframe import PolarsDataFrame
|
||||
|
||||
series = self.native
|
||||
|
||||
# Polars inconsistently handles NaN values when computing histograms
|
||||
# against predefined bins: https://github.com/pola-rs/polars/issues/21082
|
||||
if BACKEND_VERSION < (1, 15) or bins is not None:
|
||||
series = series.fill_nan(None)
|
||||
|
||||
df = series.hist(
|
||||
bins,
|
||||
bin_count=bin_count,
|
||||
include_category=False,
|
||||
include_breakpoint=include_breakpoint,
|
||||
)
|
||||
|
||||
# Apply post-processing corrections
|
||||
|
||||
# Handle column naming
|
||||
if not include_breakpoint:
|
||||
col_name = df.columns[0]
|
||||
df = df.select(pl.col(col_name).alias("count"))
|
||||
elif BACKEND_VERSION < (1, 0): # pragma: no cover
|
||||
df = df.rename({"break_point": "breakpoint"})
|
||||
|
||||
if bins is not None: # pragma: no cover
|
||||
# polars<1.6 implicitly adds -inf and inf to either end of bins
|
||||
if BACKEND_VERSION < (1, 6):
|
||||
r = pl.int_range(0, len(df))
|
||||
df = df.filter((r > 0) & (r < len(df) - 1))
|
||||
# polars<1.27 makes the lowest bin a left/right closed interval
|
||||
if BACKEND_VERSION < (1, 27):
|
||||
df = (
|
||||
df.slice(0, 1)
|
||||
.with_columns(pl.col("count") + ((pl.lit(series) == bins[0]).sum()))
|
||||
.vstack(df.slice(1))
|
||||
)
|
||||
|
||||
return PolarsDataFrame.from_native(df, context=self)
|
||||
|
||||
def to_polars(self) -> pl.Series:
|
||||
return self.native
|
||||
|
||||
@property
|
||||
def dt(self) -> PolarsSeriesDateTimeNamespace:
|
||||
return PolarsSeriesDateTimeNamespace(self)
|
||||
|
||||
@property
|
||||
def str(self) -> PolarsSeriesStringNamespace:
|
||||
return PolarsSeriesStringNamespace(self)
|
||||
|
||||
@property
|
||||
def cat(self) -> PolarsSeriesCatNamespace:
|
||||
return PolarsSeriesCatNamespace(self)
|
||||
|
||||
@property
|
||||
def struct(self) -> PolarsSeriesStructNamespace:
|
||||
return PolarsSeriesStructNamespace(self)
|
||||
|
||||
__add__: Method[Self]
|
||||
__and__: Method[Self]
|
||||
__floordiv__: Method[Self]
|
||||
__invert__: Method[Self]
|
||||
__iter__: Method[Iterator[Any]]
|
||||
__mod__: Method[Self]
|
||||
__mul__: Method[Self]
|
||||
__or__: Method[Self]
|
||||
__pow__: Method[Self]
|
||||
__radd__: Method[Self]
|
||||
__rand__: Method[Self]
|
||||
__rfloordiv__: Method[Self]
|
||||
__rmod__: Method[Self]
|
||||
__rmul__: Method[Self]
|
||||
__ror__: Method[Self]
|
||||
__rsub__: Method[Self]
|
||||
__rtruediv__: Method[Self]
|
||||
__sub__: Method[Self]
|
||||
__truediv__: Method[Self]
|
||||
abs: Method[Self]
|
||||
all: Method[bool]
|
||||
any: Method[bool]
|
||||
arg_max: Method[int]
|
||||
arg_min: Method[int]
|
||||
arg_true: Method[Self]
|
||||
clip: Method[Self]
|
||||
count: Method[int]
|
||||
cum_max: Method[Self]
|
||||
cum_min: Method[Self]
|
||||
cum_prod: Method[Self]
|
||||
cum_sum: Method[Self]
|
||||
diff: Method[Self]
|
||||
drop_nulls: Method[Self]
|
||||
exp: Method[Self]
|
||||
fill_null: Method[Self]
|
||||
filter: Method[Self]
|
||||
gather_every: Method[Self]
|
||||
head: Method[Self]
|
||||
is_between: Method[Self]
|
||||
is_finite: Method[Self]
|
||||
is_first_distinct: Method[Self]
|
||||
is_in: Method[Self]
|
||||
is_last_distinct: Method[Self]
|
||||
is_null: Method[Self]
|
||||
is_sorted: Method[bool]
|
||||
is_unique: Method[Self]
|
||||
item: Method[Any]
|
||||
kurtosis: Method[float | None]
|
||||
len: Method[int]
|
||||
log: Method[Self]
|
||||
max: Method[Any]
|
||||
mean: Method[float]
|
||||
min: Method[Any]
|
||||
mode: Method[Self]
|
||||
n_unique: Method[int]
|
||||
null_count: Method[int]
|
||||
quantile: Method[float]
|
||||
rank: Method[Self]
|
||||
round: Method[Self]
|
||||
sample: Method[Self]
|
||||
shift: Method[Self]
|
||||
skew: Method[float | None]
|
||||
sqrt: Method[Self]
|
||||
std: Method[float]
|
||||
sum: Method[float]
|
||||
tail: Method[Self]
|
||||
to_arrow: Method[pa.Array[Any]]
|
||||
to_frame: Method[PolarsDataFrame]
|
||||
to_list: Method[list[Any]]
|
||||
to_pandas: Method[pd.Series[Any]]
|
||||
unique: Method[Self]
|
||||
var: Method[float]
|
||||
zip_with: Method[Self]
|
||||
|
||||
@property
|
||||
def list(self) -> PolarsSeriesListNamespace:
|
||||
return PolarsSeriesListNamespace(self)
|
||||
|
||||
|
||||
class PolarsSeriesDateTimeNamespace:
|
||||
def __init__(self, series: PolarsSeries) -> None:
|
||||
self._compliant_series = series
|
||||
|
||||
def __getattr__(self, attr: str) -> Any:
|
||||
def func(*args: Any, **kwargs: Any) -> Any:
|
||||
pos, kwds = extract_args_kwargs(args, kwargs)
|
||||
return self._compliant_series._with_native(
|
||||
getattr(self._compliant_series.native.dt, attr)(*pos, **kwds)
|
||||
)
|
||||
|
||||
return func
|
||||
|
||||
|
||||
class PolarsSeriesStringNamespace:
|
||||
def __init__(self, series: PolarsSeries) -> None:
|
||||
self._compliant_series = series
|
||||
|
||||
def zfill(self, width: int) -> PolarsSeries:
|
||||
series = self._compliant_series
|
||||
name = series.name
|
||||
ns = series.__narwhals_namespace__()
|
||||
return series.to_frame().select(ns.col(name).str.zfill(width)).get_column(name)
|
||||
|
||||
def __getattr__(self, attr: str) -> Any:
|
||||
def func(*args: Any, **kwargs: Any) -> Any:
|
||||
pos, kwds = extract_args_kwargs(args, kwargs)
|
||||
return self._compliant_series._with_native(
|
||||
getattr(self._compliant_series.native.str, attr)(*pos, **kwds)
|
||||
)
|
||||
|
||||
return func
|
||||
|
||||
|
||||
class PolarsSeriesCatNamespace:
|
||||
def __init__(self, series: PolarsSeries) -> None:
|
||||
self._compliant_series = series
|
||||
|
||||
def __getattr__(self, attr: str) -> Any:
|
||||
def func(*args: Any, **kwargs: Any) -> Any:
|
||||
pos, kwds = extract_args_kwargs(args, kwargs)
|
||||
return self._compliant_series._with_native(
|
||||
getattr(self._compliant_series.native.cat, attr)(*pos, **kwds)
|
||||
)
|
||||
|
||||
return func
|
||||
|
||||
|
||||
class PolarsSeriesListNamespace:
|
||||
def __init__(self, series: PolarsSeries) -> None:
|
||||
self._series = series
|
||||
|
||||
def len(self) -> PolarsSeries:
|
||||
native_series = self._series.native
|
||||
native_result = native_series.list.len()
|
||||
|
||||
if self._series._backend_version < (1, 16): # pragma: no cover
|
||||
native_result = pl.select(
|
||||
pl.when(~native_series.is_null()).then(native_result).otherwise(None)
|
||||
)[native_series.name].cast(pl.UInt32())
|
||||
|
||||
elif self._series._backend_version < (1, 17): # pragma: no cover
|
||||
native_result = native_series.cast(pl.UInt32())
|
||||
|
||||
return self._series._with_native(native_result)
|
||||
|
||||
# TODO(FBruzzesi): Remove `pragma: no cover` once other namespace methods are added
|
||||
def __getattr__(self, attr: str) -> Any: # pragma: no cover
|
||||
def func(*args: Any, **kwargs: Any) -> Any:
|
||||
pos, kwds = extract_args_kwargs(args, kwargs)
|
||||
return self._series._with_native(
|
||||
getattr(self._series.native.list, attr)(*pos, **kwds)
|
||||
)
|
||||
|
||||
return func
|
||||
|
||||
|
||||
class PolarsSeriesStructNamespace:
|
||||
def __init__(self, series: PolarsSeries) -> None:
|
||||
self._compliant_series = series
|
||||
|
||||
def __getattr__(self, attr: str) -> Any:
|
||||
def func(*args: Any, **kwargs: Any) -> Any:
|
||||
pos, kwds = extract_args_kwargs(args, kwargs)
|
||||
return self._compliant_series._with_native(
|
||||
getattr(self._compliant_series.native.struct, attr)(*pos, **kwds)
|
||||
)
|
||||
|
||||
return func
|
||||
22
venv/lib/python3.11/site-packages/narwhals/_polars/typing.py
Normal file
22
venv/lib/python3.11/site-packages/narwhals/_polars/typing.py
Normal file
@ -0,0 +1,22 @@
|
||||
from __future__ import annotations # pragma: no cover
|
||||
|
||||
from typing import (
|
||||
TYPE_CHECKING, # pragma: no cover
|
||||
Union, # pragma: no cover
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import sys
|
||||
from typing import TypeVar
|
||||
|
||||
if sys.version_info >= (3, 10):
|
||||
from typing import TypeAlias
|
||||
else:
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
from narwhals._polars.dataframe import PolarsDataFrame, PolarsLazyFrame
|
||||
from narwhals._polars.expr import PolarsExpr
|
||||
from narwhals._polars.series import PolarsSeries
|
||||
|
||||
IntoPolarsExpr: TypeAlias = Union[PolarsExpr, PolarsSeries]
|
||||
FrameT = TypeVar("FrameT", PolarsDataFrame, PolarsLazyFrame)
|
||||
242
venv/lib/python3.11/site-packages/narwhals/_polars/utils.py
Normal file
242
venv/lib/python3.11/site-packages/narwhals/_polars/utils.py
Normal file
@ -0,0 +1,242 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import lru_cache
|
||||
from typing import TYPE_CHECKING, Any, TypeVar, overload
|
||||
|
||||
import polars as pl
|
||||
|
||||
from narwhals._utils import (
|
||||
Implementation,
|
||||
Version,
|
||||
_DeferredIterable,
|
||||
isinstance_or_issubclass,
|
||||
)
|
||||
from narwhals.exceptions import (
|
||||
ColumnNotFoundError,
|
||||
ComputeError,
|
||||
DuplicateError,
|
||||
InvalidOperationError,
|
||||
NarwhalsError,
|
||||
ShapeError,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterable, Iterator, Mapping
|
||||
|
||||
from typing_extensions import TypeIs
|
||||
|
||||
from narwhals._utils import _StoresNative
|
||||
from narwhals.dtypes import DType
|
||||
from narwhals.typing import IntoDType
|
||||
|
||||
T = TypeVar("T")
|
||||
NativeT = TypeVar(
|
||||
"NativeT", bound="pl.DataFrame | pl.LazyFrame | pl.Series | pl.Expr"
|
||||
)
|
||||
|
||||
BACKEND_VERSION = Implementation.POLARS._backend_version()
|
||||
"""Static backend version for `polars`."""
|
||||
|
||||
|
||||
@overload
|
||||
def extract_native(obj: _StoresNative[NativeT]) -> NativeT: ...
|
||||
@overload
|
||||
def extract_native(obj: T) -> T: ...
|
||||
def extract_native(obj: _StoresNative[NativeT] | T) -> NativeT | T:
|
||||
return obj.native if _is_compliant_polars(obj) else obj
|
||||
|
||||
|
||||
def _is_compliant_polars(
|
||||
obj: _StoresNative[NativeT] | Any,
|
||||
) -> TypeIs[_StoresNative[NativeT]]:
|
||||
from narwhals._polars.dataframe import PolarsDataFrame, PolarsLazyFrame
|
||||
from narwhals._polars.expr import PolarsExpr
|
||||
from narwhals._polars.series import PolarsSeries
|
||||
|
||||
return isinstance(obj, (PolarsDataFrame, PolarsLazyFrame, PolarsSeries, PolarsExpr))
|
||||
|
||||
|
||||
def extract_args_kwargs(
|
||||
args: Iterable[Any], kwds: Mapping[str, Any], /
|
||||
) -> tuple[Iterator[Any], dict[str, Any]]:
|
||||
it_args = (extract_native(arg) for arg in args)
|
||||
return it_args, {k: extract_native(v) for k, v in kwds.items()}
|
||||
|
||||
|
||||
@lru_cache(maxsize=16)
|
||||
def native_to_narwhals_dtype( # noqa: C901, PLR0912
|
||||
dtype: pl.DataType, version: Version
|
||||
) -> DType:
|
||||
dtypes = version.dtypes
|
||||
if dtype == pl.Float64:
|
||||
return dtypes.Float64()
|
||||
if dtype == pl.Float32:
|
||||
return dtypes.Float32()
|
||||
if hasattr(pl, "Int128") and dtype == pl.Int128: # pragma: no cover
|
||||
# Not available for Polars pre 1.8.0
|
||||
return dtypes.Int128()
|
||||
if dtype == pl.Int64:
|
||||
return dtypes.Int64()
|
||||
if dtype == pl.Int32:
|
||||
return dtypes.Int32()
|
||||
if dtype == pl.Int16:
|
||||
return dtypes.Int16()
|
||||
if dtype == pl.Int8:
|
||||
return dtypes.Int8()
|
||||
if hasattr(pl, "UInt128") and dtype == pl.UInt128: # pragma: no cover
|
||||
# Not available for Polars pre 1.8.0
|
||||
return dtypes.UInt128()
|
||||
if dtype == pl.UInt64:
|
||||
return dtypes.UInt64()
|
||||
if dtype == pl.UInt32:
|
||||
return dtypes.UInt32()
|
||||
if dtype == pl.UInt16:
|
||||
return dtypes.UInt16()
|
||||
if dtype == pl.UInt8:
|
||||
return dtypes.UInt8()
|
||||
if dtype == pl.String:
|
||||
return dtypes.String()
|
||||
if dtype == pl.Boolean:
|
||||
return dtypes.Boolean()
|
||||
if dtype == pl.Object:
|
||||
return dtypes.Object()
|
||||
if dtype == pl.Categorical:
|
||||
return dtypes.Categorical()
|
||||
if isinstance_or_issubclass(dtype, pl.Enum):
|
||||
if version is Version.V1:
|
||||
return dtypes.Enum() # type: ignore[call-arg]
|
||||
categories = _DeferredIterable(dtype.categories.to_list)
|
||||
return dtypes.Enum(categories)
|
||||
if dtype == pl.Date:
|
||||
return dtypes.Date()
|
||||
if isinstance_or_issubclass(dtype, pl.Datetime):
|
||||
return (
|
||||
dtypes.Datetime()
|
||||
if dtype is pl.Datetime
|
||||
else dtypes.Datetime(dtype.time_unit, dtype.time_zone)
|
||||
)
|
||||
if isinstance_or_issubclass(dtype, pl.Duration):
|
||||
return (
|
||||
dtypes.Duration()
|
||||
if dtype is pl.Duration
|
||||
else dtypes.Duration(dtype.time_unit)
|
||||
)
|
||||
if isinstance_or_issubclass(dtype, pl.Struct):
|
||||
fields = [
|
||||
dtypes.Field(name, native_to_narwhals_dtype(tp, version))
|
||||
for name, tp in dtype
|
||||
]
|
||||
return dtypes.Struct(fields)
|
||||
if isinstance_or_issubclass(dtype, pl.List):
|
||||
return dtypes.List(native_to_narwhals_dtype(dtype.inner, version))
|
||||
if isinstance_or_issubclass(dtype, pl.Array):
|
||||
outer_shape = dtype.width if BACKEND_VERSION < (0, 20, 30) else dtype.size
|
||||
return dtypes.Array(native_to_narwhals_dtype(dtype.inner, version), outer_shape)
|
||||
if dtype == pl.Decimal:
|
||||
return dtypes.Decimal()
|
||||
if dtype == pl.Time:
|
||||
return dtypes.Time()
|
||||
if dtype == pl.Binary:
|
||||
return dtypes.Binary()
|
||||
return dtypes.Unknown()
|
||||
|
||||
|
||||
def narwhals_to_native_dtype( # noqa: C901, PLR0912
|
||||
dtype: IntoDType, version: Version
|
||||
) -> pl.DataType:
|
||||
dtypes = version.dtypes
|
||||
if dtype == dtypes.Float64:
|
||||
return pl.Float64()
|
||||
if dtype == dtypes.Float32:
|
||||
return pl.Float32()
|
||||
if dtype == dtypes.Int128 and hasattr(pl, "Int128"):
|
||||
# Not available for Polars pre 1.8.0
|
||||
return pl.Int128()
|
||||
if dtype == dtypes.Int64:
|
||||
return pl.Int64()
|
||||
if dtype == dtypes.Int32:
|
||||
return pl.Int32()
|
||||
if dtype == dtypes.Int16:
|
||||
return pl.Int16()
|
||||
if dtype == dtypes.Int8:
|
||||
return pl.Int8()
|
||||
if dtype == dtypes.UInt64:
|
||||
return pl.UInt64()
|
||||
if dtype == dtypes.UInt32:
|
||||
return pl.UInt32()
|
||||
if dtype == dtypes.UInt16:
|
||||
return pl.UInt16()
|
||||
if dtype == dtypes.UInt8:
|
||||
return pl.UInt8()
|
||||
if dtype == dtypes.String:
|
||||
return pl.String()
|
||||
if dtype == dtypes.Boolean:
|
||||
return pl.Boolean()
|
||||
if dtype == dtypes.Object: # pragma: no cover
|
||||
return pl.Object()
|
||||
if dtype == dtypes.Categorical:
|
||||
return pl.Categorical()
|
||||
if isinstance_or_issubclass(dtype, dtypes.Enum):
|
||||
if version is Version.V1:
|
||||
msg = "Converting to Enum is not supported in narwhals.stable.v1"
|
||||
raise NotImplementedError(msg)
|
||||
if isinstance(dtype, dtypes.Enum):
|
||||
return pl.Enum(dtype.categories)
|
||||
msg = "Can not cast / initialize Enum without categories present"
|
||||
raise ValueError(msg)
|
||||
if dtype == dtypes.Date:
|
||||
return pl.Date()
|
||||
if dtype == dtypes.Time:
|
||||
return pl.Time()
|
||||
if dtype == dtypes.Binary:
|
||||
return pl.Binary()
|
||||
if dtype == dtypes.Decimal:
|
||||
msg = "Casting to Decimal is not supported yet."
|
||||
raise NotImplementedError(msg)
|
||||
if isinstance_or_issubclass(dtype, dtypes.Datetime):
|
||||
return pl.Datetime(dtype.time_unit, dtype.time_zone) # type: ignore[arg-type]
|
||||
if isinstance_or_issubclass(dtype, dtypes.Duration):
|
||||
return pl.Duration(dtype.time_unit) # type: ignore[arg-type]
|
||||
if isinstance_or_issubclass(dtype, dtypes.List):
|
||||
return pl.List(narwhals_to_native_dtype(dtype.inner, version))
|
||||
if isinstance_or_issubclass(dtype, dtypes.Struct):
|
||||
fields = [
|
||||
pl.Field(field.name, narwhals_to_native_dtype(field.dtype, version))
|
||||
for field in dtype.fields
|
||||
]
|
||||
return pl.Struct(fields)
|
||||
if isinstance_or_issubclass(dtype, dtypes.Array): # pragma: no cover
|
||||
size = dtype.size
|
||||
kwargs = {"width": size} if BACKEND_VERSION < (0, 20, 30) else {"shape": size}
|
||||
return pl.Array(narwhals_to_native_dtype(dtype.inner, version), **kwargs)
|
||||
return pl.Unknown() # pragma: no cover
|
||||
|
||||
|
||||
def _is_polars_exception(exception: Exception) -> bool:
|
||||
if BACKEND_VERSION >= (1,):
|
||||
# Old versions of Polars didn't have PolarsError.
|
||||
return isinstance(exception, pl.exceptions.PolarsError)
|
||||
# Last attempt, for old Polars versions.
|
||||
return "polars.exceptions" in str(type(exception)) # pragma: no cover
|
||||
|
||||
|
||||
def _is_cudf_exception(exception: Exception) -> bool:
|
||||
# These exceptions are raised when running polars on GPUs via cuDF
|
||||
return str(exception).startswith("CUDF failure")
|
||||
|
||||
|
||||
def catch_polars_exception(exception: Exception) -> NarwhalsError | Exception:
|
||||
if isinstance(exception, pl.exceptions.ColumnNotFoundError):
|
||||
return ColumnNotFoundError(str(exception))
|
||||
elif isinstance(exception, pl.exceptions.ShapeError):
|
||||
return ShapeError(str(exception))
|
||||
elif isinstance(exception, pl.exceptions.InvalidOperationError):
|
||||
return InvalidOperationError(str(exception))
|
||||
elif isinstance(exception, pl.exceptions.DuplicateError):
|
||||
return DuplicateError(str(exception))
|
||||
elif isinstance(exception, pl.exceptions.ComputeError):
|
||||
return ComputeError(str(exception))
|
||||
if _is_polars_exception(exception) or _is_cudf_exception(exception):
|
||||
return NarwhalsError(str(exception)) # pragma: no cover
|
||||
# Just return exception as-is.
|
||||
return exception
|
||||
Reference in New Issue
Block a user