some new features
186
.venv/lib/python3.12/site-packages/narwhals/__init__.py
Normal file
@@ -0,0 +1,186 @@
from __future__ import annotations

import typing as _t

from narwhals import dependencies, dtypes, exceptions, selectors
from narwhals._utils import (
    Implementation,
    generate_temporary_column_name,
    is_ordered_categorical,
    maybe_align_index,
    maybe_convert_dtypes,
    maybe_get_index,
    maybe_reset_index,
    maybe_set_index,
)
from narwhals.dataframe import DataFrame, LazyFrame
from narwhals.dtypes import (
    Array,
    Binary,
    Boolean,
    Categorical,
    Date,
    Datetime,
    Decimal,
    Duration,
    Enum,
    Field,
    Float32,
    Float64,
    Int8,
    Int16,
    Int32,
    Int64,
    Int128,
    List,
    Object,
    String,
    Struct,
    Time,
    UInt8,
    UInt16,
    UInt32,
    UInt64,
    UInt128,
    Unknown,
)
from narwhals.expr import Expr
from narwhals.functions import (
    all_ as all,
    all_horizontal,
    any_horizontal,
    coalesce,
    col,
    concat,
    concat_str,
    exclude,
    from_arrow,
    from_dict,
    from_numpy,
    len_ as len,
    lit,
    max,
    max_horizontal,
    mean,
    mean_horizontal,
    median,
    min,
    min_horizontal,
    new_series,
    nth,
    read_csv,
    read_parquet,
    scan_csv,
    scan_parquet,
    show_versions,
    sum,
    sum_horizontal,
    when,
)
from narwhals.schema import Schema
from narwhals.series import Series
from narwhals.translate import (
    from_native,
    get_native_namespace,
    narwhalify,
    to_native,
    to_py_scalar,
)

__version__: str

__all__ = [
    "Array",
    "Binary",
    "Boolean",
    "Categorical",
    "DataFrame",
    "Date",
    "Datetime",
    "Decimal",
    "Duration",
    "Enum",
    "Expr",
    "Field",
    "Float32",
    "Float64",
    "Implementation",
    "Int8",
    "Int16",
    "Int32",
    "Int64",
    "Int128",
    "LazyFrame",
    "List",
    "Object",
    "Schema",
    "Series",
    "String",
    "Struct",
    "Time",
    "UInt8",
    "UInt16",
    "UInt32",
    "UInt64",
    "UInt128",
    "Unknown",
    "all",
    "all_horizontal",
    "any_horizontal",
    "coalesce",
    "col",
    "concat",
    "concat_str",
    "dependencies",
    "dtypes",
    "exceptions",
    "exclude",
    "from_arrow",
    "from_dict",
    "from_native",
    "from_numpy",
    "generate_temporary_column_name",
    "get_native_namespace",
    "is_ordered_categorical",
    "len",
    "lit",
    "max",
    "max_horizontal",
    "maybe_align_index",
    "maybe_convert_dtypes",
    "maybe_get_index",
    "maybe_reset_index",
    "maybe_set_index",
    "mean",
    "mean_horizontal",
    "median",
    "min",
    "min_horizontal",
    "narwhalify",
    "new_series",
    "nth",
    "read_csv",
    "read_parquet",
    "scan_csv",
    "scan_parquet",
    "selectors",
    "show_versions",
    "sum",
    "sum_horizontal",
    "to_native",
    "to_py_scalar",
    "when",
]


def __getattr__(name: _t.Literal["__version__"]) -> str:  # type: ignore[misc]
    if name == "__version__":
        global __version__  # noqa: PLW0603

        from importlib import metadata

        __version__ = metadata.version(__name__)
        return __version__
    else:
        msg = f"module {__name__!r} has no attribute {name!r}"
        raise AttributeError(msg)
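The `__getattr__` hook above is PEP 562's module-level attribute access: the `importlib.metadata` lookup is deferred until `narwhals.__version__` is first read, and the `global __version__` assignment caches the result so later reads bypass the hook entirely. A minimal sketch of the same pattern, with a hypothetical distribution name `mypackage`:

from importlib import metadata

def __getattr__(name: str) -> str:
    # Invoked only when normal module attribute lookup fails (PEP 562),
    # so the metadata read is deferred until first access.
    if name == "__version__":
        return metadata.version("mypackage")  # hypothetical package name
    msg = f"module {__name__!r} has no attribute {name!r}"
    raise AttributeError(msg)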
49 binary files not shown.
761
.venv/lib/python3.12/site-packages/narwhals/_arrow/dataframe.py
Normal file
@@ -0,0 +1,761 @@
from __future__ import annotations

from collections.abc import Collection, Iterator, Mapping, Sequence
from functools import partial
from typing import TYPE_CHECKING, Any, Literal, cast, overload

import pyarrow as pa
import pyarrow.compute as pc

from narwhals._arrow.series import ArrowSeries
from narwhals._arrow.utils import native_to_narwhals_dtype
from narwhals._compliant import EagerDataFrame
from narwhals._expression_parsing import ExprKind
from narwhals._utils import (
    Implementation,
    Version,
    check_column_names_are_unique,
    convert_str_slice_to_int_slice,
    generate_temporary_column_name,
    not_implemented,
    parse_columns_to_drop,
    scale_bytes,
    supports_arrow_c_stream,
)
from narwhals.dependencies import is_numpy_array_1d
from narwhals.exceptions import ShapeError

if TYPE_CHECKING:
    from io import BytesIO
    from pathlib import Path
    from types import ModuleType

    import pandas as pd
    import polars as pl
    from typing_extensions import Self, TypeAlias, TypeIs

    from narwhals._arrow.expr import ArrowExpr
    from narwhals._arrow.group_by import ArrowGroupBy
    from narwhals._arrow.namespace import ArrowNamespace
    from narwhals._arrow.typing import (  # type: ignore[attr-defined]
        ChunkedArrayAny,
        Mask,
        Order,
    )
    from narwhals._compliant.typing import CompliantDataFrameAny, CompliantLazyFrameAny
    from narwhals._translate import IntoArrowTable
    from narwhals._utils import Version, _LimitedContext
    from narwhals.dtypes import DType
    from narwhals.schema import Schema
    from narwhals.typing import (
        JoinStrategy,
        SizedMultiIndexSelector,
        SizedMultiNameSelector,
        SizeUnit,
        UniqueKeepStrategy,
        _1DArray,
        _2DArray,
        _SliceIndex,
        _SliceName,
    )

JoinType: TypeAlias = Literal[
    "left semi",
    "right semi",
    "left anti",
    "right anti",
    "inner",
    "left outer",
    "right outer",
    "full outer",
]
PromoteOptions: TypeAlias = Literal["none", "default", "permissive"]


class ArrowDataFrame(
    EagerDataFrame["ArrowSeries", "ArrowExpr", "pa.Table", "ChunkedArrayAny"]
):
    _implementation = Implementation.PYARROW

    def __init__(
        self,
        native_dataframe: pa.Table,
        *,
        version: Version,
        validate_column_names: bool,
        validate_backend_version: bool = False,
    ) -> None:
        if validate_column_names:
            check_column_names_are_unique(native_dataframe.column_names)
        if validate_backend_version:
            self._validate_backend_version()
        self._native_frame = native_dataframe
        self._version = version

    @classmethod
    def from_arrow(cls, data: IntoArrowTable, /, *, context: _LimitedContext) -> Self:
        backend_version = context._implementation._backend_version()
        if cls._is_native(data):
            native = data
        elif backend_version >= (14,) or isinstance(data, Collection):
            native = pa.table(data)
        elif supports_arrow_c_stream(data):  # pragma: no cover
            msg = f"'pyarrow>=14.0.0' is required for `from_arrow` for object of type {type(data).__name__!r}."
            raise ModuleNotFoundError(msg)
        else:  # pragma: no cover
            msg = f"`from_arrow` is not supported for object of type {type(data).__name__!r}."
            raise TypeError(msg)
        return cls.from_native(native, context=context)

    @classmethod
    def from_dict(
        cls,
        data: Mapping[str, Any],
        /,
        *,
        context: _LimitedContext,
        schema: Mapping[str, DType] | Schema | None,
    ) -> Self:
        from narwhals.schema import Schema

        pa_schema = Schema(schema).to_arrow() if schema is not None else schema
        if pa_schema and not data:
            native = pa_schema.empty_table()
        else:
            native = pa.Table.from_pydict(data, schema=pa_schema)
        return cls.from_native(native, context=context)

    @staticmethod
    def _is_native(obj: pa.Table | Any) -> TypeIs[pa.Table]:
        return isinstance(obj, pa.Table)

    @classmethod
    def from_native(cls, data: pa.Table, /, *, context: _LimitedContext) -> Self:
        return cls(data, version=context._version, validate_column_names=True)

    @classmethod
    def from_numpy(
        cls,
        data: _2DArray,
        /,
        *,
        context: _LimitedContext,
        schema: Mapping[str, DType] | Schema | Sequence[str] | None,
    ) -> Self:
        from narwhals.schema import Schema

        arrays = [pa.array(val) for val in data.T]
        if isinstance(schema, (Mapping, Schema)):
            native = pa.Table.from_arrays(arrays, schema=Schema(schema).to_arrow())
        else:
            native = pa.Table.from_arrays(arrays, cls._numpy_column_names(data, schema))
        return cls.from_native(native, context=context)

    def __narwhals_namespace__(self) -> ArrowNamespace:
        from narwhals._arrow.namespace import ArrowNamespace

        return ArrowNamespace(version=self._version)

    def __native_namespace__(self) -> ModuleType:
        if self._implementation is Implementation.PYARROW:
            return self._implementation.to_native_namespace()

        msg = f"Expected pyarrow, got: {type(self._implementation)}"  # pragma: no cover
        raise AssertionError(msg)

    def __narwhals_dataframe__(self) -> Self:
        return self

    def __narwhals_lazyframe__(self) -> Self:
        return self

    def _with_version(self, version: Version) -> Self:
        return self.__class__(self.native, version=version, validate_column_names=False)

    def _with_native(self, df: pa.Table, *, validate_column_names: bool = True) -> Self:
        return self.__class__(
            df, version=self._version, validate_column_names=validate_column_names
        )

    @property
    def shape(self) -> tuple[int, int]:
        return self.native.shape

    def __len__(self) -> int:
        return len(self.native)

    def row(self, index: int) -> tuple[Any, ...]:
        return tuple(col[index] for col in self.native.itercolumns())

    @overload
    def rows(self, *, named: Literal[True]) -> list[dict[str, Any]]: ...

    @overload
    def rows(self, *, named: Literal[False]) -> list[tuple[Any, ...]]: ...

    @overload
    def rows(self, *, named: bool) -> list[tuple[Any, ...]] | list[dict[str, Any]]: ...

    def rows(self, *, named: bool) -> list[tuple[Any, ...]] | list[dict[str, Any]]:
        if not named:
            return list(self.iter_rows(named=False, buffer_size=512))  # type: ignore[return-value]
        return self.native.to_pylist()

    def iter_columns(self) -> Iterator[ArrowSeries]:
        for name, series in zip(self.columns, self.native.itercolumns()):
            yield ArrowSeries.from_native(series, context=self, name=name)

    _iter_columns = iter_columns

    def iter_rows(
        self, *, named: bool, buffer_size: int
    ) -> Iterator[tuple[Any, ...]] | Iterator[dict[str, Any]]:
        df = self.native
        num_rows = df.num_rows

        if not named:
            for i in range(0, num_rows, buffer_size):
                rows = df[i : i + buffer_size].to_pydict().values()
                yield from zip(*rows)
        else:
            for i in range(0, num_rows, buffer_size):
                yield from df[i : i + buffer_size].to_pylist()

    def get_column(self, name: str) -> ArrowSeries:
        if not isinstance(name, str):
            msg = f"Expected str, got: {type(name)}"
            raise TypeError(msg)
        return ArrowSeries.from_native(self.native[name], context=self, name=name)

    def __array__(self, dtype: Any, *, copy: bool | None) -> _2DArray:
        return self.native.__array__(dtype, copy=copy)

    def _gather(self, rows: SizedMultiIndexSelector[ChunkedArrayAny]) -> Self:
        if len(rows) == 0:
            return self._with_native(self.native.slice(0, 0))
        if self._backend_version < (18,) and isinstance(rows, tuple):
            rows = list(rows)
        return self._with_native(self.native.take(rows))

    def _gather_slice(self, rows: _SliceIndex | range) -> Self:
        start = rows.start or 0
        stop = rows.stop if rows.stop is not None else len(self.native)
        if start < 0:
            start = len(self.native) + start
        if stop < 0:
            stop = len(self.native) + stop
        if rows.step is not None and rows.step != 1:
            msg = "Slicing with step is not supported on PyArrow tables"
            raise NotImplementedError(msg)
        return self._with_native(self.native.slice(start, stop - start))

    def _select_slice_name(self, columns: _SliceName) -> Self:
        start, stop, step = convert_str_slice_to_int_slice(columns, self.columns)
        return self._with_native(self.native.select(self.columns[start:stop:step]))

    def _select_slice_index(self, columns: _SliceIndex | range) -> Self:
        return self._with_native(
            self.native.select(self.columns[columns.start : columns.stop : columns.step])
        )

    def _select_multi_index(
        self, columns: SizedMultiIndexSelector[ChunkedArrayAny]
    ) -> Self:
        selector: Sequence[int]
        if isinstance(columns, pa.ChunkedArray):
            # TODO @dangotbanned: Fix upstream with `pa.ChunkedArray.to_pylist(self) -> list[Any]:`
            selector = cast("Sequence[int]", columns.to_pylist())
        # TODO @dangotbanned: Fix upstream, it is actually much narrower
        # **Doesn't accept `ndarray`**
        elif is_numpy_array_1d(columns):
            selector = columns.tolist()
        else:
            selector = columns
        return self._with_native(self.native.select(selector))

    def _select_multi_name(
        self, columns: SizedMultiNameSelector[ChunkedArrayAny]
    ) -> Self:
        selector: Sequence[str] | _1DArray
        if isinstance(columns, pa.ChunkedArray):
            # TODO @dangotbanned: Fix upstream with `pa.ChunkedArray.to_pylist(self) -> list[Any]:`
            selector = cast("Sequence[str]", columns.to_pylist())
        else:
            selector = columns
        # NOTE: Fixed in https://github.com/zen-xu/pyarrow-stubs/pull/221
        return self._with_native(self.native.select(selector))  # pyright: ignore[reportArgumentType]

    @property
    def schema(self) -> dict[str, DType]:
        schema = self.native.schema
        return {
            name: native_to_narwhals_dtype(dtype, self._version)
            for name, dtype in zip(schema.names, schema.types)
        }

    def collect_schema(self) -> dict[str, DType]:
        return self.schema

    def estimated_size(self, unit: SizeUnit) -> int | float:
        sz = self.native.nbytes
        return scale_bytes(sz, unit)

    explode = not_implemented()

    @property
    def columns(self) -> list[str]:
        return self.native.column_names

    def simple_select(self, *column_names: str) -> Self:
        return self._with_native(
            self.native.select(list(column_names)), validate_column_names=False
        )

    def select(self, *exprs: ArrowExpr) -> Self:
        new_series = self._evaluate_into_exprs(*exprs)
        if not new_series:
            # return empty dataframe, like Polars does
            return self._with_native(
                self.native.__class__.from_arrays([]), validate_column_names=False
            )
        names = [s.name for s in new_series]
        align = new_series[0]._align_full_broadcast
        reshaped = align(*new_series)
        df = pa.Table.from_arrays([s.native for s in reshaped], names=names)
        return self._with_native(df, validate_column_names=True)

    def _extract_comparand(self, other: ArrowSeries) -> ChunkedArrayAny:
        length = len(self)
        if not other._broadcast:
            if (len_other := len(other)) != length:
                msg = f"Expected object of length {length}, got: {len_other}."
                raise ShapeError(msg)
            return other.native

        value = other.native[0]
        return pa.chunked_array([pa.repeat(value, length)])

    def with_columns(self, *exprs: ArrowExpr) -> Self:
        # NOTE: We use a faux-mutable variable and repeatedly "overwrite" (native_frame)
        # All `pyarrow` data is immutable, so this is fine
        native_frame = self.native
        new_columns = self._evaluate_into_exprs(*exprs)
        columns = self.columns

        for col_value in new_columns:
            col_name = col_value.name
            column = self._extract_comparand(col_value)
            native_frame = (
                native_frame.set_column(columns.index(col_name), col_name, column=column)
                if col_name in columns
                else native_frame.append_column(col_name, column=column)
            )

        return self._with_native(native_frame, validate_column_names=False)

    def group_by(
        self, keys: Sequence[str] | Sequence[ArrowExpr], *, drop_null_keys: bool
    ) -> ArrowGroupBy:
        from narwhals._arrow.group_by import ArrowGroupBy

        return ArrowGroupBy(self, keys, drop_null_keys=drop_null_keys)

    def join(
        self,
        other: Self,
        *,
        how: JoinStrategy,
        left_on: Sequence[str] | None,
        right_on: Sequence[str] | None,
        suffix: str,
    ) -> Self:
        how_to_join_map: dict[str, JoinType] = {
            "anti": "left anti",
            "semi": "left semi",
            "inner": "inner",
            "left": "left outer",
            "full": "full outer",
        }

        if how == "cross":
            plx = self.__narwhals_namespace__()
            key_token = generate_temporary_column_name(
                n_bytes=8, columns=[*self.columns, *other.columns]
            )

            return self._with_native(
                self.with_columns(
                    plx.lit(0, None).alias(key_token).broadcast(ExprKind.LITERAL)
                )
                .native.join(
                    other.with_columns(
                        plx.lit(0, None).alias(key_token).broadcast(ExprKind.LITERAL)
                    ).native,
                    keys=key_token,
                    right_keys=key_token,
                    join_type="inner",
                    right_suffix=suffix,
                )
                .drop([key_token])
            )

        coalesce_keys = how != "full"  # polars full join does not coalesce keys
        return self._with_native(
            self.native.join(
                other.native,
                keys=left_on or [],  # type: ignore[arg-type]
                right_keys=right_on,  # type: ignore[arg-type]
                join_type=how_to_join_map[how],
                right_suffix=suffix,
                coalesce_keys=coalesce_keys,
            )
        )

    join_asof = not_implemented()

    def drop(self, columns: Sequence[str], *, strict: bool) -> Self:
        to_drop = parse_columns_to_drop(self, columns, strict=strict)
        return self._with_native(self.native.drop(to_drop), validate_column_names=False)

    def drop_nulls(self, subset: Sequence[str] | None) -> Self:
        if subset is None:
            return self._with_native(self.native.drop_null(), validate_column_names=False)
        plx = self.__narwhals_namespace__()
        mask = ~plx.any_horizontal(plx.col(*subset).is_null(), ignore_nulls=True)
        return self.filter(mask)

    def sort(self, *by: str, descending: bool | Sequence[bool], nulls_last: bool) -> Self:
        if isinstance(descending, bool):
            order: Order = "descending" if descending else "ascending"
            sorting: list[tuple[str, Order]] = [(key, order) for key in by]
        else:
            sorting = [
                (key, "descending" if is_descending else "ascending")
                for key, is_descending in zip(by, descending)
            ]

        null_placement = "at_end" if nulls_last else "at_start"

        return self._with_native(
            self.native.sort_by(sorting, null_placement=null_placement),
            validate_column_names=False,
        )

    def to_pandas(self) -> pd.DataFrame:
        return self.native.to_pandas()

    def to_polars(self) -> pl.DataFrame:
        import polars as pl  # ignore-banned-import

        return pl.from_arrow(self.native)  # type: ignore[return-value]

    def to_numpy(self, dtype: Any = None, *, copy: bool | None = None) -> _2DArray:
        import numpy as np  # ignore-banned-import

        arr: Any = np.column_stack([col.to_numpy() for col in self.native.columns])
        return arr

    @overload
    def to_dict(self, *, as_series: Literal[True]) -> dict[str, ArrowSeries]: ...

    @overload
    def to_dict(self, *, as_series: Literal[False]) -> dict[str, list[Any]]: ...

    def to_dict(
        self, *, as_series: bool
    ) -> dict[str, ArrowSeries] | dict[str, list[Any]]:
        it = self.iter_columns()
        if as_series:
            return {ser.name: ser for ser in it}
        return {ser.name: ser.to_list() for ser in it}

    def with_row_index(self, name: str, order_by: Sequence[str] | None) -> Self:
        plx = self.__narwhals_namespace__()
        if order_by is None:
            import numpy as np  # ignore-banned-import

            data = pa.array(np.arange(len(self), dtype=np.int64))
            row_index = plx._expr._from_series(
                plx._series.from_iterable(data, context=self, name=name)
            )
        else:
            rank = plx.col(order_by[0]).rank("ordinal", descending=False)
            row_index = (rank.over(partition_by=[], order_by=order_by) - 1).alias(name)
        return self.select(row_index, plx.all())

    def filter(self, predicate: ArrowExpr | list[bool | None]) -> Self:
        if isinstance(predicate, list):
            mask_native: Mask | ChunkedArrayAny = predicate
        else:
            # `[0]` is safe as the predicate's expression only returns a single column
            mask_native = self._evaluate_into_exprs(predicate)[0].native
        return self._with_native(
            self.native.filter(mask_native), validate_column_names=False
        )

    def head(self, n: int) -> Self:
        df = self.native
        if n >= 0:
            return self._with_native(df.slice(0, n), validate_column_names=False)
        else:
            num_rows = df.num_rows
            return self._with_native(
                df.slice(0, max(0, num_rows + n)), validate_column_names=False
            )

    def tail(self, n: int) -> Self:
        df = self.native
        if n >= 0:
            num_rows = df.num_rows
            return self._with_native(
                df.slice(max(0, num_rows - n)), validate_column_names=False
            )
        else:
            return self._with_native(df.slice(abs(n)), validate_column_names=False)

    def lazy(self, *, backend: Implementation | None = None) -> CompliantLazyFrameAny:
        if backend is None:
            return self
        elif backend is Implementation.DUCKDB:
            import duckdb  # ignore-banned-import

            from narwhals._duckdb.dataframe import DuckDBLazyFrame

            df = self.native  # noqa: F841
            return DuckDBLazyFrame(
                duckdb.table("df"), validate_backend_version=True, version=self._version
            )
        elif backend is Implementation.POLARS:
            import polars as pl  # ignore-banned-import

            from narwhals._polars.dataframe import PolarsLazyFrame

            return PolarsLazyFrame(
                cast("pl.DataFrame", pl.from_arrow(self.native)).lazy(),
                validate_backend_version=True,
                version=self._version,
            )
        elif backend is Implementation.DASK:
            import dask.dataframe as dd  # ignore-banned-import

            from narwhals._dask.dataframe import DaskLazyFrame

            return DaskLazyFrame(
                dd.from_pandas(self.native.to_pandas()),
                validate_backend_version=True,
                version=self._version,
            )
        elif backend.is_ibis():
            import ibis  # ignore-banned-import

            from narwhals._ibis.dataframe import IbisLazyFrame

            return IbisLazyFrame(
                ibis.memtable(self.native, columns=self.columns),
                validate_backend_version=True,
                version=self._version,
            )

        raise AssertionError  # pragma: no cover

    def collect(
        self, backend: Implementation | None, **kwargs: Any
    ) -> CompliantDataFrameAny:
        if backend is Implementation.PYARROW or backend is None:
            from narwhals._arrow.dataframe import ArrowDataFrame

            return ArrowDataFrame(
                self.native, version=self._version, validate_column_names=False
            )

        if backend is Implementation.PANDAS:
            from narwhals._pandas_like.dataframe import PandasLikeDataFrame

            return PandasLikeDataFrame(
                self.native.to_pandas(),
                implementation=Implementation.PANDAS,
                validate_backend_version=True,
                version=self._version,
                validate_column_names=False,
            )

        if backend is Implementation.POLARS:
            import polars as pl  # ignore-banned-import

            from narwhals._polars.dataframe import PolarsDataFrame

            return PolarsDataFrame(
                cast("pl.DataFrame", pl.from_arrow(self.native)),
                validate_backend_version=True,
                version=self._version,
            )

        msg = f"Unsupported `backend` value: {backend}"  # pragma: no cover
        raise AssertionError(msg)  # pragma: no cover

    def clone(self) -> Self:
        return self._with_native(self.native, validate_column_names=False)

    def item(self, row: int | None, column: int | str | None) -> Any:
        from narwhals._arrow.series import maybe_extract_py_scalar

        if row is None and column is None:
            if self.shape != (1, 1):
                msg = (
                    "can only call `.item()` if the dataframe is of shape (1, 1),"
                    " or if explicit row/col values are provided;"
                    f" frame has shape {self.shape!r}"
                )
                raise ValueError(msg)
            return maybe_extract_py_scalar(self.native[0][0], return_py_scalar=True)

        elif row is None or column is None:
            msg = "cannot call `.item()` with only one of `row` or `column`"
            raise ValueError(msg)

        _col = self.columns.index(column) if isinstance(column, str) else column
        return maybe_extract_py_scalar(self.native[_col][row], return_py_scalar=True)

    def rename(self, mapping: Mapping[str, str]) -> Self:
        names: dict[str, str] | list[str]
        if self._backend_version >= (17,):
            names = cast("dict[str, str]", mapping)
        else:  # pragma: no cover
            names = [mapping.get(c, c) for c in self.columns]
        return self._with_native(self.native.rename_columns(names))

    def write_parquet(self, file: str | Path | BytesIO) -> None:
        import pyarrow.parquet as pp

        pp.write_table(self.native, file)

    @overload
    def write_csv(self, file: None) -> str: ...

    @overload
    def write_csv(self, file: str | Path | BytesIO) -> None: ...

    def write_csv(self, file: str | Path | BytesIO | None) -> str | None:
        import pyarrow.csv as pa_csv

        if file is None:
            csv_buffer = pa.BufferOutputStream()
            pa_csv.write_csv(self.native, csv_buffer)
            return csv_buffer.getvalue().to_pybytes().decode()
        pa_csv.write_csv(self.native, file)
        return None

    def is_unique(self) -> ArrowSeries:
        import numpy as np  # ignore-banned-import

        col_token = generate_temporary_column_name(n_bytes=8, columns=self.columns)
        row_index = pa.array(np.arange(len(self)))
        keep_idx = (
            self.native.append_column(col_token, row_index)
            .group_by(self.columns)
            .aggregate([(col_token, "min"), (col_token, "max")])
        )
        native = pa.chunked_array(
            pc.and_(
                pc.is_in(row_index, keep_idx[f"{col_token}_min"]),
                pc.is_in(row_index, keep_idx[f"{col_token}_max"]),
            )
        )
        return ArrowSeries.from_native(native, context=self)

    def unique(
        self,
        subset: Sequence[str] | None,
        *,
        keep: UniqueKeepStrategy,
        maintain_order: bool | None = None,
    ) -> Self:
        # The param `maintain_order` is only here for compatibility with the Polars API
        # and has no effect on the output.
        import numpy as np  # ignore-banned-import

        if subset and (error := self._check_columns_exist(subset)):
            raise error
        subset = list(subset or self.columns)

        if keep in {"any", "first", "last"}:
            from narwhals._arrow.group_by import ArrowGroupBy

            agg_func = ArrowGroupBy._REMAP_UNIQUE[keep]
            col_token = generate_temporary_column_name(n_bytes=8, columns=self.columns)
            keep_idx_native = (
                self.native.append_column(col_token, pa.array(np.arange(len(self))))
                .group_by(subset)
                .aggregate([(col_token, agg_func)])
                .column(f"{col_token}_{agg_func}")
            )
            return self._with_native(
                self.native.take(keep_idx_native), validate_column_names=False
            )

        keep_idx = self.simple_select(*subset).is_unique()
        plx = self.__narwhals_namespace__()
        return self.filter(plx._expr._from_series(keep_idx))

    def gather_every(self, n: int, offset: int) -> Self:
        return self._with_native(self.native[offset::n], validate_column_names=False)

    def to_arrow(self) -> pa.Table:
        return self.native

    def sample(
        self,
        n: int | None,
        *,
        fraction: float | None,
        with_replacement: bool,
        seed: int | None,
    ) -> Self:
        import numpy as np  # ignore-banned-import

        num_rows = len(self)
        if n is None and fraction is not None:
            n = int(num_rows * fraction)
        rng = np.random.default_rng(seed=seed)
        idx = np.arange(num_rows)
        mask = rng.choice(idx, size=n, replace=with_replacement)
        return self._with_native(self.native.take(mask), validate_column_names=False)

    def unpivot(
        self,
        on: Sequence[str] | None,
        index: Sequence[str] | None,
        variable_name: str,
        value_name: str,
    ) -> Self:
        n_rows = len(self)
        index_ = [] if index is None else index
        on_ = [c for c in self.columns if c not in index_] if on is None else on
        concat = (
            partial(pa.concat_tables, promote_options="permissive")
            if self._backend_version >= (14, 0, 0)
            else pa.concat_tables
        )
        names = [*index_, variable_name, value_name]
        return self._with_native(
            concat(
                [
                    pa.Table.from_arrays(
                        [
                            *(self.native.column(idx_col) for idx_col in index_),
                            cast(
                                "ChunkedArrayAny",
                                pa.array([on_col] * n_rows, pa.string()),
                            ),
                            self.native.column(on_col),
                        ],
                        names=names,
                    )
                    for on_col in on_
                ]
            )
        )
        # TODO(Unassigned): Even with promote_options="permissive", pyarrow does not
        # upcast numeric to non-numeric (e.g. string) datatypes

    pivot = not_implemented()
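As a usage sketch (not part of this commit): `ArrowDataFrame` is normally reached through narwhals' public API by wrapping a `pyarrow.Table` with `nw.from_native`:

import pyarrow as pa
import narwhals as nw

table = pa.table({"a": [1, 2, 3], "b": ["x", "y", "z"]})
df = nw.from_native(table, eager_only=True)  # backed by ArrowDataFrame
out = df.with_columns((nw.col("a") * 2).alias("a2")).sort("a", descending=True)
print(out.to_native())  # back to a pyarrow.Table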
172
.venv/lib/python3.12/site-packages/narwhals/_arrow/expr.py
Normal file
@@ -0,0 +1,172 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any

import pyarrow.compute as pc

from narwhals._arrow.series import ArrowSeries
from narwhals._compliant import EagerExpr
from narwhals._expression_parsing import evaluate_output_names_and_aliases
from narwhals._utils import (
    Implementation,
    generate_temporary_column_name,
    not_implemented,
)

if TYPE_CHECKING:
    from collections.abc import Sequence

    from typing_extensions import Self

    from narwhals._arrow.dataframe import ArrowDataFrame
    from narwhals._arrow.namespace import ArrowNamespace
    from narwhals._compliant.typing import AliasNames, EvalNames, EvalSeries, ScalarKwargs
    from narwhals._expression_parsing import ExprMetadata
    from narwhals._utils import Version, _LimitedContext


class ArrowExpr(EagerExpr["ArrowDataFrame", ArrowSeries]):
    _implementation: Implementation = Implementation.PYARROW

    def __init__(
        self,
        call: EvalSeries[ArrowDataFrame, ArrowSeries],
        *,
        depth: int,
        function_name: str,
        evaluate_output_names: EvalNames[ArrowDataFrame],
        alias_output_names: AliasNames | None,
        version: Version,
        scalar_kwargs: ScalarKwargs | None = None,
        implementation: Implementation | None = None,
    ) -> None:
        self._call = call
        self._depth = depth
        self._function_name = function_name
        self._evaluate_output_names = evaluate_output_names
        self._alias_output_names = alias_output_names
        self._version = version
        self._scalar_kwargs = scalar_kwargs or {}
        self._metadata: ExprMetadata | None = None

    @classmethod
    def from_column_names(
        cls: type[Self],
        evaluate_column_names: EvalNames[ArrowDataFrame],
        /,
        *,
        context: _LimitedContext,
        function_name: str = "",
    ) -> Self:
        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
            try:
                return [
                    ArrowSeries(
                        df.native[column_name], name=column_name, version=df._version
                    )
                    for column_name in evaluate_column_names(df)
                ]
            except KeyError as e:
                if error := df._check_columns_exist(evaluate_column_names(df)):
                    raise error from e
                raise

        return cls(
            func,
            depth=0,
            function_name=function_name,
            evaluate_output_names=evaluate_column_names,
            alias_output_names=None,
            version=context._version,
        )

    @classmethod
    def from_column_indices(cls, *column_indices: int, context: _LimitedContext) -> Self:
        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
            tbl = df.native
            cols = df.columns
            return [
                ArrowSeries.from_native(tbl[i], name=cols[i], context=df)
                for i in column_indices
            ]

        return cls(
            func,
            depth=0,
            function_name="nth",
            evaluate_output_names=cls._eval_names_indices(column_indices),
            alias_output_names=None,
            version=context._version,
        )

    def __narwhals_namespace__(self) -> ArrowNamespace:
        from narwhals._arrow.namespace import ArrowNamespace

        return ArrowNamespace(version=self._version)

    def __narwhals_expr__(self) -> None: ...

    def _reuse_series_extra_kwargs(
        self, *, returns_scalar: bool = False
    ) -> dict[str, Any]:
        return {"_return_py_scalar": False} if returns_scalar else {}

    def over(self, partition_by: Sequence[str], order_by: Sequence[str]) -> Self:
        if (
            partition_by
            and self._metadata is not None
            and not self._metadata.is_scalar_like
        ):
            msg = "Only aggregation or literal operations are supported in grouped `over` context for PyArrow."
            raise NotImplementedError(msg)

        if not partition_by:
            # e.g. `nw.col('a').cum_sum().order_by(key)`
            # which we can always easily support, as it doesn't require grouping.
            assert order_by  # noqa: S101

            def func(df: ArrowDataFrame) -> Sequence[ArrowSeries]:
                token = generate_temporary_column_name(8, df.columns)
                df = df.with_row_index(token, order_by=None).sort(
                    *order_by, descending=False, nulls_last=False
                )
                result = self(df.drop([token], strict=True))
                # TODO(marco): is there a way to do this efficiently without
                # doing 2 sorts? Here we're sorting the dataframe and then
                # again calling `sort_indices`. `ArrowSeries.scatter` would also sort.
                sorting_indices = pc.sort_indices(df.get_column(token).native)
                return [s._with_native(s.native.take(sorting_indices)) for s in result]
        else:

            def func(df: ArrowDataFrame) -> Sequence[ArrowSeries]:
                output_names, aliases = evaluate_output_names_and_aliases(self, df, [])
                if overlap := set(output_names).intersection(partition_by):
                    # E.g. `df.select(nw.all().sum().over('a'))`. This is well-defined,
                    # we just don't support it yet.
                    msg = (
                        f"Column names {overlap} appear in both expression output names and in `over` keys.\n"
                        "This is not yet supported."
                    )
                    raise NotImplementedError(msg)

                tmp = df.group_by(partition_by, drop_null_keys=False).agg(self)
                tmp = df.simple_select(*partition_by).join(
                    tmp,
                    how="left",
                    left_on=partition_by,
                    right_on=partition_by,
                    suffix="_right",
                )
                return [tmp.get_column(alias) for alias in aliases]

        return self.__class__(
            func,
            depth=self._depth + 1,
            function_name=self._function_name + "->over",
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            version=self._version,
        )

    ewm_mean = not_implemented()
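The `partition_by`-free branch of `over` above sorts by `order_by`, evaluates the expression, then restores the original row order with `pc.sort_indices` on a temporary row-index column. The same trick in plain pyarrow, as a sketch (a cumulative sum ordered by `key` stands in for the wrapped expression):

import numpy as np
import pyarrow as pa
import pyarrow.compute as pc

tbl = pa.table({"key": [3, 1, 2], "x": [10, 20, 30]})
tbl = tbl.append_column("_idx", pa.array(np.arange(len(tbl))))  # remember row order
ordered = tbl.sort_by([("key", "ascending")])
cum = pc.cumulative_sum(ordered["x"])  # computed in key order: [20, 50, 60]
restore = pc.sort_indices(ordered["_idx"])  # indices that undo the sort
print(cum.take(restore).to_pylist())  # [60, 20, 50], back in original row order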
159
.venv/lib/python3.12/site-packages/narwhals/_arrow/group_by.py
Normal file
@@ -0,0 +1,159 @@
from __future__ import annotations

import collections
from typing import TYPE_CHECKING, Any, ClassVar

import pyarrow as pa
import pyarrow.compute as pc

from narwhals._arrow.utils import cast_to_comparable_string_types, extract_py_scalar
from narwhals._compliant import EagerGroupBy
from narwhals._expression_parsing import evaluate_output_names_and_aliases
from narwhals._utils import generate_temporary_column_name

if TYPE_CHECKING:
    from collections.abc import Iterator, Mapping, Sequence

    from narwhals._arrow.dataframe import ArrowDataFrame
    from narwhals._arrow.expr import ArrowExpr
    from narwhals._arrow.typing import (  # type: ignore[attr-defined]
        AggregateOptions,
        Aggregation,
        Incomplete,
    )
    from narwhals._compliant.typing import NarwhalsAggregation
    from narwhals.typing import UniqueKeepStrategy


class ArrowGroupBy(EagerGroupBy["ArrowDataFrame", "ArrowExpr", "Aggregation"]):
    _REMAP_AGGS: ClassVar[Mapping[NarwhalsAggregation, Aggregation]] = {
        "sum": "sum",
        "mean": "mean",
        "median": "approximate_median",
        "max": "max",
        "min": "min",
        "std": "stddev",
        "var": "variance",
        "len": "count",
        "n_unique": "count_distinct",
        "count": "count",
        "all": "all",
        "any": "any",
    }
    _REMAP_UNIQUE: ClassVar[Mapping[UniqueKeepStrategy, Aggregation]] = {
        "any": "min",
        "first": "min",
        "last": "max",
    }

    def __init__(
        self,
        df: ArrowDataFrame,
        keys: Sequence[ArrowExpr] | Sequence[str],
        /,
        *,
        drop_null_keys: bool,
    ) -> None:
        self._df = df
        frame, self._keys, self._output_key_names = self._parse_keys(df, keys=keys)
        self._compliant_frame = frame.drop_nulls(self._keys) if drop_null_keys else frame
        self._grouped = pa.TableGroupBy(self.compliant.native, self._keys)
        self._drop_null_keys = drop_null_keys

    def agg(self, *exprs: ArrowExpr) -> ArrowDataFrame:
        self._ensure_all_simple(exprs)
        aggs: list[tuple[str, Aggregation, AggregateOptions | None]] = []
        expected_pyarrow_column_names: list[str] = self._keys.copy()
        new_column_names: list[str] = self._keys.copy()
        exclude = (*self._keys, *self._output_key_names)

        for expr in exprs:
            output_names, aliases = evaluate_output_names_and_aliases(
                expr, self.compliant, exclude
            )

            if expr._depth == 0:
                # e.g. `agg(nw.len())`
                if expr._function_name != "len":  # pragma: no cover
                    msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues"
                    raise AssertionError(msg)

                new_column_names.append(aliases[0])
                expected_pyarrow_column_names.append(f"{self._keys[0]}_count")
                aggs.append((self._keys[0], "count", pc.CountOptions(mode="all")))
                continue

            function_name = self._leaf_name(expr)
            if function_name in {"std", "var"}:
                assert "ddof" in expr._scalar_kwargs  # noqa: S101
                option: Any = pc.VarianceOptions(ddof=expr._scalar_kwargs["ddof"])
            elif function_name in {"len", "n_unique"}:
                option = pc.CountOptions(mode="all")
            elif function_name == "count":
                option = pc.CountOptions(mode="only_valid")
            elif function_name in {"all", "any"}:
                option = pc.ScalarAggregateOptions(min_count=0)
            else:
                option = None

            function_name = self._remap_expr_name(function_name)
            new_column_names.extend(aliases)
            expected_pyarrow_column_names.extend(
                [f"{output_name}_{function_name}" for output_name in output_names]
            )
            aggs.extend(
                [(output_name, function_name, option) for output_name in output_names]
            )

        result_simple = self._grouped.aggregate(aggs)

        # Rename columns, being very careful
        expected_old_names_indices: dict[str, list[int]] = collections.defaultdict(list)
        for idx, item in enumerate(expected_pyarrow_column_names):
            expected_old_names_indices[item].append(idx)
        if not (
            set(result_simple.column_names) == set(expected_pyarrow_column_names)
            and len(result_simple.column_names) == len(expected_pyarrow_column_names)
        ):  # pragma: no cover
            msg = (
                f"Safety assertion failed, expected {expected_pyarrow_column_names} "
                f"got {result_simple.column_names}, "
                "please report a bug at https://github.com/narwhals-dev/narwhals/issues"
            )
            raise AssertionError(msg)
        index_map: list[int] = [
            expected_old_names_indices[item].pop(0) for item in result_simple.column_names
        ]
        new_column_names = [new_column_names[i] for i in index_map]
        result_simple = result_simple.rename_columns(new_column_names)
        return self.compliant._with_native(result_simple).rename(
            dict(zip(self._keys, self._output_key_names))
        )

    def __iter__(self) -> Iterator[tuple[Any, ArrowDataFrame]]:
        col_token = generate_temporary_column_name(
            n_bytes=8, columns=self.compliant.columns
        )
        null_token: str = "__null_token_value__"  # noqa: S105

        table = self.compliant.native
        it, separator_scalar = cast_to_comparable_string_types(
            *(table[key] for key in self._keys), separator=""
        )
        # NOTE: stubs indicate `separator` must also be a `ChunkedArray`
        # Reality: `str` is fine
        concat_str: Incomplete = pc.binary_join_element_wise
        key_values = concat_str(
            *it, separator_scalar, null_handling="replace", null_replacement=null_token
        )
        table = table.add_column(i=0, field_=col_token, column=key_values)

        for v in pc.unique(key_values):
            t = self.compliant._with_native(
                table.filter(pc.equal(table[col_token], v)).drop([col_token])
            )
            row = t.simple_select(*self._keys).row(0)
            yield (
                tuple(extract_py_scalar(el) for el in row),
                t.simple_select(*self._df.columns),
            )
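The rename step in `agg` above leans on pyarrow's naming convention: `TableGroupBy.aggregate` names each output column `<column>_<aggregation>`. A quick sketch; note that the key column's position is not guaranteed across pyarrow versions, which is one reason the code maps names back through an index map rather than assuming an order:

import pyarrow as pa

tbl = pa.table({"g": ["a", "a", "b"], "x": [1, 2, 3]})
out = pa.TableGroupBy(tbl, "g").aggregate([("x", "sum"), ("x", "count")])
print(out.column_names)  # e.g. ['x_sum', 'x_count', 'g']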
309
.venv/lib/python3.12/site-packages/narwhals/_arrow/namespace.py
Normal file
@@ -0,0 +1,309 @@
from __future__ import annotations

import operator
from functools import reduce
from itertools import chain
from typing import TYPE_CHECKING, Literal

import pyarrow as pa
import pyarrow.compute as pc

from narwhals._arrow.dataframe import ArrowDataFrame
from narwhals._arrow.expr import ArrowExpr
from narwhals._arrow.selectors import ArrowSelectorNamespace
from narwhals._arrow.series import ArrowSeries
from narwhals._arrow.utils import cast_to_comparable_string_types
from narwhals._compliant import CompliantThen, EagerNamespace, EagerWhen
from narwhals._expression_parsing import (
    combine_alias_output_names,
    combine_evaluate_output_names,
)
from narwhals._utils import Implementation

if TYPE_CHECKING:
    from collections.abc import Sequence

    from narwhals._arrow.typing import ArrayOrScalar, ChunkedArrayAny, Incomplete
    from narwhals._compliant.typing import ScalarKwargs
    from narwhals._utils import Version
    from narwhals.typing import IntoDType, NonNestedLiteral


class ArrowNamespace(
    EagerNamespace[ArrowDataFrame, ArrowSeries, ArrowExpr, pa.Table, "ChunkedArrayAny"]
):
    _implementation = Implementation.PYARROW

    @property
    def _dataframe(self) -> type[ArrowDataFrame]:
        return ArrowDataFrame

    @property
    def _expr(self) -> type[ArrowExpr]:
        return ArrowExpr

    @property
    def _series(self) -> type[ArrowSeries]:
        return ArrowSeries

    def __init__(self, *, version: Version) -> None:
        self._version = version

    def len(self) -> ArrowExpr:
        # coverage bug? this is definitely hit
        return self._expr(  # pragma: no cover
            lambda df: [
                ArrowSeries.from_iterable([len(df.native)], name="len", context=self)
            ],
            depth=0,
            function_name="len",
            evaluate_output_names=lambda _df: ["len"],
            alias_output_names=None,
            version=self._version,
        )

    def lit(self, value: NonNestedLiteral, dtype: IntoDType | None) -> ArrowExpr:
        def _lit_arrow_series(_: ArrowDataFrame) -> ArrowSeries:
            arrow_series = ArrowSeries.from_iterable(
                data=[value], name="literal", context=self
            )
            if dtype:
                return arrow_series.cast(dtype)
            return arrow_series

        return self._expr(
            lambda df: [_lit_arrow_series(df)],
            depth=0,
            function_name="lit",
            evaluate_output_names=lambda _df: ["literal"],
            alias_output_names=None,
            version=self._version,
        )

    def all_horizontal(self, *exprs: ArrowExpr, ignore_nulls: bool) -> ArrowExpr:
        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
            series = chain.from_iterable(expr(df) for expr in exprs)
            align = self._series._align_full_broadcast
            it = (
                (s.fill_null(True, None, None) for s in series)  # noqa: FBT003
                if ignore_nulls
                else series
            )
            return [reduce(operator.and_, align(*it))]

        return self._expr._from_callable(
            func=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="all_horizontal",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            context=self,
        )

    def any_horizontal(self, *exprs: ArrowExpr, ignore_nulls: bool) -> ArrowExpr:
        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
            series = chain.from_iterable(expr(df) for expr in exprs)
            align = self._series._align_full_broadcast
            it = (
                (s.fill_null(False, None, None) for s in series)  # noqa: FBT003
                if ignore_nulls
                else series
            )
            return [reduce(operator.or_, align(*it))]

        return self._expr._from_callable(
            func=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="any_horizontal",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            context=self,
        )

    def sum_horizontal(self, *exprs: ArrowExpr) -> ArrowExpr:
        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
            it = chain.from_iterable(expr(df) for expr in exprs)
            series = (s.fill_null(0, strategy=None, limit=None) for s in it)
            align = self._series._align_full_broadcast
            return [reduce(operator.add, align(*series))]

        return self._expr._from_callable(
            func=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="sum_horizontal",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            context=self,
        )

    def mean_horizontal(self, *exprs: ArrowExpr) -> ArrowExpr:
        int_64 = self._version.dtypes.Int64()

        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
            expr_results = list(chain.from_iterable(expr(df) for expr in exprs))
            align = self._series._align_full_broadcast
            series = align(
                *(s.fill_null(0, strategy=None, limit=None) for s in expr_results)
            )
            non_na = align(*(1 - s.is_null().cast(int_64) for s in expr_results))
            return [reduce(operator.add, series) / reduce(operator.add, non_na)]

        return self._expr._from_callable(
            func=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="mean_horizontal",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            context=self,
        )

    def min_horizontal(self, *exprs: ArrowExpr) -> ArrowExpr:
        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
            align = self._series._align_full_broadcast
            init_series, *series = list(chain.from_iterable(expr(df) for expr in exprs))
            init_series, *series = align(init_series, *series)
            native_series = reduce(
                pc.min_element_wise, [s.native for s in series], init_series.native
            )
            return [
                ArrowSeries(native_series, name=init_series.name, version=self._version)
            ]

        return self._expr._from_callable(
            func=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="min_horizontal",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            context=self,
        )

    def max_horizontal(self, *exprs: ArrowExpr) -> ArrowExpr:
        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
            align = self._series._align_full_broadcast
            init_series, *series = list(chain.from_iterable(expr(df) for expr in exprs))
            init_series, *series = align(init_series, *series)
            native_series = reduce(
                pc.max_element_wise, [s.native for s in series], init_series.native
            )
            return [
                ArrowSeries(native_series, name=init_series.name, version=self._version)
            ]

        return self._expr._from_callable(
            func=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="max_horizontal",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            context=self,
        )

    def _concat_diagonal(self, dfs: Sequence[pa.Table], /) -> pa.Table:
        if self._backend_version >= (14,):
            return pa.concat_tables(dfs, promote_options="default")
        return pa.concat_tables(dfs, promote=True)  # pragma: no cover

    def _concat_horizontal(self, dfs: Sequence[pa.Table], /) -> pa.Table:
        names = list(chain.from_iterable(df.column_names for df in dfs))
        arrays = list(chain.from_iterable(df.itercolumns() for df in dfs))
        return pa.Table.from_arrays(arrays, names=names)

    def _concat_vertical(self, dfs: Sequence[pa.Table], /) -> pa.Table:
        cols_0 = dfs[0].column_names
        for i, df in enumerate(dfs[1:], start=1):
            cols_current = df.column_names
            if cols_current != cols_0:
                msg = (
                    "unable to vstack, column names don't match:\n"
                    f" - dataframe 0: {cols_0}\n"
                    f" - dataframe {i}: {cols_current}\n"
                )
                raise TypeError(msg)
        return pa.concat_tables(dfs)

    @property
    def selectors(self) -> ArrowSelectorNamespace:
        return ArrowSelectorNamespace.from_namespace(self)

    def when(self, predicate: ArrowExpr) -> ArrowWhen:
        return ArrowWhen.from_expr(predicate, context=self)

    def concat_str(
        self, *exprs: ArrowExpr, separator: str, ignore_nulls: bool
    ) -> ArrowExpr:
        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
            align = self._series._align_full_broadcast
            compliant_series_list = align(
                *(chain.from_iterable(expr(df) for expr in exprs))
            )
            name = compliant_series_list[0].name
            null_handling: Literal["skip", "emit_null"] = (
                "skip" if ignore_nulls else "emit_null"
            )
            it, separator_scalar = cast_to_comparable_string_types(
                *(s.native for s in compliant_series_list), separator=separator
            )
            # NOTE: stubs indicate `separator` must also be a `ChunkedArray`
            # Reality: `str` is fine
            concat_str: Incomplete = pc.binary_join_element_wise
            compliant = self._series(
                concat_str(*it, separator_scalar, null_handling=null_handling),
                name=name,
                version=self._version,
            )
            return [compliant]

        return self._expr._from_callable(
            func=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="concat_str",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            context=self,
        )

    def coalesce(self, *exprs: ArrowExpr) -> ArrowExpr:
        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
            align = self._series._align_full_broadcast
            init_series, *series = align(*chain.from_iterable(expr(df) for expr in exprs))
            return [
                ArrowSeries(
                    pc.coalesce(init_series.native, *(s.native for s in series)),
                    name=init_series.name,
                    version=self._version,
                )
            ]

        return self._expr._from_callable(
            func=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="coalesce",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            context=self,
        )


class ArrowWhen(EagerWhen[ArrowDataFrame, ArrowSeries, ArrowExpr, "ChunkedArrayAny"]):
    @property
    def _then(self) -> type[ArrowThen]:
        return ArrowThen

    def _if_then_else(
        self,
        when: ChunkedArrayAny,
        then: ChunkedArrayAny,
        otherwise: ArrayOrScalar | NonNestedLiteral,
        /,
    ) -> ChunkedArrayAny:
        otherwise = pa.nulls(len(when), then.type) if otherwise is None else otherwise
        return pc.if_else(when, then, otherwise)


class ArrowThen(
    CompliantThen[ArrowDataFrame, ArrowSeries, ArrowExpr, ArrowWhen], ArrowExpr
):
    _depth: int = 0
    _scalar_kwargs: ScalarKwargs = {}  # noqa: RUF012
    _function_name: str = "whenthen"
@@ -0,0 +1,33 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from narwhals._arrow.expr import ArrowExpr
from narwhals._compliant import CompliantSelector, EagerSelectorNamespace

if TYPE_CHECKING:
    from narwhals._arrow.dataframe import ArrowDataFrame  # noqa: F401
    from narwhals._arrow.series import ArrowSeries  # noqa: F401
    from narwhals._compliant.typing import ScalarKwargs


class ArrowSelectorNamespace(EagerSelectorNamespace["ArrowDataFrame", "ArrowSeries"]):
    @property
    def _selector(self) -> type[ArrowSelector]:
        return ArrowSelector


class ArrowSelector(CompliantSelector["ArrowDataFrame", "ArrowSeries"], ArrowExpr):  # type: ignore[misc]
    _depth: int = 0
    _scalar_kwargs: ScalarKwargs = {}  # noqa: RUF012
    _function_name: str = "selector"

    def _to_expr(self) -> ArrowExpr:
        return ArrowExpr(
            self._call,
            depth=self._depth,
            function_name=self._function_name,
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            version=self._version,
        )
1183
.venv/lib/python3.12/site-packages/narwhals/_arrow/series.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,18 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import pyarrow as pa

from narwhals._arrow.utils import ArrowSeriesNamespace

if TYPE_CHECKING:
    from narwhals._arrow.series import ArrowSeries
    from narwhals._arrow.typing import Incomplete


class ArrowSeriesCatNamespace(ArrowSeriesNamespace):
    def get_categories(self) -> ArrowSeries:
        # NOTE: Should be `list[pa.DictionaryArray]`, but `DictionaryArray` has no attributes
        chunks: Incomplete = self.native.chunks
        return self.with_native(pa.concat_arrays(x.dictionary for x in chunks).unique())
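# A hedged sketch (not part of the vendored file): for a dictionary-encoded
# column, each chunk carries its own `dictionary` of categories, which is why
# `get_categories` concatenates them and de-duplicates with `.unique()`:
#
#     import pyarrow as pa
#     arr = pa.chunked_array([pa.array(["a", "b", "a"]).dictionary_encode()])
#     pa.concat_arrays(c.dictionary for c in arr.chunks).unique()  # -> ["a", "b"]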
227
.venv/lib/python3.12/site-packages/narwhals/_arrow/series_dt.py
Normal file
@@ -0,0 +1,227 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable, ClassVar, cast

import pyarrow as pa
import pyarrow.compute as pc

from narwhals._arrow.utils import UNITS_DICT, ArrowSeriesNamespace, floordiv_compat, lit
from narwhals._constants import (
    MS_PER_MINUTE,
    MS_PER_SECOND,
    NS_PER_MICROSECOND,
    NS_PER_MILLISECOND,
    NS_PER_MINUTE,
    NS_PER_SECOND,
    SECONDS_PER_DAY,
    SECONDS_PER_MINUTE,
    US_PER_MINUTE,
    US_PER_SECOND,
)
from narwhals._duration import Interval

if TYPE_CHECKING:
    from collections.abc import Mapping

    from typing_extensions import TypeAlias

    from narwhals._arrow.series import ArrowSeries
    from narwhals._arrow.typing import ChunkedArrayAny, ScalarAny
    from narwhals.dtypes import Datetime
    from narwhals.typing import TimeUnit

UnitCurrent: TypeAlias = TimeUnit
UnitTarget: TypeAlias = TimeUnit
BinOpBroadcast: TypeAlias = Callable[[ChunkedArrayAny, ScalarAny], ChunkedArrayAny]
IntoRhs: TypeAlias = int


class ArrowSeriesDateTimeNamespace(ArrowSeriesNamespace):
    _TIMESTAMP_DATE_FACTOR: ClassVar[Mapping[TimeUnit, int]] = {
        "ns": NS_PER_SECOND,
        "us": US_PER_SECOND,
        "ms": MS_PER_SECOND,
        "s": 1,
    }
    _TIMESTAMP_DATETIME_OP_FACTOR: ClassVar[
        Mapping[tuple[UnitCurrent, UnitTarget], tuple[BinOpBroadcast, IntoRhs]]
    ] = {
        ("ns", "us"): (floordiv_compat, 1_000),
        ("ns", "ms"): (floordiv_compat, 1_000_000),
        ("us", "ns"): (pc.multiply, NS_PER_MICROSECOND),
        ("us", "ms"): (floordiv_compat, 1_000),
        ("ms", "ns"): (pc.multiply, NS_PER_MILLISECOND),
        ("ms", "us"): (pc.multiply, 1_000),
        ("s", "ns"): (pc.multiply, NS_PER_SECOND),
        ("s", "us"): (pc.multiply, US_PER_SECOND),
        ("s", "ms"): (pc.multiply, MS_PER_SECOND),
    }

    @property
    def unit(self) -> TimeUnit:  # NOTE: Unsafe (native).
        return cast("pa.TimestampType[TimeUnit, Any]", self.native.type).unit

    @property
    def time_zone(self) -> str | None:  # NOTE: Unsafe (narwhals).
        return cast("Datetime", self.compliant.dtype).time_zone

    def to_string(self, format: str) -> ArrowSeries:
        # PyArrow differs from other libraries in that %S also prints out
        # the fractional part of the second...:'(
        # https://arrow.apache.org/docs/python/generated/pyarrow.compute.strftime.html
        format = format.replace("%S.%f", "%S").replace("%S%.f", "%S")
        return self.with_native(pc.strftime(self.native, format))

    def replace_time_zone(self, time_zone: str | None) -> ArrowSeries:
        if time_zone is not None:
            result = pc.assume_timezone(pc.local_timestamp(self.native), time_zone)
        else:
            result = pc.local_timestamp(self.native)
        return self.with_native(result)

    def convert_time_zone(self, time_zone: str) -> ArrowSeries:
        ser = self.replace_time_zone("UTC") if self.time_zone is None else self.compliant
        return self.with_native(ser.native.cast(pa.timestamp(self.unit, time_zone)))

    def timestamp(self, time_unit: TimeUnit) -> ArrowSeries:
        ser = self.compliant
        dtypes = ser._version.dtypes
        if isinstance(ser.dtype, dtypes.Datetime):
            current = ser.dtype.time_unit
            s_cast = self.native.cast(pa.int64())
            if current == time_unit:
                result = s_cast
            elif item := self._TIMESTAMP_DATETIME_OP_FACTOR.get((current, time_unit)):
                fn, factor = item
                result = fn(s_cast, lit(factor))
            else:  # pragma: no cover
                msg = f"unexpected time unit {current}, please report an issue at https://github.com/narwhals-dev/narwhals"
                raise AssertionError(msg)
            return self.with_native(result)
        elif isinstance(ser.dtype, dtypes.Date):
            time_s = pc.multiply(self.native.cast(pa.int32()), lit(SECONDS_PER_DAY))
            factor = self._TIMESTAMP_DATE_FACTOR[time_unit]
            return self.with_native(pc.multiply(time_s, lit(factor)))
        else:
            msg = "Input should be either of Date or Datetime type"
            raise TypeError(msg)
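    # A hedged sketch (not part of the vendored file) of the factor table used
    # above: converting to a coarser unit floor-divides, converting to a finer
    # unit multiplies. E.g. for a Datetime("us") value converted to "ms" / "ns":
    #
    #     us = 1_700_000_000_123_456   # microseconds since epoch
    #     us // 1_000                  # -> milliseconds (floordiv_compat branch)
    #     us * 1_000                   # -> nanoseconds  (pc.multiply branch)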

    def date(self) -> ArrowSeries:
        return self.with_native(self.native.cast(pa.date32()))

    def year(self) -> ArrowSeries:
        return self.with_native(pc.year(self.native))

    def month(self) -> ArrowSeries:
        return self.with_native(pc.month(self.native))

    def day(self) -> ArrowSeries:
        return self.with_native(pc.day(self.native))

    def hour(self) -> ArrowSeries:
        return self.with_native(pc.hour(self.native))

    def minute(self) -> ArrowSeries:
        return self.with_native(pc.minute(self.native))

    def second(self) -> ArrowSeries:
        return self.with_native(pc.second(self.native))

    def millisecond(self) -> ArrowSeries:
        return self.with_native(pc.millisecond(self.native))

    def microsecond(self) -> ArrowSeries:
        arr = self.native
        result = pc.add(pc.multiply(pc.millisecond(arr), lit(1000)), pc.microsecond(arr))
        return self.with_native(result)

    def nanosecond(self) -> ArrowSeries:
        result = pc.add(
            pc.multiply(self.microsecond().native, lit(1000)), pc.nanosecond(self.native)
        )
        return self.with_native(result)

    def ordinal_day(self) -> ArrowSeries:
        return self.with_native(pc.day_of_year(self.native))

    def weekday(self) -> ArrowSeries:
        return self.with_native(pc.day_of_week(self.native, count_from_zero=False))

    def total_minutes(self) -> ArrowSeries:
        unit_to_minutes_factor = {
            "s": SECONDS_PER_MINUTE,
            "ms": MS_PER_MINUTE,
            "us": US_PER_MINUTE,
            "ns": NS_PER_MINUTE,
        }
        factor = lit(unit_to_minutes_factor[self.unit], type=pa.int64())
        return self.with_native(pc.divide(self.native, factor).cast(pa.int64()))

    def total_seconds(self) -> ArrowSeries:
        unit_to_seconds_factor = {
            "s": 1,
            "ms": MS_PER_SECOND,
            "us": US_PER_SECOND,
            "ns": NS_PER_SECOND,
        }
        factor = lit(unit_to_seconds_factor[self.unit], type=pa.int64())
        return self.with_native(pc.divide(self.native, factor).cast(pa.int64()))

    def total_milliseconds(self) -> ArrowSeries:
        unit_to_milli_factor = {
            "s": 1e3,  # seconds
            "ms": 1,  # milli
            "us": 1e3,  # micro
            "ns": 1e6,  # nano
        }
        factor = lit(unit_to_milli_factor[self.unit], type=pa.int64())
        if self.unit == "s":
            return self.with_native(pc.multiply(self.native, factor).cast(pa.int64()))
        return self.with_native(pc.divide(self.native, factor).cast(pa.int64()))

    def total_microseconds(self) -> ArrowSeries:
        unit_to_micro_factor = {
            "s": 1e6,  # seconds
            "ms": 1e3,  # milli
            "us": 1,  # micro
            "ns": 1e3,  # nano
        }
        factor = lit(unit_to_micro_factor[self.unit], type=pa.int64())
        if self.unit in {"s", "ms"}:
            return self.with_native(pc.multiply(self.native, factor).cast(pa.int64()))
        return self.with_native(pc.divide(self.native, factor).cast(pa.int64()))

    def total_nanoseconds(self) -> ArrowSeries:
        unit_to_nano_factor = {
            "s": NS_PER_SECOND,
            "ms": NS_PER_MILLISECOND,
            "us": NS_PER_MICROSECOND,
            "ns": 1,
        }
        factor = lit(unit_to_nano_factor[self.unit], type=pa.int64())
        return self.with_native(pc.multiply(self.native, factor).cast(pa.int64()))

    def truncate(self, every: str) -> ArrowSeries:
        interval = Interval.parse(every)
        return self.with_native(
            pc.floor_temporal(self.native, interval.multiple, UNITS_DICT[interval.unit])
        )

    def offset_by(self, by: str) -> ArrowSeries:
        interval = Interval.parse_no_constraints(by)
        native = self.native
        if interval.unit in {"y", "q", "mo"}:
            msg = f"Offsetting by {interval.unit} is not yet supported for pyarrow."
            raise NotImplementedError(msg)
        dtype = self.compliant.dtype
        datetime_dtype = self.version.dtypes.Datetime
        if interval.unit == "d" and isinstance(dtype, datetime_dtype) and dtype.time_zone:
            offset: pa.DurationScalar[Any] = lit(interval.to_timedelta())
            native_naive = pc.local_timestamp(native)
            result = pc.assume_timezone(pc.add(native_naive, offset), dtype.time_zone)
            return self.with_native(result)
        if interval.unit == "ns":  # pragma: no cover
            offset = lit(interval.multiple, pa.duration("ns"))  # type: ignore[assignment]
        else:
            offset = lit(interval.to_timedelta())
        return self.with_native(pc.add(native, offset))
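# A hedged sketch (not part of the vendored file): `truncate` maps a narwhals
# interval such as "15m" onto `pc.floor_temporal` through `UNITS_DICT`:
#
#     from datetime import datetime
#     import pyarrow as pa
#     import pyarrow.compute as pc
#     ts = pa.chunked_array([[datetime(2024, 1, 1, 10, 37)]], pa.timestamp("us"))
#     pc.floor_temporal(ts, multiple=15, unit="minute")  # -> 2024-01-01 10:30:00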
@@ -0,0 +1,16 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import pyarrow as pa
import pyarrow.compute as pc

from narwhals._arrow.utils import ArrowSeriesNamespace

if TYPE_CHECKING:
    from narwhals._arrow.series import ArrowSeries


class ArrowSeriesListNamespace(ArrowSeriesNamespace):
    def len(self) -> ArrowSeries:
        return self.with_native(pc.list_value_length(self.native).cast(pa.uint32()))
103
.venv/lib/python3.12/site-packages/narwhals/_arrow/series_str.py
Normal file
@@ -0,0 +1,103 @@
from __future__ import annotations

import string
from typing import TYPE_CHECKING

import pyarrow as pa
import pyarrow.compute as pc

from narwhals._arrow.utils import ArrowSeriesNamespace, lit, parse_datetime_format

if TYPE_CHECKING:
    from narwhals._arrow.series import ArrowSeries
    from narwhals._arrow.typing import Incomplete


class ArrowSeriesStringNamespace(ArrowSeriesNamespace):
    def len_chars(self) -> ArrowSeries:
        return self.with_native(pc.utf8_length(self.native))

    def replace(self, pattern: str, value: str, *, literal: bool, n: int) -> ArrowSeries:
        fn = pc.replace_substring if literal else pc.replace_substring_regex
        arr = fn(self.native, pattern, replacement=value, max_replacements=n)
        return self.with_native(arr)

    def replace_all(self, pattern: str, value: str, *, literal: bool) -> ArrowSeries:
        return self.replace(pattern, value, literal=literal, n=-1)

    def strip_chars(self, characters: str | None) -> ArrowSeries:
        return self.with_native(
            pc.utf8_trim(self.native, characters or string.whitespace)
        )

    def starts_with(self, prefix: str) -> ArrowSeries:
        return self.with_native(pc.equal(self.slice(0, len(prefix)).native, lit(prefix)))

    def ends_with(self, suffix: str) -> ArrowSeries:
        return self.with_native(
            pc.equal(self.slice(-len(suffix), None).native, lit(suffix))
        )

    def contains(self, pattern: str, *, literal: bool) -> ArrowSeries:
        check_func = pc.match_substring if literal else pc.match_substring_regex
        return self.with_native(check_func(self.native, pattern))

    def slice(self, offset: int, length: int | None) -> ArrowSeries:
        stop = offset + length if length is not None else None
        return self.with_native(
            pc.utf8_slice_codeunits(self.native, start=offset, stop=stop)
        )

    def split(self, by: str) -> ArrowSeries:
        split_series = pc.split_pattern(self.native, by)  # type: ignore[call-overload]
        return self.with_native(split_series)

    def to_datetime(self, format: str | None) -> ArrowSeries:
        format = parse_datetime_format(self.native) if format is None else format
        timestamp_array = pc.strptime(self.native, format=format, unit="us")
        return self.with_native(timestamp_array)

    def to_date(self, format: str | None) -> ArrowSeries:
        return self.to_datetime(format=format).dt.date()

    def to_uppercase(self) -> ArrowSeries:
        return self.with_native(pc.utf8_upper(self.native))

    def to_lowercase(self) -> ArrowSeries:
        return self.with_native(pc.utf8_lower(self.native))

    def zfill(self, width: int) -> ArrowSeries:
        binary_join: Incomplete = pc.binary_join_element_wise
        native = self.native
        hyphen, plus = lit("-"), lit("+")
        first_char, remaining_chars = (
            self.slice(0, 1).native,
            self.slice(1, None).native,
        )

        # Conditions
        less_than_width = pc.less(pc.utf8_length(native), lit(width))
        starts_with_hyphen = pc.equal(first_char, hyphen)
        starts_with_plus = pc.equal(first_char, plus)

        conditions = pc.make_struct(
            pc.and_(starts_with_hyphen, less_than_width),
            pc.and_(starts_with_plus, less_than_width),
            less_than_width,
        )

        # Cases
        padded_remaining_chars = pc.utf8_lpad(remaining_chars, width - 1, padding="0")

        result = pc.case_when(
            conditions,
            binary_join(
                pa.repeat(hyphen, len(native)), padded_remaining_chars, ""
            ),  # starts with hyphen and less than width
            binary_join(
                pa.repeat(plus, len(native)), padded_remaining_chars, ""
            ),  # starts with plus and less than width
            pc.utf8_lpad(native, width=width, padding="0"),  # less than width
            native,
        )
        return self.with_native(result)
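# A hedged sketch (not part of the vendored file) of the sign-aware `zfill`
# above, which mirrors Python's `str.zfill`:
#
#     "42".zfill(5)      # -> "00042"  (plain left-pad with "0")
#     "-42".zfill(5)     # -> "-0042"  (sign kept in front of the padding)
#     "+42".zfill(5)     # -> "+0042"
#     "123456".zfill(5)  # -> "123456" (already >= width: unchanged)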
@@ -0,0 +1,15 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import pyarrow.compute as pc

from narwhals._arrow.utils import ArrowSeriesNamespace

if TYPE_CHECKING:
    from narwhals._arrow.series import ArrowSeries


class ArrowSeriesStructNamespace(ArrowSeriesNamespace):
    def field(self, name: str) -> ArrowSeries:
        return self.with_native(pc.struct_field(self.native, name)).alias(name)
72
.venv/lib/python3.12/site-packages/narwhals/_arrow/typing.py
Normal file
@@ -0,0 +1,72 @@
from __future__ import annotations  # pragma: no cover

from typing import (
    TYPE_CHECKING,  # pragma: no cover
    Any,  # pragma: no cover
    TypeVar,  # pragma: no cover
)

if TYPE_CHECKING:
    import sys
    from typing import Generic, Literal

    if sys.version_info >= (3, 10):
        from typing import TypeAlias
    else:
        from typing_extensions import TypeAlias

    import pyarrow as pa
    from pyarrow.__lib_pxi.table import (
        AggregateOptions,  # noqa: F401
        Aggregation,  # noqa: F401
    )
    from pyarrow._stubs_typing import (  # pyright: ignore[reportMissingModuleSource]
        Indices,  # noqa: F401
        Mask,  # noqa: F401
        Order,  # noqa: F401
    )

    from narwhals._arrow.expr import ArrowExpr
    from narwhals._arrow.series import ArrowSeries

    IntoArrowExpr: TypeAlias = "ArrowExpr | ArrowSeries"
    TieBreaker: TypeAlias = Literal["min", "max", "first", "dense"]
    NullPlacement: TypeAlias = Literal["at_start", "at_end"]
    NativeIntervalUnit: TypeAlias = Literal[
        "year",
        "quarter",
        "month",
        "week",
        "day",
        "hour",
        "minute",
        "second",
        "millisecond",
        "microsecond",
        "nanosecond",
    ]

    ChunkedArrayAny: TypeAlias = pa.ChunkedArray[Any]
    ArrayAny: TypeAlias = pa.Array[Any]
    ArrayOrChunkedArray: TypeAlias = "ArrayAny | ChunkedArrayAny"
    ScalarAny: TypeAlias = pa.Scalar[Any]
    ArrayOrScalar: TypeAlias = "ArrayOrChunkedArray | ScalarAny"
    ArrayOrScalarT1 = TypeVar("ArrayOrScalarT1", ArrayAny, ChunkedArrayAny, ScalarAny)
    ArrayOrScalarT2 = TypeVar("ArrayOrScalarT2", ArrayAny, ChunkedArrayAny, ScalarAny)
    _AsPyType = TypeVar("_AsPyType")

    class _BasicDataType(pa.DataType, Generic[_AsPyType]): ...


Incomplete: TypeAlias = Any  # pragma: no cover
"""
Marker for working code that fails on the stubs.

Common issues:
- Annotated for `Array`, but not `ChunkedArray`
- Relies on typing information that the stubs don't provide statically
- Missing attributes
- Incorrect return types
- Inconsistent use of generic/concrete types
- `_clone_signature` used on signatures that are not identical
"""
443
.venv/lib/python3.12/site-packages/narwhals/_arrow/utils.py
Normal file
@@ -0,0 +1,443 @@
from __future__ import annotations

from functools import lru_cache
from typing import TYPE_CHECKING, Any, cast

import pyarrow as pa
import pyarrow.compute as pc

from narwhals._compliant import EagerSeriesNamespace
from narwhals._utils import isinstance_or_issubclass

if TYPE_CHECKING:
    from collections.abc import Iterable, Iterator, Mapping

    from typing_extensions import TypeAlias, TypeIs

    from narwhals._arrow.series import ArrowSeries
    from narwhals._arrow.typing import (
        ArrayAny,
        ArrayOrScalar,
        ArrayOrScalarT1,
        ArrayOrScalarT2,
        ChunkedArrayAny,
        NativeIntervalUnit,
        ScalarAny,
    )
    from narwhals._duration import IntervalUnit
    from narwhals._utils import Version
    from narwhals.dtypes import DType
    from narwhals.typing import IntoDType, PythonLiteral

    # NOTE: stubs don't allow for `ChunkedArray[StructArray]`
    # Intended to represent the `.chunks` property storing `list[pa.StructArray]`
    ChunkedArrayStructArray: TypeAlias = ChunkedArrayAny

    def is_timestamp(t: Any) -> TypeIs[pa.TimestampType[Any, Any]]: ...
    def is_duration(t: Any) -> TypeIs[pa.DurationType[Any]]: ...
    def is_list(t: Any) -> TypeIs[pa.ListType[Any]]: ...
    def is_large_list(t: Any) -> TypeIs[pa.LargeListType[Any]]: ...
    def is_fixed_size_list(t: Any) -> TypeIs[pa.FixedSizeListType[Any, Any]]: ...
    def is_dictionary(t: Any) -> TypeIs[pa.DictionaryType[Any, Any, Any]]: ...
    def extract_regex(
        strings: ChunkedArrayAny,
        /,
        pattern: str,
        *,
        options: Any = None,
        memory_pool: Any = None,
    ) -> ChunkedArrayStructArray: ...

else:
    from pyarrow.compute import extract_regex
    from pyarrow.types import (
        is_dictionary,  # noqa: F401
        is_duration,
        is_fixed_size_list,
        is_large_list,
        is_list,
        is_timestamp,
    )

UNITS_DICT: Mapping[IntervalUnit, NativeIntervalUnit] = {
    "y": "year",
    "q": "quarter",
    "mo": "month",
    "d": "day",
    "h": "hour",
    "m": "minute",
    "s": "second",
    "ms": "millisecond",
    "us": "microsecond",
    "ns": "nanosecond",
}

lit = pa.scalar
"""Alias for `pyarrow.scalar`."""


def extract_py_scalar(value: Any, /) -> Any:
    from narwhals._arrow.series import maybe_extract_py_scalar

    return maybe_extract_py_scalar(value, return_py_scalar=True)


def chunked_array(
    arr: ArrayOrScalar | list[Iterable[Any]], dtype: pa.DataType | None = None, /
) -> ChunkedArrayAny:
    if isinstance(arr, pa.ChunkedArray):
        return arr
    if isinstance(arr, list):
        return pa.chunked_array(arr, dtype)
    else:
        return pa.chunked_array([arr], arr.type)
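# A hedged sketch (not part of the vendored file): `chunked_array` normalises
# the three accepted input shapes to a `ChunkedArray`:
#
#     import pyarrow as pa
#     chunked_array(pa.chunked_array([[1, 2]]))  # passed through unchanged
#     chunked_array([[1, 2], [3]])               # list of chunks -> 2 chunks
#     chunked_array(pa.array([1, 2]))            # single array -> 1 chunk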

def nulls_like(n: int, series: ArrowSeries) -> ArrayAny:
    """Create a strongly-typed Array instance with all elements null.

    Uses the type of `series`, without upsetting `mypy`.
    """
    return pa.nulls(n, series.native.type)


def zeros(n: int, /) -> pa.Int64Array:
    return pa.repeat(0, n)


@lru_cache(maxsize=16)
def native_to_narwhals_dtype(dtype: pa.DataType, version: Version) -> DType:  # noqa: C901, PLR0912
    dtypes = version.dtypes
    if pa.types.is_int64(dtype):
        return dtypes.Int64()
    if pa.types.is_int32(dtype):
        return dtypes.Int32()
    if pa.types.is_int16(dtype):
        return dtypes.Int16()
    if pa.types.is_int8(dtype):
        return dtypes.Int8()
    if pa.types.is_uint64(dtype):
        return dtypes.UInt64()
    if pa.types.is_uint32(dtype):
        return dtypes.UInt32()
    if pa.types.is_uint16(dtype):
        return dtypes.UInt16()
    if pa.types.is_uint8(dtype):
        return dtypes.UInt8()
    if pa.types.is_boolean(dtype):
        return dtypes.Boolean()
    if pa.types.is_float64(dtype):
        return dtypes.Float64()
    if pa.types.is_float32(dtype):
        return dtypes.Float32()
    # bug in coverage? it shows `31->exit` (where `31` is currently the line number
    # of the next line), even though both branches (condition true and false) are
    # covered
    if (  # pragma: no cover
        pa.types.is_string(dtype)
        or pa.types.is_large_string(dtype)
        or getattr(pa.types, "is_string_view", lambda _: False)(dtype)
    ):
        return dtypes.String()
    if pa.types.is_date32(dtype):
        return dtypes.Date()
    if is_timestamp(dtype):
        return dtypes.Datetime(time_unit=dtype.unit, time_zone=dtype.tz)
    if is_duration(dtype):
        return dtypes.Duration(time_unit=dtype.unit)
    if pa.types.is_dictionary(dtype):
        return dtypes.Categorical()
    if pa.types.is_struct(dtype):
        return dtypes.Struct(
            [
                dtypes.Field(
                    dtype.field(i).name,
                    native_to_narwhals_dtype(dtype.field(i).type, version),
                )
                for i in range(dtype.num_fields)
            ]
        )
    if is_list(dtype) or is_large_list(dtype):
        return dtypes.List(native_to_narwhals_dtype(dtype.value_type, version))
    if is_fixed_size_list(dtype):
        return dtypes.Array(
            native_to_narwhals_dtype(dtype.value_type, version), dtype.list_size
        )
    if pa.types.is_decimal(dtype):
        return dtypes.Decimal()
    if pa.types.is_time32(dtype) or pa.types.is_time64(dtype):
        return dtypes.Time()
    if pa.types.is_binary(dtype):
        return dtypes.Binary()
    return dtypes.Unknown()  # pragma: no cover


def narwhals_to_native_dtype(dtype: IntoDType, version: Version) -> pa.DataType:  # noqa: C901, PLR0912
    dtypes = version.dtypes
    if isinstance_or_issubclass(dtype, dtypes.Decimal):
        msg = "Casting to Decimal is not supported yet."
        raise NotImplementedError(msg)
    if isinstance_or_issubclass(dtype, dtypes.Float64):
        return pa.float64()
    if isinstance_or_issubclass(dtype, dtypes.Float32):
        return pa.float32()
    if isinstance_or_issubclass(dtype, dtypes.Int64):
        return pa.int64()
    if isinstance_or_issubclass(dtype, dtypes.Int32):
        return pa.int32()
    if isinstance_or_issubclass(dtype, dtypes.Int16):
        return pa.int16()
    if isinstance_or_issubclass(dtype, dtypes.Int8):
        return pa.int8()
    if isinstance_or_issubclass(dtype, dtypes.UInt64):
        return pa.uint64()
    if isinstance_or_issubclass(dtype, dtypes.UInt32):
        return pa.uint32()
    if isinstance_or_issubclass(dtype, dtypes.UInt16):
        return pa.uint16()
    if isinstance_or_issubclass(dtype, dtypes.UInt8):
        return pa.uint8()
    if isinstance_or_issubclass(dtype, dtypes.String):
        return pa.string()
    if isinstance_or_issubclass(dtype, dtypes.Boolean):
        return pa.bool_()
    if isinstance_or_issubclass(dtype, dtypes.Categorical):
        return pa.dictionary(pa.uint32(), pa.string())
    if isinstance_or_issubclass(dtype, dtypes.Datetime):
        unit = dtype.time_unit
        return pa.timestamp(unit, tz) if (tz := dtype.time_zone) else pa.timestamp(unit)
    if isinstance_or_issubclass(dtype, dtypes.Duration):
        return pa.duration(dtype.time_unit)
    if isinstance_or_issubclass(dtype, dtypes.Date):
        return pa.date32()
    if isinstance_or_issubclass(dtype, dtypes.List):
        return pa.list_(value_type=narwhals_to_native_dtype(dtype.inner, version=version))
    if isinstance_or_issubclass(dtype, dtypes.Struct):
        return pa.struct(
            [
                (field.name, narwhals_to_native_dtype(field.dtype, version=version))
                for field in dtype.fields
            ]
        )
    if isinstance_or_issubclass(dtype, dtypes.Array):  # pragma: no cover
        inner = narwhals_to_native_dtype(dtype.inner, version=version)
        list_size = dtype.size
        return pa.list_(inner, list_size=list_size)
    if isinstance_or_issubclass(dtype, dtypes.Time):
        return pa.time64("ns")
    if isinstance_or_issubclass(dtype, dtypes.Binary):
        return pa.binary()

    msg = f"Unknown dtype: {dtype}"  # pragma: no cover
    raise AssertionError(msg)


def extract_native(
    lhs: ArrowSeries, rhs: ArrowSeries | PythonLiteral | ScalarAny
) -> tuple[ChunkedArrayAny | ScalarAny, ChunkedArrayAny | ScalarAny]:
    """Extract native objects in binary operation.

    If the comparison isn't supported, return `NotImplemented` so that the
    "right-hand-side" operation (e.g. `__radd__`) can be tried.

    If one of the two sides has a `_broadcast` flag, then extract the scalar
    underneath it so that PyArrow can do its own broadcasting.
    """
    from narwhals._arrow.series import ArrowSeries

    if rhs is None:  # pragma: no cover
        return lhs.native, lit(None, type=lhs._type)

    if isinstance(rhs, ArrowSeries):
        if lhs._broadcast and not rhs._broadcast:
            return lhs.native[0], rhs.native
        if rhs._broadcast:
            return lhs.native, rhs.native[0]
        return lhs.native, rhs.native

    if isinstance(rhs, list):
        msg = "Expected Series or scalar, got list."
        raise TypeError(msg)

    return lhs.native, rhs if isinstance(rhs, pa.Scalar) else lit(rhs)


def floordiv_compat(left: ArrayOrScalar, right: ArrayOrScalar, /) -> Any:
    # The following lines are adapted from pandas' pyarrow implementation.
    # Ref: https://github.com/pandas-dev/pandas/blob/262fcfbffcee5c3116e86a951d8b693f90411e68/pandas/core/arrays/arrow/array.py#L124-L154

    if pa.types.is_integer(left.type) and pa.types.is_integer(right.type):
        divided = pc.divide_checked(left, right)
        # TODO @dangotbanned: Use a `TypeVar` in guards
        # Narrowing to a `Union` isn't interacting well with the rest of the stubs
        # https://github.com/zen-xu/pyarrow-stubs/pull/215
        if pa.types.is_signed_integer(divided.type):
            div_type = cast("pa._lib.Int64Type", divided.type)
            has_remainder = pc.not_equal(pc.multiply(divided, right), left)
            has_one_negative_operand = pc.less(
                pc.bit_wise_xor(left, right), lit(0, div_type)
            )
            result = pc.if_else(
                pc.and_(has_remainder, has_one_negative_operand),
                pc.subtract(divided, lit(1, div_type)),
                divided,
            )
        else:
            result = divided  # pragma: no cover
        result = result.cast(left.type)
    else:
        divided = pc.divide(left, right)
        result = pc.floor(divided)
    return result
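# A hedged sketch (not part of the vendored file): the extra branch above is
# what makes integer division floor (Python semantics) instead of truncate
# (the semantics of a plain Arrow divide):
#
#     -7 // 2      # -> -4  (Python floors)
#     int(-7 / 2)  # -> -3  (truncation toward zero)
#
# `floordiv_compat` detects "has remainder and exactly one negative operand"
# and subtracts 1 from the truncated quotient to recover the floored result.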

def cast_for_truediv(
    arrow_array: ArrayOrScalarT1, pa_object: ArrayOrScalarT2
) -> tuple[ArrayOrScalarT1, ArrayOrScalarT2]:
    # Lifted from:
    # https://github.com/pandas-dev/pandas/blob/262fcfbffcee5c3116e86a951d8b693f90411e68/pandas/core/arrays/arrow/array.py#L108-L122
    # Ensure int / int -> float mirroring Python/Numpy behavior
    # as pc.divide_checked(int, int) -> int
    if pa.types.is_integer(arrow_array.type) and pa.types.is_integer(pa_object.type):
        # GH: 56645.  # noqa: ERA001
        # https://github.com/apache/arrow/issues/35563
        return arrow_array.cast(pa.float64(), safe=False), pa_object.cast(
            pa.float64(), safe=False
        )

    return arrow_array, pa_object
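# A hedged sketch (not part of the vendored file): without the cast above,
# Arrow's division of two integer arrays stays integral, unlike Python's `/`:
#
#     import pyarrow as pa
#     import pyarrow.compute as pc
#     pc.divide(pa.array([1]), pa.array([2]))                              # -> [0]
#     pc.divide(pa.array([1], pa.float64()), pa.array([2], pa.float64())) # -> [0.5]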

# Regex for date, time, separator and timezone components
DATE_RE = r"(?P<date>\d{1,4}[-/.]\d{1,2}[-/.]\d{1,4}|\d{8})"
SEP_RE = r"(?P<sep>\s|T)"
TIME_RE = r"(?P<time>\d{2}:\d{2}(?::\d{2})?|\d{6}?)"  # \s*(?P<period>[AP]M)?)?
HMS_RE = r"^(?P<hms>\d{2}:\d{2}:\d{2})$"
HM_RE = r"^(?P<hm>\d{2}:\d{2})$"
HMS_RE_NO_SEP = r"^(?P<hms_no_sep>\d{6})$"
TZ_RE = r"(?P<tz>Z|[+-]\d{2}:?\d{2})"  # Matches 'Z', '+02:00', '+0200', '+02', etc.
FULL_RE = rf"{DATE_RE}{SEP_RE}?{TIME_RE}?{TZ_RE}?$"

# Separate regexes for different date formats
YMD_RE = r"^(?P<year>(?:[12][0-9])?[0-9]{2})(?P<sep1>[-/.])(?P<month>0[1-9]|1[0-2])(?P<sep2>[-/.])(?P<day>0[1-9]|[12][0-9]|3[01])$"
DMY_RE = r"^(?P<day>0[1-9]|[12][0-9]|3[01])(?P<sep1>[-/.])(?P<month>0[1-9]|1[0-2])(?P<sep2>[-/.])(?P<year>(?:[12][0-9])?[0-9]{2})$"
MDY_RE = r"^(?P<month>0[1-9]|1[0-2])(?P<sep1>[-/.])(?P<day>0[1-9]|[12][0-9]|3[01])(?P<sep2>[-/.])(?P<year>(?:[12][0-9])?[0-9]{2})$"
YMD_RE_NO_SEP = r"^(?P<year>(?:[12][0-9])?[0-9]{2})(?P<month>0[1-9]|1[0-2])(?P<day>0[1-9]|[12][0-9]|3[01])$"

DATE_FORMATS = (
    (YMD_RE_NO_SEP, "%Y%m%d"),
    (YMD_RE, "%Y-%m-%d"),
    (DMY_RE, "%d-%m-%Y"),
    (MDY_RE, "%m-%d-%Y"),
)
TIME_FORMATS = ((HMS_RE, "%H:%M:%S"), (HM_RE, "%H:%M"), (HMS_RE_NO_SEP, "%H%M%S"))


def _extract_regex_concat_arrays(
    strings: ChunkedArrayAny,
    /,
    pattern: str,
    *,
    options: Any = None,
    memory_pool: Any = None,
) -> pa.StructArray:
    r = pa.concat_arrays(
        extract_regex(strings, pattern, options=options, memory_pool=memory_pool).chunks
    )
    return cast("pa.StructArray", r)


def parse_datetime_format(arr: ChunkedArrayAny) -> str:
    """Try to infer datetime format from StringArray."""
    matches = _extract_regex_concat_arrays(arr.drop_null().slice(0, 10), pattern=FULL_RE)
    if not pc.all(matches.is_valid()).as_py():
        msg = (
            "Unable to infer datetime format, provided format is not supported. "
            "Please report a bug to https://github.com/narwhals-dev/narwhals/issues"
        )
        raise NotImplementedError(msg)

    separators = matches.field("sep")
    tz = matches.field("tz")

    # separators and time zones must be unique
    if pc.count(pc.unique(separators)).as_py() > 1:
        msg = "Found multiple separator values while inferring datetime format."
        raise ValueError(msg)

    if pc.count(pc.unique(tz)).as_py() > 1:
        msg = "Found multiple timezone values while inferring datetime format."
        raise ValueError(msg)

    date_value = _parse_date_format(cast("pc.StringArray", matches.field("date")))
    time_value = _parse_time_format(cast("pc.StringArray", matches.field("time")))

    sep_value = separators[0].as_py()
    tz_value = "%z" if tz[0].as_py() else ""

    return f"{date_value}{sep_value}{time_value}{tz_value}"
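# A hedged sketch (not part of the vendored file) of the inference pipeline:
#
#     import pyarrow as pa
#     arr = pa.chunked_array([["2024-01-02 10:30:00", "2024-02-03 11:45:10"]])
#     parse_datetime_format(arr)  # -> "%Y-%m-%d %H:%M:%S"
#
# The date/sep/time/tz groups come from `FULL_RE`; `_parse_date_format` and
# `_parse_time_format` (below) then pick the concrete strftime directives.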

def _parse_date_format(arr: pc.StringArray) -> str:
    for date_rgx, date_fmt in DATE_FORMATS:
        matches = pc.extract_regex(arr, pattern=date_rgx)
        if date_fmt == "%Y%m%d" and pc.all(matches.is_valid()).as_py():
            return date_fmt
        elif (
            pc.all(matches.is_valid()).as_py()
            and pc.count(pc.unique(sep1 := matches.field("sep1"))).as_py() == 1
            and pc.count(pc.unique(sep2 := matches.field("sep2"))).as_py() == 1
            and (date_sep_value := sep1[0].as_py()) == sep2[0].as_py()
        ):
            return date_fmt.replace("-", date_sep_value)

    msg = (
        "Unable to infer datetime format. "
        "Please report a bug to https://github.com/narwhals-dev/narwhals/issues"
    )
    raise ValueError(msg)


def _parse_time_format(arr: pc.StringArray) -> str:
    for time_rgx, time_fmt in TIME_FORMATS:
        matches = pc.extract_regex(arr, pattern=time_rgx)
        if pc.all(matches.is_valid()).as_py():
            return time_fmt
    return ""


def pad_series(
    series: ArrowSeries, *, window_size: int, center: bool
) -> tuple[ArrowSeries, int]:
    """Pad series with None values on the left and/or right side, depending on the specified parameters.

    Arguments:
        series: The input ArrowSeries to be padded.
        window_size: The desired size of the window.
        center: Specifies whether to center the padding or not.

    Returns:
        A tuple containing the padded ArrowSeries and the offset value.
    """
    if not center:
        return series, 0
    offset_left = window_size // 2
    # subtract one if window_size is even
    offset_right = offset_left - (window_size % 2 == 0)
    pad_left = pa.array([None] * offset_left, type=series._type)
    pad_right = pa.array([None] * offset_right, type=series._type)
    concat = pa.concat_arrays([pad_left, *series.native.chunks, pad_right])
    return series._with_native(concat), offset_left + offset_right
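# A hedged worked example (not part of the vendored file) of the padding
# arithmetic: for a centered window of size 5, offset_left = 5 // 2 = 2 and
# offset_right = 2 (the window is odd, so nothing is subtracted); for size 4,
# offset_left = 2 and offset_right = 1. The series gains offset_left leading
# and offset_right trailing nulls, and the caller later trims the rolling
# result by the returned combined offset.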

def cast_to_comparable_string_types(
    *chunked_arrays: ChunkedArrayAny, separator: str
) -> tuple[Iterator[ChunkedArrayAny], ScalarAny]:
    # Ensure `chunked_arrays` are either all `string` or all `large_string`.
    dtype = (
        pa.string()  # (PyArrow default)
        if not any(pa.types.is_large_string(ca.type) for ca in chunked_arrays)
        else pa.large_string()
    )
    return (ca.cast(dtype) for ca in chunked_arrays), lit(separator, dtype)
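# A hedged sketch (not part of the vendored file): mixing `string` and
# `large_string` inputs promotes everything to `large_string`; otherwise the
# PyArrow default `string` is kept:
#
#     import pyarrow as pa
#     small = pa.chunked_array([["a"]])                     # string
#     large = pa.chunked_array([["b"]], pa.large_string())  # large_string
#     it, sep = cast_to_comparable_string_types(small, large, separator="-")
#     [ca.type for ca in it]  # -> [large_string, large_string]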

class ArrowSeriesNamespace(EagerSeriesNamespace["ArrowSeries", "ChunkedArrayAny"]): ...
@@ -0,0 +1,103 @@
from __future__ import annotations

from narwhals._compliant.dataframe import (
    CompliantDataFrame,
    CompliantLazyFrame,
    EagerDataFrame,
)
from narwhals._compliant.expr import (
    CompliantExpr,
    DepthTrackingExpr,
    EagerExpr,
    LazyExpr,
    LazyExprNamespace,
)
from narwhals._compliant.group_by import (
    CompliantGroupBy,
    DepthTrackingGroupBy,
    EagerGroupBy,
)
from narwhals._compliant.namespace import (
    CompliantNamespace,
    DepthTrackingNamespace,
    EagerNamespace,
    LazyNamespace,
)
from narwhals._compliant.selectors import (
    CompliantSelector,
    CompliantSelectorNamespace,
    EagerSelectorNamespace,
    LazySelectorNamespace,
)
from narwhals._compliant.series import (
    CompliantSeries,
    EagerSeries,
    EagerSeriesCatNamespace,
    EagerSeriesDateTimeNamespace,
    EagerSeriesHist,
    EagerSeriesListNamespace,
    EagerSeriesNamespace,
    EagerSeriesStringNamespace,
    EagerSeriesStructNamespace,
)
from narwhals._compliant.typing import (
    CompliantExprT,
    CompliantFrameT,
    CompliantSeriesOrNativeExprT_co,
    CompliantSeriesT,
    EagerDataFrameT,
    EagerSeriesT,
    EvalNames,
    EvalSeries,
    IntoCompliantExpr,
    NativeFrameT_co,
    NativeSeriesT_co,
)
from narwhals._compliant.when_then import CompliantThen, CompliantWhen, EagerWhen
from narwhals._compliant.window import WindowInputs

__all__ = [
    "CompliantDataFrame",
    "CompliantExpr",
    "CompliantExprT",
    "CompliantFrameT",
    "CompliantGroupBy",
    "CompliantLazyFrame",
    "CompliantNamespace",
    "CompliantSelector",
    "CompliantSelectorNamespace",
    "CompliantSeries",
    "CompliantSeriesOrNativeExprT_co",
    "CompliantSeriesT",
    "CompliantThen",
    "CompliantWhen",
    "DepthTrackingExpr",
    "DepthTrackingGroupBy",
    "DepthTrackingNamespace",
    "EagerDataFrame",
    "EagerDataFrameT",
    "EagerExpr",
    "EagerGroupBy",
    "EagerNamespace",
    "EagerSelectorNamespace",
    "EagerSeries",
    "EagerSeriesCatNamespace",
    "EagerSeriesDateTimeNamespace",
    "EagerSeriesHist",
    "EagerSeriesListNamespace",
    "EagerSeriesNamespace",
    "EagerSeriesStringNamespace",
    "EagerSeriesStructNamespace",
    "EagerSeriesT",
    "EagerWhen",
    "EvalNames",
    "EvalSeries",
    "IntoCompliantExpr",
    "LazyExpr",
    "LazyExprNamespace",
    "LazyNamespace",
    "LazySelectorNamespace",
    "NativeFrameT_co",
    "NativeSeriesT_co",
    "WindowInputs",
]
Binary file not shown.
@@ -0,0 +1,89 @@
"""`Expr` and `Series` namespace accessor protocols."""

from __future__ import annotations

from typing import TYPE_CHECKING, Protocol

from narwhals._utils import CompliantT_co, _StoresCompliant

if TYPE_CHECKING:
    from typing import Callable

    from narwhals.typing import TimeUnit

__all__ = [
    "CatNamespace",
    "DateTimeNamespace",
    "ListNamespace",
    "NameNamespace",
    "StringNamespace",
    "StructNamespace",
]


class CatNamespace(_StoresCompliant[CompliantT_co], Protocol[CompliantT_co]):
    def get_categories(self) -> CompliantT_co: ...
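# A hedged sketch (not part of the vendored file): these are structural
# `Protocol`s, so a backend accessor conforms by shape alone. Assuming
# `_StoresCompliant` requires a `compliant` property, something like:
#
#     class MyCatNamespace:
#         def __init__(self, compliant: MySeries) -> None:
#             self._compliant = compliant
#
#         @property
#         def compliant(self) -> MySeries:
#             return self._compliant
#
#         def get_categories(self) -> MySeries: ...
#
# would type-check as `CatNamespace[MySeries]` without inheriting from it.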

class DateTimeNamespace(_StoresCompliant[CompliantT_co], Protocol[CompliantT_co]):
    def to_string(self, format: str) -> CompliantT_co: ...
    def replace_time_zone(self, time_zone: str | None) -> CompliantT_co: ...
    def convert_time_zone(self, time_zone: str) -> CompliantT_co: ...
    def timestamp(self, time_unit: TimeUnit) -> CompliantT_co: ...
    def date(self) -> CompliantT_co: ...
    def year(self) -> CompliantT_co: ...
    def month(self) -> CompliantT_co: ...
    def day(self) -> CompliantT_co: ...
    def hour(self) -> CompliantT_co: ...
    def minute(self) -> CompliantT_co: ...
    def second(self) -> CompliantT_co: ...
    def millisecond(self) -> CompliantT_co: ...
    def microsecond(self) -> CompliantT_co: ...
    def nanosecond(self) -> CompliantT_co: ...
    def ordinal_day(self) -> CompliantT_co: ...
    def weekday(self) -> CompliantT_co: ...
    def total_minutes(self) -> CompliantT_co: ...
    def total_seconds(self) -> CompliantT_co: ...
    def total_milliseconds(self) -> CompliantT_co: ...
    def total_microseconds(self) -> CompliantT_co: ...
    def total_nanoseconds(self) -> CompliantT_co: ...
    def truncate(self, every: str) -> CompliantT_co: ...
    def offset_by(self, by: str) -> CompliantT_co: ...


class ListNamespace(_StoresCompliant[CompliantT_co], Protocol[CompliantT_co]):
    def len(self) -> CompliantT_co: ...


class NameNamespace(_StoresCompliant[CompliantT_co], Protocol[CompliantT_co]):
    def keep(self) -> CompliantT_co: ...
    def map(self, function: Callable[[str], str]) -> CompliantT_co: ...
    def prefix(self, prefix: str) -> CompliantT_co: ...
    def suffix(self, suffix: str) -> CompliantT_co: ...
    def to_lowercase(self) -> CompliantT_co: ...
    def to_uppercase(self) -> CompliantT_co: ...


class StringNamespace(_StoresCompliant[CompliantT_co], Protocol[CompliantT_co]):
    def len_chars(self) -> CompliantT_co: ...
    def replace(
        self, pattern: str, value: str, *, literal: bool, n: int
    ) -> CompliantT_co: ...
    def replace_all(
        self, pattern: str, value: str, *, literal: bool
    ) -> CompliantT_co: ...
    def strip_chars(self, characters: str | None) -> CompliantT_co: ...
    def starts_with(self, prefix: str) -> CompliantT_co: ...
    def ends_with(self, suffix: str) -> CompliantT_co: ...
    def contains(self, pattern: str, *, literal: bool) -> CompliantT_co: ...
    def slice(self, offset: int, length: int | None) -> CompliantT_co: ...
    def split(self, by: str) -> CompliantT_co: ...
    def to_datetime(self, format: str | None) -> CompliantT_co: ...
    def to_date(self, format: str | None) -> CompliantT_co: ...
    def to_lowercase(self) -> CompliantT_co: ...
    def to_uppercase(self) -> CompliantT_co: ...
    def zfill(self, width: int) -> CompliantT_co: ...


class StructNamespace(_StoresCompliant[CompliantT_co], Protocol[CompliantT_co]):
    def field(self, name: str) -> CompliantT_co: ...
@@ -0,0 +1,486 @@
from __future__ import annotations

from collections.abc import Iterator, Mapping, Sequence, Sized
from itertools import chain
from typing import TYPE_CHECKING, Any, Literal, Protocol, TypeVar, overload

from narwhals._compliant.typing import (
    CompliantDataFrameAny,
    CompliantExprT_contra,
    CompliantLazyFrameAny,
    CompliantSeriesT,
    EagerExprT,
    EagerSeriesT,
    NativeFrameT,
    NativeSeriesT,
)
from narwhals._translate import (
    ArrowConvertible,
    DictConvertible,
    FromNative,
    NumpyConvertible,
    ToNarwhals,
    ToNarwhalsT_co,
)
from narwhals._typing_compat import assert_never
from narwhals._utils import (
    ValidateBackendVersion,
    Version,
    _StoresNative,
    check_columns_exist,
    is_compliant_series,
    is_index_selector,
    is_range,
    is_sequence_like,
    is_sized_multi_index_selector,
    is_slice_index,
    is_slice_none,
)

if TYPE_CHECKING:
    from io import BytesIO
    from pathlib import Path

    import pandas as pd
    import polars as pl
    import pyarrow as pa
    from typing_extensions import Self, TypeAlias

    from narwhals._compliant.group_by import CompliantGroupBy, DataFrameGroupBy
    from narwhals._compliant.namespace import EagerNamespace
    from narwhals._translate import IntoArrowTable
    from narwhals._utils import Implementation, _LimitedContext
    from narwhals.dataframe import DataFrame
    from narwhals.dtypes import DType
    from narwhals.exceptions import ColumnNotFoundError
    from narwhals.schema import Schema
    from narwhals.typing import (
        AsofJoinStrategy,
        JoinStrategy,
        LazyUniqueKeepStrategy,
        MultiColSelector,
        MultiIndexSelector,
        PivotAgg,
        SingleIndexSelector,
        SizedMultiIndexSelector,
        SizedMultiNameSelector,
        SizeUnit,
        UniqueKeepStrategy,
        _2DArray,
        _SliceIndex,
        _SliceName,
    )

    Incomplete: TypeAlias = Any

__all__ = ["CompliantDataFrame", "CompliantLazyFrame", "EagerDataFrame"]

T = TypeVar("T")

_ToDict: TypeAlias = "dict[str, CompliantSeriesT] | dict[str, list[Any]]"  # noqa: PYI047

class CompliantDataFrame(
    NumpyConvertible["_2DArray", "_2DArray"],
    DictConvertible["_ToDict[CompliantSeriesT]", Mapping[str, Any]],
    ArrowConvertible["pa.Table", "IntoArrowTable"],
    _StoresNative[NativeFrameT],
    FromNative[NativeFrameT],
    ToNarwhals[ToNarwhalsT_co],
    Sized,
    Protocol[CompliantSeriesT, CompliantExprT_contra, NativeFrameT, ToNarwhalsT_co],
):
    _native_frame: NativeFrameT
    _implementation: Implementation
    _version: Version

    def __narwhals_dataframe__(self) -> Self: ...
    def __narwhals_namespace__(self) -> Any: ...
    @classmethod
    def from_arrow(cls, data: IntoArrowTable, /, *, context: _LimitedContext) -> Self: ...
    @classmethod
    def from_dict(
        cls,
        data: Mapping[str, Any],
        /,
        *,
        context: _LimitedContext,
        schema: Mapping[str, DType] | Schema | None,
    ) -> Self: ...
    @classmethod
    def from_native(cls, data: NativeFrameT, /, *, context: _LimitedContext) -> Self: ...
    @classmethod
    def from_numpy(
        cls,
        data: _2DArray,
        /,
        *,
        context: _LimitedContext,
        schema: Mapping[str, DType] | Schema | Sequence[str] | None,
    ) -> Self: ...

    def __array__(self, dtype: Any, *, copy: bool | None) -> _2DArray: ...
    def __getitem__(
        self,
        item: tuple[
            SingleIndexSelector | MultiIndexSelector[CompliantSeriesT],
            MultiColSelector[CompliantSeriesT],
        ],
    ) -> Self: ...
    def simple_select(self, *column_names: str) -> Self:
        """`select` where all args are column names."""
        ...

    def aggregate(self, *exprs: CompliantExprT_contra) -> Self:
        """`select` where all args are aggregations or literals.

        (so, no broadcasting is necessary).
        """
        # NOTE: Ignore is to avoid an intermittent false positive
        return self.select(*exprs)  # pyright: ignore[reportArgumentType]

    def _with_version(self, version: Version) -> Self: ...

    @property
    def native(self) -> NativeFrameT:
        return self._native_frame

    @property
    def columns(self) -> Sequence[str]: ...
    @property
    def schema(self) -> Mapping[str, DType]: ...
    @property
    def shape(self) -> tuple[int, int]: ...
    def clone(self) -> Self: ...
    def collect(
        self, backend: Implementation | None, **kwargs: Any
    ) -> CompliantDataFrameAny: ...
    def collect_schema(self) -> Mapping[str, DType]: ...
    def drop(self, columns: Sequence[str], *, strict: bool) -> Self: ...
    def drop_nulls(self, subset: Sequence[str] | None) -> Self: ...
    def estimated_size(self, unit: SizeUnit) -> int | float: ...
    def explode(self, columns: Sequence[str]) -> Self: ...
    def filter(self, predicate: CompliantExprT_contra | Incomplete) -> Self: ...
    def gather_every(self, n: int, offset: int) -> Self: ...
    def get_column(self, name: str) -> CompliantSeriesT: ...
    def group_by(
        self,
        keys: Sequence[str] | Sequence[CompliantExprT_contra],
        *,
        drop_null_keys: bool,
    ) -> DataFrameGroupBy[Self, Any]: ...
    def head(self, n: int) -> Self: ...
    def item(self, row: int | None, column: int | str | None) -> Any: ...
    def iter_columns(self) -> Iterator[CompliantSeriesT]: ...
    def iter_rows(
        self, *, named: bool, buffer_size: int
    ) -> Iterator[tuple[Any, ...]] | Iterator[Mapping[str, Any]]: ...
    def is_unique(self) -> CompliantSeriesT: ...
    def join(
        self,
        other: Self,
        *,
        how: JoinStrategy,
        left_on: Sequence[str] | None,
        right_on: Sequence[str] | None,
        suffix: str,
    ) -> Self: ...
    def join_asof(
        self,
        other: Self,
        *,
        left_on: str,
        right_on: str,
        by_left: Sequence[str] | None,
        by_right: Sequence[str] | None,
        strategy: AsofJoinStrategy,
        suffix: str,
    ) -> Self: ...
    def lazy(self, *, backend: Implementation | None) -> CompliantLazyFrameAny: ...
    def pivot(
        self,
        on: Sequence[str],
        *,
        index: Sequence[str] | None,
        values: Sequence[str] | None,
        aggregate_function: PivotAgg | None,
        sort_columns: bool,
        separator: str,
    ) -> Self: ...
    def rename(self, mapping: Mapping[str, str]) -> Self: ...
    def row(self, index: int) -> tuple[Any, ...]: ...
    def rows(
        self, *, named: bool
    ) -> Sequence[tuple[Any, ...]] | Sequence[Mapping[str, Any]]: ...
    def sample(
        self,
        n: int | None,
        *,
        fraction: float | None,
        with_replacement: bool,
        seed: int | None,
    ) -> Self: ...
    def select(self, *exprs: CompliantExprT_contra) -> Self: ...
    def sort(
        self, *by: str, descending: bool | Sequence[bool], nulls_last: bool
    ) -> Self: ...
    def tail(self, n: int) -> Self: ...
    def to_arrow(self) -> pa.Table: ...
    def to_pandas(self) -> pd.DataFrame: ...
    def to_polars(self) -> pl.DataFrame: ...
    @overload
    def to_dict(self, *, as_series: Literal[True]) -> dict[str, CompliantSeriesT]: ...
    @overload
    def to_dict(self, *, as_series: Literal[False]) -> dict[str, list[Any]]: ...
    def to_dict(
        self, *, as_series: bool
    ) -> dict[str, CompliantSeriesT] | dict[str, list[Any]]: ...
    def unique(
        self,
        subset: Sequence[str] | None,
        *,
        keep: UniqueKeepStrategy,
        maintain_order: bool | None = None,
    ) -> Self: ...
    def unpivot(
        self,
        on: Sequence[str] | None,
        index: Sequence[str] | None,
        variable_name: str,
        value_name: str,
    ) -> Self: ...
    def with_columns(self, *exprs: CompliantExprT_contra) -> Self: ...
    def with_row_index(self, name: str, order_by: Sequence[str] | None) -> Self: ...
    @overload
    def write_csv(self, file: None) -> str: ...
    @overload
    def write_csv(self, file: str | Path | BytesIO) -> None: ...
    def write_csv(self, file: str | Path | BytesIO | None) -> str | None: ...
    def write_parquet(self, file: str | Path | BytesIO) -> None: ...

    def _evaluate_aliases(self, *exprs: CompliantExprT_contra) -> list[str]:
        it = (expr._evaluate_aliases(self) for expr in exprs)
        return list(chain.from_iterable(it))

    def _check_columns_exist(self, subset: Sequence[str]) -> ColumnNotFoundError | None:
        return check_columns_exist(subset, available=self.columns)

class CompliantLazyFrame(
|
||||
_StoresNative[NativeFrameT],
|
||||
FromNative[NativeFrameT],
|
||||
ToNarwhals[ToNarwhalsT_co],
|
||||
Protocol[CompliantExprT_contra, NativeFrameT, ToNarwhalsT_co],
|
||||
):
|
||||
_native_frame: NativeFrameT
|
||||
_implementation: Implementation
|
||||
_version: Version
|
||||
|
||||
def __narwhals_lazyframe__(self) -> Self: ...
|
||||
def __narwhals_namespace__(self) -> Any: ...
|
||||
|
||||
@classmethod
|
||||
def from_native(cls, data: NativeFrameT, /, *, context: _LimitedContext) -> Self: ...
|
||||
|
||||
def simple_select(self, *column_names: str) -> Self:
|
||||
"""`select` where all args are column names."""
|
||||
...
|
||||
|
||||
def aggregate(self, *exprs: CompliantExprT_contra) -> Self:
|
||||
"""`select` where all args are aggregations or literals.
|
||||
|
||||
(so, no broadcasting is necessary).
|
||||
"""
|
||||
...
|
||||
|
||||
def _with_version(self, version: Version) -> Self: ...
|
||||
|
||||
@property
|
||||
def native(self) -> NativeFrameT:
|
||||
return self._native_frame
|
||||
|
||||
@property
|
||||
def columns(self) -> Sequence[str]: ...
|
||||
@property
|
||||
def schema(self) -> Mapping[str, DType]: ...
|
||||
def _iter_columns(self) -> Iterator[Any]: ...
|
||||
def collect(
|
||||
self, backend: Implementation | None, **kwargs: Any
|
||||
) -> CompliantDataFrameAny: ...
|
||||
def collect_schema(self) -> Mapping[str, DType]: ...
|
||||
def drop(self, columns: Sequence[str], *, strict: bool) -> Self: ...
|
||||
def drop_nulls(self, subset: Sequence[str] | None) -> Self: ...
|
||||
def explode(self, columns: Sequence[str]) -> Self: ...
|
||||
def filter(self, predicate: CompliantExprT_contra | Incomplete) -> Self: ...
|
||||
def group_by(
|
||||
self,
|
||||
keys: Sequence[str] | Sequence[CompliantExprT_contra],
|
||||
*,
|
||||
drop_null_keys: bool,
|
||||
) -> CompliantGroupBy[Self, CompliantExprT_contra]: ...
|
||||
def head(self, n: int) -> Self: ...
|
||||
def join(
|
||||
self,
|
||||
other: Self,
|
||||
*,
|
||||
how: JoinStrategy,
|
||||
left_on: Sequence[str] | None,
|
||||
right_on: Sequence[str] | None,
|
||||
suffix: str,
|
||||
) -> Self: ...
|
||||
def join_asof(
|
||||
self,
|
||||
other: Self,
|
||||
*,
|
||||
left_on: str,
|
||||
right_on: str,
|
||||
by_left: Sequence[str] | None,
|
||||
by_right: Sequence[str] | None,
|
||||
strategy: AsofJoinStrategy,
|
||||
suffix: str,
|
||||
) -> Self: ...
|
||||
def rename(self, mapping: Mapping[str, str]) -> Self: ...
|
||||
def select(self, *exprs: CompliantExprT_contra) -> Self: ...
|
||||
def sink_parquet(self, file: str | Path | BytesIO) -> None: ...
|
||||
def sort(
|
||||
self, *by: str, descending: bool | Sequence[bool], nulls_last: bool
|
||||
) -> Self: ...
|
||||
def unique(
|
||||
self, subset: Sequence[str] | None, *, keep: LazyUniqueKeepStrategy
|
||||
) -> Self: ...
|
||||
def unpivot(
|
||||
self,
|
||||
on: Sequence[str] | None,
|
||||
index: Sequence[str] | None,
|
||||
variable_name: str,
|
||||
value_name: str,
|
||||
) -> Self: ...
|
||||
def with_columns(self, *exprs: CompliantExprT_contra) -> Self: ...
|
||||
def with_row_index(self, name: str, order_by: Sequence[str]) -> Self: ...
|
||||
def _evaluate_expr(self, expr: CompliantExprT_contra, /) -> Any:
|
||||
result = expr(self)
|
||||
assert len(result) == 1 # debug assertion # noqa: S101
|
||||
return result[0]
|
||||
|
||||
def _evaluate_aliases(self, *exprs: CompliantExprT_contra) -> list[str]:
|
||||
it = (expr._evaluate_aliases(self) for expr in exprs)
|
||||
return list(chain.from_iterable(it))
|
||||
|
||||
def _check_columns_exist(self, subset: Sequence[str]) -> ColumnNotFoundError | None:
|
||||
return check_columns_exist(subset, available=self.columns)


class EagerDataFrame(
    CompliantDataFrame[EagerSeriesT, EagerExprT, NativeFrameT, "DataFrame[NativeFrameT]"],
    CompliantLazyFrame[EagerExprT, NativeFrameT, "DataFrame[NativeFrameT]"],
    ValidateBackendVersion,
    Protocol[EagerSeriesT, EagerExprT, NativeFrameT, NativeSeriesT],
):
    @property
    def _backend_version(self) -> tuple[int, ...]:
        return self._implementation._backend_version()

    def __narwhals_namespace__(
        self,
    ) -> EagerNamespace[Self, EagerSeriesT, EagerExprT, NativeFrameT, NativeSeriesT]: ...

    def to_narwhals(self) -> DataFrame[NativeFrameT]:
        return self._version.dataframe(self, level="full")

    def _with_native(
        self, df: NativeFrameT, *, validate_column_names: bool = True
    ) -> Self: ...

    def _evaluate_expr(self, expr: EagerExprT, /) -> EagerSeriesT:
        """Evaluate `expr` and ensure it has a **single** output."""
        result: Sequence[EagerSeriesT] = expr(self)
        assert len(result) == 1  # debug assertion  # noqa: S101
        return result[0]

    def _evaluate_into_exprs(self, *exprs: EagerExprT) -> Sequence[EagerSeriesT]:
        # NOTE: Ignore is to avoid an intermittent false positive
        return list(chain.from_iterable(self._evaluate_into_expr(expr) for expr in exprs))  # pyright: ignore[reportArgumentType]

    def _evaluate_into_expr(self, expr: EagerExprT, /) -> Sequence[EagerSeriesT]:
        """Return list of raw columns.

        For eager backends we alias operations at each step.

        As a safety precaution, here we can check that the result names match those
        we were expecting from the various `evaluate_output_names` / `alias_output_names` calls.

        Note that for PySpark / DuckDB, we are less free to set aliases liberally.
        """
        aliases = expr._evaluate_aliases(self)
        result = expr(self)
        if list(aliases) != (
            result_aliases := [s.name for s in result]
        ):  # pragma: no cover
            msg = f"Safety assertion failed, expected {aliases}, got {result_aliases}"
            raise AssertionError(msg)
        return result

    def _extract_comparand(self, other: EagerSeriesT, /) -> Any:
        """Extract native Series, broadcasting to `len(self)` if necessary."""
        ...

    @staticmethod
    def _numpy_column_names(
        data: _2DArray, columns: Sequence[str] | None, /
    ) -> list[str]:
        return list(columns or (f"column_{x}" for x in range(data.shape[1])))

    def _gather(self, rows: SizedMultiIndexSelector[NativeSeriesT]) -> Self: ...
    def _gather_slice(self, rows: _SliceIndex | range) -> Self: ...
    def _select_multi_index(
        self, columns: SizedMultiIndexSelector[NativeSeriesT]
    ) -> Self: ...
    def _select_multi_name(
        self, columns: SizedMultiNameSelector[NativeSeriesT]
    ) -> Self: ...
    def _select_slice_index(self, columns: _SliceIndex | range) -> Self: ...
    def _select_slice_name(self, columns: _SliceName) -> Self: ...
    def __getitem__(  # noqa: C901, PLR0912
        self,
        item: tuple[
            SingleIndexSelector | MultiIndexSelector[EagerSeriesT],
            MultiColSelector[EagerSeriesT],
        ],
    ) -> Self:
        rows, columns = item
        compliant = self
        if not is_slice_none(columns):
            if isinstance(columns, Sized) and len(columns) == 0:
                return compliant.select()
            if is_index_selector(columns):
                if is_slice_index(columns) or is_range(columns):
                    compliant = compliant._select_slice_index(columns)
                elif is_compliant_series(columns):
                    compliant = self._select_multi_index(columns.native)
                else:
                    compliant = compliant._select_multi_index(columns)
            elif isinstance(columns, slice):
                compliant = compliant._select_slice_name(columns)
            elif is_compliant_series(columns):
                compliant = self._select_multi_name(columns.native)
            elif is_sequence_like(columns):
                compliant = self._select_multi_name(columns)
            else:
                assert_never(columns)

        if not is_slice_none(rows):
            if isinstance(rows, int):
                compliant = compliant._gather([rows])
            elif isinstance(rows, (slice, range)):
                compliant = compliant._gather_slice(rows)
            elif is_compliant_series(rows):
                compliant = compliant._gather(rows.native)
            elif is_sized_multi_index_selector(rows):
                compliant = compliant._gather(rows)
            else:
                assert_never(rows)

        return compliant

    def sink_parquet(self, file: str | Path | BytesIO) -> None:
        return self.write_parquet(file)
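
# Illustrative, standalone sketch (plain Python, not narwhals API): the
# `__getitem__` above narrows along two axes independently - columns first,
# then rows - so `df[rows, cols]` is one column selection composed with one
# row gather. The same two-phase dispatch, with a dict-of-lists "frame":
def _sketch_getitem(frame: dict, rows: list[int] | None, cols: list[str] | None) -> dict:
    if cols is not None:  # column phase
        frame = {name: frame[name] for name in cols}
    if rows is not None:  # row phase
        frame = {name: [col[i] for i in rows] for name, col in frame.items()}
    return frame

assert _sketch_getitem({"a": [1, 2, 3], "b": [4, 5, 6]}, [0, 2], ["a"]) == {"a": [1, 3]}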

1169
.venv/lib/python3.12/site-packages/narwhals/_compliant/expr.py
Normal file
File diff suppressed because it is too large
@ -0,0 +1,178 @@
from __future__ import annotations

import re
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Protocol, TypeVar

from narwhals._compliant.typing import (
    CompliantDataFrameAny,
    CompliantDataFrameT_co,
    CompliantExprT_contra,
    CompliantFrameT_co,
    CompliantLazyFrameAny,
    DepthTrackingExprAny,
    DepthTrackingExprT_contra,
    EagerExprT_contra,
    NarwhalsAggregation,
)
from narwhals._utils import is_sequence_of

if TYPE_CHECKING:
    from collections.abc import Iterator, Mapping, Sequence

    _SameFrameT = TypeVar("_SameFrameT", CompliantDataFrameAny, CompliantLazyFrameAny)


__all__ = ["CompliantGroupBy", "DepthTrackingGroupBy", "EagerGroupBy"]

NativeAggregationT_co = TypeVar(
    "NativeAggregationT_co", bound="str | Callable[..., Any]", covariant=True
)

_RE_LEAF_NAME: re.Pattern[str] = re.compile(r"(\w+->)")


class CompliantGroupBy(Protocol[CompliantFrameT_co, CompliantExprT_contra]):
    _compliant_frame: Any

    @property
    def compliant(self) -> CompliantFrameT_co:
        return self._compliant_frame  # type: ignore[no-any-return]

    def __init__(
        self,
        compliant_frame: CompliantFrameT_co,
        keys: Sequence[CompliantExprT_contra] | Sequence[str],
        /,
        *,
        drop_null_keys: bool,
    ) -> None: ...

    def agg(self, *exprs: CompliantExprT_contra) -> CompliantFrameT_co: ...


class DataFrameGroupBy(
    CompliantGroupBy[CompliantDataFrameT_co, CompliantExprT_contra],
    Protocol[CompliantDataFrameT_co, CompliantExprT_contra],
):
    def __iter__(self) -> Iterator[tuple[Any, CompliantDataFrameT_co]]: ...


class ParseKeysGroupBy(
    CompliantGroupBy[CompliantFrameT_co, CompliantExprT_contra],
    Protocol[CompliantFrameT_co, CompliantExprT_contra],
):
    def _parse_keys(
        self,
        compliant_frame: _SameFrameT,
        keys: Sequence[CompliantExprT_contra] | Sequence[str],
    ) -> tuple[_SameFrameT, list[str], list[str]]:
        if is_sequence_of(keys, str):
            keys_str = list(keys)
            return compliant_frame, keys_str, keys_str.copy()
        else:
            return self._parse_expr_keys(compliant_frame, keys=keys)

    @staticmethod
    def _parse_expr_keys(
        compliant_frame: _SameFrameT, keys: Sequence[CompliantExprT_contra]
    ) -> tuple[_SameFrameT, list[str], list[str]]:
        """Parse key expressions to set up the `.agg` operation with correct information.

        Since keys are expressions, it's possible to alias any such key to match
        other dataframe column names.

        In order to match polars behavior and not overwrite columns when evaluating keys:

        - We evaluate what the output key names should be, in order to remap temporary column
          names to the expected ones, and to exclude those from unnamed expressions in
          `.agg(...)` context (see https://github.com/narwhals-dev/narwhals/pull/2325#issuecomment-2800004520)
        - Create temporary names for evaluated key expressions that are guaranteed to have
          no overlap with any existing column name.
        - Add these temporary columns to the compliant dataframe.
        """
        tmp_name_length = max(len(str(c)) for c in compliant_frame.columns) + 1

        def _temporary_name(key: str) -> str:
            # 5 accounts for the `_` prefix plus the `_tmp` suffix
            key_str = str(key)  # pandas allows non-string column names :sob:
            return f"_{key_str}_tmp{'_' * (tmp_name_length - len(key_str) - 5)}"

        output_names = compliant_frame._evaluate_aliases(*keys)

        safe_keys = [
            # a multi-output expression cannot have duplicate names, hence it's safe to suffix
            key.name.map(_temporary_name)
            if (metadata := key._metadata) and metadata.expansion_kind.is_multi_output()
            # otherwise it's single named and we can use Expr.alias
            else key.alias(_temporary_name(new_name))
            for key, new_name in zip(keys, output_names)
        ]
        return (
            compliant_frame.with_columns(*safe_keys),
            compliant_frame._evaluate_aliases(*safe_keys),
            output_names,
        )
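
# Illustrative, standalone sketch (hypothetical helper, not narwhals API):
# the `_temporary_name` scheme above pads generated key names past the
# longest existing column name, so a temporary name can never collide with
# a real column.
def _sketch_temporary_name(key: str, existing: list[str]) -> str:
    tmp_name_length = max(len(c) for c in existing) + 1
    # `_` prefix + `_tmp` suffix account for the 5 subtracted characters
    return f"_{key}_tmp{'_' * (tmp_name_length - len(key) - 5)}"

assert _sketch_temporary_name("a", ["a", "bb", "ccc"]) == "_a_tmp"
assert _sketch_temporary_name("a", ["abcdefgh"]) == "_a_tmp___"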


class DepthTrackingGroupBy(
    ParseKeysGroupBy[CompliantFrameT_co, DepthTrackingExprT_contra],
    Protocol[CompliantFrameT_co, DepthTrackingExprT_contra, NativeAggregationT_co],
):
    """`CompliantGroupBy` variant for `Eager` and other backends that utilize `CompliantExpr._depth`."""

    _REMAP_AGGS: ClassVar[Mapping[NarwhalsAggregation, Any]]
    """Mapping from `narwhals` to native representation.

    Note:
        - `Dask` *may* return a `Callable` instead of a `str` referring to one.
    """

    def _ensure_all_simple(self, exprs: Sequence[DepthTrackingExprT_contra]) -> None:
        for expr in exprs:
            if not self._is_simple(expr):
                name = self.compliant._implementation.name.lower()
                msg = (
                    f"Non-trivial complex aggregation found.\n\n"
                    f"Hint: you were probably trying to apply a non-elementary aggregation with a "
                    f"{name!r} table.\n"
                    "Please rewrite your query such that group-by aggregations "
                    "are elementary. For example, instead of:\n\n"
                    "    df.group_by('a').agg(nw.col('b').round(2).mean())\n\n"
                    "use:\n\n"
                    "    df.with_columns(nw.col('b').round(2)).group_by('a').agg(nw.col('b').mean())\n\n"
                )
                raise ValueError(msg)

    @classmethod
    def _is_simple(cls, expr: DepthTrackingExprAny, /) -> bool:
        """Return `True` if we can efficiently use `expr` in a native `group_by` context."""
        return expr._is_elementary() and cls._leaf_name(expr) in cls._REMAP_AGGS

    @classmethod
    def _remap_expr_name(
        cls, name: NarwhalsAggregation | Any, /
    ) -> NativeAggregationT_co:
        """Replace `name` with some native representation.

        Arguments:
            name: Name of a `nw.Expr` aggregation method.

        Returns:
            A native compatible representation.
        """
        return cls._REMAP_AGGS.get(name, name)

    @classmethod
    def _leaf_name(cls, expr: DepthTrackingExprAny, /) -> NarwhalsAggregation | Any:
        """Return the last function name in the chain defined by `expr`."""
        return _RE_LEAF_NAME.sub("", expr._function_name)
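
# Illustrative, standalone check: `_RE_LEAF_NAME` (defined above) strips every
# `name->` hop from a chained function name, leaving only the leaf aggregation
# that `_is_simple` looks up in `_REMAP_AGGS`:
assert _RE_LEAF_NAME.sub("", "col->round->mean") == "mean"
assert _RE_LEAF_NAME.sub("", "len") == "len"  # no chain: unchanged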


class EagerGroupBy(
    DepthTrackingGroupBy[
        CompliantDataFrameT_co, EagerExprT_contra, NativeAggregationT_co
    ],
    DataFrameGroupBy[CompliantDataFrameT_co, EagerExprT_contra],
    Protocol[CompliantDataFrameT_co, EagerExprT_contra, NativeAggregationT_co],
): ...
@ -0,0 +1,211 @@
from __future__ import annotations

from functools import partial
from typing import TYPE_CHECKING, Any, Protocol, overload

from narwhals._compliant.typing import (
    CompliantExprT,
    CompliantFrameT,
    CompliantLazyFrameT,
    DepthTrackingExprT,
    EagerDataFrameT,
    EagerExprT,
    EagerSeriesT,
    LazyExprT,
    NativeFrameT,
    NativeFrameT_co,
    NativeSeriesT,
)
from narwhals._utils import (
    exclude_column_names,
    get_column_names,
    passthrough_column_names,
)
from narwhals.dependencies import is_numpy_array_2d

if TYPE_CHECKING:
    from collections.abc import Container, Iterable, Mapping, Sequence

    from typing_extensions import TypeAlias

    from narwhals._compliant.selectors import CompliantSelectorNamespace
    from narwhals._compliant.when_then import CompliantWhen, EagerWhen
    from narwhals._utils import Implementation, Version
    from narwhals.dtypes import DType
    from narwhals.schema import Schema
    from narwhals.typing import (
        ConcatMethod,
        Into1DArray,
        IntoDType,
        NonNestedLiteral,
        _2DArray,
    )

Incomplete: TypeAlias = Any

__all__ = [
    "CompliantNamespace",
    "DepthTrackingNamespace",
    "EagerNamespace",
    "LazyNamespace",
]


class CompliantNamespace(Protocol[CompliantFrameT, CompliantExprT]):
    _implementation: Implementation
    _version: Version

    def all(self) -> CompliantExprT:
        return self._expr.from_column_names(get_column_names, context=self)

    def col(self, *column_names: str) -> CompliantExprT:
        return self._expr.from_column_names(
            passthrough_column_names(column_names), context=self
        )

    def exclude(self, excluded_names: Container[str]) -> CompliantExprT:
        return self._expr.from_column_names(
            partial(exclude_column_names, names=excluded_names), context=self
        )

    def nth(self, *column_indices: int) -> CompliantExprT:
        return self._expr.from_column_indices(*column_indices, context=self)

    def len(self) -> CompliantExprT: ...
    def lit(self, value: NonNestedLiteral, dtype: IntoDType | None) -> CompliantExprT: ...
    def all_horizontal(
        self, *exprs: CompliantExprT, ignore_nulls: bool
    ) -> CompliantExprT: ...
    def any_horizontal(
        self, *exprs: CompliantExprT, ignore_nulls: bool
    ) -> CompliantExprT: ...
    def sum_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...
    def mean_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...
    def min_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...
    def max_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...
    def concat(
        self, items: Iterable[CompliantFrameT], *, how: ConcatMethod
    ) -> CompliantFrameT: ...
    def when(
        self, predicate: CompliantExprT
    ) -> CompliantWhen[CompliantFrameT, Incomplete, CompliantExprT]: ...
    def concat_str(
        self, *exprs: CompliantExprT, separator: str, ignore_nulls: bool
    ) -> CompliantExprT: ...
    @property
    def selectors(self) -> CompliantSelectorNamespace[Any, Any]: ...
    @property
    def _expr(self) -> type[CompliantExprT]: ...
    def coalesce(self, *exprs: CompliantExprT) -> CompliantExprT: ...
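
# Illustrative, standalone sketch (hypothetical names, plain Python): the
# `exclude` method above just pre-binds the excluded names with
# `functools.partial`, producing a one-argument "column names in, kept names
# out" callable of the shape `from_column_names` expects:
def _sketch_exclude_column_names(columns: list[str], names: frozenset[str]) -> list[str]:
    return [c for c in columns if c not in names]

_sketch_eval_names = partial(_sketch_exclude_column_names, names=frozenset({"b"}))
assert _sketch_eval_names(["a", "b", "c"]) == ["a", "c"]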


class DepthTrackingNamespace(
    CompliantNamespace[CompliantFrameT, DepthTrackingExprT],
    Protocol[CompliantFrameT, DepthTrackingExprT],
):
    def all(self) -> DepthTrackingExprT:
        return self._expr.from_column_names(
            get_column_names, function_name="all", context=self
        )

    def col(self, *column_names: str) -> DepthTrackingExprT:
        return self._expr.from_column_names(
            passthrough_column_names(column_names), function_name="col", context=self
        )

    def exclude(self, excluded_names: Container[str]) -> DepthTrackingExprT:
        return self._expr.from_column_names(
            partial(exclude_column_names, names=excluded_names),
            function_name="exclude",
            context=self,
        )


class LazyNamespace(
    CompliantNamespace[CompliantLazyFrameT, LazyExprT],
    Protocol[CompliantLazyFrameT, LazyExprT, NativeFrameT_co],
):
    @property
    def _backend_version(self) -> tuple[int, ...]:
        return self._implementation._backend_version()

    @property
    def _lazyframe(self) -> type[CompliantLazyFrameT]: ...

    def from_native(self, data: NativeFrameT_co | Any, /) -> CompliantLazyFrameT:
        if self._lazyframe._is_native(data):
            return self._lazyframe.from_native(data, context=self)
        else:  # pragma: no cover
            msg = f"Unsupported type: {type(data).__name__!r}"
            raise TypeError(msg)


class EagerNamespace(
    DepthTrackingNamespace[EagerDataFrameT, EagerExprT],
    Protocol[EagerDataFrameT, EagerSeriesT, EagerExprT, NativeFrameT, NativeSeriesT],
):
    @property
    def _backend_version(self) -> tuple[int, ...]:
        return self._implementation._backend_version()

    @property
    def _dataframe(self) -> type[EagerDataFrameT]: ...
    @property
    def _series(self) -> type[EagerSeriesT]: ...
    def when(
        self, predicate: EagerExprT
    ) -> EagerWhen[EagerDataFrameT, EagerSeriesT, EagerExprT, NativeSeriesT]: ...

    @overload
    def from_native(self, data: NativeFrameT, /) -> EagerDataFrameT: ...
    @overload
    def from_native(self, data: NativeSeriesT, /) -> EagerSeriesT: ...
    def from_native(
        self, data: NativeFrameT | NativeSeriesT | Any, /
    ) -> EagerDataFrameT | EagerSeriesT:
        if self._dataframe._is_native(data):
            return self._dataframe.from_native(data, context=self)
        elif self._series._is_native(data):
            return self._series.from_native(data, context=self)
        msg = f"Unsupported type: {type(data).__name__!r}"
        raise TypeError(msg)

    @overload
    def from_numpy(self, data: Into1DArray, /, schema: None = ...) -> EagerSeriesT: ...

    @overload
    def from_numpy(
        self,
        data: _2DArray,
        /,
        schema: Mapping[str, DType] | Schema | Sequence[str] | None,
    ) -> EagerDataFrameT: ...

    def from_numpy(
        self,
        data: Into1DArray | _2DArray,
        /,
        schema: Mapping[str, DType] | Schema | Sequence[str] | None = None,
    ) -> EagerDataFrameT | EagerSeriesT:
        if is_numpy_array_2d(data):
            return self._dataframe.from_numpy(data, schema=schema, context=self)
        return self._series.from_numpy(data, context=self)

    def _concat_diagonal(self, dfs: Sequence[NativeFrameT], /) -> NativeFrameT: ...
    def _concat_horizontal(
        self, dfs: Sequence[NativeFrameT | Any], /
    ) -> NativeFrameT: ...
    def _concat_vertical(self, dfs: Sequence[NativeFrameT], /) -> NativeFrameT: ...
    def concat(
        self, items: Iterable[EagerDataFrameT], *, how: ConcatMethod
    ) -> EagerDataFrameT:
        dfs = [item.native for item in items]
        if how == "horizontal":
            native = self._concat_horizontal(dfs)
        elif how == "vertical":
            native = self._concat_vertical(dfs)
        elif how == "diagonal":
            native = self._concat_diagonal(dfs)
        else:  # pragma: no cover
            raise NotImplementedError
        return self._dataframe.from_native(native, context=self)
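
# Illustrative, standalone sketch: `concat` above is a pure dispatch on `how`,
# delegating to one backend-specific helper per strategy. With dict-of-lists
# "frames" the same shape looks like:
def _sketch_concat(frames: list[dict], *, how: str) -> dict:
    if how == "horizontal":  # side by side; distinct column names assumed
        return {k: v for frame in frames for k, v in frame.items()}
    if how == "vertical":  # stack rows; identical schemas assumed
        return {k: [x for f in frames for x in f[k]] for k in frames[0]}
    raise NotImplementedError(how)

assert _sketch_concat([{"a": [1]}, {"b": [2]}], how="horizontal") == {"a": [1], "b": [2]}
assert _sketch_concat([{"a": [1]}, {"a": [2]}], how="vertical") == {"a": [1, 2]}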

@ -0,0 +1,311 @@
"""Almost entirely complete, generic `selectors` implementation."""

from __future__ import annotations

import re
from functools import partial
from typing import TYPE_CHECKING, Protocol, TypeVar, overload

from narwhals._compliant.expr import CompliantExpr
from narwhals._utils import (
    _parse_time_unit_and_time_zone,
    dtype_matches_time_unit_and_time_zone,
    get_column_names,
    is_compliant_dataframe,
)

if TYPE_CHECKING:
    from collections.abc import Collection, Iterable, Iterator, Sequence
    from datetime import timezone

    from typing_extensions import Self, TypeAlias, TypeIs

    from narwhals._compliant.expr import NativeExpr
    from narwhals._compliant.typing import (
        CompliantDataFrameAny,
        CompliantExprAny,
        CompliantFrameAny,
        CompliantLazyFrameAny,
        CompliantSeriesAny,
        CompliantSeriesOrNativeExprAny,
        EvalNames,
        EvalSeries,
    )
    from narwhals._utils import Implementation, Version, _LimitedContext
    from narwhals.dtypes import DType
    from narwhals.typing import TimeUnit

__all__ = [
    "CompliantSelector",
    "CompliantSelectorNamespace",
    "EagerSelectorNamespace",
    "LazySelectorNamespace",
]


SeriesOrExprT = TypeVar("SeriesOrExprT", bound="CompliantSeriesOrNativeExprAny")
SeriesT = TypeVar("SeriesT", bound="CompliantSeriesAny")
ExprT = TypeVar("ExprT", bound="NativeExpr")
FrameT = TypeVar("FrameT", bound="CompliantFrameAny")
DataFrameT = TypeVar("DataFrameT", bound="CompliantDataFrameAny")
LazyFrameT = TypeVar("LazyFrameT", bound="CompliantLazyFrameAny")
SelectorOrExpr: TypeAlias = (
    "CompliantSelector[FrameT, SeriesOrExprT] | CompliantExpr[FrameT, SeriesOrExprT]"
)


class CompliantSelectorNamespace(Protocol[FrameT, SeriesOrExprT]):
    _implementation: Implementation
    _version: Version

    @classmethod
    def from_namespace(cls, context: _LimitedContext, /) -> Self:
        obj = cls.__new__(cls)
        obj._implementation = context._implementation
        obj._version = context._version
        return obj

    @property
    def _selector(self) -> type[CompliantSelector[FrameT, SeriesOrExprT]]: ...

    def _iter_columns(self, df: FrameT, /) -> Iterator[SeriesOrExprT]: ...

    def _iter_schema(self, df: FrameT, /) -> Iterator[tuple[str, DType]]: ...

    def _iter_columns_dtypes(
        self, df: FrameT, /
    ) -> Iterator[tuple[SeriesOrExprT, DType]]: ...

    def _iter_columns_names(self, df: FrameT, /) -> Iterator[tuple[SeriesOrExprT, str]]:
        yield from zip(self._iter_columns(df), df.columns)

    def _is_dtype(
        self: CompliantSelectorNamespace[FrameT, SeriesOrExprT], dtype: type[DType], /
    ) -> CompliantSelector[FrameT, SeriesOrExprT]:
        def series(df: FrameT) -> Sequence[SeriesOrExprT]:
            return [
                ser for ser, tp in self._iter_columns_dtypes(df) if isinstance(tp, dtype)
            ]

        def names(df: FrameT) -> Sequence[str]:
            return [name for name, tp in self._iter_schema(df) if isinstance(tp, dtype)]

        return self._selector.from_callables(series, names, context=self)

    def by_dtype(
        self, dtypes: Collection[DType | type[DType]]
    ) -> CompliantSelector[FrameT, SeriesOrExprT]:
        def series(df: FrameT) -> Sequence[SeriesOrExprT]:
            return [ser for ser, tp in self._iter_columns_dtypes(df) if tp in dtypes]

        def names(df: FrameT) -> Sequence[str]:
            return [name for name, tp in self._iter_schema(df) if tp in dtypes]

        return self._selector.from_callables(series, names, context=self)

    def matches(self, pattern: str) -> CompliantSelector[FrameT, SeriesOrExprT]:
        p = re.compile(pattern)

        def series(df: FrameT) -> Sequence[SeriesOrExprT]:
            if (
                is_compliant_dataframe(df)
                and not self._implementation.is_duckdb()
                and not self._implementation.is_ibis()
            ):
                return [df.get_column(col) for col in df.columns if p.search(col)]

            return [ser for ser, name in self._iter_columns_names(df) if p.search(name)]

        def names(df: FrameT) -> Sequence[str]:
            return [col for col in df.columns if p.search(col)]

        return self._selector.from_callables(series, names, context=self)

    def numeric(self) -> CompliantSelector[FrameT, SeriesOrExprT]:
        def series(df: FrameT) -> Sequence[SeriesOrExprT]:
            return [ser for ser, tp in self._iter_columns_dtypes(df) if tp.is_numeric()]

        def names(df: FrameT) -> Sequence[str]:
            return [name for name, tp in self._iter_schema(df) if tp.is_numeric()]

        return self._selector.from_callables(series, names, context=self)

    def categorical(self) -> CompliantSelector[FrameT, SeriesOrExprT]:
        return self._is_dtype(self._version.dtypes.Categorical)

    def string(self) -> CompliantSelector[FrameT, SeriesOrExprT]:
        return self._is_dtype(self._version.dtypes.String)

    def boolean(self) -> CompliantSelector[FrameT, SeriesOrExprT]:
        return self._is_dtype(self._version.dtypes.Boolean)

    def all(self) -> CompliantSelector[FrameT, SeriesOrExprT]:
        def series(df: FrameT) -> Sequence[SeriesOrExprT]:
            return list(self._iter_columns(df))

        return self._selector.from_callables(series, get_column_names, context=self)

    def datetime(
        self,
        time_unit: TimeUnit | Iterable[TimeUnit] | None,
        time_zone: str | timezone | Iterable[str | timezone | None] | None,
    ) -> CompliantSelector[FrameT, SeriesOrExprT]:
        time_units, time_zones = _parse_time_unit_and_time_zone(time_unit, time_zone)
        matches = partial(
            dtype_matches_time_unit_and_time_zone,
            dtypes=self._version.dtypes,
            time_units=time_units,
            time_zones=time_zones,
        )

        def series(df: FrameT) -> Sequence[SeriesOrExprT]:
            return [ser for ser, tp in self._iter_columns_dtypes(df) if matches(tp)]

        def names(df: FrameT) -> Sequence[str]:
            return [name for name, tp in self._iter_schema(df) if matches(tp)]

        return self._selector.from_callables(series, names, context=self)
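
# Illustrative, standalone check: every selector above reduces to a pair of
# callables (`series`, `names`) evaluated lazily against a frame; for
# `matches`, both are just regex filters over the column names:
_sketch_pattern = re.compile(r"^foo_")
assert [c for c in ["foo_a", "bar", "foo_b"] if _sketch_pattern.search(c)] == ["foo_a", "foo_b"]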


class EagerSelectorNamespace(
    CompliantSelectorNamespace[DataFrameT, SeriesT], Protocol[DataFrameT, SeriesT]
):
    def _iter_schema(self, df: DataFrameT, /) -> Iterator[tuple[str, DType]]:
        for ser in self._iter_columns(df):
            yield ser.name, ser.dtype

    def _iter_columns(self, df: DataFrameT, /) -> Iterator[SeriesT]:
        yield from df.iter_columns()

    def _iter_columns_dtypes(self, df: DataFrameT, /) -> Iterator[tuple[SeriesT, DType]]:
        for ser in self._iter_columns(df):
            yield ser, ser.dtype


class LazySelectorNamespace(
    CompliantSelectorNamespace[LazyFrameT, ExprT], Protocol[LazyFrameT, ExprT]
):
    def _iter_schema(self, df: LazyFrameT) -> Iterator[tuple[str, DType]]:
        yield from df.schema.items()

    def _iter_columns(self, df: LazyFrameT) -> Iterator[ExprT]:
        yield from df._iter_columns()

    def _iter_columns_dtypes(self, df: LazyFrameT, /) -> Iterator[tuple[ExprT, DType]]:
        yield from zip(self._iter_columns(df), df.schema.values())


class CompliantSelector(
    CompliantExpr[FrameT, SeriesOrExprT], Protocol[FrameT, SeriesOrExprT]
):
    _call: EvalSeries[FrameT, SeriesOrExprT]
    _function_name: str
    _implementation: Implementation
    _version: Version

    @classmethod
    def from_callables(
        cls,
        call: EvalSeries[FrameT, SeriesOrExprT],
        evaluate_output_names: EvalNames[FrameT],
        *,
        context: _LimitedContext,
    ) -> Self:
        obj = cls.__new__(cls)
        obj._call = call
        obj._evaluate_output_names = evaluate_output_names
        obj._alias_output_names = None
        obj._implementation = context._implementation
        obj._version = context._version
        return obj

    @property
    def selectors(self) -> CompliantSelectorNamespace[FrameT, SeriesOrExprT]:
        return self.__narwhals_namespace__().selectors

    def _to_expr(self) -> CompliantExpr[FrameT, SeriesOrExprT]: ...

    def _is_selector(
        self, other: Self | CompliantExpr[FrameT, SeriesOrExprT]
    ) -> TypeIs[CompliantSelector[FrameT, SeriesOrExprT]]:
        return isinstance(other, type(self))

    @overload
    def __sub__(self, other: Self) -> Self: ...
    @overload
    def __sub__(
        self, other: CompliantExpr[FrameT, SeriesOrExprT]
    ) -> CompliantExpr[FrameT, SeriesOrExprT]: ...
    def __sub__(
        self, other: SelectorOrExpr[FrameT, SeriesOrExprT]
    ) -> SelectorOrExpr[FrameT, SeriesOrExprT]:
        if self._is_selector(other):

            def series(df: FrameT) -> Sequence[SeriesOrExprT]:
                lhs_names, rhs_names = _eval_lhs_rhs(df, self, other)
                return [
                    x for x, name in zip(self(df), lhs_names) if name not in rhs_names
                ]

            def names(df: FrameT) -> Sequence[str]:
                lhs_names, rhs_names = _eval_lhs_rhs(df, self, other)
                return [x for x in lhs_names if x not in rhs_names]

            return self.selectors._selector.from_callables(series, names, context=self)
        return self._to_expr() - other

    @overload
    def __or__(self, other: Self) -> Self: ...
    @overload
    def __or__(
        self, other: CompliantExpr[FrameT, SeriesOrExprT]
    ) -> CompliantExpr[FrameT, SeriesOrExprT]: ...
    def __or__(
        self, other: SelectorOrExpr[FrameT, SeriesOrExprT]
    ) -> SelectorOrExpr[FrameT, SeriesOrExprT]:
        if self._is_selector(other):

            def series(df: FrameT) -> Sequence[SeriesOrExprT]:
                lhs_names, rhs_names = _eval_lhs_rhs(df, self, other)
                return [
                    *(x for x, name in zip(self(df), lhs_names) if name not in rhs_names),
                    *other(df),
                ]

            def names(df: FrameT) -> Sequence[str]:
                lhs_names, rhs_names = _eval_lhs_rhs(df, self, other)
                return [*(x for x in lhs_names if x not in rhs_names), *rhs_names]

            return self.selectors._selector.from_callables(series, names, context=self)
        return self._to_expr() | other

    @overload
    def __and__(self, other: Self) -> Self: ...
    @overload
    def __and__(
        self, other: CompliantExpr[FrameT, SeriesOrExprT]
    ) -> CompliantExpr[FrameT, SeriesOrExprT]: ...
    def __and__(
        self, other: SelectorOrExpr[FrameT, SeriesOrExprT]
    ) -> SelectorOrExpr[FrameT, SeriesOrExprT]:
        if self._is_selector(other):

            def series(df: FrameT) -> Sequence[SeriesOrExprT]:
                lhs_names, rhs_names = _eval_lhs_rhs(df, self, other)
                return [x for x, name in zip(self(df), lhs_names) if name in rhs_names]

            def names(df: FrameT) -> Sequence[str]:
                lhs_names, rhs_names = _eval_lhs_rhs(df, self, other)
                return [x for x in lhs_names if x in rhs_names]

            return self.selectors._selector.from_callables(series, names, context=self)
        return self._to_expr() & other

    def __invert__(self) -> CompliantSelector[FrameT, SeriesOrExprT]:
        return self.selectors.all() - self


def _eval_lhs_rhs(
    df: CompliantFrameAny, lhs: CompliantExprAny, rhs: CompliantExprAny
) -> tuple[Sequence[str], Sequence[str]]:
    return lhs._evaluate_output_names(df), rhs._evaluate_output_names(df)
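
# Illustrative, standalone check: the selector set operations above are
# order-preserving filters over column-name lists, not true sets:
_lhs, _rhs = ["a", "b", "c"], ["b", "d"]
assert [x for x in _lhs if x not in _rhs] == ["a", "c"]  # `-` (difference)
assert [*(x for x in _lhs if x not in _rhs), *_rhs] == ["a", "c", "b", "d"]  # `|` (union)
assert [x for x in _lhs if x in _rhs] == ["b"]  # `&` (intersection)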

516
.venv/lib/python3.12/site-packages/narwhals/_compliant/series.py
Normal file
@ -0,0 +1,516 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Generic, Literal, Protocol

from narwhals._compliant.any_namespace import (
    CatNamespace,
    DateTimeNamespace,
    ListNamespace,
    StringNamespace,
    StructNamespace,
)
from narwhals._compliant.typing import (
    CompliantSeriesT_co,
    EagerDataFrameAny,
    EagerSeriesT_co,
    NativeSeriesT,
    NativeSeriesT_co,
)
from narwhals._translate import FromIterable, FromNative, NumpyConvertible, ToNarwhals
from narwhals._typing_compat import TypeVar, assert_never
from narwhals._utils import (
    _StoresCompliant,
    _StoresNative,
    is_compliant_series,
    is_sized_multi_index_selector,
    unstable,
)

if TYPE_CHECKING:
    from collections.abc import Iterable, Iterator, Mapping, Sequence
    from types import ModuleType

    import pandas as pd
    import polars as pl
    import pyarrow as pa
    from typing_extensions import NotRequired, Self, TypedDict

    from narwhals._compliant.dataframe import CompliantDataFrame
    from narwhals._compliant.expr import CompliantExpr, EagerExpr
    from narwhals._compliant.namespace import CompliantNamespace, EagerNamespace
    from narwhals._utils import Implementation, Version, _LimitedContext
    from narwhals.dtypes import DType
    from narwhals.series import Series
    from narwhals.typing import (
        ClosedInterval,
        FillNullStrategy,
        Into1DArray,
        IntoDType,
        MultiIndexSelector,
        NonNestedLiteral,
        NumericLiteral,
        RankMethod,
        RollingInterpolationMethod,
        SizedMultiIndexSelector,
        TemporalLiteral,
        _1DArray,
        _SliceIndex,
    )

    class HistData(TypedDict, Generic[NativeSeriesT, "_CountsT_co"]):
        breakpoint: NotRequired[list[float] | _1DArray | list[Any]]
        count: NativeSeriesT | _1DArray | _CountsT_co | list[Any]


_CountsT_co = TypeVar("_CountsT_co", bound="Iterable[Any]", covariant=True)

__all__ = [
    "CompliantSeries",
    "EagerSeries",
    "EagerSeriesCatNamespace",
    "EagerSeriesDateTimeNamespace",
    "EagerSeriesHist",
    "EagerSeriesListNamespace",
    "EagerSeriesNamespace",
    "EagerSeriesStringNamespace",
    "EagerSeriesStructNamespace",
]


class CompliantSeries(
    NumpyConvertible["_1DArray", "Into1DArray"],
    FromIterable,
    FromNative[NativeSeriesT],
    ToNarwhals["Series[NativeSeriesT]"],
    Protocol[NativeSeriesT],
):
    _implementation: Implementation
    _version: Version

    @property
    def dtype(self) -> DType: ...
    @property
    def name(self) -> str: ...
    @property
    def native(self) -> NativeSeriesT: ...
    def __narwhals_series__(self) -> Self:
        return self

    def __narwhals_namespace__(self) -> CompliantNamespace[Any, Any]: ...
    def __native_namespace__(self) -> ModuleType: ...
    def __array__(self, dtype: Any, *, copy: bool | None) -> _1DArray: ...
    def __contains__(self, other: Any) -> bool: ...
    def __getitem__(self, item: MultiIndexSelector[Self]) -> Any: ...
    def __iter__(self) -> Iterator[Any]: ...
    def __len__(self) -> int:
        return len(self.native)

    def _with_native(self, series: Any) -> Self: ...
    def _with_version(self, version: Version) -> Self: ...
    def _to_expr(self) -> CompliantExpr[Any, Self]: ...
    @classmethod
    def from_native(cls, data: NativeSeriesT, /, *, context: _LimitedContext) -> Self: ...
    @classmethod
    def from_numpy(cls, data: Into1DArray, /, *, context: _LimitedContext) -> Self: ...
    @classmethod
    def from_iterable(
        cls,
        data: Iterable[Any],
        /,
        *,
        context: _LimitedContext,
        name: str = "",
        dtype: IntoDType | None = None,
    ) -> Self: ...
    def to_narwhals(self) -> Series[NativeSeriesT]:
        return self._version.series(self, level="full")

    # Operators
    def __add__(self, other: Any) -> Self: ...
    def __and__(self, other: Any) -> Self: ...
    def __eq__(self, other: object) -> Self: ...  # type: ignore[override]
    def __floordiv__(self, other: Any) -> Self: ...
    def __ge__(self, other: Any) -> Self: ...
    def __gt__(self, other: Any) -> Self: ...
    def __invert__(self) -> Self: ...
    def __le__(self, other: Any) -> Self: ...
    def __lt__(self, other: Any) -> Self: ...
    def __mod__(self, other: Any) -> Self: ...
    def __mul__(self, other: Any) -> Self: ...
    def __ne__(self, other: object) -> Self: ...  # type: ignore[override]
    def __or__(self, other: Any) -> Self: ...
    def __pow__(self, other: Any) -> Self: ...
    def __radd__(self, other: Any) -> Self: ...
    def __rand__(self, other: Any) -> Self: ...
    def __rfloordiv__(self, other: Any) -> Self: ...
    def __rmod__(self, other: Any) -> Self: ...
    def __rmul__(self, other: Any) -> Self: ...
    def __ror__(self, other: Any) -> Self: ...
    def __rpow__(self, other: Any) -> Self: ...
    def __rsub__(self, other: Any) -> Self: ...
    def __rtruediv__(self, other: Any) -> Self: ...
    def __sub__(self, other: Any) -> Self: ...
    def __truediv__(self, other: Any) -> Self: ...

    def abs(self) -> Self: ...
    def alias(self, name: str) -> Self: ...
    def all(self) -> bool: ...
    def any(self) -> bool: ...
    def arg_max(self) -> int: ...
    def arg_min(self) -> int: ...
    def arg_true(self) -> Self: ...
    def cast(self, dtype: IntoDType) -> Self: ...
    def clip(
        self,
        lower_bound: Self | NumericLiteral | TemporalLiteral | None,
        upper_bound: Self | NumericLiteral | TemporalLiteral | None,
    ) -> Self: ...
    def count(self) -> int: ...
    def cum_count(self, *, reverse: bool) -> Self: ...
    def cum_max(self, *, reverse: bool) -> Self: ...
    def cum_min(self, *, reverse: bool) -> Self: ...
    def cum_prod(self, *, reverse: bool) -> Self: ...
    def cum_sum(self, *, reverse: bool) -> Self: ...
    def diff(self) -> Self: ...
    def drop_nulls(self) -> Self: ...
    def ewm_mean(
        self,
        *,
        com: float | None,
        span: float | None,
        half_life: float | None,
        alpha: float | None,
        adjust: bool,
        min_samples: int,
        ignore_nulls: bool,
    ) -> Self: ...
    def exp(self) -> Self: ...
    def sqrt(self) -> Self: ...
    def fill_null(
        self,
        value: Self | NonNestedLiteral,
        strategy: FillNullStrategy | None,
        limit: int | None,
    ) -> Self: ...
    def filter(self, predicate: Any) -> Self: ...
    def gather_every(self, n: int, offset: int) -> Self: ...
    def head(self, n: int) -> Self: ...
    def is_between(
        self, lower_bound: Any, upper_bound: Any, closed: ClosedInterval
    ) -> Self: ...
    def is_finite(self) -> Self: ...
    def is_first_distinct(self) -> Self: ...
    def is_in(self, other: Any) -> Self: ...
    def is_last_distinct(self) -> Self: ...
    def is_nan(self) -> Self: ...
    def is_null(self) -> Self: ...
    def is_sorted(self, *, descending: bool) -> bool: ...
    def is_unique(self) -> Self: ...
    def item(self, index: int | None) -> Any: ...
    def kurtosis(self) -> float | None: ...
    def len(self) -> int: ...
    def log(self, base: float) -> Self: ...
    def max(self) -> Any: ...
    def mean(self) -> float: ...
    def median(self) -> float: ...
    def min(self) -> Any: ...
    def mode(self) -> Self: ...
    def n_unique(self) -> int: ...
    def null_count(self) -> int: ...
    def quantile(
        self, quantile: float, interpolation: RollingInterpolationMethod
    ) -> float: ...
    def rank(self, method: RankMethod, *, descending: bool) -> Self: ...
    def replace_strict(
        self,
        old: Sequence[Any] | Mapping[Any, Any],
        new: Sequence[Any],
        *,
        return_dtype: IntoDType | None,
    ) -> Self: ...
    def rolling_mean(
        self, window_size: int, *, min_samples: int, center: bool
    ) -> Self: ...
    def rolling_std(
        self, window_size: int, *, min_samples: int, center: bool, ddof: int
    ) -> Self: ...
    def rolling_sum(
        self, window_size: int, *, min_samples: int, center: bool
    ) -> Self: ...
    def rolling_var(
        self, window_size: int, *, min_samples: int, center: bool, ddof: int
    ) -> Self: ...
    def round(self, decimals: int) -> Self: ...
    def sample(
        self,
        n: int | None,
        *,
        fraction: float | None,
        with_replacement: bool,
        seed: int | None,
    ) -> Self: ...
    def scatter(self, indices: int | Sequence[int], values: Any) -> Self: ...
    def shift(self, n: int) -> Self: ...
    def skew(self) -> float | None: ...
    def sort(self, *, descending: bool, nulls_last: bool) -> Self: ...
    def std(self, *, ddof: int) -> float: ...
    def sum(self) -> float: ...
    def tail(self, n: int) -> Self: ...
    def to_arrow(self) -> pa.Array[Any]: ...
    def to_dummies(
        self, *, separator: str, drop_first: bool
    ) -> CompliantDataFrame[Self, Any, Any, Any]: ...
    def to_frame(self) -> CompliantDataFrame[Self, Any, Any, Any]: ...
    def to_list(self) -> list[Any]: ...
    def to_pandas(self) -> pd.Series[Any]: ...
    def to_polars(self) -> pl.Series: ...
    def unique(self, *, maintain_order: bool) -> Self: ...
    def value_counts(
        self, *, sort: bool, parallel: bool, name: str | None, normalize: bool
    ) -> CompliantDataFrame[Self, Any, Any, Any]: ...
    def var(self, *, ddof: int) -> float: ...
    def zip_with(self, mask: Any, other: Any) -> Self: ...
    @unstable
    def hist_from_bins(
        self, bins: list[float], *, include_breakpoint: bool
    ) -> CompliantDataFrame[Self, Any, Any, Any]:
        """`Series.hist(bins=..., bin_count=None)`."""
        ...

    @unstable
    def hist_from_bin_count(
        self, bin_count: int, *, include_breakpoint: bool
    ) -> CompliantDataFrame[Self, Any, Any, Any]:
        """`Series.hist(bins=None, bin_count=...)`."""
        ...

    @property
    def str(self) -> Any: ...
    @property
    def dt(self) -> Any: ...
    @property
    def cat(self) -> Any: ...
    @property
    def list(self) -> Any: ...
    @property
    def struct(self) -> Any: ...


class EagerSeries(CompliantSeries[NativeSeriesT], Protocol[NativeSeriesT]):
    _native_series: Any
    _implementation: Implementation
    _version: Version
    _broadcast: bool

    @property
    def _backend_version(self) -> tuple[int, ...]:
        return self._implementation._backend_version()

    @classmethod
    def _align_full_broadcast(cls, *series: Self) -> Sequence[Self]:
        """Ensure all of `series` have the same length (and index if `pandas`).

        Scalars get broadcast to the full length of the longest Series.

        This is useful when you need to construct a full Series anyway, such as:

            DataFrame.select(...)

        It should not be used in binary operations, such as:

            nw.col("a") - nw.col("a").mean()

        because then it's more efficient to extract the right-hand-side's single element as a scalar.
        """
        ...

    def _from_scalar(self, value: Any) -> Self:
        return self.from_iterable([value], name=self.name, context=self)

    def _with_native(
        self, series: NativeSeriesT, *, preserve_broadcast: bool = False
    ) -> Self:
        """Return a new `CompliantSeries`, wrapping the native `series`.

        In cases when operations are known to not affect whether a result should
        be broadcast, we can pass `preserve_broadcast=True`.
        Set this with care - it should only be set for unary expressions which don't
        change length or order, such as `.alias` or `.fill_null`. If in doubt, don't
        set it; you probably don't need it.
        """
        ...

    def __narwhals_namespace__(
        self,
    ) -> EagerNamespace[Any, Self, Any, Any, NativeSeriesT]: ...

    def _to_expr(self) -> EagerExpr[Any, Any]:
        return self.__narwhals_namespace__()._expr._from_series(self)  # type: ignore[no-any-return]

    def _gather(self, rows: SizedMultiIndexSelector[NativeSeriesT]) -> Self: ...
    def _gather_slice(self, rows: _SliceIndex | range) -> Self: ...
    def __getitem__(self, item: MultiIndexSelector[Self]) -> Self:
        if isinstance(item, (slice, range)):
            return self._gather_slice(item)
        elif is_compliant_series(item):
            return self._gather(item.native)
        elif is_sized_multi_index_selector(item):
            return self._gather(item)
        else:
            assert_never(item)

    @property
    def str(self) -> EagerSeriesStringNamespace[Self, NativeSeriesT]: ...
    @property
    def dt(self) -> EagerSeriesDateTimeNamespace[Self, NativeSeriesT]: ...
    @property
    def cat(self) -> EagerSeriesCatNamespace[Self, NativeSeriesT]: ...
    @property
    def list(self) -> EagerSeriesListNamespace[Self, NativeSeriesT]: ...
    @property
    def struct(self) -> EagerSeriesStructNamespace[Self, NativeSeriesT]: ...
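
# Illustrative, standalone sketch (an assumption about `_broadcast`, inferred
# from `_align_full_broadcast` above): the flag marks a length-1 result
# (e.g. `nw.col("a").mean()`) that later gets stretched to the frame's height
# when combined with full-length columns:
def _sketch_broadcast(values: list, target_len: int) -> list:
    return values * target_len if len(values) == 1 else values

assert _sketch_broadcast([3.5], 4) == [3.5, 3.5, 3.5, 3.5]
assert _sketch_broadcast([1, 2], 2) == [1, 2]  # full-length: unchanged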


class _SeriesNamespace(  # type: ignore[misc]
    _StoresCompliant[CompliantSeriesT_co],
    _StoresNative[NativeSeriesT_co],
    Protocol[CompliantSeriesT_co, NativeSeriesT_co],
):
    _compliant_series: CompliantSeriesT_co

    @property
    def compliant(self) -> CompliantSeriesT_co:
        return self._compliant_series

    @property
    def implementation(self) -> Implementation:
        return self.compliant._implementation

    @property
    def backend_version(self) -> tuple[int, ...]:
        return self.implementation._backend_version()

    @property
    def version(self) -> Version:
        return self.compliant._version

    @property
    def native(self) -> NativeSeriesT_co:
        return self._compliant_series.native  # type: ignore[no-any-return]

    def with_native(self, series: Any, /) -> CompliantSeriesT_co:
        return self.compliant._with_native(series)


class EagerSeriesNamespace(
    _SeriesNamespace[EagerSeriesT_co, NativeSeriesT_co],
    Generic[EagerSeriesT_co, NativeSeriesT_co],
):
    _compliant_series: EagerSeriesT_co

    def __init__(self, series: EagerSeriesT_co, /) -> None:
        self._compliant_series = series


class EagerSeriesCatNamespace(  # type: ignore[misc]
    _SeriesNamespace[EagerSeriesT_co, NativeSeriesT_co],
    CatNamespace[EagerSeriesT_co],
    Protocol[EagerSeriesT_co, NativeSeriesT_co],
): ...


class EagerSeriesDateTimeNamespace(  # type: ignore[misc]
    _SeriesNamespace[EagerSeriesT_co, NativeSeriesT_co],
    DateTimeNamespace[EagerSeriesT_co],
    Protocol[EagerSeriesT_co, NativeSeriesT_co],
): ...


class EagerSeriesListNamespace(  # type: ignore[misc]
    _SeriesNamespace[EagerSeriesT_co, NativeSeriesT_co],
    ListNamespace[EagerSeriesT_co],
    Protocol[EagerSeriesT_co, NativeSeriesT_co],
): ...


class EagerSeriesStringNamespace(  # type: ignore[misc]
    _SeriesNamespace[EagerSeriesT_co, NativeSeriesT_co],
    StringNamespace[EagerSeriesT_co],
    Protocol[EagerSeriesT_co, NativeSeriesT_co],
): ...


class EagerSeriesStructNamespace(  # type: ignore[misc]
    _SeriesNamespace[EagerSeriesT_co, NativeSeriesT_co],
    StructNamespace[EagerSeriesT_co],
    Protocol[EagerSeriesT_co, NativeSeriesT_co],
): ...


class EagerSeriesHist(Protocol[NativeSeriesT, _CountsT_co]):
    _series: EagerSeries[NativeSeriesT]
    _breakpoint: bool
    _data: HistData[NativeSeriesT, _CountsT_co]

    @property
    def native(self) -> NativeSeriesT:
        return self._series.native

    @classmethod
    def from_series(
        cls, series: EagerSeries[NativeSeriesT], *, include_breakpoint: bool
    ) -> Self:
        obj = cls.__new__(cls)
        obj._series = series
        obj._breakpoint = include_breakpoint
        return obj

    def to_frame(self) -> EagerDataFrameAny: ...
    def _linear_space(  # NOTE: Roughly `pl.linear_space`
        self,
        start: float,
        end: float,
        num_samples: int,
        *,
        closed: Literal["both", "none"] = "both",
    ) -> _1DArray: ...

    # NOTE: *Could* be handled at narwhals-level
    def is_empty_series(self) -> bool: ...

    # NOTE: **Should** be handled at narwhals-level
    def data_empty(self) -> HistData[NativeSeriesT, _CountsT_co]:
        return {"breakpoint": [], "count": []} if self._breakpoint else {"count": []}

    # NOTE: *Could* be handled at narwhals-level, **iff** we add `nw.repeat`, `nw.linear_space`
    # See https://github.com/narwhals-dev/narwhals/pull/2839#discussion_r2215630696
    def series_empty(
        self, arg: int | list[float], /
    ) -> HistData[NativeSeriesT, _CountsT_co]: ...

    def with_bins(self, bins: list[float], /) -> Self:
        if len(bins) <= 1:
            self._data = self.data_empty()
        elif self.is_empty_series():
            self._data = self.series_empty(bins)
        else:
            self._data = self._calculate_hist(bins)
        return self

    def with_bin_count(self, bin_count: int, /) -> Self:
        if bin_count == 0:
            self._data = self.data_empty()
        elif self.is_empty_series():
            self._data = self.series_empty(bin_count)
        else:
            self._data = self._calculate_hist(self._calculate_bins(bin_count))
        return self

    def _calculate_breakpoint(self, arg: int | list[float], /) -> list[float] | _1DArray:
        bins = self._linear_space(0, 1, arg + 1) if isinstance(arg, int) else arg
        return bins[1:]

    def _calculate_bins(self, bin_count: int) -> _1DArray: ...
    def _calculate_hist(
        self, bins: list[float] | _1DArray
    ) -> HistData[NativeSeriesT, _CountsT_co]: ...
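
# Illustrative, standalone sketch (stdlib-only, not the backend kernels): the
# `with_bins` branching above has three cases - degenerate bins, empty input,
# or a real histogram. A plain-Python stand-in for the last case, using one
# common bucketing convention (right-closed buckets, first closed on both ends):
from bisect import bisect_left as _sketch_bisect_left

def _sketch_hist(values: list[float], bins: list[float]) -> list[int]:
    if len(bins) <= 1 or not values:
        return []  # mirrors the `data_empty` / `series_empty` branches
    counts = [0] * (len(bins) - 1)
    for v in values:
        if bins[0] <= v <= bins[-1]:
            counts[max(_sketch_bisect_left(bins, v, lo=1) - 1, 0)] += 1
    return counts

assert _sketch_hist([0.1, 0.9, 1.0, 1.5], [0.0, 1.0, 2.0]) == [3, 1]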

196
.venv/lib/python3.12/site-packages/narwhals/_compliant/typing.py
Normal file
@ -0,0 +1,196 @@
from __future__ import annotations

from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, Callable, Literal, TypedDict, TypeVar

if TYPE_CHECKING:
    from typing_extensions import TypeAlias

    from narwhals._compliant.dataframe import (
        CompliantDataFrame,
        CompliantLazyFrame,
        EagerDataFrame,
    )
    from narwhals._compliant.expr import (
        CompliantExpr,
        DepthTrackingExpr,
        EagerExpr,
        LazyExpr,
        NativeExpr,
    )
    from narwhals._compliant.namespace import CompliantNamespace, EagerNamespace
    from narwhals._compliant.series import CompliantSeries, EagerSeries
    from narwhals._compliant.window import WindowInputs
    from narwhals.typing import (
        FillNullStrategy,
        NativeFrame,
        NativeSeries,
        RankMethod,
        RollingInterpolationMethod,
    )


class ScalarKwargs(TypedDict, total=False):
    """Non-expressifiable args which we may need to reuse in `agg` or `over`."""

    adjust: bool
    alpha: float | None
    center: int
    com: float | None
    ddof: int
    descending: bool
    half_life: float | None
    ignore_nulls: bool
    interpolation: RollingInterpolationMethod
    limit: int | None
    method: RankMethod
    min_samples: int
    n: int
    quantile: float
    reverse: bool
    span: float | None
    strategy: FillNullStrategy | None
    window_size: int


__all__ = [
    "AliasName",
    "AliasNames",
    "CompliantDataFrameT",
    "CompliantFrameT",
    "CompliantLazyFrameT",
    "CompliantSeriesT",
    "EvalNames",
    "EvalSeries",
    "IntoCompliantExpr",
    "NarwhalsAggregation",
    "NativeFrameT_co",
    "NativeSeriesT_co",
]

CompliantExprAny: TypeAlias = "CompliantExpr[Any, Any]"
CompliantSeriesAny: TypeAlias = "CompliantSeries[Any]"
CompliantSeriesOrNativeExprAny: TypeAlias = "CompliantSeriesAny | NativeExpr"
CompliantDataFrameAny: TypeAlias = "CompliantDataFrame[Any, Any, Any, Any]"
CompliantLazyFrameAny: TypeAlias = "CompliantLazyFrame[Any, Any, Any]"
CompliantFrameAny: TypeAlias = "CompliantDataFrameAny | CompliantLazyFrameAny"
CompliantNamespaceAny: TypeAlias = "CompliantNamespace[Any, Any]"

DepthTrackingExprAny: TypeAlias = "DepthTrackingExpr[Any, Any]"

EagerDataFrameAny: TypeAlias = "EagerDataFrame[Any, Any, Any, Any]"
EagerSeriesAny: TypeAlias = "EagerSeries[Any]"
EagerExprAny: TypeAlias = "EagerExpr[Any, Any]"
EagerNamespaceAny: TypeAlias = "EagerNamespace[EagerDataFrameAny, EagerSeriesAny, EagerExprAny, NativeFrame, NativeSeries]"

LazyExprAny: TypeAlias = "LazyExpr[Any, Any]"

NativeExprT = TypeVar("NativeExprT", bound="NativeExpr")
NativeExprT_co = TypeVar("NativeExprT_co", bound="NativeExpr", covariant=True)
NativeSeriesT = TypeVar("NativeSeriesT", bound="NativeSeries")
NativeSeriesT_co = TypeVar("NativeSeriesT_co", bound="NativeSeries", covariant=True)
NativeSeriesT_contra = TypeVar(
    "NativeSeriesT_contra", bound="NativeSeries", contravariant=True
)
NativeFrameT = TypeVar("NativeFrameT", bound="NativeFrame")
NativeFrameT_co = TypeVar("NativeFrameT_co", bound="NativeFrame", covariant=True)
NativeFrameT_contra = TypeVar(
    "NativeFrameT_contra", bound="NativeFrame", contravariant=True
)

CompliantExprT = TypeVar("CompliantExprT", bound=CompliantExprAny)
CompliantExprT_co = TypeVar("CompliantExprT_co", bound=CompliantExprAny, covariant=True)
CompliantExprT_contra = TypeVar(
    "CompliantExprT_contra", bound=CompliantExprAny, contravariant=True
)
CompliantSeriesT = TypeVar("CompliantSeriesT", bound=CompliantSeriesAny)
CompliantSeriesT_co = TypeVar(
    "CompliantSeriesT_co", bound=CompliantSeriesAny, covariant=True
)
CompliantSeriesOrNativeExprT = TypeVar(
    "CompliantSeriesOrNativeExprT", bound=CompliantSeriesOrNativeExprAny
)
CompliantSeriesOrNativeExprT_co = TypeVar(
    "CompliantSeriesOrNativeExprT_co",
    bound=CompliantSeriesOrNativeExprAny,
    covariant=True,
)
CompliantFrameT = TypeVar("CompliantFrameT", bound=CompliantFrameAny)
CompliantFrameT_co = TypeVar(
    "CompliantFrameT_co", bound=CompliantFrameAny, covariant=True
)
CompliantDataFrameT = TypeVar("CompliantDataFrameT", bound=CompliantDataFrameAny)
CompliantDataFrameT_co = TypeVar(
    "CompliantDataFrameT_co", bound=CompliantDataFrameAny, covariant=True
)
CompliantLazyFrameT = TypeVar("CompliantLazyFrameT", bound=CompliantLazyFrameAny)
CompliantLazyFrameT_co = TypeVar(
    "CompliantLazyFrameT_co", bound=CompliantLazyFrameAny, covariant=True
)
CompliantNamespaceT = TypeVar("CompliantNamespaceT", bound=CompliantNamespaceAny)
CompliantNamespaceT_co = TypeVar(
    "CompliantNamespaceT_co", bound=CompliantNamespaceAny, covariant=True
)

IntoCompliantExpr: TypeAlias = "CompliantExpr[CompliantFrameT, CompliantSeriesOrNativeExprT_co] | CompliantSeriesOrNativeExprT_co"

DepthTrackingExprT = TypeVar("DepthTrackingExprT", bound=DepthTrackingExprAny)
DepthTrackingExprT_contra = TypeVar(
    "DepthTrackingExprT_contra", bound=DepthTrackingExprAny, contravariant=True
)

EagerExprT = TypeVar("EagerExprT", bound=EagerExprAny)
EagerExprT_contra = TypeVar("EagerExprT_contra", bound=EagerExprAny, contravariant=True)
EagerSeriesT = TypeVar("EagerSeriesT", bound=EagerSeriesAny)
EagerSeriesT_co = TypeVar("EagerSeriesT_co", bound=EagerSeriesAny, covariant=True)

# NOTE: `pyright` gives false (8) positives if this uses `EagerDataFrameAny`?
EagerDataFrameT = TypeVar("EagerDataFrameT", bound="EagerDataFrame[Any, Any, Any, Any]")

LazyExprT = TypeVar("LazyExprT", bound=LazyExprAny)
LazyExprT_contra = TypeVar("LazyExprT_contra", bound=LazyExprAny, contravariant=True)

AliasNames: TypeAlias = Callable[[Sequence[str]], Sequence[str]]
"""A function aliasing a *sequence* of column names."""

AliasName: TypeAlias = Callable[[str], str]
"""A function aliasing a *single* column name."""

EvalSeries: TypeAlias = Callable[
    [CompliantFrameT], Sequence[CompliantSeriesOrNativeExprT]
]
"""A function from a `Frame` to a sequence of `Series`*.

See [underwater unicorn magic](https://narwhals-dev.github.io/narwhals/how_it_works/).
"""

EvalNames: TypeAlias = Callable[[CompliantFrameT], Sequence[str]]
"""A function from a `Frame` to a sequence of column names *before* any aliasing takes place."""

WindowFunction: TypeAlias = (
    "Callable[[CompliantFrameT, WindowInputs[NativeExprT]], Sequence[NativeExprT]]"
)
"""A function evaluated with `over(partition_by=..., order_by=...)`."""

NarwhalsAggregation: TypeAlias = Literal[
    "sum",
    "mean",
    "median",
    "max",
    "min",
    "std",
    "var",
    "len",
    "n_unique",
    "count",
    "quantile",
    "all",
    "any",
]
"""`Expr` methods we aim to support in `DepthTrackingGroupBy`.

Be sure to update me if you're working on one of these:
- https://github.com/narwhals-dev/narwhals/issues/981
- https://github.com/narwhals-dev/narwhals/issues/2385
- https://github.com/narwhals-dev/narwhals/issues/2484
- https://github.com/narwhals-dev/narwhals/issues/2526
- https://github.com/narwhals-dev/narwhals/issues/2660
"""
|
||||
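As a minimal illustration of how these aliases can be consumed downstream, the sketch below checks an aggregation name against the `NarwhalsAggregation` literal; the `is_supported_agg` helper is hypothetical and not part of narwhals, and it assumes `NarwhalsAggregation` (defined above) is in scope:

from typing import get_args

def is_supported_agg(name: str) -> bool:
    # Literal members are retrievable at runtime via `get_args`.
    return name in get_args(NarwhalsAggregation)

assert is_supported_agg("mean") and not is_supported_agg("rank")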
@ -0,0 +1,130 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Protocol, TypeVar, cast

from narwhals._compliant.expr import CompliantExpr
from narwhals._compliant.typing import (
    CompliantExprAny,
    CompliantFrameAny,
    CompliantSeriesOrNativeExprAny,
    EagerDataFrameT,
    EagerExprT,
    EagerSeriesT,
    LazyExprAny,
    NativeSeriesT,
)

if TYPE_CHECKING:
    from collections.abc import Sequence

    from typing_extensions import Self, TypeAlias

    from narwhals._compliant.typing import EvalSeries
    from narwhals._utils import Implementation, Version, _LimitedContext
    from narwhals.typing import NonNestedLiteral


__all__ = ["CompliantThen", "CompliantWhen", "EagerWhen"]

ExprT = TypeVar("ExprT", bound=CompliantExprAny)
LazyExprT = TypeVar("LazyExprT", bound=LazyExprAny)
SeriesT = TypeVar("SeriesT", bound=CompliantSeriesOrNativeExprAny)
FrameT = TypeVar("FrameT", bound=CompliantFrameAny)

Scalar: TypeAlias = Any
"""A native literal value."""

IntoExpr: TypeAlias = "SeriesT | ExprT | NonNestedLiteral | Scalar"
"""Anything that is convertible into a `CompliantExpr`."""


class CompliantWhen(Protocol[FrameT, SeriesT, ExprT]):
    _condition: ExprT
    _then_value: IntoExpr[SeriesT, ExprT]
    _otherwise_value: IntoExpr[SeriesT, ExprT] | None
    _implementation: Implementation
    _version: Version

    @property
    def _then(self) -> type[CompliantThen[FrameT, SeriesT, ExprT, Self]]: ...
    def __call__(self, compliant_frame: FrameT, /) -> Sequence[SeriesT]: ...
    def then(
        self, value: IntoExpr[SeriesT, ExprT], /
    ) -> CompliantThen[FrameT, SeriesT, ExprT, Self]:
        return self._then.from_when(self, value)

    @classmethod
    def from_expr(cls, condition: ExprT, /, *, context: _LimitedContext) -> Self:
        obj = cls.__new__(cls)
        obj._condition = condition
        obj._then_value = None
        obj._otherwise_value = None
        obj._implementation = context._implementation
        obj._version = context._version
        return obj


WhenT_contra = TypeVar(
    "WhenT_contra", bound=CompliantWhen[Any, Any, Any], contravariant=True
)


class CompliantThen(
    CompliantExpr[FrameT, SeriesT], Protocol[FrameT, SeriesT, ExprT, WhenT_contra]
):
    _call: EvalSeries[FrameT, SeriesT]
    _when_value: CompliantWhen[FrameT, SeriesT, ExprT]
    _implementation: Implementation
    _version: Version

    @classmethod
    def from_when(cls, when: WhenT_contra, then: IntoExpr[SeriesT, ExprT], /) -> Self:
        when._then_value = then
        obj = cls.__new__(cls)
        obj._call = when
        obj._when_value = when
        obj._evaluate_output_names = getattr(
            then, "_evaluate_output_names", lambda _df: ["literal"]
        )
        obj._alias_output_names = getattr(then, "_alias_output_names", None)
        obj._implementation = when._implementation
        obj._version = when._version
        return obj

    def otherwise(self, otherwise: IntoExpr[SeriesT, ExprT], /) -> ExprT:
        self._when_value._otherwise_value = otherwise
        return cast("ExprT", self)


class EagerWhen(
    CompliantWhen[EagerDataFrameT, EagerSeriesT, EagerExprT],
    Protocol[EagerDataFrameT, EagerSeriesT, EagerExprT, NativeSeriesT],
):
    def _if_then_else(
        self,
        when: NativeSeriesT,
        then: NativeSeriesT,
        otherwise: NativeSeriesT | NonNestedLiteral | Scalar,
        /,
    ) -> NativeSeriesT: ...

    def __call__(self, df: EagerDataFrameT, /) -> Sequence[EagerSeriesT]:
        is_expr = self._condition._is_expr
        when: EagerSeriesT = self._condition(df)[0]
        then: EagerSeriesT
        align = when._align_full_broadcast

        if is_expr(self._then_value):
            then = self._then_value(df)[0]
        else:
            then = when.alias("literal")._from_scalar(self._then_value)
            then._broadcast = True

        if is_expr(self._otherwise_value):
            otherwise = self._otherwise_value(df)[0]
            when, then, otherwise = align(when, then, otherwise)
            result = self._if_then_else(when.native, then.native, otherwise.native)
        else:
            when, then = align(when, then)
            result = self._if_then_else(when.native, then.native, self._otherwise_value)
        return [then._with_native(result)]
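These protocols back the public `when`/`then`/`otherwise` chain. A minimal usage sketch at the narwhals API level (assuming pandas is installed; the column name "b" is arbitrary):

import pandas as pd
import narwhals as nw

df = nw.from_native(pd.DataFrame({"a": [1, 2, 3]}))
# Rows where a > 1 get 10, the rest get 0.
result = df.with_columns(nw.when(nw.col("a") > 1).then(10).otherwise(0).alias("b"))
print(result.to_native())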
@ -0,0 +1,20 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Generic

from narwhals._compliant.typing import NativeExprT_co

if TYPE_CHECKING:
    from collections.abc import Sequence

__all__ = ["WindowInputs"]


class WindowInputs(Generic[NativeExprT_co]):
    __slots__ = ("order_by", "partition_by")

    def __init__(
        self, partition_by: Sequence[str | NativeExprT_co], order_by: Sequence[str]
    ) -> None:
        self.partition_by = partition_by
        self.order_by = order_by
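`WindowInputs` is a plain container bundling the arguments to `over(partition_by=..., order_by=...)`; a trivial construction, for illustration only (column names are hypothetical):

inputs = WindowInputs(partition_by=["group"], order_by=["ts"])
assert inputs.partition_by == ["group"] and inputs.order_by == ["ts"]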
30
.venv/lib/python3.12/site-packages/narwhals/_constants.py
Normal file
@ -0,0 +1,30 @@
from __future__ import annotations

import datetime as dt

# Temporal (from `polars._utils.constants`)
SECONDS_PER_DAY = 86_400
SECONDS_PER_MINUTE = 60
NS_PER_MINUTE = 60_000_000_000
"""Nanoseconds (`[ns]`) per minute."""
US_PER_MINUTE = 60_000_000
"""Microseconds (`[μs]`) per minute."""
MS_PER_MINUTE = 60_000
"""Milliseconds (`[ms]`) per minute."""
NS_PER_SECOND = 1_000_000_000
"""Nanoseconds (`[ns]`) per second (`[s]`)."""
US_PER_SECOND = 1_000_000
"""Microseconds (`[μs]`) per second (`[s]`)."""
MS_PER_SECOND = 1_000
"""Milliseconds (`[ms]`) per second (`[s]`)."""
NS_PER_MICROSECOND = 1_000
"""Nanoseconds (`[ns]`) per microsecond (`[μs]`)."""
NS_PER_MILLISECOND = 1_000_000
"""Nanoseconds (`[ns]`) per millisecond (`[ms]`).

From [polars](https://github.com/pola-rs/polars/blob/2c7a3e77f0faa37c86a3745db4ef7707ae50c72e/crates/polars-time/src/chunkedarray/duration.rs#L7).
"""
EPOCH_YEAR = 1970
"""See [Unix time](https://en.wikipedia.org/wiki/Unix_time)."""
EPOCH = dt.datetime(EPOCH_YEAR, 1, 1).replace(tzinfo=None)
"""See [Unix time](https://en.wikipedia.org/wiki/Unix_time)."""
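A small sketch of how these constants combine, e.g. milliseconds since the Unix epoch for a naive datetime (illustrative only, not narwhals API):

import datetime as dt

from narwhals._constants import EPOCH, MS_PER_SECOND

ts = dt.datetime(2024, 1, 1)
ms_since_epoch = int((ts - EPOCH).total_seconds() * MS_PER_SECOND)
assert ms_since_epoch == 1_704_067_200_000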
9 binary files changed (not shown).
483
.venv/lib/python3.12/site-packages/narwhals/_dask/dataframe.py
Normal file
@ -0,0 +1,483 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any

import dask.dataframe as dd

from narwhals._dask.utils import add_row_index, evaluate_exprs
from narwhals._expression_parsing import ExprKind
from narwhals._pandas_like.utils import native_to_narwhals_dtype, select_columns_by_name
from narwhals._typing_compat import assert_never
from narwhals._utils import (
    Implementation,
    ValidateBackendVersion,
    _remap_full_join_keys,
    check_column_names_are_unique,
    generate_temporary_column_name,
    not_implemented,
    parse_columns_to_drop,
)
from narwhals.typing import CompliantLazyFrame

if TYPE_CHECKING:
    from collections.abc import Iterator, Mapping, Sequence
    from io import BytesIO
    from pathlib import Path
    from types import ModuleType

    import dask.dataframe.dask_expr as dx
    from typing_extensions import Self, TypeAlias, TypeIs

    from narwhals._compliant.typing import CompliantDataFrameAny
    from narwhals._dask.expr import DaskExpr
    from narwhals._dask.group_by import DaskLazyGroupBy
    from narwhals._dask.namespace import DaskNamespace
    from narwhals._utils import Version, _LimitedContext
    from narwhals.dataframe import LazyFrame
    from narwhals.dtypes import DType
    from narwhals.typing import AsofJoinStrategy, JoinStrategy, LazyUniqueKeepStrategy

Incomplete: TypeAlias = "Any"
"""Using `_pandas_like` utils with `_dask`.

Typing this correctly will complicate the `_pandas_like`-side.
Very low priority until `dask` adds typing.
"""


class DaskLazyFrame(
    CompliantLazyFrame["DaskExpr", "dd.DataFrame", "LazyFrame[dd.DataFrame]"],
    ValidateBackendVersion,
):
    _implementation = Implementation.DASK

    def __init__(
        self,
        native_dataframe: dd.DataFrame,
        *,
        version: Version,
        validate_backend_version: bool = False,
    ) -> None:
        self._native_frame: dd.DataFrame = native_dataframe
        self._version = version
        self._cached_schema: dict[str, DType] | None = None
        self._cached_columns: list[str] | None = None
        if validate_backend_version:
            self._validate_backend_version()

    @staticmethod
    def _is_native(obj: dd.DataFrame | Any) -> TypeIs[dd.DataFrame]:
        return isinstance(obj, dd.DataFrame)

    @classmethod
    def from_native(cls, data: dd.DataFrame, /, *, context: _LimitedContext) -> Self:
        return cls(data, version=context._version)

    def to_narwhals(self) -> LazyFrame[dd.DataFrame]:
        return self._version.lazyframe(self, level="lazy")

    def __native_namespace__(self) -> ModuleType:
        if self._implementation is Implementation.DASK:
            return self._implementation.to_native_namespace()

        msg = f"Expected dask, got: {type(self._implementation)}"  # pragma: no cover
        raise AssertionError(msg)

    def __narwhals_namespace__(self) -> DaskNamespace:
        from narwhals._dask.namespace import DaskNamespace

        return DaskNamespace(version=self._version)

    def __narwhals_lazyframe__(self) -> Self:
        return self

    def _with_version(self, version: Version) -> Self:
        return self.__class__(self.native, version=version)

    def _with_native(self, df: Any) -> Self:
        return self.__class__(df, version=self._version)

    def _iter_columns(self) -> Iterator[dx.Series]:
        for _col, ser in self.native.items():  # noqa: PERF102
            yield ser

    def with_columns(self, *exprs: DaskExpr) -> Self:
        new_series = evaluate_exprs(self, *exprs)
        return self._with_native(self.native.assign(**dict(new_series)))

    def collect(
        self, backend: Implementation | None, **kwargs: Any
    ) -> CompliantDataFrameAny:
        result = self.native.compute(**kwargs)

        if backend is None or backend is Implementation.PANDAS:
            from narwhals._pandas_like.dataframe import PandasLikeDataFrame

            return PandasLikeDataFrame(
                result,
                implementation=Implementation.PANDAS,
                validate_backend_version=True,
                version=self._version,
                validate_column_names=True,
            )

        if backend is Implementation.POLARS:
            import polars as pl  # ignore-banned-import

            from narwhals._polars.dataframe import PolarsDataFrame

            return PolarsDataFrame(
                pl.from_pandas(result),
                validate_backend_version=True,
                version=self._version,
            )

        if backend is Implementation.PYARROW:
            import pyarrow as pa  # ignore-banned-import

            from narwhals._arrow.dataframe import ArrowDataFrame

            return ArrowDataFrame(
                pa.Table.from_pandas(result),
                validate_backend_version=True,
                version=self._version,
                validate_column_names=True,
            )

        msg = f"Unsupported `backend` value: {backend}"  # pragma: no cover
        raise ValueError(msg)  # pragma: no cover

    @property
    def columns(self) -> list[str]:
        if self._cached_columns is None:
            self._cached_columns = (
                list(self.schema)
                if self._cached_schema is not None
                else self.native.columns.tolist()
            )
        return self._cached_columns

    def filter(self, predicate: DaskExpr) -> Self:
        # `[0]` is safe as the predicate's expression only returns a single column
        mask = predicate(self)[0]
        return self._with_native(self.native.loc[mask])

    def simple_select(self, *column_names: str) -> Self:
        df: Incomplete = self.native
        native = select_columns_by_name(df, list(column_names), self._implementation)
        return self._with_native(native)

    def aggregate(self, *exprs: DaskExpr) -> Self:
        new_series = evaluate_exprs(self, *exprs)
        df = dd.concat([val.rename(name) for name, val in new_series], axis=1)
        return self._with_native(df)

    def select(self, *exprs: DaskExpr) -> Self:
        new_series = evaluate_exprs(self, *exprs)
        df: Incomplete = self.native
        df = select_columns_by_name(
            df.assign(**dict(new_series)),
            [s[0] for s in new_series],
            self._implementation,
        )
        return self._with_native(df)

    def drop_nulls(self, subset: Sequence[str] | None) -> Self:
        if subset is None:
            return self._with_native(self.native.dropna())
        plx = self.__narwhals_namespace__()
        mask = ~plx.any_horizontal(plx.col(*subset).is_null(), ignore_nulls=True)
        return self.filter(mask)

    @property
    def schema(self) -> dict[str, DType]:
        if self._cached_schema is None:
            native_dtypes = self.native.dtypes
            self._cached_schema = {
                col: native_to_narwhals_dtype(
                    native_dtypes[col], self._version, self._implementation
                )
                for col in self.native.columns
            }
        return self._cached_schema

    def collect_schema(self) -> dict[str, DType]:
        return self.schema

    def drop(self, columns: Sequence[str], *, strict: bool) -> Self:
        to_drop = parse_columns_to_drop(self, columns, strict=strict)

        return self._with_native(self.native.drop(columns=to_drop))

    def with_row_index(self, name: str, order_by: Sequence[str] | None) -> Self:
        # Implementation is based on the following StackOverflow reply:
        # https://stackoverflow.com/questions/60831518/in-dask-how-does-one-add-a-range-of-integersauto-increment-to-a-new-column/60852409#60852409
        if order_by is None:
            return self._with_native(add_row_index(self.native, name))
        else:
            plx = self.__narwhals_namespace__()
            columns = self.columns
            const_expr = (
                plx.lit(value=1, dtype=None).alias(name).broadcast(ExprKind.LITERAL)
            )
            row_index_expr = (
                plx.col(name)
                .cum_sum(reverse=False)
                .over(partition_by=[], order_by=order_by)
                - 1
            )
            return self.with_columns(const_expr).select(row_index_expr, plx.col(*columns))

    def rename(self, mapping: Mapping[str, str]) -> Self:
        return self._with_native(self.native.rename(columns=mapping))

    def head(self, n: int) -> Self:
        return self._with_native(self.native.head(n=n, compute=False, npartitions=-1))

    def unique(
        self, subset: Sequence[str] | None, *, keep: LazyUniqueKeepStrategy
    ) -> Self:
        if subset and (error := self._check_columns_exist(subset)):
            raise error
        if keep == "none":
            subset = subset or self.columns
            token = generate_temporary_column_name(n_bytes=8, columns=subset)
            ser = self.native.groupby(subset).size().rename(token)
            ser = ser[ser == 1]
            unique = ser.reset_index().drop(columns=token)
            result = self.native.merge(unique, on=subset, how="inner")
        else:
            mapped_keep = {"any": "first"}.get(keep, keep)
            result = self.native.drop_duplicates(subset=subset, keep=mapped_keep)
        return self._with_native(result)

    def sort(self, *by: str, descending: bool | Sequence[bool], nulls_last: bool) -> Self:
        if isinstance(descending, bool):
            ascending: bool | list[bool] = not descending
        else:
            ascending = [not d for d in descending]
        position = "last" if nulls_last else "first"
        return self._with_native(
            self.native.sort_values(list(by), ascending=ascending, na_position=position)
        )

    def _join_inner(
        self, other: Self, *, left_on: Sequence[str], right_on: Sequence[str], suffix: str
    ) -> dd.DataFrame:
        return self.native.merge(
            other.native,
            left_on=left_on,
            right_on=right_on,
            how="inner",
            suffixes=("", suffix),
        )

    def _join_left(
        self, other: Self, *, left_on: Sequence[str], right_on: Sequence[str], suffix: str
    ) -> dd.DataFrame:
        result_native = self.native.merge(
            other.native,
            how="left",
            left_on=left_on,
            right_on=right_on,
            suffixes=("", suffix),
        )
        extra = [
            right_key if right_key not in self.columns else f"{right_key}{suffix}"
            for left_key, right_key in zip(left_on, right_on)
            if right_key != left_key
        ]
        return result_native.drop(columns=extra)

    def _join_full(
        self, other: Self, *, left_on: Sequence[str], right_on: Sequence[str], suffix: str
    ) -> dd.DataFrame:
        # dask does not retain keys post-join
        # we must append the suffix to each key before-hand

        right_on_mapper = _remap_full_join_keys(left_on, right_on, suffix)
        other_native = other.native.rename(columns=right_on_mapper)
        check_column_names_are_unique(other_native.columns)
        right_suffixed = list(right_on_mapper.values())
        return self.native.merge(
            other_native,
            left_on=left_on,
            right_on=right_suffixed,
            how="outer",
            suffixes=("", suffix),
        )

    def _join_cross(self, other: Self, *, suffix: str) -> dd.DataFrame:
        key_token = generate_temporary_column_name(
            n_bytes=8, columns=(*self.columns, *other.columns)
        )
        return (
            self.native.assign(**{key_token: 0})
            .merge(
                other.native.assign(**{key_token: 0}),
                how="inner",
                left_on=key_token,
                right_on=key_token,
                suffixes=("", suffix),
            )
            .drop(columns=key_token)
        )

    def _join_semi(
        self, other: Self, *, left_on: Sequence[str], right_on: Sequence[str]
    ) -> dd.DataFrame:
        other_native = self._join_filter_rename(
            other=other,
            columns_to_select=list(right_on),
            columns_mapping=dict(zip(right_on, left_on)),
        )
        return self.native.merge(
            other_native, how="inner", left_on=left_on, right_on=left_on
        )

    def _join_anti(
        self, other: Self, *, left_on: Sequence[str], right_on: Sequence[str]
    ) -> dd.DataFrame:
        indicator_token = generate_temporary_column_name(
            n_bytes=8, columns=(*self.columns, *other.columns)
        )
        other_native = self._join_filter_rename(
            other=other,
            columns_to_select=list(right_on),
            columns_mapping=dict(zip(right_on, left_on)),
        )
        df = self.native.merge(
            other_native,
            how="left",
            indicator=indicator_token,  # pyright: ignore[reportArgumentType]
            left_on=left_on,
            right_on=left_on,
        )
        return df[df[indicator_token] == "left_only"].drop(columns=[indicator_token])

    def _join_filter_rename(
        self, other: Self, columns_to_select: list[str], columns_mapping: dict[str, str]
    ) -> dd.DataFrame:
        """Helper function to avoid creating extra columns and row duplication.

        Used in `"anti"` and `"semi"` joins.

        Notice that a native object is returned.
        """
        other_native: Incomplete = other.native
        # rename to avoid creating extra columns in join
        return (
            select_columns_by_name(other_native, columns_to_select, self._implementation)
            .rename(columns=columns_mapping)
            .drop_duplicates()
        )

    def join(
        self,
        other: Self,
        *,
        how: JoinStrategy,
        left_on: Sequence[str] | None,
        right_on: Sequence[str] | None,
        suffix: str,
    ) -> Self:
        if how == "cross":
            result = self._join_cross(other=other, suffix=suffix)

        elif left_on is None or right_on is None:  # pragma: no cover
            raise ValueError(left_on, right_on)

        elif how == "inner":
            result = self._join_inner(
                other=other, left_on=left_on, right_on=right_on, suffix=suffix
            )
        elif how == "anti":
            result = self._join_anti(other=other, left_on=left_on, right_on=right_on)
        elif how == "semi":
            result = self._join_semi(other=other, left_on=left_on, right_on=right_on)
        elif how == "left":
            result = self._join_left(
                other=other, left_on=left_on, right_on=right_on, suffix=suffix
            )
        elif how == "full":
            result = self._join_full(
                other=other, left_on=left_on, right_on=right_on, suffix=suffix
            )
        else:
            assert_never(how)
        return self._with_native(result)

    def join_asof(
        self,
        other: Self,
        *,
        left_on: str,
        right_on: str,
        by_left: Sequence[str] | None,
        by_right: Sequence[str] | None,
        strategy: AsofJoinStrategy,
        suffix: str,
    ) -> Self:
        plx = self.__native_namespace__()
        return self._with_native(
            plx.merge_asof(
                self.native,
                other.native,
                left_on=left_on,
                right_on=right_on,
                left_by=by_left,
                right_by=by_right,
                direction=strategy,
                suffixes=("", suffix),
            )
        )

    def group_by(
        self, keys: Sequence[str] | Sequence[DaskExpr], *, drop_null_keys: bool
    ) -> DaskLazyGroupBy:
        from narwhals._dask.group_by import DaskLazyGroupBy

        return DaskLazyGroupBy(self, keys, drop_null_keys=drop_null_keys)

    def tail(self, n: int) -> Self:  # pragma: no cover
        native_frame = self.native
        n_partitions = native_frame.npartitions

        if n_partitions == 1:
            return self._with_native(self.native.tail(n=n, compute=False))
        else:
            msg = "`LazyFrame.tail` is not supported for Dask backend with multiple partitions."
            raise NotImplementedError(msg)

    def gather_every(self, n: int, offset: int) -> Self:
        row_index_token = generate_temporary_column_name(n_bytes=8, columns=self.columns)
        plx = self.__narwhals_namespace__()
        return (
            self.with_row_index(row_index_token, order_by=None)
            .filter(
                (plx.col(row_index_token) >= offset)
                & ((plx.col(row_index_token) - offset) % n == 0)
            )
            .drop([row_index_token], strict=False)
        )

    def unpivot(
        self,
        on: Sequence[str] | None,
        index: Sequence[str] | None,
        variable_name: str,
        value_name: str,
    ) -> Self:
        return self._with_native(
            self.native.melt(
                id_vars=index,
                value_vars=on,
                var_name=variable_name,
                value_name=value_name,
            )
        )

    def sink_parquet(self, file: str | Path | BytesIO) -> None:
        self.native.to_parquet(file)

    explode = not_implemented()
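End-to-end, `DaskLazyFrame` is reached through `nw.from_native` on a dask collection; a minimal sketch (assuming dask and pandas are installed):

import dask.dataframe as dd
import pandas as pd
import narwhals as nw

ddf = dd.from_pandas(pd.DataFrame({"a": [1, 2, 3]}), npartitions=1)
lf = nw.from_native(ddf)  # narwhals LazyFrame wrapping a DaskLazyFrame
print(lf.select(nw.col("a").sum()).to_native().compute())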
692
.venv/lib/python3.12/site-packages/narwhals/_dask/expr.py
Normal file
@ -0,0 +1,692 @@
from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Any, Callable, Literal

from narwhals._compliant import DepthTrackingExpr, LazyExpr
from narwhals._dask.expr_dt import DaskExprDateTimeNamespace
from narwhals._dask.expr_str import DaskExprStringNamespace
from narwhals._dask.utils import (
    add_row_index,
    maybe_evaluate_expr,
    narwhals_to_native_dtype,
)
from narwhals._expression_parsing import ExprKind, evaluate_output_names_and_aliases
from narwhals._pandas_like.utils import native_to_narwhals_dtype
from narwhals._utils import (
    Implementation,
    generate_temporary_column_name,
    not_implemented,
)
from narwhals.exceptions import InvalidOperationError

if TYPE_CHECKING:
    from collections.abc import Sequence

    import dask.dataframe.dask_expr as dx
    from typing_extensions import Self

    from narwhals._compliant.typing import AliasNames, EvalNames, EvalSeries, ScalarKwargs
    from narwhals._dask.dataframe import DaskLazyFrame
    from narwhals._dask.namespace import DaskNamespace
    from narwhals._expression_parsing import ExprKind, ExprMetadata
    from narwhals._utils import Version, _LimitedContext
    from narwhals.typing import (
        FillNullStrategy,
        IntoDType,
        NonNestedLiteral,
        NumericLiteral,
        RollingInterpolationMethod,
        TemporalLiteral,
    )


class DaskExpr(
    LazyExpr["DaskLazyFrame", "dx.Series"],
    DepthTrackingExpr["DaskLazyFrame", "dx.Series"],
):
    _implementation: Implementation = Implementation.DASK

    def __init__(
        self,
        call: EvalSeries[DaskLazyFrame, dx.Series],
        *,
        depth: int,
        function_name: str,
        evaluate_output_names: EvalNames[DaskLazyFrame],
        alias_output_names: AliasNames | None,
        version: Version,
        scalar_kwargs: ScalarKwargs | None = None,
    ) -> None:
        self._call = call
        self._depth = depth
        self._function_name = function_name
        self._evaluate_output_names = evaluate_output_names
        self._alias_output_names = alias_output_names
        self._version = version
        self._scalar_kwargs = scalar_kwargs or {}
        self._metadata: ExprMetadata | None = None

    def __call__(self, df: DaskLazyFrame) -> Sequence[dx.Series]:
        return self._call(df)

    def __narwhals_expr__(self) -> None: ...

    def __narwhals_namespace__(self) -> DaskNamespace:  # pragma: no cover
        from narwhals._dask.namespace import DaskNamespace

        return DaskNamespace(version=self._version)

    def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Self:
        def func(df: DaskLazyFrame) -> list[dx.Series]:
            # result.loc[0][0] is a workaround for dask~<=2024.10.0/dask_expr~<=1.1.16
            # that raised a KeyError for result[0] during collection.
            return [result.loc[0][0] for result in self(df)]

        return self.__class__(
            func,
            depth=self._depth,
            function_name=self._function_name,
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            version=self._version,
            scalar_kwargs=self._scalar_kwargs,
        )

    @classmethod
    def from_column_names(
        cls: type[Self],
        evaluate_column_names: EvalNames[DaskLazyFrame],
        /,
        *,
        context: _LimitedContext,
        function_name: str = "",
    ) -> Self:
        def func(df: DaskLazyFrame) -> list[dx.Series]:
            try:
                return [
                    df._native_frame[column_name]
                    for column_name in evaluate_column_names(df)
                ]
            except KeyError as e:
                if error := df._check_columns_exist(evaluate_column_names(df)):
                    raise error from e
                raise

        return cls(
            func,
            depth=0,
            function_name=function_name,
            evaluate_output_names=evaluate_column_names,
            alias_output_names=None,
            version=context._version,
        )

    @classmethod
    def from_column_indices(cls, *column_indices: int, context: _LimitedContext) -> Self:
        def func(df: DaskLazyFrame) -> list[dx.Series]:
            return [df.native.iloc[:, i] for i in column_indices]

        return cls(
            func,
            depth=0,
            function_name="nth",
            evaluate_output_names=cls._eval_names_indices(column_indices),
            alias_output_names=None,
            version=context._version,
        )

    def _with_callable(
        self,
        # First argument to `call` should be `dx.Series`
        call: Callable[..., dx.Series],
        /,
        expr_name: str = "",
        scalar_kwargs: ScalarKwargs | None = None,
        **expressifiable_args: Self | Any,
    ) -> Self:
        def func(df: DaskLazyFrame) -> list[dx.Series]:
            native_results: list[dx.Series] = []
            native_series_list = self._call(df)
            other_native_series = {
                key: maybe_evaluate_expr(df, value)
                for key, value in expressifiable_args.items()
            }
            for native_series in native_series_list:
                result_native = call(native_series, **other_native_series)
                native_results.append(result_native)
            return native_results

        return self.__class__(
            func,
            depth=self._depth + 1,
            function_name=f"{self._function_name}->{expr_name}",
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            version=self._version,
            scalar_kwargs=scalar_kwargs,
        )

    def _with_alias_output_names(self, func: AliasNames | None, /) -> Self:
        current_alias_output_names = self._alias_output_names
        alias_output_names = (
            None
            if func is None
            else func
            if current_alias_output_names is None
            else lambda output_names: func(current_alias_output_names(output_names))
        )
        return type(self)(
            call=self._call,
            depth=self._depth,
            function_name=self._function_name,
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=alias_output_names,
            version=self._version,
            scalar_kwargs=self._scalar_kwargs,
        )

    def __add__(self, other: Any) -> Self:
        return self._with_callable(
            lambda expr, other: expr.__add__(other), "__add__", other=other
        )

    def __sub__(self, other: Any) -> Self:
        return self._with_callable(
            lambda expr, other: expr.__sub__(other), "__sub__", other=other
        )

    def __rsub__(self, other: Any) -> Self:
        return self._with_callable(
            lambda expr, other: other - expr, "__rsub__", other=other
        ).alias("literal")

    def __mul__(self, other: Any) -> Self:
        return self._with_callable(
            lambda expr, other: expr.__mul__(other), "__mul__", other=other
        )

    def __truediv__(self, other: Any) -> Self:
        return self._with_callable(
            lambda expr, other: expr.__truediv__(other), "__truediv__", other=other
        )

    def __rtruediv__(self, other: Any) -> Self:
        return self._with_callable(
            lambda expr, other: other / expr, "__rtruediv__", other=other
        ).alias("literal")

    def __floordiv__(self, other: Any) -> Self:
        return self._with_callable(
            lambda expr, other: expr.__floordiv__(other), "__floordiv__", other=other
        )

    def __rfloordiv__(self, other: Any) -> Self:
        return self._with_callable(
            lambda expr, other: other // expr, "__rfloordiv__", other=other
        ).alias("literal")

    def __pow__(self, other: Any) -> Self:
        return self._with_callable(
            lambda expr, other: expr.__pow__(other), "__pow__", other=other
        )

    def __rpow__(self, other: Any) -> Self:
        return self._with_callable(
            lambda expr, other: other**expr, "__rpow__", other=other
        ).alias("literal")

    def __mod__(self, other: Any) -> Self:
        return self._with_callable(
            lambda expr, other: expr.__mod__(other), "__mod__", other=other
        )

    def __rmod__(self, other: Any) -> Self:
        return self._with_callable(
            lambda expr, other: other % expr, "__rmod__", other=other
        ).alias("literal")

    def __eq__(self, other: DaskExpr) -> Self:  # type: ignore[override]
        return self._with_callable(
            lambda expr, other: expr.__eq__(other), "__eq__", other=other
        )

    def __ne__(self, other: DaskExpr) -> Self:  # type: ignore[override]
        return self._with_callable(
            lambda expr, other: expr.__ne__(other), "__ne__", other=other
        )

    def __ge__(self, other: DaskExpr | Any) -> Self:
        return self._with_callable(
            lambda expr, other: expr.__ge__(other), "__ge__", other=other
        )

    def __gt__(self, other: DaskExpr) -> Self:
        return self._with_callable(
            lambda expr, other: expr.__gt__(other), "__gt__", other=other
        )

    def __le__(self, other: DaskExpr) -> Self:
        return self._with_callable(
            lambda expr, other: expr.__le__(other), "__le__", other=other
        )

    def __lt__(self, other: DaskExpr) -> Self:
        return self._with_callable(
            lambda expr, other: expr.__lt__(other), "__lt__", other=other
        )

    def __and__(self, other: DaskExpr | Any) -> Self:
        return self._with_callable(
            lambda expr, other: expr.__and__(other), "__and__", other=other
        )

    def __or__(self, other: DaskExpr) -> Self:
        return self._with_callable(
            lambda expr, other: expr.__or__(other), "__or__", other=other
        )

    def __invert__(self) -> Self:
        return self._with_callable(lambda expr: expr.__invert__(), "__invert__")

    def mean(self) -> Self:
        return self._with_callable(lambda expr: expr.mean().to_series(), "mean")

    def median(self) -> Self:
        from narwhals.exceptions import InvalidOperationError

        def func(s: dx.Series) -> dx.Series:
            dtype = native_to_narwhals_dtype(s.dtype, self._version, Implementation.DASK)
            if not dtype.is_numeric():
                msg = "`median` operation not supported for non-numeric input type."
                raise InvalidOperationError(msg)
            return s.median_approximate().to_series()

        return self._with_callable(func, "median")

    def min(self) -> Self:
        return self._with_callable(lambda expr: expr.min().to_series(), "min")

    def max(self) -> Self:
        return self._with_callable(lambda expr: expr.max().to_series(), "max")

    def std(self, ddof: int) -> Self:
        return self._with_callable(
            lambda expr: expr.std(ddof=ddof).to_series(),
            "std",
            scalar_kwargs={"ddof": ddof},
        )

    def var(self, ddof: int) -> Self:
        return self._with_callable(
            lambda expr: expr.var(ddof=ddof).to_series(),
            "var",
            scalar_kwargs={"ddof": ddof},
        )

    def skew(self) -> Self:
        return self._with_callable(lambda expr: expr.skew().to_series(), "skew")

    def kurtosis(self) -> Self:
        return self._with_callable(lambda expr: expr.kurtosis().to_series(), "kurtosis")

    def shift(self, n: int) -> Self:
        return self._with_callable(lambda expr: expr.shift(n), "shift")

    def cum_sum(self, *, reverse: bool) -> Self:
        if reverse:  # pragma: no cover
            # https://github.com/dask/dask/issues/11802
            msg = "`cum_sum(reverse=True)` is not supported with Dask backend"
            raise NotImplementedError(msg)

        return self._with_callable(lambda expr: expr.cumsum(), "cum_sum")

    def cum_count(self, *, reverse: bool) -> Self:
        if reverse:  # pragma: no cover
            msg = "`cum_count(reverse=True)` is not supported with Dask backend"
            raise NotImplementedError(msg)

        return self._with_callable(
            lambda expr: (~expr.isna()).astype(int).cumsum(), "cum_count"
        )

    def cum_min(self, *, reverse: bool) -> Self:
        if reverse:  # pragma: no cover
            msg = "`cum_min(reverse=True)` is not supported with Dask backend"
            raise NotImplementedError(msg)

        return self._with_callable(lambda expr: expr.cummin(), "cum_min")

    def cum_max(self, *, reverse: bool) -> Self:
        if reverse:  # pragma: no cover
            msg = "`cum_max(reverse=True)` is not supported with Dask backend"
            raise NotImplementedError(msg)

        return self._with_callable(lambda expr: expr.cummax(), "cum_max")

    def cum_prod(self, *, reverse: bool) -> Self:
        if reverse:  # pragma: no cover
            msg = "`cum_prod(reverse=True)` is not supported with Dask backend"
            raise NotImplementedError(msg)

        return self._with_callable(lambda expr: expr.cumprod(), "cum_prod")

    def rolling_sum(self, window_size: int, *, min_samples: int, center: bool) -> Self:
        return self._with_callable(
            lambda expr: expr.rolling(
                window=window_size, min_periods=min_samples, center=center
            ).sum(),
            "rolling_sum",
        )

    def rolling_mean(self, window_size: int, *, min_samples: int, center: bool) -> Self:
        return self._with_callable(
            lambda expr: expr.rolling(
                window=window_size, min_periods=min_samples, center=center
            ).mean(),
            "rolling_mean",
        )

    def rolling_var(
        self, window_size: int, *, min_samples: int, center: bool, ddof: int
    ) -> Self:
        if ddof == 1:
            return self._with_callable(
                lambda expr: expr.rolling(
                    window=window_size, min_periods=min_samples, center=center
                ).var(),
                "rolling_var",
            )
        else:
            msg = "Dask backend only supports `ddof=1` for `rolling_var`"
            raise NotImplementedError(msg)

    def rolling_std(
        self, window_size: int, *, min_samples: int, center: bool, ddof: int
    ) -> Self:
        if ddof == 1:
            return self._with_callable(
                lambda expr: expr.rolling(
                    window=window_size, min_periods=min_samples, center=center
                ).std(),
                "rolling_std",
            )
        else:
            msg = "Dask backend only supports `ddof=1` for `rolling_std`"
            raise NotImplementedError(msg)

    def sum(self) -> Self:
        return self._with_callable(lambda expr: expr.sum().to_series(), "sum")

    def count(self) -> Self:
        return self._with_callable(lambda expr: expr.count().to_series(), "count")

    def round(self, decimals: int) -> Self:
        return self._with_callable(lambda expr: expr.round(decimals), "round")

    def unique(self) -> Self:
        return self._with_callable(lambda expr: expr.unique(), "unique")

    def drop_nulls(self) -> Self:
        return self._with_callable(lambda expr: expr.dropna(), "drop_nulls")

    def abs(self) -> Self:
        return self._with_callable(lambda expr: expr.abs(), "abs")

    def all(self) -> Self:
        return self._with_callable(
            lambda expr: expr.all(
                axis=None, skipna=True, split_every=False, out=None
            ).to_series(),
            "all",
        )

    def any(self) -> Self:
        return self._with_callable(
            lambda expr: expr.any(axis=0, skipna=True, split_every=False).to_series(),
            "any",
        )

    def fill_null(
        self,
        value: Self | NonNestedLiteral,
        strategy: FillNullStrategy | None,
        limit: int | None,
    ) -> Self:
        def func(expr: dx.Series) -> dx.Series:
            if value is not None:
                res_ser = expr.fillna(value)
            else:
                res_ser = (
                    expr.ffill(limit=limit)
                    if strategy == "forward"
                    else expr.bfill(limit=limit)
                )
            return res_ser

        return self._with_callable(func, "fillna")

    def clip(
        self,
        lower_bound: Self | NumericLiteral | TemporalLiteral | None,
        upper_bound: Self | NumericLiteral | TemporalLiteral | None,
    ) -> Self:
        return self._with_callable(
            lambda expr, lower_bound, upper_bound: expr.clip(
                lower=lower_bound, upper=upper_bound
            ),
            "clip",
            lower_bound=lower_bound,
            upper_bound=upper_bound,
        )

    def diff(self) -> Self:
        return self._with_callable(lambda expr: expr.diff(), "diff")

    def n_unique(self) -> Self:
        return self._with_callable(
            lambda expr: expr.nunique(dropna=False).to_series(), "n_unique"
        )

    def is_null(self) -> Self:
        return self._with_callable(lambda expr: expr.isna(), "is_null")

    def is_nan(self) -> Self:
        def func(expr: dx.Series) -> dx.Series:
            dtype = native_to_narwhals_dtype(
                expr.dtype, self._version, self._implementation
            )
            if dtype.is_numeric():
                return expr != expr  # pyright: ignore[reportReturnType]  # noqa: PLR0124
            msg = f"`.is_nan` only supported for numeric dtypes and not {dtype}, did you mean `.is_null`?"
            raise InvalidOperationError(msg)

        return self._with_callable(func, "is_nan")

    def len(self) -> Self:
        return self._with_callable(lambda expr: expr.size.to_series(), "len")

    def quantile(
        self, quantile: float, interpolation: RollingInterpolationMethod
    ) -> Self:
        if interpolation == "linear":

            def func(expr: dx.Series, quantile: float) -> dx.Series:
                if expr.npartitions > 1:
                    msg = "`Expr.quantile` is not supported for Dask backend with multiple partitions."
                    raise NotImplementedError(msg)
                return expr.quantile(
                    q=quantile, method="dask"
                ).to_series()  # pragma: no cover

            return self._with_callable(func, "quantile", quantile=quantile)
        else:
            msg = "`higher`, `lower`, `midpoint`, `nearest` - interpolation methods are not supported by Dask. Please use `linear` instead."
            raise NotImplementedError(msg)

    def is_first_distinct(self) -> Self:
        def func(expr: dx.Series) -> dx.Series:
            _name = expr.name
            col_token = generate_temporary_column_name(n_bytes=8, columns=[_name])
            frame = add_row_index(expr.to_frame(), col_token)
            first_distinct_index = frame.groupby(_name).agg({col_token: "min"})[col_token]
            return frame[col_token].isin(first_distinct_index)

        return self._with_callable(func, "is_first_distinct")

    def is_last_distinct(self) -> Self:
        def func(expr: dx.Series) -> dx.Series:
            _name = expr.name
            col_token = generate_temporary_column_name(n_bytes=8, columns=[_name])
            frame = add_row_index(expr.to_frame(), col_token)
            last_distinct_index = frame.groupby(_name).agg({col_token: "max"})[col_token]
            return frame[col_token].isin(last_distinct_index)

        return self._with_callable(func, "is_last_distinct")

    def is_unique(self) -> Self:
        def func(expr: dx.Series) -> dx.Series:
            _name = expr.name
            return (
                expr.to_frame()
                .groupby(_name, dropna=False)
                .transform("size", meta=(_name, int))
                == 1
            )

        return self._with_callable(func, "is_unique")

    def is_in(self, other: Any) -> Self:
        return self._with_callable(lambda expr: expr.isin(other), "is_in")

    def null_count(self) -> Self:
        return self._with_callable(
            lambda expr: expr.isna().sum().to_series(), "null_count"
        )

    def over(self, partition_by: Sequence[str], order_by: Sequence[str]) -> Self:
        # pandas is a required dependency of dask so it's safe to import this
        from narwhals._pandas_like.group_by import PandasLikeGroupBy

        if not partition_by:
            assert order_by  # noqa: S101

            # This is something like `nw.col('a').cum_sum().order_by(key)`
            # which we can always easily support, as it doesn't require grouping.
            def func(df: DaskLazyFrame) -> Sequence[dx.Series]:
                return self(df.sort(*order_by, descending=False, nulls_last=False))
        elif not self._is_elementary():  # pragma: no cover
            msg = (
                "Only elementary expressions are supported for `.over` in dask.\n\n"
                "Please see: "
                "https://narwhals-dev.github.io/narwhals/concepts/improve_group_by_operation/"
            )
            raise NotImplementedError(msg)
        elif order_by:
            # Wrong results https://github.com/dask/dask/issues/11806.
            msg = "`over` with `order_by` is not yet supported in Dask."
            raise NotImplementedError(msg)
        else:
            function_name = PandasLikeGroupBy._leaf_name(self)
            try:
                dask_function_name = PandasLikeGroupBy._REMAP_AGGS[function_name]
            except KeyError:
                # window functions are unsupported: https://github.com/dask/dask/issues/11806
                msg = (
                    f"Unsupported function: {function_name} in `over` context.\n\n"
                    f"Supported functions are {', '.join(PandasLikeGroupBy._REMAP_AGGS)}\n"
                )
                raise NotImplementedError(msg) from None

            def func(df: DaskLazyFrame) -> Sequence[dx.Series]:
                output_names, aliases = evaluate_output_names_and_aliases(self, df, [])

                with warnings.catch_warnings():
                    # https://github.com/dask/dask/issues/11804
                    warnings.filterwarnings(
                        "ignore",
                        message=".*`meta` is not specified",
                        category=UserWarning,
                    )
                    grouped = df.native.groupby(partition_by)
                    if dask_function_name == "size":
                        if len(output_names) != 1:  # pragma: no cover
                            msg = "Safety check failed, please report a bug."
                            raise AssertionError(msg)
                        res_native = grouped.transform(
                            dask_function_name, **self._scalar_kwargs
                        ).to_frame(output_names[0])
                    else:
                        res_native = grouped[list(output_names)].transform(
                            dask_function_name, **self._scalar_kwargs
                        )
                result_frame = df._with_native(
                    res_native.rename(columns=dict(zip(output_names, aliases)))
                ).native
                return [result_frame[name] for name in aliases]

        return self.__class__(
            func,
            depth=self._depth + 1,
            function_name=self._function_name + "->over",
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            version=self._version,
        )

    def cast(self, dtype: IntoDType) -> Self:
        def func(expr: dx.Series) -> dx.Series:
            native_dtype = narwhals_to_native_dtype(dtype, self._version)
            return expr.astype(native_dtype)

        return self._with_callable(func, "cast")

    def is_finite(self) -> Self:
        import dask.array as da

        return self._with_callable(da.isfinite, "is_finite")

    def log(self, base: float) -> Self:
        import dask.array as da

        def _log(expr: dx.Series) -> dx.Series:
            return da.log(expr) / da.log(base)

        return self._with_callable(_log, "log")

    def exp(self) -> Self:
        import dask.array as da

        return self._with_callable(da.exp, "exp")

    def sqrt(self) -> Self:
        import dask.array as da

        return self._with_callable(da.sqrt, "sqrt")

    @property
    def str(self) -> DaskExprStringNamespace:
        return DaskExprStringNamespace(self)

    @property
    def dt(self) -> DaskExprDateTimeNamespace:
        return DaskExprDateTimeNamespace(self)

    arg_max: not_implemented = not_implemented()
    arg_min: not_implemented = not_implemented()
    arg_true: not_implemented = not_implemented()
    ewm_mean: not_implemented = not_implemented()
    gather_every: not_implemented = not_implemented()
    head: not_implemented = not_implemented()
    map_batches: not_implemented = not_implemented()
    mode: not_implemented = not_implemented()
    sample: not_implemented = not_implemented()
    rank: not_implemented = not_implemented()
    replace_strict: not_implemented = not_implemented()
    sort: not_implemented = not_implemented()
    tail: not_implemented = not_implemented()

    # namespaces
    list: not_implemented = not_implemented()  # type: ignore[assignment]
    cat: not_implemented = not_implemented()  # type: ignore[assignment]
    struct: not_implemented = not_implemented()  # type: ignore[assignment]
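Methods assigned `not_implemented()` surface as errors only when exercised against the Dask backend; supported expressions evaluate lazily as usual. A brief sketch, reusing the `lf` LazyFrame from the earlier dask example:

lf.select(nw.col("a").mean())   # supported, stays lazy until collected
# lf.select(nw.col("a").rank())  # expected to raise: `rank` is not implemented for Dask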
176
.venv/lib/python3.12/site-packages/narwhals/_dask/expr_dt.py
Normal file
@ -0,0 +1,176 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from narwhals._compliant import LazyExprNamespace
from narwhals._compliant.any_namespace import DateTimeNamespace
from narwhals._constants import MS_PER_SECOND, NS_PER_SECOND, US_PER_SECOND
from narwhals._duration import Interval
from narwhals._pandas_like.utils import (
    ALIAS_DICT,
    calculate_timestamp_date,
    calculate_timestamp_datetime,
    native_to_narwhals_dtype,
)
from narwhals._utils import Implementation

if TYPE_CHECKING:
    import dask.dataframe.dask_expr as dx

    from narwhals._dask.expr import DaskExpr
    from narwhals.typing import TimeUnit


class DaskExprDateTimeNamespace(
    LazyExprNamespace["DaskExpr"], DateTimeNamespace["DaskExpr"]
):
    def date(self) -> DaskExpr:
        return self.compliant._with_callable(lambda expr: expr.dt.date, "date")

    def year(self) -> DaskExpr:
        return self.compliant._with_callable(lambda expr: expr.dt.year, "year")

    def month(self) -> DaskExpr:
        return self.compliant._with_callable(lambda expr: expr.dt.month, "month")

    def day(self) -> DaskExpr:
        return self.compliant._with_callable(lambda expr: expr.dt.day, "day")

    def hour(self) -> DaskExpr:
        return self.compliant._with_callable(lambda expr: expr.dt.hour, "hour")

    def minute(self) -> DaskExpr:
        return self.compliant._with_callable(lambda expr: expr.dt.minute, "minute")

    def second(self) -> DaskExpr:
        return self.compliant._with_callable(lambda expr: expr.dt.second, "second")

    def millisecond(self) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr: expr.dt.microsecond // 1000, "millisecond"
        )

    def microsecond(self) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr: expr.dt.microsecond, "microsecond"
        )

    def nanosecond(self) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr: expr.dt.microsecond * 1000 + expr.dt.nanosecond, "nanosecond"
        )

    def ordinal_day(self) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr: expr.dt.dayofyear, "ordinal_day"
        )

    def weekday(self) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr: expr.dt.weekday + 1,  # Dask is 0-6
            "weekday",
        )

    def to_string(self, format: str) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr, format: expr.dt.strftime(format.replace("%.f", ".%f")),
            "strftime",
            format=format,
        )

    def replace_time_zone(self, time_zone: str | None) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr, time_zone: expr.dt.tz_localize(None).dt.tz_localize(time_zone)
            if time_zone is not None
            else expr.dt.tz_localize(None),
            "tz_localize",
            time_zone=time_zone,
        )

    def convert_time_zone(self, time_zone: str) -> DaskExpr:
        def func(s: dx.Series, time_zone: str) -> dx.Series:
            dtype = native_to_narwhals_dtype(
                s.dtype, self.compliant._version, Implementation.DASK
            )
            if dtype.time_zone is None:  # type: ignore[attr-defined]
                return s.dt.tz_localize("UTC").dt.tz_convert(time_zone)  # pyright: ignore[reportAttributeAccessIssue]
            else:
                return s.dt.tz_convert(time_zone)  # pyright: ignore[reportAttributeAccessIssue]

        return self.compliant._with_callable(func, "tz_convert", time_zone=time_zone)

    # ignoring coverage due to https://github.com/narwhals-dev/narwhals/issues/2808.
    def timestamp(self, time_unit: TimeUnit) -> DaskExpr:  # pragma: no cover
        def func(s: dx.Series, time_unit: TimeUnit) -> dx.Series:
            dtype = native_to_narwhals_dtype(
                s.dtype, self.compliant._version, Implementation.DASK
            )
            is_pyarrow_dtype = "pyarrow" in str(dtype)
            mask_na = s.isna()
            dtypes = self.compliant._version.dtypes
            if dtype == dtypes.Date:
                # Date is only supported in pandas dtypes if pyarrow-backed
                s_cast = s.astype("Int32[pyarrow]")
                result = calculate_timestamp_date(s_cast, time_unit)
            elif isinstance(dtype, dtypes.Datetime):
                original_time_unit = dtype.time_unit
                s_cast = (
                    s.astype("Int64[pyarrow]") if is_pyarrow_dtype else s.astype("int64")
                )
                result = calculate_timestamp_datetime(
                    s_cast, original_time_unit, time_unit
                )
            else:
                msg = "Input should be either of Date or Datetime type"
                raise TypeError(msg)
            return result.where(~mask_na)  # pyright: ignore[reportReturnType]

        return self.compliant._with_callable(func, "datetime", time_unit=time_unit)

    def total_minutes(self) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr: expr.dt.total_seconds() // 60, "total_minutes"
        )

    def total_seconds(self) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr: expr.dt.total_seconds() // 1, "total_seconds"
        )

    def total_milliseconds(self) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr: expr.dt.total_seconds() * MS_PER_SECOND // 1,
            "total_milliseconds",
        )

    def total_microseconds(self) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr: expr.dt.total_seconds() * US_PER_SECOND // 1,
            "total_microseconds",
        )

    def total_nanoseconds(self) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr: expr.dt.total_seconds() * NS_PER_SECOND // 1, "total_nanoseconds"
        )

    def truncate(self, every: str) -> DaskExpr:
        interval = Interval.parse(every)
        unit = interval.unit
        if unit in {"mo", "q", "y"}:
            msg = f"Truncating to {unit} is not yet supported for dask."
            raise NotImplementedError(msg)
        freq = f"{interval.multiple}{ALIAS_DICT.get(unit, unit)}"
        return self.compliant._with_callable(lambda expr: expr.dt.floor(freq), "truncate")

    def offset_by(self, by: str) -> DaskExpr:
        def func(s: dx.Series, by: str) -> dx.Series:
            interval = Interval.parse_no_constraints(by)
            unit = interval.unit
            if unit in {"y", "q", "mo", "d", "ns"}:
                msg = f"Offsetting by {unit} is not yet supported for dask."
                raise NotImplementedError(msg)
            offset = interval.to_timedelta()
            return s.add(offset)

        return self.compliant._with_callable(func, "offset_by", by=by)
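The datetime namespace above is reached through `.dt` on an expression. A brief sketch, reusing the dask setup from the earlier example and adding a hypothetical `ts` datetime column:

lf = nw.from_native(ddf.assign(ts=pd.Timestamp("2024-01-15")))
print(lf.select(nw.col("ts").dt.year(), nw.col("ts").dt.month()).to_native().compute())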
Some files were not shown because too many files have changed in this diff.