some new features

ilgazca
2025-07-30 17:09:11 +03:00
parent db5d46760a
commit 8019bd3b7c
20616 changed files with 4375466 additions and 8 deletions

@@ -0,0 +1,186 @@
from __future__ import annotations
import typing as _t
from narwhals import dependencies, dtypes, exceptions, selectors
from narwhals._utils import (
Implementation,
generate_temporary_column_name,
is_ordered_categorical,
maybe_align_index,
maybe_convert_dtypes,
maybe_get_index,
maybe_reset_index,
maybe_set_index,
)
from narwhals.dataframe import DataFrame, LazyFrame
from narwhals.dtypes import (
Array,
Binary,
Boolean,
Categorical,
Date,
Datetime,
Decimal,
Duration,
Enum,
Field,
Float32,
Float64,
Int8,
Int16,
Int32,
Int64,
Int128,
List,
Object,
String,
Struct,
Time,
UInt8,
UInt16,
UInt32,
UInt64,
UInt128,
Unknown,
)
from narwhals.expr import Expr
from narwhals.functions import (
all_ as all,
all_horizontal,
any_horizontal,
coalesce,
col,
concat,
concat_str,
exclude,
from_arrow,
from_dict,
from_numpy,
len_ as len,
lit,
max,
max_horizontal,
mean,
mean_horizontal,
median,
min,
min_horizontal,
new_series,
nth,
read_csv,
read_parquet,
scan_csv,
scan_parquet,
show_versions,
sum,
sum_horizontal,
when,
)
from narwhals.schema import Schema
from narwhals.series import Series
from narwhals.translate import (
from_native,
get_native_namespace,
narwhalify,
to_native,
to_py_scalar,
)
__version__: str
__all__ = [
"Array",
"Binary",
"Boolean",
"Categorical",
"DataFrame",
"Date",
"Datetime",
"Decimal",
"Duration",
"Enum",
"Expr",
"Field",
"Float32",
"Float64",
"Implementation",
"Int8",
"Int16",
"Int32",
"Int64",
"Int128",
"LazyFrame",
"List",
"Object",
"Schema",
"Series",
"String",
"Struct",
"Time",
"UInt8",
"UInt16",
"UInt32",
"UInt64",
"UInt128",
"Unknown",
"all",
"all_horizontal",
"any_horizontal",
"coalesce",
"col",
"concat",
"concat_str",
"dependencies",
"dtypes",
"exceptions",
"exclude",
"from_arrow",
"from_dict",
"from_native",
"from_numpy",
"generate_temporary_column_name",
"get_native_namespace",
"is_ordered_categorical",
"len",
"lit",
"max",
"max_horizontal",
"maybe_align_index",
"maybe_convert_dtypes",
"maybe_get_index",
"maybe_reset_index",
"maybe_set_index",
"mean",
"mean_horizontal",
"median",
"min",
"min_horizontal",
"narwhalify",
"new_series",
"nth",
"read_csv",
"read_parquet",
"scan_csv",
"scan_parquet",
"selectors",
"show_versions",
"sum",
"sum_horizontal",
"to_native",
"to_py_scalar",
"when",
]
def __getattr__(name: _t.Literal["__version__"]) -> str: # type: ignore[misc]
if name == "__version__":
global __version__ # noqa: PLW0603
from importlib import metadata
__version__ = metadata.version(__name__)
return __version__
else:
msg = f"module {__name__!r} has no attribute {name!r}"
raise AttributeError(msg)
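
The module-level `__getattr__` above uses PEP 562 so that `importlib.metadata.version` is only consulted on first access to `narwhals.__version__`; the `global` write then caches the result, so later lookups hit the module dict directly. A minimal standalone sketch of the same pattern, assuming a hypothetical installed distribution named "mypkg":

from importlib import metadata

def __getattr__(name: str) -> str:
    if name == "__version__":
        global __version__
        # Resolved once from package metadata, then cached on the module.
        __version__ = metadata.version("mypkg")  # hypothetical distribution name
        return __version__
    msg = f"module {__name__!r} has no attribute {name!r}"
    raise AttributeError(msg)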

@@ -0,0 +1,761 @@
from __future__ import annotations
from collections.abc import Collection, Iterator, Mapping, Sequence
from functools import partial
from typing import TYPE_CHECKING, Any, Literal, cast, overload
import pyarrow as pa
import pyarrow.compute as pc
from narwhals._arrow.series import ArrowSeries
from narwhals._arrow.utils import native_to_narwhals_dtype
from narwhals._compliant import EagerDataFrame
from narwhals._expression_parsing import ExprKind
from narwhals._utils import (
Implementation,
Version,
check_column_names_are_unique,
convert_str_slice_to_int_slice,
generate_temporary_column_name,
not_implemented,
parse_columns_to_drop,
scale_bytes,
supports_arrow_c_stream,
)
from narwhals.dependencies import is_numpy_array_1d
from narwhals.exceptions import ShapeError
if TYPE_CHECKING:
from io import BytesIO
from pathlib import Path
from types import ModuleType
import pandas as pd
import polars as pl
from typing_extensions import Self, TypeAlias, TypeIs
from narwhals._arrow.expr import ArrowExpr
from narwhals._arrow.group_by import ArrowGroupBy
from narwhals._arrow.namespace import ArrowNamespace
from narwhals._arrow.typing import ( # type: ignore[attr-defined]
ChunkedArrayAny,
Mask,
Order,
)
from narwhals._compliant.typing import CompliantDataFrameAny, CompliantLazyFrameAny
from narwhals._translate import IntoArrowTable
from narwhals._utils import Version, _LimitedContext
from narwhals.dtypes import DType
from narwhals.schema import Schema
from narwhals.typing import (
JoinStrategy,
SizedMultiIndexSelector,
SizedMultiNameSelector,
SizeUnit,
UniqueKeepStrategy,
_1DArray,
_2DArray,
_SliceIndex,
_SliceName,
)
JoinType: TypeAlias = Literal[
"left semi",
"right semi",
"left anti",
"right anti",
"inner",
"left outer",
"right outer",
"full outer",
]
PromoteOptions: TypeAlias = Literal["none", "default", "permissive"]
class ArrowDataFrame(
EagerDataFrame["ArrowSeries", "ArrowExpr", "pa.Table", "ChunkedArrayAny"]
):
_implementation = Implementation.PYARROW
def __init__(
self,
native_dataframe: pa.Table,
*,
version: Version,
validate_column_names: bool,
validate_backend_version: bool = False,
) -> None:
if validate_column_names:
check_column_names_are_unique(native_dataframe.column_names)
if validate_backend_version:
self._validate_backend_version()
self._native_frame = native_dataframe
self._version = version
@classmethod
def from_arrow(cls, data: IntoArrowTable, /, *, context: _LimitedContext) -> Self:
backend_version = context._implementation._backend_version()
if cls._is_native(data):
native = data
elif backend_version >= (14,) or isinstance(data, Collection):
native = pa.table(data)
elif supports_arrow_c_stream(data): # pragma: no cover
msg = f"'pyarrow>=14.0.0' is required for `from_arrow` for object of type {type(data).__name__!r}."
raise ModuleNotFoundError(msg)
else: # pragma: no cover
msg = f"`from_arrow` is not supported for object of type {type(data).__name__!r}."
raise TypeError(msg)
return cls.from_native(native, context=context)
@classmethod
def from_dict(
cls,
data: Mapping[str, Any],
/,
*,
context: _LimitedContext,
schema: Mapping[str, DType] | Schema | None,
) -> Self:
from narwhals.schema import Schema
pa_schema = Schema(schema).to_arrow() if schema is not None else schema
if pa_schema and not data:
native = pa_schema.empty_table()
else:
native = pa.Table.from_pydict(data, schema=pa_schema)
return cls.from_native(native, context=context)
@staticmethod
def _is_native(obj: pa.Table | Any) -> TypeIs[pa.Table]:
return isinstance(obj, pa.Table)
@classmethod
def from_native(cls, data: pa.Table, /, *, context: _LimitedContext) -> Self:
return cls(data, version=context._version, validate_column_names=True)
@classmethod
def from_numpy(
cls,
data: _2DArray,
/,
*,
context: _LimitedContext,
schema: Mapping[str, DType] | Schema | Sequence[str] | None,
) -> Self:
from narwhals.schema import Schema
arrays = [pa.array(val) for val in data.T]
if isinstance(schema, (Mapping, Schema)):
native = pa.Table.from_arrays(arrays, schema=Schema(schema).to_arrow())
else:
native = pa.Table.from_arrays(arrays, cls._numpy_column_names(data, schema))
return cls.from_native(native, context=context)
def __narwhals_namespace__(self) -> ArrowNamespace:
from narwhals._arrow.namespace import ArrowNamespace
return ArrowNamespace(version=self._version)
def __native_namespace__(self) -> ModuleType:
if self._implementation is Implementation.PYARROW:
return self._implementation.to_native_namespace()
msg = f"Expected pyarrow, got: {type(self._implementation)}" # pragma: no cover
raise AssertionError(msg)
def __narwhals_dataframe__(self) -> Self:
return self
def __narwhals_lazyframe__(self) -> Self:
return self
def _with_version(self, version: Version) -> Self:
return self.__class__(self.native, version=version, validate_column_names=False)
def _with_native(self, df: pa.Table, *, validate_column_names: bool = True) -> Self:
return self.__class__(
df, version=self._version, validate_column_names=validate_column_names
)
@property
def shape(self) -> tuple[int, int]:
return self.native.shape
def __len__(self) -> int:
return len(self.native)
def row(self, index: int) -> tuple[Any, ...]:
return tuple(col[index] for col in self.native.itercolumns())
@overload
def rows(self, *, named: Literal[True]) -> list[dict[str, Any]]: ...
@overload
def rows(self, *, named: Literal[False]) -> list[tuple[Any, ...]]: ...
@overload
def rows(self, *, named: bool) -> list[tuple[Any, ...]] | list[dict[str, Any]]: ...
def rows(self, *, named: bool) -> list[tuple[Any, ...]] | list[dict[str, Any]]:
if not named:
return list(self.iter_rows(named=False, buffer_size=512)) # type: ignore[return-value]
return self.native.to_pylist()
def iter_columns(self) -> Iterator[ArrowSeries]:
for name, series in zip(self.columns, self.native.itercolumns()):
yield ArrowSeries.from_native(series, context=self, name=name)
_iter_columns = iter_columns
def iter_rows(
self, *, named: bool, buffer_size: int
) -> Iterator[tuple[Any, ...]] | Iterator[dict[str, Any]]:
df = self.native
num_rows = df.num_rows
if not named:
for i in range(0, num_rows, buffer_size):
rows = df[i : i + buffer_size].to_pydict().values()
yield from zip(*rows)
else:
for i in range(0, num_rows, buffer_size):
yield from df[i : i + buffer_size].to_pylist()
def get_column(self, name: str) -> ArrowSeries:
if not isinstance(name, str):
msg = f"Expected str, got: {type(name)}"
raise TypeError(msg)
return ArrowSeries.from_native(self.native[name], context=self, name=name)
def __array__(self, dtype: Any, *, copy: bool | None) -> _2DArray:
return self.native.__array__(dtype, copy=copy)
def _gather(self, rows: SizedMultiIndexSelector[ChunkedArrayAny]) -> Self:
if len(rows) == 0:
return self._with_native(self.native.slice(0, 0))
if self._backend_version < (18,) and isinstance(rows, tuple):
rows = list(rows)
return self._with_native(self.native.take(rows))
def _gather_slice(self, rows: _SliceIndex | range) -> Self:
start = rows.start or 0
stop = rows.stop if rows.stop is not None else len(self.native)
if start < 0:
start = len(self.native) + start
if stop < 0:
stop = len(self.native) + stop
if rows.step is not None and rows.step != 1:
msg = "Slicing with step is not supported on PyArrow tables"
raise NotImplementedError(msg)
return self._with_native(self.native.slice(start, stop - start))
def _select_slice_name(self, columns: _SliceName) -> Self:
start, stop, step = convert_str_slice_to_int_slice(columns, self.columns)
return self._with_native(self.native.select(self.columns[start:stop:step]))
def _select_slice_index(self, columns: _SliceIndex | range) -> Self:
return self._with_native(
self.native.select(self.columns[columns.start : columns.stop : columns.step])
)
def _select_multi_index(
self, columns: SizedMultiIndexSelector[ChunkedArrayAny]
) -> Self:
selector: Sequence[int]
if isinstance(columns, pa.ChunkedArray):
# TODO @dangotbanned: Fix upstream with `pa.ChunkedArray.to_pylist(self) -> list[Any]:`
selector = cast("Sequence[int]", columns.to_pylist())
# TODO @dangotbanned: Fix upstream, it is actually much narrower
# **Doesn't accept `ndarray`**
elif is_numpy_array_1d(columns):
selector = columns.tolist()
else:
selector = columns
return self._with_native(self.native.select(selector))
def _select_multi_name(
self, columns: SizedMultiNameSelector[ChunkedArrayAny]
) -> Self:
selector: Sequence[str] | _1DArray
if isinstance(columns, pa.ChunkedArray):
# TODO @dangotbanned: Fix upstream with `pa.ChunkedArray.to_pylist(self) -> list[Any]:`
selector = cast("Sequence[str]", columns.to_pylist())
else:
selector = columns
# NOTE: Fixed in https://github.com/zen-xu/pyarrow-stubs/pull/221
return self._with_native(self.native.select(selector)) # pyright: ignore[reportArgumentType]
@property
def schema(self) -> dict[str, DType]:
schema = self.native.schema
return {
name: native_to_narwhals_dtype(dtype, self._version)
for name, dtype in zip(schema.names, schema.types)
}
def collect_schema(self) -> dict[str, DType]:
return self.schema
def estimated_size(self, unit: SizeUnit) -> int | float:
sz = self.native.nbytes
return scale_bytes(sz, unit)
explode = not_implemented()
@property
def columns(self) -> list[str]:
return self.native.column_names
def simple_select(self, *column_names: str) -> Self:
return self._with_native(
self.native.select(list(column_names)), validate_column_names=False
)
def select(self, *exprs: ArrowExpr) -> Self:
new_series = self._evaluate_into_exprs(*exprs)
if not new_series:
# return empty dataframe, like Polars does
return self._with_native(
self.native.__class__.from_arrays([]), validate_column_names=False
)
names = [s.name for s in new_series]
align = new_series[0]._align_full_broadcast
reshaped = align(*new_series)
df = pa.Table.from_arrays([s.native for s in reshaped], names=names)
return self._with_native(df, validate_column_names=True)
def _extract_comparand(self, other: ArrowSeries) -> ChunkedArrayAny:
length = len(self)
if not other._broadcast:
if (len_other := len(other)) != length:
msg = f"Expected object of length {length}, got: {len_other}."
raise ShapeError(msg)
return other.native
value = other.native[0]
return pa.chunked_array([pa.repeat(value, length)])
def with_columns(self, *exprs: ArrowExpr) -> Self:
        # NOTE: We use a faux-mutable variable and repeatedly "overwrite" it (`native_frame`)
# All `pyarrow` data is immutable, so this is fine
native_frame = self.native
new_columns = self._evaluate_into_exprs(*exprs)
columns = self.columns
for col_value in new_columns:
col_name = col_value.name
column = self._extract_comparand(col_value)
native_frame = (
native_frame.set_column(columns.index(col_name), col_name, column=column)
if col_name in columns
else native_frame.append_column(col_name, column=column)
)
return self._with_native(native_frame, validate_column_names=False)
def group_by(
self, keys: Sequence[str] | Sequence[ArrowExpr], *, drop_null_keys: bool
) -> ArrowGroupBy:
from narwhals._arrow.group_by import ArrowGroupBy
return ArrowGroupBy(self, keys, drop_null_keys=drop_null_keys)
def join(
self,
other: Self,
*,
how: JoinStrategy,
left_on: Sequence[str] | None,
right_on: Sequence[str] | None,
suffix: str,
) -> Self:
how_to_join_map: dict[str, JoinType] = {
"anti": "left anti",
"semi": "left semi",
"inner": "inner",
"left": "left outer",
"full": "full outer",
}
if how == "cross":
plx = self.__narwhals_namespace__()
key_token = generate_temporary_column_name(
n_bytes=8, columns=[*self.columns, *other.columns]
)
return self._with_native(
self.with_columns(
plx.lit(0, None).alias(key_token).broadcast(ExprKind.LITERAL)
)
.native.join(
other.with_columns(
plx.lit(0, None).alias(key_token).broadcast(ExprKind.LITERAL)
).native,
keys=key_token,
right_keys=key_token,
join_type="inner",
right_suffix=suffix,
)
.drop([key_token])
)
coalesce_keys = how != "full" # polars full join does not coalesce keys
return self._with_native(
self.native.join(
other.native,
keys=left_on or [], # type: ignore[arg-type]
right_keys=right_on, # type: ignore[arg-type]
join_type=how_to_join_map[how],
right_suffix=suffix,
coalesce_keys=coalesce_keys,
)
)
join_asof = not_implemented()
def drop(self, columns: Sequence[str], *, strict: bool) -> Self:
to_drop = parse_columns_to_drop(self, columns, strict=strict)
return self._with_native(self.native.drop(to_drop), validate_column_names=False)
def drop_nulls(self, subset: Sequence[str] | None) -> Self:
if subset is None:
return self._with_native(self.native.drop_null(), validate_column_names=False)
plx = self.__narwhals_namespace__()
mask = ~plx.any_horizontal(plx.col(*subset).is_null(), ignore_nulls=True)
return self.filter(mask)
def sort(self, *by: str, descending: bool | Sequence[bool], nulls_last: bool) -> Self:
if isinstance(descending, bool):
order: Order = "descending" if descending else "ascending"
sorting: list[tuple[str, Order]] = [(key, order) for key in by]
else:
sorting = [
(key, "descending" if is_descending else "ascending")
for key, is_descending in zip(by, descending)
]
null_placement = "at_end" if nulls_last else "at_start"
return self._with_native(
self.native.sort_by(sorting, null_placement=null_placement),
validate_column_names=False,
)
def to_pandas(self) -> pd.DataFrame:
return self.native.to_pandas()
def to_polars(self) -> pl.DataFrame:
import polars as pl # ignore-banned-import
return pl.from_arrow(self.native) # type: ignore[return-value]
def to_numpy(self, dtype: Any = None, *, copy: bool | None = None) -> _2DArray:
import numpy as np # ignore-banned-import
arr: Any = np.column_stack([col.to_numpy() for col in self.native.columns])
return arr
@overload
def to_dict(self, *, as_series: Literal[True]) -> dict[str, ArrowSeries]: ...
@overload
def to_dict(self, *, as_series: Literal[False]) -> dict[str, list[Any]]: ...
def to_dict(
self, *, as_series: bool
) -> dict[str, ArrowSeries] | dict[str, list[Any]]:
it = self.iter_columns()
if as_series:
return {ser.name: ser for ser in it}
return {ser.name: ser.to_list() for ser in it}
def with_row_index(self, name: str, order_by: Sequence[str] | None) -> Self:
plx = self.__narwhals_namespace__()
if order_by is None:
import numpy as np # ignore-banned-import
data = pa.array(np.arange(len(self), dtype=np.int64))
row_index = plx._expr._from_series(
plx._series.from_iterable(data, context=self, name=name)
)
else:
rank = plx.col(order_by[0]).rank("ordinal", descending=False)
row_index = (rank.over(partition_by=[], order_by=order_by) - 1).alias(name)
return self.select(row_index, plx.all())
def filter(self, predicate: ArrowExpr | list[bool | None]) -> Self:
if isinstance(predicate, list):
mask_native: Mask | ChunkedArrayAny = predicate
else:
# `[0]` is safe as the predicate's expression only returns a single column
mask_native = self._evaluate_into_exprs(predicate)[0].native
return self._with_native(
self.native.filter(mask_native), validate_column_names=False
)
def head(self, n: int) -> Self:
df = self.native
if n >= 0:
return self._with_native(df.slice(0, n), validate_column_names=False)
else:
num_rows = df.num_rows
return self._with_native(
df.slice(0, max(0, num_rows + n)), validate_column_names=False
)
def tail(self, n: int) -> Self:
df = self.native
if n >= 0:
num_rows = df.num_rows
return self._with_native(
df.slice(max(0, num_rows - n)), validate_column_names=False
)
else:
return self._with_native(df.slice(abs(n)), validate_column_names=False)
def lazy(self, *, backend: Implementation | None = None) -> CompliantLazyFrameAny:
if backend is None:
return self
elif backend is Implementation.DUCKDB:
import duckdb # ignore-banned-import
from narwhals._duckdb.dataframe import DuckDBLazyFrame
df = self.native # noqa: F841
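            # NOTE: duckdb resolves the string "df" through its replacement scan,
            # which finds the local Python variable of the same name in the calling
            # frame, so the seemingly unused assignment above is load-bearing.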
return DuckDBLazyFrame(
duckdb.table("df"), validate_backend_version=True, version=self._version
)
elif backend is Implementation.POLARS:
import polars as pl # ignore-banned-import
from narwhals._polars.dataframe import PolarsLazyFrame
return PolarsLazyFrame(
cast("pl.DataFrame", pl.from_arrow(self.native)).lazy(),
validate_backend_version=True,
version=self._version,
)
elif backend is Implementation.DASK:
import dask.dataframe as dd # ignore-banned-import
from narwhals._dask.dataframe import DaskLazyFrame
return DaskLazyFrame(
dd.from_pandas(self.native.to_pandas()),
validate_backend_version=True,
version=self._version,
)
elif backend.is_ibis():
import ibis # ignore-banned-import
from narwhals._ibis.dataframe import IbisLazyFrame
return IbisLazyFrame(
ibis.memtable(self.native, columns=self.columns),
validate_backend_version=True,
version=self._version,
)
raise AssertionError # pragma: no cover
def collect(
self, backend: Implementation | None, **kwargs: Any
) -> CompliantDataFrameAny:
if backend is Implementation.PYARROW or backend is None:
from narwhals._arrow.dataframe import ArrowDataFrame
return ArrowDataFrame(
self.native, version=self._version, validate_column_names=False
)
if backend is Implementation.PANDAS:
from narwhals._pandas_like.dataframe import PandasLikeDataFrame
return PandasLikeDataFrame(
self.native.to_pandas(),
implementation=Implementation.PANDAS,
validate_backend_version=True,
version=self._version,
validate_column_names=False,
)
if backend is Implementation.POLARS:
import polars as pl # ignore-banned-import
from narwhals._polars.dataframe import PolarsDataFrame
return PolarsDataFrame(
cast("pl.DataFrame", pl.from_arrow(self.native)),
validate_backend_version=True,
version=self._version,
)
msg = f"Unsupported `backend` value: {backend}" # pragma: no cover
raise AssertionError(msg) # pragma: no cover
def clone(self) -> Self:
return self._with_native(self.native, validate_column_names=False)
def item(self, row: int | None, column: int | str | None) -> Any:
from narwhals._arrow.series import maybe_extract_py_scalar
if row is None and column is None:
if self.shape != (1, 1):
msg = (
"can only call `.item()` if the dataframe is of shape (1, 1),"
" or if explicit row/col values are provided;"
f" frame has shape {self.shape!r}"
)
raise ValueError(msg)
return maybe_extract_py_scalar(self.native[0][0], return_py_scalar=True)
elif row is None or column is None:
msg = "cannot call `.item()` with only one of `row` or `column`"
raise ValueError(msg)
_col = self.columns.index(column) if isinstance(column, str) else column
return maybe_extract_py_scalar(self.native[_col][row], return_py_scalar=True)
def rename(self, mapping: Mapping[str, str]) -> Self:
names: dict[str, str] | list[str]
if self._backend_version >= (17,):
names = cast("dict[str, str]", mapping)
else: # pragma: no cover
names = [mapping.get(c, c) for c in self.columns]
return self._with_native(self.native.rename_columns(names))
def write_parquet(self, file: str | Path | BytesIO) -> None:
import pyarrow.parquet as pp
pp.write_table(self.native, file)
@overload
def write_csv(self, file: None) -> str: ...
@overload
def write_csv(self, file: str | Path | BytesIO) -> None: ...
def write_csv(self, file: str | Path | BytesIO | None) -> str | None:
import pyarrow.csv as pa_csv
if file is None:
csv_buffer = pa.BufferOutputStream()
pa_csv.write_csv(self.native, csv_buffer)
return csv_buffer.getvalue().to_pybytes().decode()
pa_csv.write_csv(self.native, file)
return None
def is_unique(self) -> ArrowSeries:
import numpy as np # ignore-banned-import
col_token = generate_temporary_column_name(n_bytes=8, columns=self.columns)
row_index = pa.array(np.arange(len(self)))
keep_idx = (
self.native.append_column(col_token, row_index)
.group_by(self.columns)
.aggregate([(col_token, "min"), (col_token, "max")])
)
native = pa.chunked_array(
pc.and_(
pc.is_in(row_index, keep_idx[f"{col_token}_min"]),
pc.is_in(row_index, keep_idx[f"{col_token}_max"]),
)
)
return ArrowSeries.from_native(native, context=self)
def unique(
self,
subset: Sequence[str] | None,
*,
keep: UniqueKeepStrategy,
maintain_order: bool | None = None,
) -> Self:
# The param `maintain_order` is only here for compatibility with the Polars API
# and has no effect on the output.
import numpy as np # ignore-banned-import
if subset and (error := self._check_columns_exist(subset)):
raise error
subset = list(subset or self.columns)
if keep in {"any", "first", "last"}:
from narwhals._arrow.group_by import ArrowGroupBy
agg_func = ArrowGroupBy._REMAP_UNIQUE[keep]
col_token = generate_temporary_column_name(n_bytes=8, columns=self.columns)
keep_idx_native = (
self.native.append_column(col_token, pa.array(np.arange(len(self))))
.group_by(subset)
.aggregate([(col_token, agg_func)])
.column(f"{col_token}_{agg_func}")
)
return self._with_native(
self.native.take(keep_idx_native), validate_column_names=False
)
keep_idx = self.simple_select(*subset).is_unique()
plx = self.__narwhals_namespace__()
return self.filter(plx._expr._from_series(keep_idx))
def gather_every(self, n: int, offset: int) -> Self:
return self._with_native(self.native[offset::n], validate_column_names=False)
def to_arrow(self) -> pa.Table:
return self.native
def sample(
self,
n: int | None,
*,
fraction: float | None,
with_replacement: bool,
seed: int | None,
) -> Self:
import numpy as np # ignore-banned-import
num_rows = len(self)
if n is None and fraction is not None:
n = int(num_rows * fraction)
rng = np.random.default_rng(seed=seed)
idx = np.arange(num_rows)
mask = rng.choice(idx, size=n, replace=with_replacement)
return self._with_native(self.native.take(mask), validate_column_names=False)
def unpivot(
self,
on: Sequence[str] | None,
index: Sequence[str] | None,
variable_name: str,
value_name: str,
) -> Self:
n_rows = len(self)
index_ = [] if index is None else index
on_ = [c for c in self.columns if c not in index_] if on is None else on
concat = (
partial(pa.concat_tables, promote_options="permissive")
if self._backend_version >= (14, 0, 0)
else pa.concat_tables
)
names = [*index_, variable_name, value_name]
return self._with_native(
concat(
[
pa.Table.from_arrays(
[
*(self.native.column(idx_col) for idx_col in index_),
cast(
"ChunkedArrayAny",
pa.array([on_col] * n_rows, pa.string()),
),
self.native.column(on_col),
],
names=names,
)
for on_col in on_
]
)
)
# TODO(Unassigned): Even with promote_options="permissive", pyarrow does not
# upcast numeric to non-numeric (e.g. string) datatypes
pivot = not_implemented()
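
`is_unique` above avoids sorting: it tags every row with its position, groups by all columns, and keeps each group's minimum and maximum position; a row is unique exactly when its position is both. The same idea in plain pyarrow (a sketch; `_idx` is a hypothetical temporary column name):

import numpy as np
import pyarrow as pa
import pyarrow.compute as pc

tbl = pa.table({"a": [1, 1, 2], "b": ["x", "x", "y"]})
idx = pa.array(np.arange(len(tbl)))
agg = (
    tbl.append_column("_idx", idx)
    .group_by(tbl.column_names)
    .aggregate([("_idx", "min"), ("_idx", "max")])
)
# True where a row's position is both the first and the last occurrence
# of its group, i.e. the group has exactly one member.
mask = pc.and_(pc.is_in(idx, agg["_idx_min"]), pc.is_in(idx, agg["_idx_max"]))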

@@ -0,0 +1,172 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any
import pyarrow.compute as pc
from narwhals._arrow.series import ArrowSeries
from narwhals._compliant import EagerExpr
from narwhals._expression_parsing import evaluate_output_names_and_aliases
from narwhals._utils import (
Implementation,
generate_temporary_column_name,
not_implemented,
)
if TYPE_CHECKING:
from collections.abc import Sequence
from typing_extensions import Self
from narwhals._arrow.dataframe import ArrowDataFrame
from narwhals._arrow.namespace import ArrowNamespace
from narwhals._compliant.typing import AliasNames, EvalNames, EvalSeries, ScalarKwargs
from narwhals._expression_parsing import ExprMetadata
from narwhals._utils import Version, _LimitedContext
class ArrowExpr(EagerExpr["ArrowDataFrame", ArrowSeries]):
_implementation: Implementation = Implementation.PYARROW
def __init__(
self,
call: EvalSeries[ArrowDataFrame, ArrowSeries],
*,
depth: int,
function_name: str,
evaluate_output_names: EvalNames[ArrowDataFrame],
alias_output_names: AliasNames | None,
version: Version,
scalar_kwargs: ScalarKwargs | None = None,
implementation: Implementation | None = None,
) -> None:
self._call = call
self._depth = depth
self._function_name = function_name
self._evaluate_output_names = evaluate_output_names
self._alias_output_names = alias_output_names
self._version = version
self._scalar_kwargs = scalar_kwargs or {}
self._metadata: ExprMetadata | None = None
@classmethod
def from_column_names(
cls: type[Self],
evaluate_column_names: EvalNames[ArrowDataFrame],
/,
*,
context: _LimitedContext,
function_name: str = "",
) -> Self:
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
try:
return [
ArrowSeries(
df.native[column_name], name=column_name, version=df._version
)
for column_name in evaluate_column_names(df)
]
except KeyError as e:
if error := df._check_columns_exist(evaluate_column_names(df)):
raise error from e
raise
return cls(
func,
depth=0,
function_name=function_name,
evaluate_output_names=evaluate_column_names,
alias_output_names=None,
version=context._version,
)
@classmethod
def from_column_indices(cls, *column_indices: int, context: _LimitedContext) -> Self:
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
tbl = df.native
cols = df.columns
return [
ArrowSeries.from_native(tbl[i], name=cols[i], context=df)
for i in column_indices
]
return cls(
func,
depth=0,
function_name="nth",
evaluate_output_names=cls._eval_names_indices(column_indices),
alias_output_names=None,
version=context._version,
)
def __narwhals_namespace__(self) -> ArrowNamespace:
from narwhals._arrow.namespace import ArrowNamespace
return ArrowNamespace(version=self._version)
def __narwhals_expr__(self) -> None: ...
def _reuse_series_extra_kwargs(
self, *, returns_scalar: bool = False
) -> dict[str, Any]:
return {"_return_py_scalar": False} if returns_scalar else {}
def over(self, partition_by: Sequence[str], order_by: Sequence[str]) -> Self:
if (
partition_by
and self._metadata is not None
and not self._metadata.is_scalar_like
):
msg = "Only aggregation or literal operations are supported in grouped `over` context for PyArrow."
raise NotImplementedError(msg)
if not partition_by:
# e.g. `nw.col('a').cum_sum().order_by(key)`
# which we can always easily support, as it doesn't require grouping.
assert order_by # noqa: S101
def func(df: ArrowDataFrame) -> Sequence[ArrowSeries]:
token = generate_temporary_column_name(8, df.columns)
df = df.with_row_index(token, order_by=None).sort(
*order_by, descending=False, nulls_last=False
)
result = self(df.drop([token], strict=True))
# TODO(marco): is there a way to do this efficiently without
# doing 2 sorts? Here we're sorting the dataframe and then
# again calling `sort_indices`. `ArrowSeries.scatter` would also sort.
sorting_indices = pc.sort_indices(df.get_column(token).native)
return [s._with_native(s.native.take(sorting_indices)) for s in result]
else:
def func(df: ArrowDataFrame) -> Sequence[ArrowSeries]:
output_names, aliases = evaluate_output_names_and_aliases(self, df, [])
if overlap := set(output_names).intersection(partition_by):
# E.g. `df.select(nw.all().sum().over('a'))`. This is well-defined,
# we just don't support it yet.
msg = (
f"Column names {overlap} appear in both expression output names and in `over` keys.\n"
"This is not yet supported."
)
raise NotImplementedError(msg)
tmp = df.group_by(partition_by, drop_null_keys=False).agg(self)
tmp = df.simple_select(*partition_by).join(
tmp,
how="left",
left_on=partition_by,
right_on=partition_by,
suffix="_right",
)
return [tmp.get_column(alias) for alias in aliases]
return self.__class__(
func,
depth=self._depth + 1,
function_name=self._function_name + "->over",
evaluate_output_names=self._evaluate_output_names,
alias_output_names=self._alias_output_names,
version=self._version,
)
ewm_mean = not_implemented()
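
In the `partition_by`-free branch of `over` above, the frame is sorted by `order_by`, the expression is evaluated on the sorted view, and the original row order is then restored with `pc.sort_indices` on a temporary row-index column. The trick in isolation (a sketch; `_row` is a hypothetical name):

import pyarrow as pa
import pyarrow.compute as pc

tbl = pa.table({"key": [3, 1, 2], "val": [30, 10, 20]})
with_idx = tbl.append_column("_row", pa.array(range(len(tbl))))
sorted_tbl = with_idx.sort_by([("key", "ascending")])
cum = pc.cumulative_sum(sorted_tbl["val"])  # evaluated in sorted order
restore = pc.sort_indices(sorted_tbl["_row"])  # permutation undoing the sort
back_in_place = cum.take(restore)  # -> [60, 10, 30], original row order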

@@ -0,0 +1,159 @@
from __future__ import annotations
import collections
from typing import TYPE_CHECKING, Any, ClassVar
import pyarrow as pa
import pyarrow.compute as pc
from narwhals._arrow.utils import cast_to_comparable_string_types, extract_py_scalar
from narwhals._compliant import EagerGroupBy
from narwhals._expression_parsing import evaluate_output_names_and_aliases
from narwhals._utils import generate_temporary_column_name
if TYPE_CHECKING:
from collections.abc import Iterator, Mapping, Sequence
from narwhals._arrow.dataframe import ArrowDataFrame
from narwhals._arrow.expr import ArrowExpr
from narwhals._arrow.typing import ( # type: ignore[attr-defined]
AggregateOptions,
Aggregation,
Incomplete,
)
from narwhals._compliant.typing import NarwhalsAggregation
from narwhals.typing import UniqueKeepStrategy
class ArrowGroupBy(EagerGroupBy["ArrowDataFrame", "ArrowExpr", "Aggregation"]):
_REMAP_AGGS: ClassVar[Mapping[NarwhalsAggregation, Aggregation]] = {
"sum": "sum",
"mean": "mean",
"median": "approximate_median",
"max": "max",
"min": "min",
"std": "stddev",
"var": "variance",
"len": "count",
"n_unique": "count_distinct",
"count": "count",
"all": "all",
"any": "any",
}
_REMAP_UNIQUE: ClassVar[Mapping[UniqueKeepStrategy, Aggregation]] = {
"any": "min",
"first": "min",
"last": "max",
}
def __init__(
self,
df: ArrowDataFrame,
keys: Sequence[ArrowExpr] | Sequence[str],
/,
*,
drop_null_keys: bool,
) -> None:
self._df = df
frame, self._keys, self._output_key_names = self._parse_keys(df, keys=keys)
self._compliant_frame = frame.drop_nulls(self._keys) if drop_null_keys else frame
self._grouped = pa.TableGroupBy(self.compliant.native, self._keys)
self._drop_null_keys = drop_null_keys
def agg(self, *exprs: ArrowExpr) -> ArrowDataFrame:
self._ensure_all_simple(exprs)
aggs: list[tuple[str, Aggregation, AggregateOptions | None]] = []
expected_pyarrow_column_names: list[str] = self._keys.copy()
new_column_names: list[str] = self._keys.copy()
exclude = (*self._keys, *self._output_key_names)
for expr in exprs:
output_names, aliases = evaluate_output_names_and_aliases(
expr, self.compliant, exclude
)
if expr._depth == 0:
# e.g. `agg(nw.len())`
if expr._function_name != "len": # pragma: no cover
msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues"
raise AssertionError(msg)
new_column_names.append(aliases[0])
expected_pyarrow_column_names.append(f"{self._keys[0]}_count")
aggs.append((self._keys[0], "count", pc.CountOptions(mode="all")))
continue
function_name = self._leaf_name(expr)
if function_name in {"std", "var"}:
assert "ddof" in expr._scalar_kwargs # noqa: S101
option: Any = pc.VarianceOptions(ddof=expr._scalar_kwargs["ddof"])
elif function_name in {"len", "n_unique"}:
option = pc.CountOptions(mode="all")
elif function_name == "count":
option = pc.CountOptions(mode="only_valid")
elif function_name in {"all", "any"}:
option = pc.ScalarAggregateOptions(min_count=0)
else:
option = None
function_name = self._remap_expr_name(function_name)
new_column_names.extend(aliases)
expected_pyarrow_column_names.extend(
[f"{output_name}_{function_name}" for output_name in output_names]
)
aggs.extend(
[(output_name, function_name, option) for output_name in output_names]
)
result_simple = self._grouped.aggregate(aggs)
# Rename columns, being very careful
expected_old_names_indices: dict[str, list[int]] = collections.defaultdict(list)
for idx, item in enumerate(expected_pyarrow_column_names):
expected_old_names_indices[item].append(idx)
if not (
set(result_simple.column_names) == set(expected_pyarrow_column_names)
and len(result_simple.column_names) == len(expected_pyarrow_column_names)
): # pragma: no cover
msg = (
f"Safety assertion failed, expected {expected_pyarrow_column_names} "
f"got {result_simple.column_names}, "
"please report a bug at https://github.com/narwhals-dev/narwhals/issues"
)
raise AssertionError(msg)
index_map: list[int] = [
expected_old_names_indices[item].pop(0) for item in result_simple.column_names
]
new_column_names = [new_column_names[i] for i in index_map]
result_simple = result_simple.rename_columns(new_column_names)
return self.compliant._with_native(result_simple).rename(
dict(zip(self._keys, self._output_key_names))
)
def __iter__(self) -> Iterator[tuple[Any, ArrowDataFrame]]:
col_token = generate_temporary_column_name(
n_bytes=8, columns=self.compliant.columns
)
null_token: str = "__null_token_value__" # noqa: S105
table = self.compliant.native
it, separator_scalar = cast_to_comparable_string_types(
*(table[key] for key in self._keys), separator=""
)
# NOTE: stubs indicate `separator` must also be a `ChunkedArray`
# Reality: `str` is fine
concat_str: Incomplete = pc.binary_join_element_wise
key_values = concat_str(
*it, separator_scalar, null_handling="replace", null_replacement=null_token
)
table = table.add_column(i=0, field_=col_token, column=key_values)
for v in pc.unique(key_values):
t = self.compliant._with_native(
table.filter(pc.equal(table[col_token], v)).drop([col_token])
)
row = t.simple_select(*self._keys).row(0)
yield (
tuple(extract_py_scalar(el) for el in row),
t.simple_select(*self._df.columns),
)
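
The renaming logic in `agg` above exists because `pa.TableGroupBy.aggregate` names its outputs mechanically as `<column>_<aggregation>` alongside the group keys, and those generated names must be mapped back to the aliases the expressions promised. A minimal illustration of the convention:

import pyarrow as pa

tbl = pa.table({"g": ["a", "a", "b"], "x": [1, 2, 3]})
out = tbl.group_by("g").aggregate([("x", "sum"), ("x", "count")])
# Generated "<column>_<aggregation>" names, plus the key column.
assert set(out.column_names) == {"x_sum", "x_count", "g"}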

@@ -0,0 +1,309 @@
from __future__ import annotations
import operator
from functools import reduce
from itertools import chain
from typing import TYPE_CHECKING, Literal
import pyarrow as pa
import pyarrow.compute as pc
from narwhals._arrow.dataframe import ArrowDataFrame
from narwhals._arrow.expr import ArrowExpr
from narwhals._arrow.selectors import ArrowSelectorNamespace
from narwhals._arrow.series import ArrowSeries
from narwhals._arrow.utils import cast_to_comparable_string_types
from narwhals._compliant import CompliantThen, EagerNamespace, EagerWhen
from narwhals._expression_parsing import (
combine_alias_output_names,
combine_evaluate_output_names,
)
from narwhals._utils import Implementation
if TYPE_CHECKING:
from collections.abc import Sequence
from narwhals._arrow.typing import ArrayOrScalar, ChunkedArrayAny, Incomplete
from narwhals._compliant.typing import ScalarKwargs
from narwhals._utils import Version
from narwhals.typing import IntoDType, NonNestedLiteral
class ArrowNamespace(
EagerNamespace[ArrowDataFrame, ArrowSeries, ArrowExpr, pa.Table, "ChunkedArrayAny"]
):
_implementation = Implementation.PYARROW
@property
def _dataframe(self) -> type[ArrowDataFrame]:
return ArrowDataFrame
@property
def _expr(self) -> type[ArrowExpr]:
return ArrowExpr
@property
def _series(self) -> type[ArrowSeries]:
return ArrowSeries
def __init__(self, *, version: Version) -> None:
self._version = version
def len(self) -> ArrowExpr:
# coverage bug? this is definitely hit
return self._expr( # pragma: no cover
lambda df: [
ArrowSeries.from_iterable([len(df.native)], name="len", context=self)
],
depth=0,
function_name="len",
evaluate_output_names=lambda _df: ["len"],
alias_output_names=None,
version=self._version,
)
def lit(self, value: NonNestedLiteral, dtype: IntoDType | None) -> ArrowExpr:
def _lit_arrow_series(_: ArrowDataFrame) -> ArrowSeries:
arrow_series = ArrowSeries.from_iterable(
data=[value], name="literal", context=self
)
if dtype:
return arrow_series.cast(dtype)
return arrow_series
return self._expr(
lambda df: [_lit_arrow_series(df)],
depth=0,
function_name="lit",
evaluate_output_names=lambda _df: ["literal"],
alias_output_names=None,
version=self._version,
)
def all_horizontal(self, *exprs: ArrowExpr, ignore_nulls: bool) -> ArrowExpr:
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
series = chain.from_iterable(expr(df) for expr in exprs)
align = self._series._align_full_broadcast
it = (
(s.fill_null(True, None, None) for s in series) # noqa: FBT003
if ignore_nulls
else series
)
return [reduce(operator.and_, align(*it))]
return self._expr._from_callable(
func=func,
depth=max(x._depth for x in exprs) + 1,
function_name="all_horizontal",
evaluate_output_names=combine_evaluate_output_names(*exprs),
alias_output_names=combine_alias_output_names(*exprs),
context=self,
)
def any_horizontal(self, *exprs: ArrowExpr, ignore_nulls: bool) -> ArrowExpr:
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
series = chain.from_iterable(expr(df) for expr in exprs)
align = self._series._align_full_broadcast
it = (
(s.fill_null(False, None, None) for s in series) # noqa: FBT003
if ignore_nulls
else series
)
return [reduce(operator.or_, align(*it))]
return self._expr._from_callable(
func=func,
depth=max(x._depth for x in exprs) + 1,
function_name="any_horizontal",
evaluate_output_names=combine_evaluate_output_names(*exprs),
alias_output_names=combine_alias_output_names(*exprs),
context=self,
)
def sum_horizontal(self, *exprs: ArrowExpr) -> ArrowExpr:
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
it = chain.from_iterable(expr(df) for expr in exprs)
series = (s.fill_null(0, strategy=None, limit=None) for s in it)
align = self._series._align_full_broadcast
return [reduce(operator.add, align(*series))]
return self._expr._from_callable(
func=func,
depth=max(x._depth for x in exprs) + 1,
function_name="sum_horizontal",
evaluate_output_names=combine_evaluate_output_names(*exprs),
alias_output_names=combine_alias_output_names(*exprs),
context=self,
)
def mean_horizontal(self, *exprs: ArrowExpr) -> ArrowExpr:
int_64 = self._version.dtypes.Int64()
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
expr_results = list(chain.from_iterable(expr(df) for expr in exprs))
align = self._series._align_full_broadcast
series = align(
*(s.fill_null(0, strategy=None, limit=None) for s in expr_results)
)
non_na = align(*(1 - s.is_null().cast(int_64) for s in expr_results))
return [reduce(operator.add, series) / reduce(operator.add, non_na)]
return self._expr._from_callable(
func=func,
depth=max(x._depth for x in exprs) + 1,
function_name="mean_horizontal",
evaluate_output_names=combine_evaluate_output_names(*exprs),
alias_output_names=combine_alias_output_names(*exprs),
context=self,
)
def min_horizontal(self, *exprs: ArrowExpr) -> ArrowExpr:
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
align = self._series._align_full_broadcast
init_series, *series = list(chain.from_iterable(expr(df) for expr in exprs))
init_series, *series = align(init_series, *series)
native_series = reduce(
pc.min_element_wise, [s.native for s in series], init_series.native
)
return [
ArrowSeries(native_series, name=init_series.name, version=self._version)
]
return self._expr._from_callable(
func=func,
depth=max(x._depth for x in exprs) + 1,
function_name="min_horizontal",
evaluate_output_names=combine_evaluate_output_names(*exprs),
alias_output_names=combine_alias_output_names(*exprs),
context=self,
)
def max_horizontal(self, *exprs: ArrowExpr) -> ArrowExpr:
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
align = self._series._align_full_broadcast
init_series, *series = list(chain.from_iterable(expr(df) for expr in exprs))
init_series, *series = align(init_series, *series)
native_series = reduce(
pc.max_element_wise, [s.native for s in series], init_series.native
)
return [
ArrowSeries(native_series, name=init_series.name, version=self._version)
]
return self._expr._from_callable(
func=func,
depth=max(x._depth for x in exprs) + 1,
function_name="max_horizontal",
evaluate_output_names=combine_evaluate_output_names(*exprs),
alias_output_names=combine_alias_output_names(*exprs),
context=self,
)
def _concat_diagonal(self, dfs: Sequence[pa.Table], /) -> pa.Table:
if self._backend_version >= (14,):
return pa.concat_tables(dfs, promote_options="default")
return pa.concat_tables(dfs, promote=True) # pragma: no cover
def _concat_horizontal(self, dfs: Sequence[pa.Table], /) -> pa.Table:
names = list(chain.from_iterable(df.column_names for df in dfs))
arrays = list(chain.from_iterable(df.itercolumns() for df in dfs))
return pa.Table.from_arrays(arrays, names=names)
def _concat_vertical(self, dfs: Sequence[pa.Table], /) -> pa.Table:
cols_0 = dfs[0].column_names
for i, df in enumerate(dfs[1:], start=1):
cols_current = df.column_names
if cols_current != cols_0:
msg = (
"unable to vstack, column names don't match:\n"
f" - dataframe 0: {cols_0}\n"
f" - dataframe {i}: {cols_current}\n"
)
raise TypeError(msg)
return pa.concat_tables(dfs)
@property
def selectors(self) -> ArrowSelectorNamespace:
return ArrowSelectorNamespace.from_namespace(self)
def when(self, predicate: ArrowExpr) -> ArrowWhen:
return ArrowWhen.from_expr(predicate, context=self)
def concat_str(
self, *exprs: ArrowExpr, separator: str, ignore_nulls: bool
) -> ArrowExpr:
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
align = self._series._align_full_broadcast
compliant_series_list = align(
*(chain.from_iterable(expr(df) for expr in exprs))
)
name = compliant_series_list[0].name
null_handling: Literal["skip", "emit_null"] = (
"skip" if ignore_nulls else "emit_null"
)
it, separator_scalar = cast_to_comparable_string_types(
*(s.native for s in compliant_series_list), separator=separator
)
# NOTE: stubs indicate `separator` must also be a `ChunkedArray`
# Reality: `str` is fine
concat_str: Incomplete = pc.binary_join_element_wise
compliant = self._series(
concat_str(*it, separator_scalar, null_handling=null_handling),
name=name,
version=self._version,
)
return [compliant]
return self._expr._from_callable(
func=func,
depth=max(x._depth for x in exprs) + 1,
function_name="concat_str",
evaluate_output_names=combine_evaluate_output_names(*exprs),
alias_output_names=combine_alias_output_names(*exprs),
context=self,
)
def coalesce(self, *exprs: ArrowExpr) -> ArrowExpr:
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
align = self._series._align_full_broadcast
init_series, *series = align(*chain.from_iterable(expr(df) for expr in exprs))
return [
ArrowSeries(
pc.coalesce(init_series.native, *(s.native for s in series)),
name=init_series.name,
version=self._version,
)
]
return self._expr._from_callable(
func=func,
depth=max(x._depth for x in exprs) + 1,
function_name="coalesce",
evaluate_output_names=combine_evaluate_output_names(*exprs),
alias_output_names=combine_alias_output_names(*exprs),
context=self,
)
class ArrowWhen(EagerWhen[ArrowDataFrame, ArrowSeries, ArrowExpr, "ChunkedArrayAny"]):
@property
def _then(self) -> type[ArrowThen]:
return ArrowThen
def _if_then_else(
self,
when: ChunkedArrayAny,
then: ChunkedArrayAny,
otherwise: ArrayOrScalar | NonNestedLiteral,
/,
) -> ChunkedArrayAny:
otherwise = pa.nulls(len(when), then.type) if otherwise is None else otherwise
return pc.if_else(when, then, otherwise)
class ArrowThen(
CompliantThen[ArrowDataFrame, ArrowSeries, ArrowExpr, ArrowWhen], ArrowExpr
):
_depth: int = 0
_scalar_kwargs: ScalarKwargs = {} # noqa: RUF012
_function_name: str = "whenthen"
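
`mean_horizontal` above sums the null-filled values and divides by the element-wise count of non-null entries. The same arithmetic in plain pyarrow (a sketch):

import pyarrow as pa
import pyarrow.compute as pc

a = pa.chunked_array([[1.0, None, 3.0]])
b = pa.chunked_array([[10.0, 20.0, None]])
num = pc.add(pc.coalesce(a, pa.scalar(0.0)), pc.coalesce(b, pa.scalar(0.0)))
cnt = pc.add(
    pc.cast(pc.is_valid(a), pa.float64()),
    pc.cast(pc.is_valid(b), pa.float64()),
)
mean = pc.divide(num, cnt)  # -> [5.5, 20.0, 3.0]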

@@ -0,0 +1,33 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from narwhals._arrow.expr import ArrowExpr
from narwhals._compliant import CompliantSelector, EagerSelectorNamespace
if TYPE_CHECKING:
from narwhals._arrow.dataframe import ArrowDataFrame # noqa: F401
from narwhals._arrow.series import ArrowSeries # noqa: F401
from narwhals._compliant.typing import ScalarKwargs
class ArrowSelectorNamespace(EagerSelectorNamespace["ArrowDataFrame", "ArrowSeries"]):
@property
def _selector(self) -> type[ArrowSelector]:
return ArrowSelector
class ArrowSelector(CompliantSelector["ArrowDataFrame", "ArrowSeries"], ArrowExpr): # type: ignore[misc]
_depth: int = 0
_scalar_kwargs: ScalarKwargs = {} # noqa: RUF012
_function_name: str = "selector"
def _to_expr(self) -> ArrowExpr:
return ArrowExpr(
self._call,
depth=self._depth,
function_name=self._function_name,
evaluate_output_names=self._evaluate_output_names,
alias_output_names=self._alias_output_names,
version=self._version,
)

File diff suppressed because it is too large

@@ -0,0 +1,18 @@
from __future__ import annotations
from typing import TYPE_CHECKING
import pyarrow as pa
from narwhals._arrow.utils import ArrowSeriesNamespace
if TYPE_CHECKING:
from narwhals._arrow.series import ArrowSeries
from narwhals._arrow.typing import Incomplete
class ArrowSeriesCatNamespace(ArrowSeriesNamespace):
def get_categories(self) -> ArrowSeries:
        # NOTE: Should be `list[pa.DictionaryArray]`, but the stubs expose no attributes on `DictionaryArray`
chunks: Incomplete = self.native.chunks
return self.with_native(pa.concat_arrays(x.dictionary for x in chunks).unique())
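
`get_categories` relies on each chunk of a dictionary-encoded column carrying its own dictionary: concatenating the per-chunk dictionaries and de-duplicating yields the categories. Roughly (a sketch):

import pyarrow as pa

chunked = pa.chunked_array([pa.array(["a", "b", "a"]).dictionary_encode()])
cats = pa.concat_arrays(chunk.dictionary for chunk in chunked.chunks).unique()
# -> ["a", "b"]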

@@ -0,0 +1,227 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Callable, ClassVar, cast
import pyarrow as pa
import pyarrow.compute as pc
from narwhals._arrow.utils import UNITS_DICT, ArrowSeriesNamespace, floordiv_compat, lit
from narwhals._constants import (
MS_PER_MINUTE,
MS_PER_SECOND,
NS_PER_MICROSECOND,
NS_PER_MILLISECOND,
NS_PER_MINUTE,
NS_PER_SECOND,
SECONDS_PER_DAY,
SECONDS_PER_MINUTE,
US_PER_MINUTE,
US_PER_SECOND,
)
from narwhals._duration import Interval
if TYPE_CHECKING:
from collections.abc import Mapping
from typing_extensions import TypeAlias
from narwhals._arrow.series import ArrowSeries
from narwhals._arrow.typing import ChunkedArrayAny, ScalarAny
from narwhals.dtypes import Datetime
from narwhals.typing import TimeUnit
UnitCurrent: TypeAlias = TimeUnit
UnitTarget: TypeAlias = TimeUnit
BinOpBroadcast: TypeAlias = Callable[[ChunkedArrayAny, ScalarAny], ChunkedArrayAny]
IntoRhs: TypeAlias = int
class ArrowSeriesDateTimeNamespace(ArrowSeriesNamespace):
_TIMESTAMP_DATE_FACTOR: ClassVar[Mapping[TimeUnit, int]] = {
"ns": NS_PER_SECOND,
"us": US_PER_SECOND,
"ms": MS_PER_SECOND,
"s": 1,
}
_TIMESTAMP_DATETIME_OP_FACTOR: ClassVar[
Mapping[tuple[UnitCurrent, UnitTarget], tuple[BinOpBroadcast, IntoRhs]]
] = {
("ns", "us"): (floordiv_compat, 1_000),
("ns", "ms"): (floordiv_compat, 1_000_000),
("us", "ns"): (pc.multiply, NS_PER_MICROSECOND),
("us", "ms"): (floordiv_compat, 1_000),
("ms", "ns"): (pc.multiply, NS_PER_MILLISECOND),
("ms", "us"): (pc.multiply, 1_000),
("s", "ns"): (pc.multiply, NS_PER_SECOND),
("s", "us"): (pc.multiply, US_PER_SECOND),
("s", "ms"): (pc.multiply, MS_PER_SECOND),
}
@property
def unit(self) -> TimeUnit: # NOTE: Unsafe (native).
return cast("pa.TimestampType[TimeUnit, Any]", self.native.type).unit
@property
def time_zone(self) -> str | None: # NOTE: Unsafe (narwhals).
return cast("Datetime", self.compliant.dtype).time_zone
def to_string(self, format: str) -> ArrowSeries:
# PyArrow differs from other libraries in that %S also prints out
# the fractional part of the second...:'(
# https://arrow.apache.org/docs/python/generated/pyarrow.compute.strftime.html
format = format.replace("%S.%f", "%S").replace("%S%.f", "%S")
return self.with_native(pc.strftime(self.native, format))
def replace_time_zone(self, time_zone: str | None) -> ArrowSeries:
if time_zone is not None:
result = pc.assume_timezone(pc.local_timestamp(self.native), time_zone)
else:
result = pc.local_timestamp(self.native)
return self.with_native(result)
def convert_time_zone(self, time_zone: str) -> ArrowSeries:
ser = self.replace_time_zone("UTC") if self.time_zone is None else self.compliant
return self.with_native(ser.native.cast(pa.timestamp(self.unit, time_zone)))
def timestamp(self, time_unit: TimeUnit) -> ArrowSeries:
ser = self.compliant
dtypes = ser._version.dtypes
if isinstance(ser.dtype, dtypes.Datetime):
current = ser.dtype.time_unit
s_cast = self.native.cast(pa.int64())
if current == time_unit:
result = s_cast
elif item := self._TIMESTAMP_DATETIME_OP_FACTOR.get((current, time_unit)):
fn, factor = item
result = fn(s_cast, lit(factor))
else: # pragma: no cover
msg = f"unexpected time unit {current}, please report an issue at https://github.com/narwhals-dev/narwhals"
raise AssertionError(msg)
return self.with_native(result)
elif isinstance(ser.dtype, dtypes.Date):
time_s = pc.multiply(self.native.cast(pa.int32()), lit(SECONDS_PER_DAY))
factor = self._TIMESTAMP_DATE_FACTOR[time_unit]
return self.with_native(pc.multiply(time_s, lit(factor)))
else:
msg = "Input should be either of Date or Datetime type"
raise TypeError(msg)
def date(self) -> ArrowSeries:
return self.with_native(self.native.cast(pa.date32()))
def year(self) -> ArrowSeries:
return self.with_native(pc.year(self.native))
def month(self) -> ArrowSeries:
return self.with_native(pc.month(self.native))
def day(self) -> ArrowSeries:
return self.with_native(pc.day(self.native))
def hour(self) -> ArrowSeries:
return self.with_native(pc.hour(self.native))
def minute(self) -> ArrowSeries:
return self.with_native(pc.minute(self.native))
def second(self) -> ArrowSeries:
return self.with_native(pc.second(self.native))
def millisecond(self) -> ArrowSeries:
return self.with_native(pc.millisecond(self.native))
def microsecond(self) -> ArrowSeries:
arr = self.native
result = pc.add(pc.multiply(pc.millisecond(arr), lit(1000)), pc.microsecond(arr))
return self.with_native(result)
def nanosecond(self) -> ArrowSeries:
result = pc.add(
pc.multiply(self.microsecond().native, lit(1000)), pc.nanosecond(self.native)
)
return self.with_native(result)
def ordinal_day(self) -> ArrowSeries:
return self.with_native(pc.day_of_year(self.native))
def weekday(self) -> ArrowSeries:
return self.with_native(pc.day_of_week(self.native, count_from_zero=False))
def total_minutes(self) -> ArrowSeries:
unit_to_minutes_factor = {
"s": SECONDS_PER_MINUTE,
"ms": MS_PER_MINUTE,
"us": US_PER_MINUTE,
"ns": NS_PER_MINUTE,
}
factor = lit(unit_to_minutes_factor[self.unit], type=pa.int64())
return self.with_native(pc.divide(self.native, factor).cast(pa.int64()))
def total_seconds(self) -> ArrowSeries:
unit_to_seconds_factor = {
"s": 1,
"ms": MS_PER_SECOND,
"us": US_PER_SECOND,
"ns": NS_PER_SECOND,
}
factor = lit(unit_to_seconds_factor[self.unit], type=pa.int64())
return self.with_native(pc.divide(self.native, factor).cast(pa.int64()))
def total_milliseconds(self) -> ArrowSeries:
unit_to_milli_factor = {
"s": 1e3, # seconds
"ms": 1, # milli
"us": 1e3, # micro
"ns": 1e6, # nano
}
factor = lit(unit_to_milli_factor[self.unit], type=pa.int64())
if self.unit == "s":
return self.with_native(pc.multiply(self.native, factor).cast(pa.int64()))
return self.with_native(pc.divide(self.native, factor).cast(pa.int64()))
def total_microseconds(self) -> ArrowSeries:
unit_to_micro_factor = {
"s": 1e6, # seconds
"ms": 1e3, # milli
"us": 1, # micro
"ns": 1e3, # nano
}
factor = lit(unit_to_micro_factor[self.unit], type=pa.int64())
if self.unit in {"s", "ms"}:
return self.with_native(pc.multiply(self.native, factor).cast(pa.int64()))
return self.with_native(pc.divide(self.native, factor).cast(pa.int64()))
def total_nanoseconds(self) -> ArrowSeries:
unit_to_nano_factor = {
"s": NS_PER_SECOND,
"ms": NS_PER_MILLISECOND,
"us": NS_PER_MICROSECOND,
"ns": 1,
}
factor = lit(unit_to_nano_factor[self.unit], type=pa.int64())
return self.with_native(pc.multiply(self.native, factor).cast(pa.int64()))
def truncate(self, every: str) -> ArrowSeries:
interval = Interval.parse(every)
return self.with_native(
pc.floor_temporal(self.native, interval.multiple, UNITS_DICT[interval.unit])
)
def offset_by(self, by: str) -> ArrowSeries:
interval = Interval.parse_no_constraints(by)
native = self.native
if interval.unit in {"y", "q", "mo"}:
msg = f"Offsetting by {interval.unit} is not yet supported for pyarrow."
raise NotImplementedError(msg)
dtype = self.compliant.dtype
datetime_dtype = self.version.dtypes.Datetime
if interval.unit == "d" and isinstance(dtype, datetime_dtype) and dtype.time_zone:
offset: pa.DurationScalar[Any] = lit(interval.to_timedelta())
native_naive = pc.local_timestamp(native)
result = pc.assume_timezone(pc.add(native_naive, offset), dtype.time_zone)
return self.with_native(result)
if interval.unit == "ns": # pragma: no cover
offset = lit(interval.multiple, pa.duration("ns")) # type: ignore[assignment]
else:
offset = lit(interval.to_timedelta())
return self.with_native(pc.add(native, offset))
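
The `timestamp` conversion earlier in this file operates on the raw integer view of a temporal array: cast to int64, then multiply (upscaling) or floor-divide (downscaling) by the factor between the two units. For example, microseconds to nanoseconds (a sketch):

import pyarrow as pa
import pyarrow.compute as pc

us = pa.chunked_array([pa.array([1_700_000_000_000_000], pa.timestamp("us"))])
as_int = us.cast(pa.int64())  # raw microsecond ticks
as_ns = pc.multiply(as_int, pa.scalar(1_000, pa.int64()))  # us -> ns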

@@ -0,0 +1,16 @@
from __future__ import annotations
from typing import TYPE_CHECKING
import pyarrow as pa
import pyarrow.compute as pc
from narwhals._arrow.utils import ArrowSeriesNamespace
if TYPE_CHECKING:
from narwhals._arrow.series import ArrowSeries
class ArrowSeriesListNamespace(ArrowSeriesNamespace):
def len(self) -> ArrowSeries:
return self.with_native(pc.list_value_length(self.native).cast(pa.uint32()))

@@ -0,0 +1,103 @@
from __future__ import annotations
import string
from typing import TYPE_CHECKING
import pyarrow as pa
import pyarrow.compute as pc
from narwhals._arrow.utils import ArrowSeriesNamespace, lit, parse_datetime_format
if TYPE_CHECKING:
from narwhals._arrow.series import ArrowSeries
from narwhals._arrow.typing import Incomplete
class ArrowSeriesStringNamespace(ArrowSeriesNamespace):
def len_chars(self) -> ArrowSeries:
return self.with_native(pc.utf8_length(self.native))
def replace(self, pattern: str, value: str, *, literal: bool, n: int) -> ArrowSeries:
fn = pc.replace_substring if literal else pc.replace_substring_regex
arr = fn(self.native, pattern, replacement=value, max_replacements=n)
return self.with_native(arr)
def replace_all(self, pattern: str, value: str, *, literal: bool) -> ArrowSeries:
return self.replace(pattern, value, literal=literal, n=-1)
def strip_chars(self, characters: str | None) -> ArrowSeries:
return self.with_native(
pc.utf8_trim(self.native, characters or string.whitespace)
)
def starts_with(self, prefix: str) -> ArrowSeries:
return self.with_native(pc.equal(self.slice(0, len(prefix)).native, lit(prefix)))
def ends_with(self, suffix: str) -> ArrowSeries:
return self.with_native(
pc.equal(self.slice(-len(suffix), None).native, lit(suffix))
)
def contains(self, pattern: str, *, literal: bool) -> ArrowSeries:
check_func = pc.match_substring if literal else pc.match_substring_regex
return self.with_native(check_func(self.native, pattern))
def slice(self, offset: int, length: int | None) -> ArrowSeries:
stop = offset + length if length is not None else None
return self.with_native(
pc.utf8_slice_codeunits(self.native, start=offset, stop=stop)
)
def split(self, by: str) -> ArrowSeries:
split_series = pc.split_pattern(self.native, by) # type: ignore[call-overload]
return self.with_native(split_series)
def to_datetime(self, format: str | None) -> ArrowSeries:
format = parse_datetime_format(self.native) if format is None else format
timestamp_array = pc.strptime(self.native, format=format, unit="us")
return self.with_native(timestamp_array)
def to_date(self, format: str | None) -> ArrowSeries:
return self.to_datetime(format=format).dt.date()
def to_uppercase(self) -> ArrowSeries:
return self.with_native(pc.utf8_upper(self.native))
def to_lowercase(self) -> ArrowSeries:
return self.with_native(pc.utf8_lower(self.native))
def zfill(self, width: int) -> ArrowSeries:
binary_join: Incomplete = pc.binary_join_element_wise
native = self.native
hyphen, plus = lit("-"), lit("+")
first_char, remaining_chars = (
self.slice(0, 1).native,
self.slice(1, None).native,
)
# Conditions
less_than_width = pc.less(pc.utf8_length(native), lit(width))
starts_with_hyphen = pc.equal(first_char, hyphen)
starts_with_plus = pc.equal(first_char, plus)
conditions = pc.make_struct(
pc.and_(starts_with_hyphen, less_than_width),
pc.and_(starts_with_plus, less_than_width),
less_than_width,
)
# Cases
padded_remaining_chars = pc.utf8_lpad(remaining_chars, width - 1, padding="0")
result = pc.case_when(
conditions,
binary_join(
pa.repeat(hyphen, len(native)), padded_remaining_chars, ""
), # starts with hyphen and less than width
binary_join(
pa.repeat(plus, len(native)), padded_remaining_chars, ""
), # starts with plus and less than width
pc.utf8_lpad(native, width=width, padding="0"), # less than width
native,
)
return self.with_native(result)
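# Hedged sketch of the branch logic above, assuming `s` wraps
# pa.chunked_array([["-1", "+2", "3", "1234"]]) as an ArrowSeries:
# s.str.zfill(4) -> ["-001", "+002", "0003", "1234"]
# i.e. the sign stays in front, zeros are padded after it, and values already
# at least `width` characters long pass through unchanged (matching str.zfill).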

View File

@ -0,0 +1,15 @@
from __future__ import annotations
from typing import TYPE_CHECKING
import pyarrow.compute as pc
from narwhals._arrow.utils import ArrowSeriesNamespace
if TYPE_CHECKING:
from narwhals._arrow.series import ArrowSeries
class ArrowSeriesStructNamespace(ArrowSeriesNamespace):
def field(self, name: str) -> ArrowSeries:
return self.with_native(pc.struct_field(self.native, name)).alias(name)
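# Minimal sketch: `pc.struct_field` extracts a single child array, and the
# `.alias` above renames the resulting series to the field name.
# >>> pc.struct_field(pa.chunked_array([[{"a": 1, "b": "x"}]]), "a")
# -> [1]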

View File

@ -0,0 +1,72 @@
from __future__ import annotations # pragma: no cover
from typing import (
TYPE_CHECKING, # pragma: no cover
Any, # pragma: no cover
TypeVar, # pragma: no cover
)
if TYPE_CHECKING:
import sys
from typing import Generic, Literal
if sys.version_info >= (3, 10):
from typing import TypeAlias
else:
from typing_extensions import TypeAlias
import pyarrow as pa
from pyarrow.__lib_pxi.table import (
AggregateOptions, # noqa: F401
Aggregation, # noqa: F401
)
from pyarrow._stubs_typing import (  # pyright: ignore[reportMissingModuleSource]
Indices, # noqa: F401
Mask, # noqa: F401
Order, # noqa: F401
)
from narwhals._arrow.expr import ArrowExpr
from narwhals._arrow.series import ArrowSeries
IntoArrowExpr: TypeAlias = "ArrowExpr | ArrowSeries"
TieBreaker: TypeAlias = Literal["min", "max", "first", "dense"]
NullPlacement: TypeAlias = Literal["at_start", "at_end"]
NativeIntervalUnit: TypeAlias = Literal[
"year",
"quarter",
"month",
"week",
"day",
"hour",
"minute",
"second",
"millisecond",
"microsecond",
"nanosecond",
]
ChunkedArrayAny: TypeAlias = pa.ChunkedArray[Any]
ArrayAny: TypeAlias = pa.Array[Any]
ArrayOrChunkedArray: TypeAlias = "ArrayAny | ChunkedArrayAny"
ScalarAny: TypeAlias = pa.Scalar[Any]
ArrayOrScalar: TypeAlias = "ArrayOrChunkedArray | ScalarAny"
ArrayOrScalarT1 = TypeVar("ArrayOrScalarT1", ArrayAny, ChunkedArrayAny, ScalarAny)
ArrayOrScalarT2 = TypeVar("ArrayOrScalarT2", ArrayAny, ChunkedArrayAny, ScalarAny)
_AsPyType = TypeVar("_AsPyType")
class _BasicDataType(pa.DataType, Generic[_AsPyType]): ...
Incomplete: TypeAlias = Any # pragma: no cover
"""
Marker for working code that fails on the stubs.
Common issues:
- Annotated for `Array`, but not `ChunkedArray`
- Relies on typing information that the stubs don't provide statically
- Missing attributes
- Incorrect return types
- Inconsistent use of generic/concrete types
- `_clone_signature` used on signatures that are not identical
"""

View File

@ -0,0 +1,443 @@
from __future__ import annotations
from functools import lru_cache
from typing import TYPE_CHECKING, Any, cast
import pyarrow as pa
import pyarrow.compute as pc
from narwhals._compliant import EagerSeriesNamespace
from narwhals._utils import isinstance_or_issubclass
if TYPE_CHECKING:
from collections.abc import Iterable, Iterator, Mapping
from typing_extensions import TypeAlias, TypeIs
from narwhals._arrow.series import ArrowSeries
from narwhals._arrow.typing import (
ArrayAny,
ArrayOrScalar,
ArrayOrScalarT1,
ArrayOrScalarT2,
ChunkedArrayAny,
NativeIntervalUnit,
ScalarAny,
)
from narwhals._duration import IntervalUnit
from narwhals._utils import Version
from narwhals.dtypes import DType
from narwhals.typing import IntoDType, PythonLiteral
# NOTE: stubs don't allow for `ChunkedArray[StructArray]`
# Intended to represent the `.chunks` property storing `list[pa.StructArray]`
ChunkedArrayStructArray: TypeAlias = ChunkedArrayAny
def is_timestamp(t: Any) -> TypeIs[pa.TimestampType[Any, Any]]: ...
def is_duration(t: Any) -> TypeIs[pa.DurationType[Any]]: ...
def is_list(t: Any) -> TypeIs[pa.ListType[Any]]: ...
def is_large_list(t: Any) -> TypeIs[pa.LargeListType[Any]]: ...
def is_fixed_size_list(t: Any) -> TypeIs[pa.FixedSizeListType[Any, Any]]: ...
def is_dictionary(t: Any) -> TypeIs[pa.DictionaryType[Any, Any, Any]]: ...
def extract_regex(
strings: ChunkedArrayAny,
/,
pattern: str,
*,
options: Any = None,
memory_pool: Any = None,
) -> ChunkedArrayStructArray: ...
else:
from pyarrow.compute import extract_regex
from pyarrow.types import (
is_dictionary, # noqa: F401
is_duration,
is_fixed_size_list,
is_large_list,
is_list,
is_timestamp,
)
UNITS_DICT: Mapping[IntervalUnit, NativeIntervalUnit] = {
"y": "year",
"q": "quarter",
"mo": "month",
"d": "day",
"h": "hour",
"m": "minute",
"s": "second",
"ms": "millisecond",
"us": "microsecond",
"ns": "nanosecond",
}
lit = pa.scalar
"""Alias for `pyarrow.scalar`."""
def extract_py_scalar(value: Any, /) -> Any:
from narwhals._arrow.series import maybe_extract_py_scalar
return maybe_extract_py_scalar(value, return_py_scalar=True)
def chunked_array(
arr: ArrayOrScalar | list[Iterable[Any]], dtype: pa.DataType | None = None, /
) -> ChunkedArrayAny:
if isinstance(arr, pa.ChunkedArray):
return arr
if isinstance(arr, list):
return pa.chunked_array(arr, dtype)
else:
return pa.chunked_array([arr], arr.type)
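# Behavior sketch: an existing ChunkedArray passes through unchanged, a list of
# chunks is assembled directly, and a plain Array is wrapped as a single chunk:
# >>> chunked_array(pa.array([1, 2])).num_chunks
# 1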
def nulls_like(n: int, series: ArrowSeries) -> ArrayAny:
"""Create a strongly-typed Array instance with all elements null.
Uses the type of `series`, without upsetting `mypy`.
"""
return pa.nulls(n, series.native.type)
def zeros(n: int, /) -> pa.Int64Array:
return pa.repeat(0, n)
@lru_cache(maxsize=16)
def native_to_narwhals_dtype(dtype: pa.DataType, version: Version) -> DType: # noqa: C901, PLR0912
dtypes = version.dtypes
if pa.types.is_int64(dtype):
return dtypes.Int64()
if pa.types.is_int32(dtype):
return dtypes.Int32()
if pa.types.is_int16(dtype):
return dtypes.Int16()
if pa.types.is_int8(dtype):
return dtypes.Int8()
if pa.types.is_uint64(dtype):
return dtypes.UInt64()
if pa.types.is_uint32(dtype):
return dtypes.UInt32()
if pa.types.is_uint16(dtype):
return dtypes.UInt16()
if pa.types.is_uint8(dtype):
return dtypes.UInt8()
if pa.types.is_boolean(dtype):
return dtypes.Boolean()
if pa.types.is_float64(dtype):
return dtypes.Float64()
if pa.types.is_float32(dtype):
return dtypes.Float32()
# Possible coverage bug: it reports `31->exit` (where `31` is currently the line
# number of the next line), even though both branches of the condition are covered.
if ( # pragma: no cover
pa.types.is_string(dtype)
or pa.types.is_large_string(dtype)
or getattr(pa.types, "is_string_view", lambda _: False)(dtype)
):
return dtypes.String()
if pa.types.is_date32(dtype):
return dtypes.Date()
if is_timestamp(dtype):
return dtypes.Datetime(time_unit=dtype.unit, time_zone=dtype.tz)
if is_duration(dtype):
return dtypes.Duration(time_unit=dtype.unit)
if pa.types.is_dictionary(dtype):
return dtypes.Categorical()
if pa.types.is_struct(dtype):
return dtypes.Struct(
[
dtypes.Field(
dtype.field(i).name,
native_to_narwhals_dtype(dtype.field(i).type, version),
)
for i in range(dtype.num_fields)
]
)
if is_list(dtype) or is_large_list(dtype):
return dtypes.List(native_to_narwhals_dtype(dtype.value_type, version))
if is_fixed_size_list(dtype):
return dtypes.Array(
native_to_narwhals_dtype(dtype.value_type, version), dtype.list_size
)
if pa.types.is_decimal(dtype):
return dtypes.Decimal()
if pa.types.is_time32(dtype) or pa.types.is_time64(dtype):
return dtypes.Time()
if pa.types.is_binary(dtype):
return dtypes.Binary()
return dtypes.Unknown() # pragma: no cover
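# Illustration of the mapping above (`version` is narwhals-internal, so the
# exact spelling is hedged):
# native_to_narwhals_dtype(pa.timestamp("us", tz="UTC"), version)
# -> Datetime(time_unit="us", time_zone="UTC")
# native_to_narwhals_dtype(pa.list_(pa.int64()), version) -> List(Int64)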
def narwhals_to_native_dtype(dtype: IntoDType, version: Version) -> pa.DataType: # noqa: C901, PLR0912
dtypes = version.dtypes
if isinstance_or_issubclass(dtype, dtypes.Decimal):
msg = "Casting to Decimal is not supported yet."
raise NotImplementedError(msg)
if isinstance_or_issubclass(dtype, dtypes.Float64):
return pa.float64()
if isinstance_or_issubclass(dtype, dtypes.Float32):
return pa.float32()
if isinstance_or_issubclass(dtype, dtypes.Int64):
return pa.int64()
if isinstance_or_issubclass(dtype, dtypes.Int32):
return pa.int32()
if isinstance_or_issubclass(dtype, dtypes.Int16):
return pa.int16()
if isinstance_or_issubclass(dtype, dtypes.Int8):
return pa.int8()
if isinstance_or_issubclass(dtype, dtypes.UInt64):
return pa.uint64()
if isinstance_or_issubclass(dtype, dtypes.UInt32):
return pa.uint32()
if isinstance_or_issubclass(dtype, dtypes.UInt16):
return pa.uint16()
if isinstance_or_issubclass(dtype, dtypes.UInt8):
return pa.uint8()
if isinstance_or_issubclass(dtype, dtypes.String):
return pa.string()
if isinstance_or_issubclass(dtype, dtypes.Boolean):
return pa.bool_()
if isinstance_or_issubclass(dtype, dtypes.Categorical):
return pa.dictionary(pa.uint32(), pa.string())
if isinstance_or_issubclass(dtype, dtypes.Datetime):
unit = dtype.time_unit
return pa.timestamp(unit, tz) if (tz := dtype.time_zone) else pa.timestamp(unit)
if isinstance_or_issubclass(dtype, dtypes.Duration):
return pa.duration(dtype.time_unit)
if isinstance_or_issubclass(dtype, dtypes.Date):
return pa.date32()
if isinstance_or_issubclass(dtype, dtypes.List):
return pa.list_(value_type=narwhals_to_native_dtype(dtype.inner, version=version))
if isinstance_or_issubclass(dtype, dtypes.Struct):
return pa.struct(
[
(field.name, narwhals_to_native_dtype(field.dtype, version=version))
for field in dtype.fields
]
)
if isinstance_or_issubclass(dtype, dtypes.Array): # pragma: no cover
inner = narwhals_to_native_dtype(dtype.inner, version=version)
list_size = dtype.size
return pa.list_(inner, list_size=list_size)
if isinstance_or_issubclass(dtype, dtypes.Time):
return pa.time64("ns")
if isinstance_or_issubclass(dtype, dtypes.Binary):
return pa.binary()
msg = f"Unknown dtype: {dtype}" # pragma: no cover
raise AssertionError(msg)
def extract_native(
lhs: ArrowSeries, rhs: ArrowSeries | PythonLiteral | ScalarAny
) -> tuple[ChunkedArrayAny | ScalarAny, ChunkedArrayAny | ScalarAny]:
"""Extract native objects in binary operation.
If the comparison isn't supported, return `NotImplemented` so that the
"right-hand-side" operation (e.g. `__radd__`) can be tried.
If one of the two sides has a `_broadcast` flag, then extract the scalar
underneath it so that PyArrow can do its own broadcasting.
"""
from narwhals._arrow.series import ArrowSeries
if rhs is None: # pragma: no cover
return lhs.native, lit(None, type=lhs._type)
if isinstance(rhs, ArrowSeries):
if lhs._broadcast and not rhs._broadcast:
return lhs.native[0], rhs.native
if rhs._broadcast:
return lhs.native, rhs.native[0]
return lhs.native, rhs.native
if isinstance(rhs, list):
msg = "Expected Series or scalar, got list."
raise TypeError(msg)
return lhs.native, rhs if isinstance(rhs, pa.Scalar) else lit(rhs)
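# Broadcast sketch: if `lhs` is a length-1 aggregate flagged `_broadcast`
# (e.g. the result of `.mean()`), its single scalar is extracted so that
# PyArrow's own broadcasting applies:
# extract_native(mean_series, other_series) -> (scalar, chunked_array)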
def floordiv_compat(left: ArrayOrScalar, right: ArrayOrScalar, /) -> Any:
# The following lines are adapted from pandas' pyarrow implementation.
# Ref: https://github.com/pandas-dev/pandas/blob/262fcfbffcee5c3116e86a951d8b693f90411e68/pandas/core/arrays/arrow/array.py#L124-L154
if pa.types.is_integer(left.type) and pa.types.is_integer(right.type):
divided = pc.divide_checked(left, right)
# TODO @dangotbanned: Use a `TypeVar` in guards
# Narrowing to a `Union` isn't interacting well with the rest of the stubs
# https://github.com/zen-xu/pyarrow-stubs/pull/215
if pa.types.is_signed_integer(divided.type):
div_type = cast("pa._lib.Int64Type", divided.type)
has_remainder = pc.not_equal(pc.multiply(divided, right), left)
has_one_negative_operand = pc.less(
pc.bit_wise_xor(left, right), lit(0, div_type)
)
result = pc.if_else(
pc.and_(has_remainder, has_one_negative_operand),
pc.subtract(divided, lit(1, div_type)),
divided,
)
else:
result = divided # pragma: no cover
result = result.cast(left.type)
else:
divided = pc.divide(left, right)
result = pc.floor(divided)
return result
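# Why the sign correction matters: PyArrow integer division truncates toward
# zero, while Python floors. A quick check (assumes pyarrow is installed):
# >>> pc.divide_checked(pa.scalar(-7), pa.scalar(2)).as_py()
# -3
# >>> floordiv_compat(pa.scalar(-7), pa.scalar(2)).as_py()  # matches -7 // 2
# -4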
def cast_for_truediv(
arrow_array: ArrayOrScalarT1, pa_object: ArrayOrScalarT2
) -> tuple[ArrayOrScalarT1, ArrayOrScalarT2]:
# Lifted from:
# https://github.com/pandas-dev/pandas/blob/262fcfbffcee5c3116e86a951d8b693f90411e68/pandas/core/arrays/arrow/array.py#L108-L122
# Ensure int / int -> float mirroring Python/Numpy behavior
# as pc.divide_checked(int, int) -> int
if pa.types.is_integer(arrow_array.type) and pa.types.is_integer(pa_object.type):
# GH: 56645. # noqa: ERA001
# https://github.com/apache/arrow/issues/35563
return arrow_array.cast(pa.float64(), safe=False), pa_object.cast(
pa.float64(), safe=False
)
return arrow_array, pa_object
# Regex for date, time, separator and timezone components
DATE_RE = r"(?P<date>\d{1,4}[-/.]\d{1,2}[-/.]\d{1,4}|\d{8})"
SEP_RE = r"(?P<sep>\s|T)"
TIME_RE = r"(?P<time>\d{2}:\d{2}(?::\d{2})?|\d{6}?)" # \s*(?P<period>[AP]M)?)?
HMS_RE = r"^(?P<hms>\d{2}:\d{2}:\d{2})$"
HM_RE = r"^(?P<hm>\d{2}:\d{2})$"
HMS_RE_NO_SEP = r"^(?P<hms_no_sep>\d{6})$"
TZ_RE = r"(?P<tz>Z|[+-]\d{2}:?\d{2})" # Matches 'Z', '+02:00', '+0200', '+02', etc.
FULL_RE = rf"{DATE_RE}{SEP_RE}?{TIME_RE}?{TZ_RE}?$"
# Separate regexes for different date formats
YMD_RE = r"^(?P<year>(?:[12][0-9])?[0-9]{2})(?P<sep1>[-/.])(?P<month>0[1-9]|1[0-2])(?P<sep2>[-/.])(?P<day>0[1-9]|[12][0-9]|3[01])$"
DMY_RE = r"^(?P<day>0[1-9]|[12][0-9]|3[01])(?P<sep1>[-/.])(?P<month>0[1-9]|1[0-2])(?P<sep2>[-/.])(?P<year>(?:[12][0-9])?[0-9]{2})$"
MDY_RE = r"^(?P<month>0[1-9]|1[0-2])(?P<sep1>[-/.])(?P<day>0[1-9]|[12][0-9]|3[01])(?P<sep2>[-/.])(?P<year>(?:[12][0-9])?[0-9]{2})$"
YMD_RE_NO_SEP = r"^(?P<year>(?:[12][0-9])?[0-9]{2})(?P<month>0[1-9]|1[0-2])(?P<day>0[1-9]|[12][0-9]|3[01])$"
DATE_FORMATS = (
(YMD_RE_NO_SEP, "%Y%m%d"),
(YMD_RE, "%Y-%m-%d"),
(DMY_RE, "%d-%m-%Y"),
(MDY_RE, "%m-%d-%Y"),
)
TIME_FORMATS = ((HMS_RE, "%H:%M:%S"), (HM_RE, "%H:%M"), (HMS_RE_NO_SEP, "%H%M%S"))
def _extract_regex_concat_arrays(
strings: ChunkedArrayAny,
/,
pattern: str,
*,
options: Any = None,
memory_pool: Any = None,
) -> pa.StructArray:
r = pa.concat_arrays(
extract_regex(strings, pattern, options=options, memory_pool=memory_pool).chunks
)
return cast("pa.StructArray", r)
def parse_datetime_format(arr: ChunkedArrayAny) -> str:
"""Try to infer datetime format from StringArray."""
matches = _extract_regex_concat_arrays(arr.drop_null().slice(0, 10), pattern=FULL_RE)
if not pc.all(matches.is_valid()).as_py():
msg = (
"Unable to infer datetime format, provided format is not supported. "
"Please report a bug to https://github.com/narwhals-dev/narwhals/issues"
)
raise NotImplementedError(msg)
separators = matches.field("sep")
tz = matches.field("tz")
# separators and time zones must be unique
if pc.count(pc.unique(separators)).as_py() > 1:
msg = "Found multiple separator values while inferring datetime format."
raise ValueError(msg)
if pc.count(pc.unique(tz)).as_py() > 1:
msg = "Found multiple timezone values while inferring datetime format."
raise ValueError(msg)
date_value = _parse_date_format(cast("pc.StringArray", matches.field("date")))
time_value = _parse_time_format(cast("pc.StringArray", matches.field("time")))
sep_value = separators[0].as_py()
tz_value = "%z" if tz[0].as_py() else ""
return f"{date_value}{sep_value}{time_value}{tz_value}"
def _parse_date_format(arr: pc.StringArray) -> str:
for date_rgx, date_fmt in DATE_FORMATS:
matches = pc.extract_regex(arr, pattern=date_rgx)
if date_fmt == "%Y%m%d" and pc.all(matches.is_valid()).as_py():
return date_fmt
elif (
pc.all(matches.is_valid()).as_py()
and pc.count(pc.unique(sep1 := matches.field("sep1"))).as_py() == 1
and pc.count(pc.unique(sep2 := matches.field("sep2"))).as_py() == 1
and (date_sep_value := sep1[0].as_py()) == sep2[0].as_py()
):
return date_fmt.replace("-", date_sep_value)
msg = (
"Unable to infer datetime format. "
"Please report a bug to https://github.com/narwhals-dev/narwhals/issues"
)
raise ValueError(msg)
def _parse_time_format(arr: pc.StringArray) -> str:
for time_rgx, time_fmt in TIME_FORMATS:
matches = pc.extract_regex(arr, pattern=time_rgx)
if pc.all(matches.is_valid()).as_py():
return time_fmt
return ""
def pad_series(
series: ArrowSeries, *, window_size: int, center: bool
) -> tuple[ArrowSeries, int]:
"""Pad series with None values on the left and/or right side, depending on the specified parameters.
Arguments:
series: The input ArrowSeries to be padded.
window_size: The desired size of the window.
center: Specifies whether to center the padding or not.
Returns:
A tuple containing the padded ArrowSeries and the offset value.
"""
if not center:
return series, 0
offset_left = window_size // 2
# subtract one if window_size is even
offset_right = offset_left - (window_size % 2 == 0)
pad_left = pa.array([None] * offset_left, type=series._type)
pad_right = pa.array([None] * offset_right, type=series._type)
concat = pa.concat_arrays([pad_left, *series.native.chunks, pad_right])
return series._with_native(concat), offset_left + offset_right
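# Padding sketch: for window_size=3, center=True, one null is added on each
# side and the reported offset is 2, e.g. chunks [1, 2, 3] become
# [null, 1, 2, 3, null].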
def cast_to_comparable_string_types(
*chunked_arrays: ChunkedArrayAny, separator: str
) -> tuple[Iterator[ChunkedArrayAny], ScalarAny]:
# Ensure `chunked_arrays` are either all `string` or all `large_string`.
dtype = (
pa.string() # (PyArrow default)
if not any(pa.types.is_large_string(ca.type) for ca in chunked_arrays)
else pa.large_string()
)
return (ca.cast(dtype) for ca in chunked_arrays), lit(separator, dtype)
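# e.g. if any input is pa.large_string(), every array (and the separator
# scalar) is cast to large_string; otherwise the PyArrow-default string type
# is kept, so downstream kernels see one consistent string type.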
class ArrowSeriesNamespace(EagerSeriesNamespace["ArrowSeries", "ChunkedArrayAny"]): ...

View File

@ -0,0 +1,103 @@
from __future__ import annotations
from narwhals._compliant.dataframe import (
CompliantDataFrame,
CompliantLazyFrame,
EagerDataFrame,
)
from narwhals._compliant.expr import (
CompliantExpr,
DepthTrackingExpr,
EagerExpr,
LazyExpr,
LazyExprNamespace,
)
from narwhals._compliant.group_by import (
CompliantGroupBy,
DepthTrackingGroupBy,
EagerGroupBy,
)
from narwhals._compliant.namespace import (
CompliantNamespace,
DepthTrackingNamespace,
EagerNamespace,
LazyNamespace,
)
from narwhals._compliant.selectors import (
CompliantSelector,
CompliantSelectorNamespace,
EagerSelectorNamespace,
LazySelectorNamespace,
)
from narwhals._compliant.series import (
CompliantSeries,
EagerSeries,
EagerSeriesCatNamespace,
EagerSeriesDateTimeNamespace,
EagerSeriesHist,
EagerSeriesListNamespace,
EagerSeriesNamespace,
EagerSeriesStringNamespace,
EagerSeriesStructNamespace,
)
from narwhals._compliant.typing import (
CompliantExprT,
CompliantFrameT,
CompliantSeriesOrNativeExprT_co,
CompliantSeriesT,
EagerDataFrameT,
EagerSeriesT,
EvalNames,
EvalSeries,
IntoCompliantExpr,
NativeFrameT_co,
NativeSeriesT_co,
)
from narwhals._compliant.when_then import CompliantThen, CompliantWhen, EagerWhen
from narwhals._compliant.window import WindowInputs
__all__ = [
"CompliantDataFrame",
"CompliantExpr",
"CompliantExprT",
"CompliantFrameT",
"CompliantGroupBy",
"CompliantLazyFrame",
"CompliantNamespace",
"CompliantSelector",
"CompliantSelectorNamespace",
"CompliantSeries",
"CompliantSeriesOrNativeExprT_co",
"CompliantSeriesT",
"CompliantThen",
"CompliantWhen",
"DepthTrackingExpr",
"DepthTrackingGroupBy",
"DepthTrackingNamespace",
"EagerDataFrame",
"EagerDataFrameT",
"EagerExpr",
"EagerGroupBy",
"EagerNamespace",
"EagerSelectorNamespace",
"EagerSeries",
"EagerSeriesCatNamespace",
"EagerSeriesDateTimeNamespace",
"EagerSeriesHist",
"EagerSeriesListNamespace",
"EagerSeriesNamespace",
"EagerSeriesStringNamespace",
"EagerSeriesStructNamespace",
"EagerSeriesT",
"EagerWhen",
"EvalNames",
"EvalSeries",
"IntoCompliantExpr",
"LazyExpr",
"LazyExprNamespace",
"LazyNamespace",
"LazySelectorNamespace",
"NativeFrameT_co",
"NativeSeriesT_co",
"WindowInputs",
]

View File

@ -0,0 +1,89 @@
"""`Expr` and `Series` namespace accessor protocols."""
from __future__ import annotations
from typing import TYPE_CHECKING, Protocol
from narwhals._utils import CompliantT_co, _StoresCompliant
if TYPE_CHECKING:
from typing import Callable
from narwhals.typing import TimeUnit
__all__ = [
"CatNamespace",
"DateTimeNamespace",
"ListNamespace",
"NameNamespace",
"StringNamespace",
"StructNamespace",
]
class CatNamespace(_StoresCompliant[CompliantT_co], Protocol[CompliantT_co]):
def get_categories(self) -> CompliantT_co: ...
class DateTimeNamespace(_StoresCompliant[CompliantT_co], Protocol[CompliantT_co]):
def to_string(self, format: str) -> CompliantT_co: ...
def replace_time_zone(self, time_zone: str | None) -> CompliantT_co: ...
def convert_time_zone(self, time_zone: str) -> CompliantT_co: ...
def timestamp(self, time_unit: TimeUnit) -> CompliantT_co: ...
def date(self) -> CompliantT_co: ...
def year(self) -> CompliantT_co: ...
def month(self) -> CompliantT_co: ...
def day(self) -> CompliantT_co: ...
def hour(self) -> CompliantT_co: ...
def minute(self) -> CompliantT_co: ...
def second(self) -> CompliantT_co: ...
def millisecond(self) -> CompliantT_co: ...
def microsecond(self) -> CompliantT_co: ...
def nanosecond(self) -> CompliantT_co: ...
def ordinal_day(self) -> CompliantT_co: ...
def weekday(self) -> CompliantT_co: ...
def total_minutes(self) -> CompliantT_co: ...
def total_seconds(self) -> CompliantT_co: ...
def total_milliseconds(self) -> CompliantT_co: ...
def total_microseconds(self) -> CompliantT_co: ...
def total_nanoseconds(self) -> CompliantT_co: ...
def truncate(self, every: str) -> CompliantT_co: ...
def offset_by(self, by: str) -> CompliantT_co: ...
class ListNamespace(_StoresCompliant[CompliantT_co], Protocol[CompliantT_co]):
def len(self) -> CompliantT_co: ...
class NameNamespace(_StoresCompliant[CompliantT_co], Protocol[CompliantT_co]):
def keep(self) -> CompliantT_co: ...
def map(self, function: Callable[[str], str]) -> CompliantT_co: ...
def prefix(self, prefix: str) -> CompliantT_co: ...
def suffix(self, suffix: str) -> CompliantT_co: ...
def to_lowercase(self) -> CompliantT_co: ...
def to_uppercase(self) -> CompliantT_co: ...
class StringNamespace(_StoresCompliant[CompliantT_co], Protocol[CompliantT_co]):
def len_chars(self) -> CompliantT_co: ...
def replace(
self, pattern: str, value: str, *, literal: bool, n: int
) -> CompliantT_co: ...
def replace_all(
self, pattern: str, value: str, *, literal: bool
) -> CompliantT_co: ...
def strip_chars(self, characters: str | None) -> CompliantT_co: ...
def starts_with(self, prefix: str) -> CompliantT_co: ...
def ends_with(self, suffix: str) -> CompliantT_co: ...
def contains(self, pattern: str, *, literal: bool) -> CompliantT_co: ...
def slice(self, offset: int, length: int | None) -> CompliantT_co: ...
def split(self, by: str) -> CompliantT_co: ...
def to_datetime(self, format: str | None) -> CompliantT_co: ...
def to_date(self, format: str | None) -> CompliantT_co: ...
def to_lowercase(self) -> CompliantT_co: ...
def to_uppercase(self) -> CompliantT_co: ...
def zfill(self, width: int) -> CompliantT_co: ...
class StructNamespace(_StoresCompliant[CompliantT_co], Protocol[CompliantT_co]):
def field(self, name: str) -> CompliantT_co: ...

View File

@ -0,0 +1,486 @@
from __future__ import annotations
from collections.abc import Iterator, Mapping, Sequence, Sized
from itertools import chain
from typing import TYPE_CHECKING, Any, Literal, Protocol, TypeVar, overload
from narwhals._compliant.typing import (
CompliantDataFrameAny,
CompliantExprT_contra,
CompliantLazyFrameAny,
CompliantSeriesT,
EagerExprT,
EagerSeriesT,
NativeFrameT,
NativeSeriesT,
)
from narwhals._translate import (
ArrowConvertible,
DictConvertible,
FromNative,
NumpyConvertible,
ToNarwhals,
ToNarwhalsT_co,
)
from narwhals._typing_compat import assert_never
from narwhals._utils import (
ValidateBackendVersion,
Version,
_StoresNative,
check_columns_exist,
is_compliant_series,
is_index_selector,
is_range,
is_sequence_like,
is_sized_multi_index_selector,
is_slice_index,
is_slice_none,
)
if TYPE_CHECKING:
from io import BytesIO
from pathlib import Path
import pandas as pd
import polars as pl
import pyarrow as pa
from typing_extensions import Self, TypeAlias
from narwhals._compliant.group_by import CompliantGroupBy, DataFrameGroupBy
from narwhals._compliant.namespace import EagerNamespace
from narwhals._translate import IntoArrowTable
from narwhals._utils import Implementation, _LimitedContext
from narwhals.dataframe import DataFrame
from narwhals.dtypes import DType
from narwhals.exceptions import ColumnNotFoundError
from narwhals.schema import Schema
from narwhals.typing import (
AsofJoinStrategy,
JoinStrategy,
LazyUniqueKeepStrategy,
MultiColSelector,
MultiIndexSelector,
PivotAgg,
SingleIndexSelector,
SizedMultiIndexSelector,
SizedMultiNameSelector,
SizeUnit,
UniqueKeepStrategy,
_2DArray,
_SliceIndex,
_SliceName,
)
Incomplete: TypeAlias = Any
__all__ = ["CompliantDataFrame", "CompliantLazyFrame", "EagerDataFrame"]
T = TypeVar("T")
_ToDict: TypeAlias = "dict[str, CompliantSeriesT] | dict[str, list[Any]]" # noqa: PYI047
class CompliantDataFrame(
NumpyConvertible["_2DArray", "_2DArray"],
DictConvertible["_ToDict[CompliantSeriesT]", Mapping[str, Any]],
ArrowConvertible["pa.Table", "IntoArrowTable"],
_StoresNative[NativeFrameT],
FromNative[NativeFrameT],
ToNarwhals[ToNarwhalsT_co],
Sized,
Protocol[CompliantSeriesT, CompliantExprT_contra, NativeFrameT, ToNarwhalsT_co],
):
_native_frame: NativeFrameT
_implementation: Implementation
_version: Version
def __narwhals_dataframe__(self) -> Self: ...
def __narwhals_namespace__(self) -> Any: ...
@classmethod
def from_arrow(cls, data: IntoArrowTable, /, *, context: _LimitedContext) -> Self: ...
@classmethod
def from_dict(
cls,
data: Mapping[str, Any],
/,
*,
context: _LimitedContext,
schema: Mapping[str, DType] | Schema | None,
) -> Self: ...
@classmethod
def from_native(cls, data: NativeFrameT, /, *, context: _LimitedContext) -> Self: ...
@classmethod
def from_numpy(
cls,
data: _2DArray,
/,
*,
context: _LimitedContext,
schema: Mapping[str, DType] | Schema | Sequence[str] | None,
) -> Self: ...
def __array__(self, dtype: Any, *, copy: bool | None) -> _2DArray: ...
def __getitem__(
self,
item: tuple[
SingleIndexSelector | MultiIndexSelector[CompliantSeriesT],
MultiColSelector[CompliantSeriesT],
],
) -> Self: ...
def simple_select(self, *column_names: str) -> Self:
"""`select` where all args are column names."""
...
def aggregate(self, *exprs: CompliantExprT_contra) -> Self:
"""`select` where all args are aggregations or literals.
(so no broadcasting is necessary).
"""
# NOTE: Ignore is to avoid an intermittent false positive
return self.select(*exprs) # pyright: ignore[reportArgumentType]
def _with_version(self, version: Version) -> Self: ...
@property
def native(self) -> NativeFrameT:
return self._native_frame
@property
def columns(self) -> Sequence[str]: ...
@property
def schema(self) -> Mapping[str, DType]: ...
@property
def shape(self) -> tuple[int, int]: ...
def clone(self) -> Self: ...
def collect(
self, backend: Implementation | None, **kwargs: Any
) -> CompliantDataFrameAny: ...
def collect_schema(self) -> Mapping[str, DType]: ...
def drop(self, columns: Sequence[str], *, strict: bool) -> Self: ...
def drop_nulls(self, subset: Sequence[str] | None) -> Self: ...
def estimated_size(self, unit: SizeUnit) -> int | float: ...
def explode(self, columns: Sequence[str]) -> Self: ...
def filter(self, predicate: CompliantExprT_contra | Incomplete) -> Self: ...
def gather_every(self, n: int, offset: int) -> Self: ...
def get_column(self, name: str) -> CompliantSeriesT: ...
def group_by(
self,
keys: Sequence[str] | Sequence[CompliantExprT_contra],
*,
drop_null_keys: bool,
) -> DataFrameGroupBy[Self, Any]: ...
def head(self, n: int) -> Self: ...
def item(self, row: int | None, column: int | str | None) -> Any: ...
def iter_columns(self) -> Iterator[CompliantSeriesT]: ...
def iter_rows(
self, *, named: bool, buffer_size: int
) -> Iterator[tuple[Any, ...]] | Iterator[Mapping[str, Any]]: ...
def is_unique(self) -> CompliantSeriesT: ...
def join(
self,
other: Self,
*,
how: JoinStrategy,
left_on: Sequence[str] | None,
right_on: Sequence[str] | None,
suffix: str,
) -> Self: ...
def join_asof(
self,
other: Self,
*,
left_on: str,
right_on: str,
by_left: Sequence[str] | None,
by_right: Sequence[str] | None,
strategy: AsofJoinStrategy,
suffix: str,
) -> Self: ...
def lazy(self, *, backend: Implementation | None) -> CompliantLazyFrameAny: ...
def pivot(
self,
on: Sequence[str],
*,
index: Sequence[str] | None,
values: Sequence[str] | None,
aggregate_function: PivotAgg | None,
sort_columns: bool,
separator: str,
) -> Self: ...
def rename(self, mapping: Mapping[str, str]) -> Self: ...
def row(self, index: int) -> tuple[Any, ...]: ...
def rows(
self, *, named: bool
) -> Sequence[tuple[Any, ...]] | Sequence[Mapping[str, Any]]: ...
def sample(
self,
n: int | None,
*,
fraction: float | None,
with_replacement: bool,
seed: int | None,
) -> Self: ...
def select(self, *exprs: CompliantExprT_contra) -> Self: ...
def sort(
self, *by: str, descending: bool | Sequence[bool], nulls_last: bool
) -> Self: ...
def tail(self, n: int) -> Self: ...
def to_arrow(self) -> pa.Table: ...
def to_pandas(self) -> pd.DataFrame: ...
def to_polars(self) -> pl.DataFrame: ...
@overload
def to_dict(self, *, as_series: Literal[True]) -> dict[str, CompliantSeriesT]: ...
@overload
def to_dict(self, *, as_series: Literal[False]) -> dict[str, list[Any]]: ...
def to_dict(
self, *, as_series: bool
) -> dict[str, CompliantSeriesT] | dict[str, list[Any]]: ...
def unique(
self,
subset: Sequence[str] | None,
*,
keep: UniqueKeepStrategy,
maintain_order: bool | None = None,
) -> Self: ...
def unpivot(
self,
on: Sequence[str] | None,
index: Sequence[str] | None,
variable_name: str,
value_name: str,
) -> Self: ...
def with_columns(self, *exprs: CompliantExprT_contra) -> Self: ...
def with_row_index(self, name: str, order_by: Sequence[str] | None) -> Self: ...
@overload
def write_csv(self, file: None) -> str: ...
@overload
def write_csv(self, file: str | Path | BytesIO) -> None: ...
def write_csv(self, file: str | Path | BytesIO | None) -> str | None: ...
def write_parquet(self, file: str | Path | BytesIO) -> None: ...
def _evaluate_aliases(self, *exprs: CompliantExprT_contra) -> list[str]:
it = (expr._evaluate_aliases(self) for expr in exprs)
return list(chain.from_iterable(it))
def _check_columns_exist(self, subset: Sequence[str]) -> ColumnNotFoundError | None:
return check_columns_exist(subset, available=self.columns)
class CompliantLazyFrame(
_StoresNative[NativeFrameT],
FromNative[NativeFrameT],
ToNarwhals[ToNarwhalsT_co],
Protocol[CompliantExprT_contra, NativeFrameT, ToNarwhalsT_co],
):
_native_frame: NativeFrameT
_implementation: Implementation
_version: Version
def __narwhals_lazyframe__(self) -> Self: ...
def __narwhals_namespace__(self) -> Any: ...
@classmethod
def from_native(cls, data: NativeFrameT, /, *, context: _LimitedContext) -> Self: ...
def simple_select(self, *column_names: str) -> Self:
"""`select` where all args are column names."""
...
def aggregate(self, *exprs: CompliantExprT_contra) -> Self:
"""`select` where all args are aggregations or literals.
(so no broadcasting is necessary).
"""
...
def _with_version(self, version: Version) -> Self: ...
@property
def native(self) -> NativeFrameT:
return self._native_frame
@property
def columns(self) -> Sequence[str]: ...
@property
def schema(self) -> Mapping[str, DType]: ...
def _iter_columns(self) -> Iterator[Any]: ...
def collect(
self, backend: Implementation | None, **kwargs: Any
) -> CompliantDataFrameAny: ...
def collect_schema(self) -> Mapping[str, DType]: ...
def drop(self, columns: Sequence[str], *, strict: bool) -> Self: ...
def drop_nulls(self, subset: Sequence[str] | None) -> Self: ...
def explode(self, columns: Sequence[str]) -> Self: ...
def filter(self, predicate: CompliantExprT_contra | Incomplete) -> Self: ...
def group_by(
self,
keys: Sequence[str] | Sequence[CompliantExprT_contra],
*,
drop_null_keys: bool,
) -> CompliantGroupBy[Self, CompliantExprT_contra]: ...
def head(self, n: int) -> Self: ...
def join(
self,
other: Self,
*,
how: JoinStrategy,
left_on: Sequence[str] | None,
right_on: Sequence[str] | None,
suffix: str,
) -> Self: ...
def join_asof(
self,
other: Self,
*,
left_on: str,
right_on: str,
by_left: Sequence[str] | None,
by_right: Sequence[str] | None,
strategy: AsofJoinStrategy,
suffix: str,
) -> Self: ...
def rename(self, mapping: Mapping[str, str]) -> Self: ...
def select(self, *exprs: CompliantExprT_contra) -> Self: ...
def sink_parquet(self, file: str | Path | BytesIO) -> None: ...
def sort(
self, *by: str, descending: bool | Sequence[bool], nulls_last: bool
) -> Self: ...
def unique(
self, subset: Sequence[str] | None, *, keep: LazyUniqueKeepStrategy
) -> Self: ...
def unpivot(
self,
on: Sequence[str] | None,
index: Sequence[str] | None,
variable_name: str,
value_name: str,
) -> Self: ...
def with_columns(self, *exprs: CompliantExprT_contra) -> Self: ...
def with_row_index(self, name: str, order_by: Sequence[str]) -> Self: ...
def _evaluate_expr(self, expr: CompliantExprT_contra, /) -> Any:
result = expr(self)
assert len(result) == 1 # debug assertion # noqa: S101
return result[0]
def _evaluate_aliases(self, *exprs: CompliantExprT_contra) -> list[str]:
it = (expr._evaluate_aliases(self) for expr in exprs)
return list(chain.from_iterable(it))
def _check_columns_exist(self, subset: Sequence[str]) -> ColumnNotFoundError | None:
return check_columns_exist(subset, available=self.columns)
class EagerDataFrame(
CompliantDataFrame[EagerSeriesT, EagerExprT, NativeFrameT, "DataFrame[NativeFrameT]"],
CompliantLazyFrame[EagerExprT, NativeFrameT, "DataFrame[NativeFrameT]"],
ValidateBackendVersion,
Protocol[EagerSeriesT, EagerExprT, NativeFrameT, NativeSeriesT],
):
@property
def _backend_version(self) -> tuple[int, ...]:
return self._implementation._backend_version()
def __narwhals_namespace__(
self,
) -> EagerNamespace[Self, EagerSeriesT, EagerExprT, NativeFrameT, NativeSeriesT]: ...
def to_narwhals(self) -> DataFrame[NativeFrameT]:
return self._version.dataframe(self, level="full")
def _with_native(
self, df: NativeFrameT, *, validate_column_names: bool = True
) -> Self: ...
def _evaluate_expr(self, expr: EagerExprT, /) -> EagerSeriesT:
"""Evaluate `expr` and ensure it has a **single** output."""
result: Sequence[EagerSeriesT] = expr(self)
assert len(result) == 1 # debug assertion # noqa: S101
return result[0]
def _evaluate_into_exprs(self, *exprs: EagerExprT) -> Sequence[EagerSeriesT]:
# NOTE: Ignore is to avoid an intermittent false positive
return list(chain.from_iterable(self._evaluate_into_expr(expr) for expr in exprs)) # pyright: ignore[reportArgumentType]
def _evaluate_into_expr(self, expr: EagerExprT, /) -> Sequence[EagerSeriesT]:
"""Return list of raw columns.
For eager backends we alias operations at each step.
As a safety precaution, we check here that the resulting names match those
predicted by the various `evaluate_output_names` / `alias_output_names` calls.
Note that for PySpark / DuckDB, we are less free to set aliases wherever we like.
"""
aliases = expr._evaluate_aliases(self)
result = expr(self)
if list(aliases) != (
result_aliases := [s.name for s in result]
): # pragma: no cover
msg = f"Safety assertion failed, expected {aliases}, got {result_aliases}"
raise AssertionError(msg)
return result
def _extract_comparand(self, other: EagerSeriesT, /) -> Any:
"""Extract native Series, broadcasting to `len(self)` if necessary."""
...
@staticmethod
def _numpy_column_names(
data: _2DArray, columns: Sequence[str] | None, /
) -> list[str]:
return list(columns or (f"column_{x}" for x in range(data.shape[1])))
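# e.g. _numpy_column_names(data, None) for an (n, 3) array falls back to
# ["column_0", "column_1", "column_2"]; explicit names are passed through.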
def _gather(self, rows: SizedMultiIndexSelector[NativeSeriesT]) -> Self: ...
def _gather_slice(self, rows: _SliceIndex | range) -> Self: ...
def _select_multi_index(
self, columns: SizedMultiIndexSelector[NativeSeriesT]
) -> Self: ...
def _select_multi_name(
self, columns: SizedMultiNameSelector[NativeSeriesT]
) -> Self: ...
def _select_slice_index(self, columns: _SliceIndex | range) -> Self: ...
def _select_slice_name(self, columns: _SliceName) -> Self: ...
def __getitem__( # noqa: C901, PLR0912
self,
item: tuple[
SingleIndexSelector | MultiIndexSelector[EagerSeriesT],
MultiColSelector[EagerSeriesT],
],
) -> Self:
rows, columns = item
compliant = self
if not is_slice_none(columns):
if isinstance(columns, Sized) and len(columns) == 0:
return compliant.select()
if is_index_selector(columns):
if is_slice_index(columns) or is_range(columns):
compliant = compliant._select_slice_index(columns)
elif is_compliant_series(columns):
compliant = self._select_multi_index(columns.native)
else:
compliant = compliant._select_multi_index(columns)
elif isinstance(columns, slice):
compliant = compliant._select_slice_name(columns)
elif is_compliant_series(columns):
compliant = self._select_multi_name(columns.native)
elif is_sequence_like(columns):
compliant = self._select_multi_name(columns)
else:
assert_never(columns)
if not is_slice_none(rows):
if isinstance(rows, int):
compliant = compliant._gather([rows])
elif isinstance(rows, (slice, range)):
compliant = compliant._gather_slice(rows)
elif is_compliant_series(rows):
compliant = compliant._gather(rows.native)
elif is_sized_multi_index_selector(rows):
compliant = compliant._gather(rows)
else:
assert_never(rows)
return compliant
def sink_parquet(self, file: str | Path | BytesIO) -> None:
return self.write_parquet(file)

File diff suppressed because it is too large

View File

@ -0,0 +1,178 @@
from __future__ import annotations
import re
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Protocol, TypeVar
from narwhals._compliant.typing import (
CompliantDataFrameAny,
CompliantDataFrameT_co,
CompliantExprT_contra,
CompliantFrameT_co,
CompliantLazyFrameAny,
DepthTrackingExprAny,
DepthTrackingExprT_contra,
EagerExprT_contra,
NarwhalsAggregation,
)
from narwhals._utils import is_sequence_of
if TYPE_CHECKING:
from collections.abc import Iterator, Mapping, Sequence
_SameFrameT = TypeVar("_SameFrameT", CompliantDataFrameAny, CompliantLazyFrameAny)
__all__ = ["CompliantGroupBy", "DepthTrackingGroupBy", "EagerGroupBy"]
NativeAggregationT_co = TypeVar(
"NativeAggregationT_co", bound="str | Callable[..., Any]", covariant=True
)
_RE_LEAF_NAME: re.Pattern[str] = re.compile(r"(\w+->)")
class CompliantGroupBy(Protocol[CompliantFrameT_co, CompliantExprT_contra]):
_compliant_frame: Any
@property
def compliant(self) -> CompliantFrameT_co:
return self._compliant_frame # type: ignore[no-any-return]
def __init__(
self,
compliant_frame: CompliantFrameT_co,
keys: Sequence[CompliantExprT_contra] | Sequence[str],
/,
*,
drop_null_keys: bool,
) -> None: ...
def agg(self, *exprs: CompliantExprT_contra) -> CompliantFrameT_co: ...
class DataFrameGroupBy(
CompliantGroupBy[CompliantDataFrameT_co, CompliantExprT_contra],
Protocol[CompliantDataFrameT_co, CompliantExprT_contra],
):
def __iter__(self) -> Iterator[tuple[Any, CompliantDataFrameT_co]]: ...
class ParseKeysGroupBy(
CompliantGroupBy[CompliantFrameT_co, CompliantExprT_contra],
Protocol[CompliantFrameT_co, CompliantExprT_contra],
):
def _parse_keys(
self,
compliant_frame: _SameFrameT,
keys: Sequence[CompliantExprT_contra] | Sequence[str],
) -> tuple[_SameFrameT, list[str], list[str]]:
if is_sequence_of(keys, str):
keys_str = list(keys)
return compliant_frame, keys_str, keys_str.copy()
else:
return self._parse_expr_keys(compliant_frame, keys=keys)
@staticmethod
def _parse_expr_keys(
compliant_frame: _SameFrameT, keys: Sequence[CompliantExprT_contra]
) -> tuple[_SameFrameT, list[str], list[str]]:
"""Parses key expressions to set up `.agg` operation with correct information.
Since keys are expressions, it's possible to alias any such key to match
other dataframe column names.
In order to match polars behavior and not overwrite columns when evaluating keys:
- We evaluate what the output key names should be, in order to remap temporary column
names to the expected ones, and to exclude those from unnamed expressions in
`.agg(...)` context (see https://github.com/narwhals-dev/narwhals/pull/2325#issuecomment-2800004520)
- Create temporary names for evaluated key expressions that are guaranteed to have
no overlap with any existing column name.
- Add these temporary columns to the compliant dataframe.
"""
tmp_name_length = max(len(str(c)) for c in compliant_frame.columns) + 1
def _temporary_name(key: str) -> str:
# 5 is the length of `__tmp`
key_str = str(key) # pandas allows non-string column names :sob:
return f"_{key_str}_tmp{'_' * (tmp_name_length - len(key_str) - 5)}"
output_names = compliant_frame._evaluate_aliases(*keys)
safe_keys = [
# multi-output expression cannot have duplicate names, hence it's safe to suffix
key.name.map(_temporary_name)
if (metadata := key._metadata) and metadata.expansion_kind.is_multi_output()
# otherwise it's single named and we can use Expr.alias
else key.alias(_temporary_name(new_name))
for key, new_name in zip(keys, output_names)
]
return (
compliant_frame.with_columns(*safe_keys),
compliant_frame._evaluate_aliases(*safe_keys),
output_names,
)
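# Naming sketch: with columns ["a", "bb"], tmp_name_length is 3 and
# _temporary_name("a") -> "_a_tmp". The generated name is always longer than
# every existing column name, so collisions are impossible.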
class DepthTrackingGroupBy(
ParseKeysGroupBy[CompliantFrameT_co, DepthTrackingExprT_contra],
Protocol[CompliantFrameT_co, DepthTrackingExprT_contra, NativeAggregationT_co],
):
"""`CompliantGroupBy` variant, deals with `Eager` and other backends that utilize `CompliantExpr._depth`."""
_REMAP_AGGS: ClassVar[Mapping[NarwhalsAggregation, Any]]
"""Mapping from `narwhals` to native representation.
Note:
- `Dask` *may* return a `Callable` instead of a `str` referring to one.
"""
def _ensure_all_simple(self, exprs: Sequence[DepthTrackingExprT_contra]) -> None:
for expr in exprs:
if not self._is_simple(expr):
name = self.compliant._implementation.name.lower()
msg = (
f"Non-trivial complex aggregation found.\n\n"
f"Hint: you were probably trying to apply a non-elementary aggregation with a"
f"{name!r} table.\n"
"Please rewrite your query such that group-by aggregations "
"are elementary. For example, instead of:\n\n"
" df.group_by('a').agg(nw.col('b').round(2).mean())\n\n"
"use:\n\n"
" df.with_columns(nw.col('b').round(2)).group_by('a').agg(nw.col('b').mean())\n\n"
)
raise ValueError(msg)
@classmethod
def _is_simple(cls, expr: DepthTrackingExprAny, /) -> bool:
"""Return `True` is we can efficiently use `expr` in a native `group_by` context."""
return expr._is_elementary() and cls._leaf_name(expr) in cls._REMAP_AGGS
@classmethod
def _remap_expr_name(
cls, name: NarwhalsAggregation | Any, /
) -> NativeAggregationT_co:
"""Replace `name`, with some native representation.
Arguments:
name: Name of a `nw.Expr` aggregation method.
Returns:
A native compatible representation.
"""
return cls._REMAP_AGGS.get(name, name)
@classmethod
def _leaf_name(cls, expr: DepthTrackingExprAny, /) -> NarwhalsAggregation | Any:
"""Return the last function name in the chain defined by `expr`."""
return _RE_LEAF_NAME.sub("", expr._function_name)
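# e.g. _RE_LEAF_NAME.sub("", "col->round->mean") -> "mean", so dispatch keys
# off the final aggregation in the expression chain.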
class EagerGroupBy(
DepthTrackingGroupBy[
CompliantDataFrameT_co, EagerExprT_contra, NativeAggregationT_co
],
DataFrameGroupBy[CompliantDataFrameT_co, EagerExprT_contra],
Protocol[CompliantDataFrameT_co, EagerExprT_contra, NativeAggregationT_co],
): ...

View File

@ -0,0 +1,211 @@
from __future__ import annotations
from functools import partial
from typing import TYPE_CHECKING, Any, Protocol, overload
from narwhals._compliant.typing import (
CompliantExprT,
CompliantFrameT,
CompliantLazyFrameT,
DepthTrackingExprT,
EagerDataFrameT,
EagerExprT,
EagerSeriesT,
LazyExprT,
NativeFrameT,
NativeFrameT_co,
NativeSeriesT,
)
from narwhals._utils import (
exclude_column_names,
get_column_names,
passthrough_column_names,
)
from narwhals.dependencies import is_numpy_array_2d
if TYPE_CHECKING:
from collections.abc import Container, Iterable, Mapping, Sequence
from typing_extensions import TypeAlias
from narwhals._compliant.selectors import CompliantSelectorNamespace
from narwhals._compliant.when_then import CompliantWhen, EagerWhen
from narwhals._utils import Implementation, Version
from narwhals.dtypes import DType
from narwhals.schema import Schema
from narwhals.typing import (
ConcatMethod,
Into1DArray,
IntoDType,
NonNestedLiteral,
_2DArray,
)
Incomplete: TypeAlias = Any
__all__ = [
"CompliantNamespace",
"DepthTrackingNamespace",
"EagerNamespace",
"LazyNamespace",
]
class CompliantNamespace(Protocol[CompliantFrameT, CompliantExprT]):
_implementation: Implementation
_version: Version
def all(self) -> CompliantExprT:
return self._expr.from_column_names(get_column_names, context=self)
def col(self, *column_names: str) -> CompliantExprT:
return self._expr.from_column_names(
passthrough_column_names(column_names), context=self
)
def exclude(self, excluded_names: Container[str]) -> CompliantExprT:
return self._expr.from_column_names(
partial(exclude_column_names, names=excluded_names), context=self
)
def nth(self, *column_indices: int) -> CompliantExprT:
return self._expr.from_column_indices(*column_indices, context=self)
def len(self) -> CompliantExprT: ...
def lit(self, value: NonNestedLiteral, dtype: IntoDType | None) -> CompliantExprT: ...
def all_horizontal(
self, *exprs: CompliantExprT, ignore_nulls: bool
) -> CompliantExprT: ...
def any_horizontal(
self, *exprs: CompliantExprT, ignore_nulls: bool
) -> CompliantExprT: ...
def sum_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...
def mean_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...
def min_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...
def max_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...
def concat(
self, items: Iterable[CompliantFrameT], *, how: ConcatMethod
) -> CompliantFrameT: ...
def when(
self, predicate: CompliantExprT
) -> CompliantWhen[CompliantFrameT, Incomplete, CompliantExprT]: ...
def concat_str(
self, *exprs: CompliantExprT, separator: str, ignore_nulls: bool
) -> CompliantExprT: ...
@property
def selectors(self) -> CompliantSelectorNamespace[Any, Any]: ...
@property
def _expr(self) -> type[CompliantExprT]: ...
def coalesce(self, *exprs: CompliantExprT) -> CompliantExprT: ...
class DepthTrackingNamespace(
CompliantNamespace[CompliantFrameT, DepthTrackingExprT],
Protocol[CompliantFrameT, DepthTrackingExprT],
):
def all(self) -> DepthTrackingExprT:
return self._expr.from_column_names(
get_column_names, function_name="all", context=self
)
def col(self, *column_names: str) -> DepthTrackingExprT:
return self._expr.from_column_names(
passthrough_column_names(column_names), function_name="col", context=self
)
def exclude(self, excluded_names: Container[str]) -> DepthTrackingExprT:
return self._expr.from_column_names(
partial(exclude_column_names, names=excluded_names),
function_name="exclude",
context=self,
)
class LazyNamespace(
CompliantNamespace[CompliantLazyFrameT, LazyExprT],
Protocol[CompliantLazyFrameT, LazyExprT, NativeFrameT_co],
):
@property
def _backend_version(self) -> tuple[int, ...]:
return self._implementation._backend_version()
@property
def _lazyframe(self) -> type[CompliantLazyFrameT]: ...
def from_native(self, data: NativeFrameT_co | Any, /) -> CompliantLazyFrameT:
if self._lazyframe._is_native(data):
return self._lazyframe.from_native(data, context=self)
else: # pragma: no cover
msg = f"Unsupported type: {type(data).__name__!r}"
raise TypeError(msg)
class EagerNamespace(
DepthTrackingNamespace[EagerDataFrameT, EagerExprT],
Protocol[EagerDataFrameT, EagerSeriesT, EagerExprT, NativeFrameT, NativeSeriesT],
):
@property
def _backend_version(self) -> tuple[int, ...]:
return self._implementation._backend_version()
@property
def _dataframe(self) -> type[EagerDataFrameT]: ...
@property
def _series(self) -> type[EagerSeriesT]: ...
def when(
self, predicate: EagerExprT
) -> EagerWhen[EagerDataFrameT, EagerSeriesT, EagerExprT, NativeSeriesT]: ...
@overload
def from_native(self, data: NativeFrameT, /) -> EagerDataFrameT: ...
@overload
def from_native(self, data: NativeSeriesT, /) -> EagerSeriesT: ...
def from_native(
self, data: NativeFrameT | NativeSeriesT | Any, /
) -> EagerDataFrameT | EagerSeriesT:
if self._dataframe._is_native(data):
return self._dataframe.from_native(data, context=self)
elif self._series._is_native(data):
return self._series.from_native(data, context=self)
msg = f"Unsupported type: {type(data).__name__!r}"
raise TypeError(msg)
@overload
def from_numpy(self, data: Into1DArray, /, schema: None = ...) -> EagerSeriesT: ...
@overload
def from_numpy(
self,
data: _2DArray,
/,
schema: Mapping[str, DType] | Schema | Sequence[str] | None,
) -> EagerDataFrameT: ...
def from_numpy(
self,
data: Into1DArray | _2DArray,
/,
schema: Mapping[str, DType] | Schema | Sequence[str] | None = None,
) -> EagerDataFrameT | EagerSeriesT:
if is_numpy_array_2d(data):
return self._dataframe.from_numpy(data, schema=schema, context=self)
return self._series.from_numpy(data, context=self)
def _concat_diagonal(self, dfs: Sequence[NativeFrameT], /) -> NativeFrameT: ...
def _concat_horizontal(
self, dfs: Sequence[NativeFrameT | Any], /
) -> NativeFrameT: ...
def _concat_vertical(self, dfs: Sequence[NativeFrameT], /) -> NativeFrameT: ...
def concat(
self, items: Iterable[EagerDataFrameT], *, how: ConcatMethod
) -> EagerDataFrameT:
dfs = [item.native for item in items]
if how == "horizontal":
native = self._concat_horizontal(dfs)
elif how == "vertical":
native = self._concat_vertical(dfs)
elif how == "diagonal":
native = self._concat_diagonal(dfs)
else: # pragma: no cover
raise NotImplementedError
return self._dataframe.from_native(native, context=self)

View File

@ -0,0 +1,311 @@
"""Almost entirely complete, generic `selectors` implementation."""
from __future__ import annotations
import re
from functools import partial
from typing import TYPE_CHECKING, Protocol, TypeVar, overload
from narwhals._compliant.expr import CompliantExpr
from narwhals._utils import (
_parse_time_unit_and_time_zone,
dtype_matches_time_unit_and_time_zone,
get_column_names,
is_compliant_dataframe,
)
if TYPE_CHECKING:
from collections.abc import Collection, Iterable, Iterator, Sequence
from datetime import timezone
from typing_extensions import Self, TypeAlias, TypeIs
from narwhals._compliant.expr import NativeExpr
from narwhals._compliant.typing import (
CompliantDataFrameAny,
CompliantExprAny,
CompliantFrameAny,
CompliantLazyFrameAny,
CompliantSeriesAny,
CompliantSeriesOrNativeExprAny,
EvalNames,
EvalSeries,
)
from narwhals._utils import Implementation, Version, _LimitedContext
from narwhals.dtypes import DType
from narwhals.typing import TimeUnit
__all__ = [
"CompliantSelector",
"CompliantSelectorNamespace",
"EagerSelectorNamespace",
"LazySelectorNamespace",
]
SeriesOrExprT = TypeVar("SeriesOrExprT", bound="CompliantSeriesOrNativeExprAny")
SeriesT = TypeVar("SeriesT", bound="CompliantSeriesAny")
ExprT = TypeVar("ExprT", bound="NativeExpr")
FrameT = TypeVar("FrameT", bound="CompliantFrameAny")
DataFrameT = TypeVar("DataFrameT", bound="CompliantDataFrameAny")
LazyFrameT = TypeVar("LazyFrameT", bound="CompliantLazyFrameAny")
SelectorOrExpr: TypeAlias = (
"CompliantSelector[FrameT, SeriesOrExprT] | CompliantExpr[FrameT, SeriesOrExprT]"
)
class CompliantSelectorNamespace(Protocol[FrameT, SeriesOrExprT]):
_implementation: Implementation
_version: Version
@classmethod
def from_namespace(cls, context: _LimitedContext, /) -> Self:
obj = cls.__new__(cls)
obj._implementation = context._implementation
obj._version = context._version
return obj
@property
def _selector(self) -> type[CompliantSelector[FrameT, SeriesOrExprT]]: ...
def _iter_columns(self, df: FrameT, /) -> Iterator[SeriesOrExprT]: ...
def _iter_schema(self, df: FrameT, /) -> Iterator[tuple[str, DType]]: ...
def _iter_columns_dtypes(
self, df: FrameT, /
) -> Iterator[tuple[SeriesOrExprT, DType]]: ...
def _iter_columns_names(self, df: FrameT, /) -> Iterator[tuple[SeriesOrExprT, str]]:
yield from zip(self._iter_columns(df), df.columns)
def _is_dtype(
self: CompliantSelectorNamespace[FrameT, SeriesOrExprT], dtype: type[DType], /
) -> CompliantSelector[FrameT, SeriesOrExprT]:
def series(df: FrameT) -> Sequence[SeriesOrExprT]:
return [
ser for ser, tp in self._iter_columns_dtypes(df) if isinstance(tp, dtype)
]
def names(df: FrameT) -> Sequence[str]:
return [name for name, tp in self._iter_schema(df) if isinstance(tp, dtype)]
return self._selector.from_callables(series, names, context=self)
def by_dtype(
self, dtypes: Collection[DType | type[DType]]
) -> CompliantSelector[FrameT, SeriesOrExprT]:
def series(df: FrameT) -> Sequence[SeriesOrExprT]:
return [ser for ser, tp in self._iter_columns_dtypes(df) if tp in dtypes]
def names(df: FrameT) -> Sequence[str]:
return [name for name, tp in self._iter_schema(df) if tp in dtypes]
return self._selector.from_callables(series, names, context=self)
def matches(self, pattern: str) -> CompliantSelector[FrameT, SeriesOrExprT]:
p = re.compile(pattern)
def series(df: FrameT) -> Sequence[SeriesOrExprT]:
if (
is_compliant_dataframe(df)
and not self._implementation.is_duckdb()
and not self._implementation.is_ibis()
):
return [df.get_column(col) for col in df.columns if p.search(col)]
return [ser for ser, name in self._iter_columns_names(df) if p.search(name)]
def names(df: FrameT) -> Sequence[str]:
return [col for col in df.columns if p.search(col)]
return self._selector.from_callables(series, names, context=self)
def numeric(self) -> CompliantSelector[FrameT, SeriesOrExprT]:
def series(df: FrameT) -> Sequence[SeriesOrExprT]:
return [ser for ser, tp in self._iter_columns_dtypes(df) if tp.is_numeric()]
def names(df: FrameT) -> Sequence[str]:
return [name for name, tp in self._iter_schema(df) if tp.is_numeric()]
return self._selector.from_callables(series, names, context=self)
def categorical(self) -> CompliantSelector[FrameT, SeriesOrExprT]:
return self._is_dtype(self._version.dtypes.Categorical)
def string(self) -> CompliantSelector[FrameT, SeriesOrExprT]:
return self._is_dtype(self._version.dtypes.String)
def boolean(self) -> CompliantSelector[FrameT, SeriesOrExprT]:
return self._is_dtype(self._version.dtypes.Boolean)
def all(self) -> CompliantSelector[FrameT, SeriesOrExprT]:
def series(df: FrameT) -> Sequence[SeriesOrExprT]:
return list(self._iter_columns(df))
return self._selector.from_callables(series, get_column_names, context=self)
def datetime(
self,
time_unit: TimeUnit | Iterable[TimeUnit] | None,
time_zone: str | timezone | Iterable[str | timezone | None] | None,
) -> CompliantSelector[FrameT, SeriesOrExprT]:
time_units, time_zones = _parse_time_unit_and_time_zone(time_unit, time_zone)
matches = partial(
dtype_matches_time_unit_and_time_zone,
dtypes=self._version.dtypes,
time_units=time_units,
time_zones=time_zones,
)
def series(df: FrameT) -> Sequence[SeriesOrExprT]:
return [ser for ser, tp in self._iter_columns_dtypes(df) if matches(tp)]
def names(df: FrameT) -> Sequence[str]:
return [name for name, tp in self._iter_schema(df) if matches(tp)]
return self._selector.from_callables(series, names, context=self)
class EagerSelectorNamespace(
CompliantSelectorNamespace[DataFrameT, SeriesT], Protocol[DataFrameT, SeriesT]
):
def _iter_schema(self, df: DataFrameT, /) -> Iterator[tuple[str, DType]]:
for ser in self._iter_columns(df):
yield ser.name, ser.dtype
def _iter_columns(self, df: DataFrameT, /) -> Iterator[SeriesT]:
yield from df.iter_columns()
def _iter_columns_dtypes(self, df: DataFrameT, /) -> Iterator[tuple[SeriesT, DType]]:
for ser in self._iter_columns(df):
yield ser, ser.dtype
class LazySelectorNamespace(
CompliantSelectorNamespace[LazyFrameT, ExprT], Protocol[LazyFrameT, ExprT]
):
def _iter_schema(self, df: LazyFrameT) -> Iterator[tuple[str, DType]]:
yield from df.schema.items()
def _iter_columns(self, df: LazyFrameT) -> Iterator[ExprT]:
yield from df._iter_columns()
def _iter_columns_dtypes(self, df: LazyFrameT, /) -> Iterator[tuple[ExprT, DType]]:
yield from zip(self._iter_columns(df), df.schema.values())
class CompliantSelector(
CompliantExpr[FrameT, SeriesOrExprT], Protocol[FrameT, SeriesOrExprT]
):
_call: EvalSeries[FrameT, SeriesOrExprT]
_function_name: str
_implementation: Implementation
_version: Version
@classmethod
def from_callables(
cls,
call: EvalSeries[FrameT, SeriesOrExprT],
evaluate_output_names: EvalNames[FrameT],
*,
context: _LimitedContext,
) -> Self:
obj = cls.__new__(cls)
obj._call = call
obj._evaluate_output_names = evaluate_output_names
obj._alias_output_names = None
obj._implementation = context._implementation
obj._version = context._version
return obj
@property
def selectors(self) -> CompliantSelectorNamespace[FrameT, SeriesOrExprT]:
return self.__narwhals_namespace__().selectors
def _to_expr(self) -> CompliantExpr[FrameT, SeriesOrExprT]: ...
def _is_selector(
self, other: Self | CompliantExpr[FrameT, SeriesOrExprT]
) -> TypeIs[CompliantSelector[FrameT, SeriesOrExprT]]:
return isinstance(other, type(self))
@overload
def __sub__(self, other: Self) -> Self: ...
@overload
def __sub__(
self, other: CompliantExpr[FrameT, SeriesOrExprT]
) -> CompliantExpr[FrameT, SeriesOrExprT]: ...
def __sub__(
self, other: SelectorOrExpr[FrameT, SeriesOrExprT]
) -> SelectorOrExpr[FrameT, SeriesOrExprT]:
if self._is_selector(other):
def series(df: FrameT) -> Sequence[SeriesOrExprT]:
lhs_names, rhs_names = _eval_lhs_rhs(df, self, other)
return [
x for x, name in zip(self(df), lhs_names) if name not in rhs_names
]
def names(df: FrameT) -> Sequence[str]:
lhs_names, rhs_names = _eval_lhs_rhs(df, self, other)
return [x for x in lhs_names if x not in rhs_names]
return self.selectors._selector.from_callables(series, names, context=self)
return self._to_expr() - other
@overload
def __or__(self, other: Self) -> Self: ...
@overload
def __or__(
self, other: CompliantExpr[FrameT, SeriesOrExprT]
) -> CompliantExpr[FrameT, SeriesOrExprT]: ...
def __or__(
self, other: SelectorOrExpr[FrameT, SeriesOrExprT]
) -> SelectorOrExpr[FrameT, SeriesOrExprT]:
if self._is_selector(other):
def series(df: FrameT) -> Sequence[SeriesOrExprT]:
lhs_names, rhs_names = _eval_lhs_rhs(df, self, other)
return [
*(x for x, name in zip(self(df), lhs_names) if name not in rhs_names),
*other(df),
]
def names(df: FrameT) -> Sequence[str]:
lhs_names, rhs_names = _eval_lhs_rhs(df, self, other)
return [*(x for x in lhs_names if x not in rhs_names), *rhs_names]
return self.selectors._selector.from_callables(series, names, context=self)
return self._to_expr() | other
@overload
def __and__(self, other: Self) -> Self: ...
@overload
def __and__(
self, other: CompliantExpr[FrameT, SeriesOrExprT]
) -> CompliantExpr[FrameT, SeriesOrExprT]: ...
def __and__(
self, other: SelectorOrExpr[FrameT, SeriesOrExprT]
) -> SelectorOrExpr[FrameT, SeriesOrExprT]:
if self._is_selector(other):
def series(df: FrameT) -> Sequence[SeriesOrExprT]:
lhs_names, rhs_names = _eval_lhs_rhs(df, self, other)
return [x for x, name in zip(self(df), lhs_names) if name in rhs_names]
def names(df: FrameT) -> Sequence[str]:
lhs_names, rhs_names = _eval_lhs_rhs(df, self, other)
return [x for x in lhs_names if x in rhs_names]
return self.selectors._selector.from_callables(series, names, context=self)
return self._to_expr() & other
def __invert__(self) -> CompliantSelector[FrameT, SeriesOrExprT]:
return self.selectors.all() - self
def _eval_lhs_rhs(
df: CompliantFrameAny, lhs: CompliantExprAny, rhs: CompliantExprAny
) -> tuple[Sequence[str], Sequence[str]]:
return lhs._evaluate_output_names(df), rhs._evaluate_output_names(df)
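# Hedged usage sketch of the selector algebra defined above (assumes the
# pandas backend is installed; `ncs` is the public selectors module). `|`,
# `&`, and `-` evaluate both operands' output names against the same frame,
# and `~s` is exactly `selectors.all() - s`.
if __name__ == "__main__":  # pragma: no cover
    import pandas as pd

    import narwhals as nw
    import narwhals.selectors as ncs

    df = nw.from_native(pd.DataFrame({"a": [1], "b": ["x"], "c": [True]}))
    print(df.select(ncs.numeric() | ncs.boolean()).columns)  # ['a', 'c']
    print(df.select(~ncs.string()).columns)  # ['a', 'c']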

View File

@ -0,0 +1,516 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Generic, Literal, Protocol
from narwhals._compliant.any_namespace import (
CatNamespace,
DateTimeNamespace,
ListNamespace,
StringNamespace,
StructNamespace,
)
from narwhals._compliant.typing import (
CompliantSeriesT_co,
EagerDataFrameAny,
EagerSeriesT_co,
NativeSeriesT,
NativeSeriesT_co,
)
from narwhals._translate import FromIterable, FromNative, NumpyConvertible, ToNarwhals
from narwhals._typing_compat import TypeVar, assert_never
from narwhals._utils import (
_StoresCompliant,
_StoresNative,
is_compliant_series,
is_sized_multi_index_selector,
unstable,
)
if TYPE_CHECKING:
from collections.abc import Iterable, Iterator, Mapping, Sequence
from types import ModuleType
import pandas as pd
import polars as pl
import pyarrow as pa
from typing_extensions import NotRequired, Self, TypedDict
from narwhals._compliant.dataframe import CompliantDataFrame
from narwhals._compliant.expr import CompliantExpr, EagerExpr
from narwhals._compliant.namespace import CompliantNamespace, EagerNamespace
from narwhals._utils import Implementation, Version, _LimitedContext
from narwhals.dtypes import DType
from narwhals.series import Series
from narwhals.typing import (
ClosedInterval,
FillNullStrategy,
Into1DArray,
IntoDType,
MultiIndexSelector,
NonNestedLiteral,
NumericLiteral,
RankMethod,
RollingInterpolationMethod,
SizedMultiIndexSelector,
TemporalLiteral,
_1DArray,
_SliceIndex,
)
_CountsT_co = TypeVar("_CountsT_co", bound="Iterable[Any]", covariant=True)
class HistData(TypedDict, Generic[NativeSeriesT, _CountsT_co]):
    # NOTE: the TypeVar is defined above so `Generic[...]` receives the
    # TypeVar itself rather than a (runtime-invalid) string forward reference.
    breakpoint: NotRequired[list[float] | _1DArray | list[Any]]
    count: NativeSeriesT | _1DArray | _CountsT_co | list[Any]
__all__ = [
"CompliantSeries",
"EagerSeries",
"EagerSeriesCatNamespace",
"EagerSeriesDateTimeNamespace",
"EagerSeriesHist",
"EagerSeriesListNamespace",
"EagerSeriesNamespace",
"EagerSeriesStringNamespace",
"EagerSeriesStructNamespace",
]
class CompliantSeries(
NumpyConvertible["_1DArray", "Into1DArray"],
FromIterable,
FromNative[NativeSeriesT],
ToNarwhals["Series[NativeSeriesT]"],
Protocol[NativeSeriesT],
):
_implementation: Implementation
_version: Version
@property
def dtype(self) -> DType: ...
@property
def name(self) -> str: ...
@property
def native(self) -> NativeSeriesT: ...
def __narwhals_series__(self) -> Self:
return self
def __narwhals_namespace__(self) -> CompliantNamespace[Any, Any]: ...
def __native_namespace__(self) -> ModuleType: ...
def __array__(self, dtype: Any, *, copy: bool | None) -> _1DArray: ...
def __contains__(self, other: Any) -> bool: ...
def __getitem__(self, item: MultiIndexSelector[Self]) -> Any: ...
def __iter__(self) -> Iterator[Any]: ...
def __len__(self) -> int:
return len(self.native)
def _with_native(self, series: Any) -> Self: ...
def _with_version(self, version: Version) -> Self: ...
def _to_expr(self) -> CompliantExpr[Any, Self]: ...
@classmethod
def from_native(cls, data: NativeSeriesT, /, *, context: _LimitedContext) -> Self: ...
@classmethod
def from_numpy(cls, data: Into1DArray, /, *, context: _LimitedContext) -> Self: ...
@classmethod
def from_iterable(
cls,
data: Iterable[Any],
/,
*,
context: _LimitedContext,
name: str = "",
dtype: IntoDType | None = None,
) -> Self: ...
def to_narwhals(self) -> Series[NativeSeriesT]:
return self._version.series(self, level="full")
# Operators
def __add__(self, other: Any) -> Self: ...
def __and__(self, other: Any) -> Self: ...
def __eq__(self, other: object) -> Self: ... # type: ignore[override]
def __floordiv__(self, other: Any) -> Self: ...
def __ge__(self, other: Any) -> Self: ...
def __gt__(self, other: Any) -> Self: ...
def __invert__(self) -> Self: ...
def __le__(self, other: Any) -> Self: ...
def __lt__(self, other: Any) -> Self: ...
def __mod__(self, other: Any) -> Self: ...
def __mul__(self, other: Any) -> Self: ...
def __ne__(self, other: object) -> Self: ... # type: ignore[override]
def __or__(self, other: Any) -> Self: ...
def __pow__(self, other: Any) -> Self: ...
def __radd__(self, other: Any) -> Self: ...
def __rand__(self, other: Any) -> Self: ...
def __rfloordiv__(self, other: Any) -> Self: ...
def __rmod__(self, other: Any) -> Self: ...
def __rmul__(self, other: Any) -> Self: ...
def __ror__(self, other: Any) -> Self: ...
def __rpow__(self, other: Any) -> Self: ...
def __rsub__(self, other: Any) -> Self: ...
def __rtruediv__(self, other: Any) -> Self: ...
def __sub__(self, other: Any) -> Self: ...
def __truediv__(self, other: Any) -> Self: ...
def abs(self) -> Self: ...
def alias(self, name: str) -> Self: ...
def all(self) -> bool: ...
def any(self) -> bool: ...
def arg_max(self) -> int: ...
def arg_min(self) -> int: ...
def arg_true(self) -> Self: ...
def cast(self, dtype: IntoDType) -> Self: ...
def clip(
self,
lower_bound: Self | NumericLiteral | TemporalLiteral | None,
upper_bound: Self | NumericLiteral | TemporalLiteral | None,
) -> Self: ...
def count(self) -> int: ...
def cum_count(self, *, reverse: bool) -> Self: ...
def cum_max(self, *, reverse: bool) -> Self: ...
def cum_min(self, *, reverse: bool) -> Self: ...
def cum_prod(self, *, reverse: bool) -> Self: ...
def cum_sum(self, *, reverse: bool) -> Self: ...
def diff(self) -> Self: ...
def drop_nulls(self) -> Self: ...
def ewm_mean(
self,
*,
com: float | None,
span: float | None,
half_life: float | None,
alpha: float | None,
adjust: bool,
min_samples: int,
ignore_nulls: bool,
) -> Self: ...
def exp(self) -> Self: ...
def sqrt(self) -> Self: ...
def fill_null(
self,
value: Self | NonNestedLiteral,
strategy: FillNullStrategy | None,
limit: int | None,
) -> Self: ...
def filter(self, predicate: Any) -> Self: ...
def gather_every(self, n: int, offset: int) -> Self: ...
def head(self, n: int) -> Self: ...
def is_between(
self, lower_bound: Any, upper_bound: Any, closed: ClosedInterval
) -> Self: ...
def is_finite(self) -> Self: ...
def is_first_distinct(self) -> Self: ...
def is_in(self, other: Any) -> Self: ...
def is_last_distinct(self) -> Self: ...
def is_nan(self) -> Self: ...
def is_null(self) -> Self: ...
def is_sorted(self, *, descending: bool) -> bool: ...
def is_unique(self) -> Self: ...
def item(self, index: int | None) -> Any: ...
def kurtosis(self) -> float | None: ...
def len(self) -> int: ...
def log(self, base: float) -> Self: ...
def max(self) -> Any: ...
def mean(self) -> float: ...
def median(self) -> float: ...
def min(self) -> Any: ...
def mode(self) -> Self: ...
def n_unique(self) -> int: ...
def null_count(self) -> int: ...
def quantile(
self, quantile: float, interpolation: RollingInterpolationMethod
) -> float: ...
def rank(self, method: RankMethod, *, descending: bool) -> Self: ...
def replace_strict(
self,
old: Sequence[Any] | Mapping[Any, Any],
new: Sequence[Any],
*,
return_dtype: IntoDType | None,
) -> Self: ...
def rolling_mean(
self, window_size: int, *, min_samples: int, center: bool
) -> Self: ...
def rolling_std(
self, window_size: int, *, min_samples: int, center: bool, ddof: int
) -> Self: ...
def rolling_sum(
self, window_size: int, *, min_samples: int, center: bool
) -> Self: ...
def rolling_var(
self, window_size: int, *, min_samples: int, center: bool, ddof: int
) -> Self: ...
def round(self, decimals: int) -> Self: ...
def sample(
self,
n: int | None,
*,
fraction: float | None,
with_replacement: bool,
seed: int | None,
) -> Self: ...
def scatter(self, indices: int | Sequence[int], values: Any) -> Self: ...
def shift(self, n: int) -> Self: ...
def skew(self) -> float | None: ...
def sort(self, *, descending: bool, nulls_last: bool) -> Self: ...
def std(self, *, ddof: int) -> float: ...
def sum(self) -> float: ...
def tail(self, n: int) -> Self: ...
def to_arrow(self) -> pa.Array[Any]: ...
def to_dummies(
self, *, separator: str, drop_first: bool
) -> CompliantDataFrame[Self, Any, Any, Any]: ...
def to_frame(self) -> CompliantDataFrame[Self, Any, Any, Any]: ...
def to_list(self) -> list[Any]: ...
def to_pandas(self) -> pd.Series[Any]: ...
def to_polars(self) -> pl.Series: ...
def unique(self, *, maintain_order: bool) -> Self: ...
def value_counts(
self, *, sort: bool, parallel: bool, name: str | None, normalize: bool
) -> CompliantDataFrame[Self, Any, Any, Any]: ...
def var(self, *, ddof: int) -> float: ...
def zip_with(self, mask: Any, other: Any) -> Self: ...
@unstable
def hist_from_bins(
self, bins: list[float], *, include_breakpoint: bool
) -> CompliantDataFrame[Self, Any, Any, Any]:
"""`Series.hist(bins=..., bin_count=None)`."""
...
@unstable
def hist_from_bin_count(
self, bin_count: int, *, include_breakpoint: bool
) -> CompliantDataFrame[Self, Any, Any, Any]:
"""`Series.hist(bins=None, bin_count=...)`."""
...
@property
def str(self) -> Any: ...
@property
def dt(self) -> Any: ...
@property
def cat(self) -> Any: ...
@property
def list(self) -> Any: ...
@property
def struct(self) -> Any: ...
class EagerSeries(CompliantSeries[NativeSeriesT], Protocol[NativeSeriesT]):
_native_series: Any
_implementation: Implementation
_version: Version
_broadcast: bool
@property
def _backend_version(self) -> tuple[int, ...]:
return self._implementation._backend_version()
@classmethod
def _align_full_broadcast(cls, *series: Self) -> Sequence[Self]:
"""Ensure all of `series` have the same length (and index if `pandas`).
        Scalars get broadcast to the full length of the longest Series.
This is useful when you need to construct a full Series anyway, such as:
DataFrame.select(...)
It should not be used in binary operations, such as:
nw.col("a") - nw.col("a").mean()
        because then it's more efficient to extract the right-hand side's single element as a scalar.
"""
...
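    # Hedged illustration of the contract above (hypothetical values): given a
    # full series [1, 2, 3] and a length-1 aggregate series [2] flagged with
    # `_broadcast=True`, `_align_full_broadcast` returns ([1, 2, 3], [2, 2, 2])
    # so `DataFrame.select` can operate on equal-length columns.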
def _from_scalar(self, value: Any) -> Self:
return self.from_iterable([value], name=self.name, context=self)
def _with_native(
self, series: NativeSeriesT, *, preserve_broadcast: bool = False
) -> Self:
"""Return a new `CompliantSeries`, wrapping the native `series`.
In cases when operations are known to not affect whether a result should
be broadcast, we can pass `preserve_broadcast=True`.
Set this with care - it should only be set for unary expressions which don't
change length or order, such as `.alias` or `.fill_null`. If in doubt, don't
set it, you probably don't need it.
"""
...
def __narwhals_namespace__(
self,
) -> EagerNamespace[Any, Self, Any, Any, NativeSeriesT]: ...
def _to_expr(self) -> EagerExpr[Any, Any]:
return self.__narwhals_namespace__()._expr._from_series(self) # type: ignore[no-any-return]
def _gather(self, rows: SizedMultiIndexSelector[NativeSeriesT]) -> Self: ...
def _gather_slice(self, rows: _SliceIndex | range) -> Self: ...
def __getitem__(self, item: MultiIndexSelector[Self]) -> Self:
if isinstance(item, (slice, range)):
return self._gather_slice(item)
elif is_compliant_series(item):
return self._gather(item.native)
elif is_sized_multi_index_selector(item):
return self._gather(item)
else:
assert_never(item)
@property
def str(self) -> EagerSeriesStringNamespace[Self, NativeSeriesT]: ...
@property
def dt(self) -> EagerSeriesDateTimeNamespace[Self, NativeSeriesT]: ...
@property
def cat(self) -> EagerSeriesCatNamespace[Self, NativeSeriesT]: ...
@property
def list(self) -> EagerSeriesListNamespace[Self, NativeSeriesT]: ...
@property
def struct(self) -> EagerSeriesStructNamespace[Self, NativeSeriesT]: ...
class _SeriesNamespace( # type: ignore[misc]
_StoresCompliant[CompliantSeriesT_co],
_StoresNative[NativeSeriesT_co],
Protocol[CompliantSeriesT_co, NativeSeriesT_co],
):
_compliant_series: CompliantSeriesT_co
@property
def compliant(self) -> CompliantSeriesT_co:
return self._compliant_series
@property
def implementation(self) -> Implementation:
return self.compliant._implementation
@property
def backend_version(self) -> tuple[int, ...]:
return self.implementation._backend_version()
@property
def version(self) -> Version:
return self.compliant._version
@property
def native(self) -> NativeSeriesT_co:
return self._compliant_series.native # type: ignore[no-any-return]
def with_native(self, series: Any, /) -> CompliantSeriesT_co:
return self.compliant._with_native(series)
class EagerSeriesNamespace(
_SeriesNamespace[EagerSeriesT_co, NativeSeriesT_co],
Generic[EagerSeriesT_co, NativeSeriesT_co],
):
_compliant_series: EagerSeriesT_co
def __init__(self, series: EagerSeriesT_co, /) -> None:
self._compliant_series = series
class EagerSeriesCatNamespace( # type: ignore[misc]
_SeriesNamespace[EagerSeriesT_co, NativeSeriesT_co],
CatNamespace[EagerSeriesT_co],
Protocol[EagerSeriesT_co, NativeSeriesT_co],
): ...
class EagerSeriesDateTimeNamespace( # type: ignore[misc]
_SeriesNamespace[EagerSeriesT_co, NativeSeriesT_co],
DateTimeNamespace[EagerSeriesT_co],
Protocol[EagerSeriesT_co, NativeSeriesT_co],
): ...
class EagerSeriesListNamespace( # type: ignore[misc]
_SeriesNamespace[EagerSeriesT_co, NativeSeriesT_co],
ListNamespace[EagerSeriesT_co],
Protocol[EagerSeriesT_co, NativeSeriesT_co],
): ...
class EagerSeriesStringNamespace( # type: ignore[misc]
_SeriesNamespace[EagerSeriesT_co, NativeSeriesT_co],
StringNamespace[EagerSeriesT_co],
Protocol[EagerSeriesT_co, NativeSeriesT_co],
): ...
class EagerSeriesStructNamespace( # type: ignore[misc]
_SeriesNamespace[EagerSeriesT_co, NativeSeriesT_co],
StructNamespace[EagerSeriesT_co],
Protocol[EagerSeriesT_co, NativeSeriesT_co],
): ...
class EagerSeriesHist(Protocol[NativeSeriesT, _CountsT_co]):
_series: EagerSeries[NativeSeriesT]
_breakpoint: bool
_data: HistData[NativeSeriesT, _CountsT_co]
@property
def native(self) -> NativeSeriesT:
return self._series.native
@classmethod
def from_series(
cls, series: EagerSeries[NativeSeriesT], *, include_breakpoint: bool
) -> Self:
obj = cls.__new__(cls)
obj._series = series
obj._breakpoint = include_breakpoint
return obj
def to_frame(self) -> EagerDataFrameAny: ...
def _linear_space( # NOTE: Roughly `pl.linear_space`
self,
start: float,
end: float,
num_samples: int,
*,
closed: Literal["both", "none"] = "both",
) -> _1DArray: ...
# NOTE: *Could* be handled at narwhals-level
def is_empty_series(self) -> bool: ...
# NOTE: **Should** be handled at narwhals-level
def data_empty(self) -> HistData[NativeSeriesT, _CountsT_co]:
return {"breakpoint": [], "count": []} if self._breakpoint else {"count": []}
# NOTE: *Could* be handled at narwhals-level, **iff** we add `nw.repeat`, `nw.linear_space`
# See https://github.com/narwhals-dev/narwhals/pull/2839#discussion_r2215630696
def series_empty(
self, arg: int | list[float], /
) -> HistData[NativeSeriesT, _CountsT_co]: ...
def with_bins(self, bins: list[float], /) -> Self:
if len(bins) <= 1:
self._data = self.data_empty()
elif self.is_empty_series():
self._data = self.series_empty(bins)
else:
self._data = self._calculate_hist(bins)
return self
def with_bin_count(self, bin_count: int, /) -> Self:
if bin_count == 0:
self._data = self.data_empty()
elif self.is_empty_series():
self._data = self.series_empty(bin_count)
else:
self._data = self._calculate_hist(self._calculate_bins(bin_count))
return self
def _calculate_breakpoint(self, arg: int | list[float], /) -> list[float] | _1DArray:
bins = self._linear_space(0, 1, arg + 1) if isinstance(arg, int) else arg
return bins[1:]
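    # e.g. `_calculate_breakpoint(4)` -> `_linear_space(0, 1, 5)` =
    # [0.0, 0.25, 0.5, 0.75, 1.0]; dropping the left-most edge yields the
    # reported breakpoints [0.25, 0.5, 0.75, 1.0].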
def _calculate_bins(self, bin_count: int) -> _1DArray: ...
def _calculate_hist(
self, bins: list[float] | _1DArray
) -> HistData[NativeSeriesT, _CountsT_co]: ...

View File

@ -0,0 +1,196 @@
from __future__ import annotations
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, Callable, Literal, TypedDict, TypeVar
if TYPE_CHECKING:
from typing_extensions import TypeAlias
from narwhals._compliant.dataframe import (
CompliantDataFrame,
CompliantLazyFrame,
EagerDataFrame,
)
from narwhals._compliant.expr import (
CompliantExpr,
DepthTrackingExpr,
EagerExpr,
LazyExpr,
NativeExpr,
)
from narwhals._compliant.namespace import CompliantNamespace, EagerNamespace
from narwhals._compliant.series import CompliantSeries, EagerSeries
from narwhals._compliant.window import WindowInputs
from narwhals.typing import (
FillNullStrategy,
NativeFrame,
NativeSeries,
RankMethod,
RollingInterpolationMethod,
)
class ScalarKwargs(TypedDict, total=False):
"""Non-expressifiable args which we may need to reuse in `agg` or `over`."""
adjust: bool
alpha: float | None
center: int
com: float | None
ddof: int
descending: bool
half_life: float | None
ignore_nulls: bool
interpolation: RollingInterpolationMethod
limit: int | None
method: RankMethod
min_samples: int
n: int
quantile: float
reverse: bool
span: float | None
strategy: FillNullStrategy | None
window_size: int
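# For example, `DaskExpr.std(ddof=0)` stores `scalar_kwargs={"ddof": 0}` so
# that a later `over(...)` / group-by `agg` can replay the same keyword
# without re-parsing it from an expression.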
__all__ = [
"AliasName",
"AliasNames",
"CompliantDataFrameT",
"CompliantFrameT",
"CompliantLazyFrameT",
"CompliantSeriesT",
"EvalNames",
"EvalSeries",
"IntoCompliantExpr",
"NarwhalsAggregation",
"NativeFrameT_co",
"NativeSeriesT_co",
]
CompliantExprAny: TypeAlias = "CompliantExpr[Any, Any]"
CompliantSeriesAny: TypeAlias = "CompliantSeries[Any]"
CompliantSeriesOrNativeExprAny: TypeAlias = "CompliantSeriesAny | NativeExpr"
CompliantDataFrameAny: TypeAlias = "CompliantDataFrame[Any, Any, Any, Any]"
CompliantLazyFrameAny: TypeAlias = "CompliantLazyFrame[Any, Any, Any]"
CompliantFrameAny: TypeAlias = "CompliantDataFrameAny | CompliantLazyFrameAny"
CompliantNamespaceAny: TypeAlias = "CompliantNamespace[Any, Any]"
DepthTrackingExprAny: TypeAlias = "DepthTrackingExpr[Any, Any]"
EagerDataFrameAny: TypeAlias = "EagerDataFrame[Any, Any, Any, Any]"
EagerSeriesAny: TypeAlias = "EagerSeries[Any]"
EagerExprAny: TypeAlias = "EagerExpr[Any, Any]"
EagerNamespaceAny: TypeAlias = "EagerNamespace[EagerDataFrameAny, EagerSeriesAny, EagerExprAny, NativeFrame, NativeSeries]"
LazyExprAny: TypeAlias = "LazyExpr[Any, Any]"
NativeExprT = TypeVar("NativeExprT", bound="NativeExpr")
NativeExprT_co = TypeVar("NativeExprT_co", bound="NativeExpr", covariant=True)
NativeSeriesT = TypeVar("NativeSeriesT", bound="NativeSeries")
NativeSeriesT_co = TypeVar("NativeSeriesT_co", bound="NativeSeries", covariant=True)
NativeSeriesT_contra = TypeVar(
"NativeSeriesT_contra", bound="NativeSeries", contravariant=True
)
NativeFrameT = TypeVar("NativeFrameT", bound="NativeFrame")
NativeFrameT_co = TypeVar("NativeFrameT_co", bound="NativeFrame", covariant=True)
NativeFrameT_contra = TypeVar(
"NativeFrameT_contra", bound="NativeFrame", contravariant=True
)
CompliantExprT = TypeVar("CompliantExprT", bound=CompliantExprAny)
CompliantExprT_co = TypeVar("CompliantExprT_co", bound=CompliantExprAny, covariant=True)
CompliantExprT_contra = TypeVar(
"CompliantExprT_contra", bound=CompliantExprAny, contravariant=True
)
CompliantSeriesT = TypeVar("CompliantSeriesT", bound=CompliantSeriesAny)
CompliantSeriesT_co = TypeVar(
"CompliantSeriesT_co", bound=CompliantSeriesAny, covariant=True
)
CompliantSeriesOrNativeExprT = TypeVar(
"CompliantSeriesOrNativeExprT", bound=CompliantSeriesOrNativeExprAny
)
CompliantSeriesOrNativeExprT_co = TypeVar(
"CompliantSeriesOrNativeExprT_co",
bound=CompliantSeriesOrNativeExprAny,
covariant=True,
)
CompliantFrameT = TypeVar("CompliantFrameT", bound=CompliantFrameAny)
CompliantFrameT_co = TypeVar(
"CompliantFrameT_co", bound=CompliantFrameAny, covariant=True
)
CompliantDataFrameT = TypeVar("CompliantDataFrameT", bound=CompliantDataFrameAny)
CompliantDataFrameT_co = TypeVar(
"CompliantDataFrameT_co", bound=CompliantDataFrameAny, covariant=True
)
CompliantLazyFrameT = TypeVar("CompliantLazyFrameT", bound=CompliantLazyFrameAny)
CompliantLazyFrameT_co = TypeVar(
"CompliantLazyFrameT_co", bound=CompliantLazyFrameAny, covariant=True
)
CompliantNamespaceT = TypeVar("CompliantNamespaceT", bound=CompliantNamespaceAny)
CompliantNamespaceT_co = TypeVar(
"CompliantNamespaceT_co", bound=CompliantNamespaceAny, covariant=True
)
IntoCompliantExpr: TypeAlias = "CompliantExpr[CompliantFrameT, CompliantSeriesOrNativeExprT_co] | CompliantSeriesOrNativeExprT_co"
DepthTrackingExprT = TypeVar("DepthTrackingExprT", bound=DepthTrackingExprAny)
DepthTrackingExprT_contra = TypeVar(
"DepthTrackingExprT_contra", bound=DepthTrackingExprAny, contravariant=True
)
EagerExprT = TypeVar("EagerExprT", bound=EagerExprAny)
EagerExprT_contra = TypeVar("EagerExprT_contra", bound=EagerExprAny, contravariant=True)
EagerSeriesT = TypeVar("EagerSeriesT", bound=EagerSeriesAny)
EagerSeriesT_co = TypeVar("EagerSeriesT_co", bound=EagerSeriesAny, covariant=True)
# NOTE: `pyright` gives false (8) positives if this uses `EagerDataFrameAny`?
EagerDataFrameT = TypeVar("EagerDataFrameT", bound="EagerDataFrame[Any, Any, Any, Any]")
LazyExprT = TypeVar("LazyExprT", bound=LazyExprAny)
LazyExprT_contra = TypeVar("LazyExprT_contra", bound=LazyExprAny, contravariant=True)
AliasNames: TypeAlias = Callable[[Sequence[str]], Sequence[str]]
"""A function aliasing a *sequence* of column names."""
AliasName: TypeAlias = Callable[[str], str]
"""A function aliasing a *single* column name."""
EvalSeries: TypeAlias = Callable[
[CompliantFrameT], Sequence[CompliantSeriesOrNativeExprT]
]
"""A function from a `Frame` to a sequence of `Series`*.
See [underwater unicorn magic](https://narwhals-dev.github.io/narwhals/how_it_works/).
"""
EvalNames: TypeAlias = Callable[[CompliantFrameT], Sequence[str]]
"""A function from a `Frame` to a sequence of columns names *before* any aliasing takes place."""
WindowFunction: TypeAlias = (
"Callable[[CompliantFrameT, WindowInputs[NativeExprT]], Sequence[NativeExprT]]"
)
"""A function evaluated with `over(partition_by=..., order_by=...)`."""
NarwhalsAggregation: TypeAlias = Literal[
"sum",
"mean",
"median",
"max",
"min",
"std",
"var",
"len",
"n_unique",
"count",
"quantile",
"all",
"any",
]
"""`Expr` methods we aim to support in `DepthTrackingGroupBy`.
Be sure to update me if you're working on one of these:
- https://github.com/narwhals-dev/narwhals/issues/981
- https://github.com/narwhals-dev/narwhals/issues/2385
- https://github.com/narwhals-dev/narwhals/issues/2484
- https://github.com/narwhals-dev/narwhals/issues/2526
- https://github.com/narwhals-dev/narwhals/issues/2660
"""

View File

@ -0,0 +1,130 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Protocol, TypeVar, cast
from narwhals._compliant.expr import CompliantExpr
from narwhals._compliant.typing import (
CompliantExprAny,
CompliantFrameAny,
CompliantSeriesOrNativeExprAny,
EagerDataFrameT,
EagerExprT,
EagerSeriesT,
LazyExprAny,
NativeSeriesT,
)
if TYPE_CHECKING:
from collections.abc import Sequence
from typing_extensions import Self, TypeAlias
from narwhals._compliant.typing import EvalSeries
from narwhals._utils import Implementation, Version, _LimitedContext
from narwhals.typing import NonNestedLiteral
__all__ = ["CompliantThen", "CompliantWhen", "EagerWhen"]
ExprT = TypeVar("ExprT", bound=CompliantExprAny)
LazyExprT = TypeVar("LazyExprT", bound=LazyExprAny)
SeriesT = TypeVar("SeriesT", bound=CompliantSeriesOrNativeExprAny)
FrameT = TypeVar("FrameT", bound=CompliantFrameAny)
Scalar: TypeAlias = Any
"""A native literal value."""
IntoExpr: TypeAlias = "SeriesT | ExprT | NonNestedLiteral | Scalar"
"""Anything that is convertible into a `CompliantExpr`."""
class CompliantWhen(Protocol[FrameT, SeriesT, ExprT]):
_condition: ExprT
_then_value: IntoExpr[SeriesT, ExprT]
_otherwise_value: IntoExpr[SeriesT, ExprT] | None
_implementation: Implementation
_version: Version
@property
def _then(self) -> type[CompliantThen[FrameT, SeriesT, ExprT, Self]]: ...
def __call__(self, compliant_frame: FrameT, /) -> Sequence[SeriesT]: ...
def then(
self, value: IntoExpr[SeriesT, ExprT], /
) -> CompliantThen[FrameT, SeriesT, ExprT, Self]:
return self._then.from_when(self, value)
@classmethod
def from_expr(cls, condition: ExprT, /, *, context: _LimitedContext) -> Self:
obj = cls.__new__(cls)
obj._condition = condition
obj._then_value = None
obj._otherwise_value = None
obj._implementation = context._implementation
obj._version = context._version
return obj
WhenT_contra = TypeVar(
"WhenT_contra", bound=CompliantWhen[Any, Any, Any], contravariant=True
)
class CompliantThen(
CompliantExpr[FrameT, SeriesT], Protocol[FrameT, SeriesT, ExprT, WhenT_contra]
):
_call: EvalSeries[FrameT, SeriesT]
_when_value: CompliantWhen[FrameT, SeriesT, ExprT]
_implementation: Implementation
_version: Version
@classmethod
def from_when(cls, when: WhenT_contra, then: IntoExpr[SeriesT, ExprT], /) -> Self:
when._then_value = then
obj = cls.__new__(cls)
obj._call = when
obj._when_value = when
obj._evaluate_output_names = getattr(
then, "_evaluate_output_names", lambda _df: ["literal"]
)
obj._alias_output_names = getattr(then, "_alias_output_names", None)
obj._implementation = when._implementation
obj._version = when._version
return obj
def otherwise(self, otherwise: IntoExpr[SeriesT, ExprT], /) -> ExprT:
self._when_value._otherwise_value = otherwise
return cast("ExprT", self)
class EagerWhen(
CompliantWhen[EagerDataFrameT, EagerSeriesT, EagerExprT],
Protocol[EagerDataFrameT, EagerSeriesT, EagerExprT, NativeSeriesT],
):
def _if_then_else(
self,
when: NativeSeriesT,
then: NativeSeriesT,
otherwise: NativeSeriesT | NonNestedLiteral | Scalar,
/,
) -> NativeSeriesT: ...
def __call__(self, df: EagerDataFrameT, /) -> Sequence[EagerSeriesT]:
is_expr = self._condition._is_expr
when: EagerSeriesT = self._condition(df)[0]
then: EagerSeriesT
align = when._align_full_broadcast
if is_expr(self._then_value):
then = self._then_value(df)[0]
else:
then = when.alias("literal")._from_scalar(self._then_value)
then._broadcast = True
if is_expr(self._otherwise_value):
otherwise = self._otherwise_value(df)[0]
when, then, otherwise = align(when, then, otherwise)
result = self._if_then_else(when.native, then.native, otherwise.native)
else:
when, then = align(when, then)
result = self._if_then_else(when.native, then.native, self._otherwise_value)
return [then._with_native(result)]
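# Hedged sketch of the flow implemented above (assumes the pandas backend is
# installed): `when(...)` builds a `CompliantWhen`, `.then(...)` wraps it in a
# `CompliantThen` expression, and `.otherwise(...)` merely records the
# fallback value on the same object.
if __name__ == "__main__":  # pragma: no cover
    import pandas as pd

    import narwhals as nw

    df = nw.from_native(pd.DataFrame({"a": [1, 2, 3]}))
    out = df.select(nw.when(nw.col("a") > 1).then(nw.col("a")).otherwise(0))
    print(out.to_native())  # column "a": [0, 2, 3]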

View File

@ -0,0 +1,20 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Generic
from narwhals._compliant.typing import NativeExprT_co
if TYPE_CHECKING:
from collections.abc import Sequence
__all__ = ["WindowInputs"]
class WindowInputs(Generic[NativeExprT_co]):
__slots__ = ("order_by", "partition_by")
def __init__(
self, partition_by: Sequence[str | NativeExprT_co], order_by: Sequence[str]
) -> None:
self.partition_by = partition_by
self.order_by = order_by
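# Hedged illustration: an expression such as
# `nw.col("a").cum_sum().over("g", order_by="t")` reaches a backend's window
# function as `WindowInputs(partition_by=["g"], order_by=["t"])`.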

View File

@ -0,0 +1,30 @@
from __future__ import annotations
import datetime as dt
# Temporal (from `polars._utils.constants`)
SECONDS_PER_DAY = 86_400
SECONDS_PER_MINUTE = 60
NS_PER_MINUTE = 60_000_000_000
"""Nanoseconds (`[ns]`) per minute."""
US_PER_MINUTE = 60_000_000
"""Microseconds (`[μs]`) per minute."""
MS_PER_MINUTE = 60_000
"""Milliseconds (`[ms]`) per minute."""
NS_PER_SECOND = 1_000_000_000
"""Nanoseconds (`[ns]`) per second (`[s]`)."""
US_PER_SECOND = 1_000_000
"""Microseconds (`[μs]`) per second (`[s]`)."""
MS_PER_SECOND = 1_000
"""Milliseconds (`[ms]`) per second (`[s]`)."""
NS_PER_MICROSECOND = 1_000
"""Nanoseconds (`[ns]`) per microsecond (`[μs]`)."""
NS_PER_MILLISECOND = 1_000_000
"""Nanoseconds (`[ns]`) per millisecond (`[ms]`).
From [polars](https://github.com/pola-rs/polars/blob/2c7a3e77f0faa37c86a3745db4ef7707ae50c72e/crates/polars-time/src/chunkedarray/duration.rs#L7).
"""
EPOCH_YEAR = 1970
"""See [Unix time](https://en.wikipedia.org/wiki/Unix_time)."""
EPOCH = dt.datetime(EPOCH_YEAR, 1, 1)  # naive datetime; no tzinfo
"""See [Unix time](https://en.wikipedia.org/wiki/Unix_time)."""

View File

@ -0,0 +1,483 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any
import dask.dataframe as dd
from narwhals._dask.utils import add_row_index, evaluate_exprs
from narwhals._expression_parsing import ExprKind
from narwhals._pandas_like.utils import native_to_narwhals_dtype, select_columns_by_name
from narwhals._typing_compat import assert_never
from narwhals._utils import (
Implementation,
ValidateBackendVersion,
_remap_full_join_keys,
check_column_names_are_unique,
generate_temporary_column_name,
not_implemented,
parse_columns_to_drop,
)
from narwhals.typing import CompliantLazyFrame
if TYPE_CHECKING:
from collections.abc import Iterator, Mapping, Sequence
from io import BytesIO
from pathlib import Path
from types import ModuleType
import dask.dataframe.dask_expr as dx
from typing_extensions import Self, TypeAlias, TypeIs
from narwhals._compliant.typing import CompliantDataFrameAny
from narwhals._dask.expr import DaskExpr
from narwhals._dask.group_by import DaskLazyGroupBy
from narwhals._dask.namespace import DaskNamespace
from narwhals._utils import Version, _LimitedContext
from narwhals.dataframe import LazyFrame
from narwhals.dtypes import DType
from narwhals.typing import AsofJoinStrategy, JoinStrategy, LazyUniqueKeepStrategy
Incomplete: TypeAlias = "Any"
"""Using `_pandas_like` utils with `_dask`.
Typing this correctly would complicate the `_pandas_like` side.
Very low priority until `dask` adds typing.
"""
class DaskLazyFrame(
CompliantLazyFrame["DaskExpr", "dd.DataFrame", "LazyFrame[dd.DataFrame]"],
ValidateBackendVersion,
):
_implementation = Implementation.DASK
def __init__(
self,
native_dataframe: dd.DataFrame,
*,
version: Version,
validate_backend_version: bool = False,
) -> None:
self._native_frame: dd.DataFrame = native_dataframe
self._version = version
self._cached_schema: dict[str, DType] | None = None
self._cached_columns: list[str] | None = None
if validate_backend_version:
self._validate_backend_version()
@staticmethod
def _is_native(obj: dd.DataFrame | Any) -> TypeIs[dd.DataFrame]:
return isinstance(obj, dd.DataFrame)
@classmethod
def from_native(cls, data: dd.DataFrame, /, *, context: _LimitedContext) -> Self:
return cls(data, version=context._version)
def to_narwhals(self) -> LazyFrame[dd.DataFrame]:
return self._version.lazyframe(self, level="lazy")
def __native_namespace__(self) -> ModuleType:
if self._implementation is Implementation.DASK:
return self._implementation.to_native_namespace()
msg = f"Expected dask, got: {type(self._implementation)}" # pragma: no cover
raise AssertionError(msg)
def __narwhals_namespace__(self) -> DaskNamespace:
from narwhals._dask.namespace import DaskNamespace
return DaskNamespace(version=self._version)
def __narwhals_lazyframe__(self) -> Self:
return self
def _with_version(self, version: Version) -> Self:
return self.__class__(self.native, version=version)
def _with_native(self, df: Any) -> Self:
return self.__class__(df, version=self._version)
def _iter_columns(self) -> Iterator[dx.Series]:
for _col, ser in self.native.items(): # noqa: PERF102
yield ser
def with_columns(self, *exprs: DaskExpr) -> Self:
new_series = evaluate_exprs(self, *exprs)
return self._with_native(self.native.assign(**dict(new_series)))
def collect(
self, backend: Implementation | None, **kwargs: Any
) -> CompliantDataFrameAny:
result = self.native.compute(**kwargs)
if backend is None or backend is Implementation.PANDAS:
from narwhals._pandas_like.dataframe import PandasLikeDataFrame
return PandasLikeDataFrame(
result,
implementation=Implementation.PANDAS,
validate_backend_version=True,
version=self._version,
validate_column_names=True,
)
if backend is Implementation.POLARS:
import polars as pl # ignore-banned-import
from narwhals._polars.dataframe import PolarsDataFrame
return PolarsDataFrame(
pl.from_pandas(result),
validate_backend_version=True,
version=self._version,
)
if backend is Implementation.PYARROW:
import pyarrow as pa # ignore-banned-import
from narwhals._arrow.dataframe import ArrowDataFrame
return ArrowDataFrame(
pa.Table.from_pandas(result),
validate_backend_version=True,
version=self._version,
validate_column_names=True,
)
msg = f"Unsupported `backend` value: {backend}" # pragma: no cover
raise ValueError(msg) # pragma: no cover
@property
def columns(self) -> list[str]:
if self._cached_columns is None:
self._cached_columns = (
list(self.schema)
if self._cached_schema is not None
else self.native.columns.tolist()
)
return self._cached_columns
def filter(self, predicate: DaskExpr) -> Self:
# `[0]` is safe as the predicate's expression only returns a single column
mask = predicate(self)[0]
return self._with_native(self.native.loc[mask])
def simple_select(self, *column_names: str) -> Self:
df: Incomplete = self.native
native = select_columns_by_name(df, list(column_names), self._implementation)
return self._with_native(native)
def aggregate(self, *exprs: DaskExpr) -> Self:
new_series = evaluate_exprs(self, *exprs)
df = dd.concat([val.rename(name) for name, val in new_series], axis=1)
return self._with_native(df)
def select(self, *exprs: DaskExpr) -> Self:
new_series = evaluate_exprs(self, *exprs)
df: Incomplete = self.native
df = select_columns_by_name(
df.assign(**dict(new_series)),
[s[0] for s in new_series],
self._implementation,
)
return self._with_native(df)
def drop_nulls(self, subset: Sequence[str] | None) -> Self:
if subset is None:
return self._with_native(self.native.dropna())
plx = self.__narwhals_namespace__()
mask = ~plx.any_horizontal(plx.col(*subset).is_null(), ignore_nulls=True)
return self.filter(mask)
@property
def schema(self) -> dict[str, DType]:
if self._cached_schema is None:
native_dtypes = self.native.dtypes
self._cached_schema = {
col: native_to_narwhals_dtype(
native_dtypes[col], self._version, self._implementation
)
for col in self.native.columns
}
return self._cached_schema
def collect_schema(self) -> dict[str, DType]:
return self.schema
def drop(self, columns: Sequence[str], *, strict: bool) -> Self:
to_drop = parse_columns_to_drop(self, columns, strict=strict)
return self._with_native(self.native.drop(columns=to_drop))
def with_row_index(self, name: str, order_by: Sequence[str] | None) -> Self:
# Implementation is based on the following StackOverflow reply:
# https://stackoverflow.com/questions/60831518/in-dask-how-does-one-add-a-range-of-integersauto-increment-to-a-new-column/60852409#60852409
if order_by is None:
return self._with_native(add_row_index(self.native, name))
else:
plx = self.__narwhals_namespace__()
columns = self.columns
const_expr = (
plx.lit(value=1, dtype=None).alias(name).broadcast(ExprKind.LITERAL)
)
row_index_expr = (
plx.col(name)
.cum_sum(reverse=False)
.over(partition_by=[], order_by=order_by)
- 1
)
return self.with_columns(const_expr).select(row_index_expr, plx.col(*columns))
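    # Worked example of the trick above: with `order_by=["t"]` and three rows,
    # `lit(1)` gives [1, 1, 1]; `cum_sum` over the t-ordering gives [1, 2, 3];
    # subtracting 1 yields the 0-based indices [0, 1, 2], reported in the
    # frame's original row order.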
def rename(self, mapping: Mapping[str, str]) -> Self:
return self._with_native(self.native.rename(columns=mapping))
def head(self, n: int) -> Self:
return self._with_native(self.native.head(n=n, compute=False, npartitions=-1))
def unique(
self, subset: Sequence[str] | None, *, keep: LazyUniqueKeepStrategy
) -> Self:
if subset and (error := self._check_columns_exist(subset)):
raise error
if keep == "none":
subset = subset or self.columns
token = generate_temporary_column_name(n_bytes=8, columns=subset)
ser = self.native.groupby(subset).size().rename(token)
ser = ser[ser == 1]
unique = ser.reset_index().drop(columns=token)
result = self.native.merge(unique, on=subset, how="inner")
else:
mapped_keep = {"any": "first"}.get(keep, keep)
result = self.native.drop_duplicates(subset=subset, keep=mapped_keep)
return self._with_native(result)
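    # `keep="none"` worked example: for key counts {1: 2, 2: 1}, only key 2's
    # group has size == 1, so the inner merge keeps just that row and every
    # duplicated key disappears entirely.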
def sort(self, *by: str, descending: bool | Sequence[bool], nulls_last: bool) -> Self:
if isinstance(descending, bool):
ascending: bool | list[bool] = not descending
else:
ascending = [not d for d in descending]
position = "last" if nulls_last else "first"
return self._with_native(
self.native.sort_values(list(by), ascending=ascending, na_position=position)
)
def _join_inner(
self, other: Self, *, left_on: Sequence[str], right_on: Sequence[str], suffix: str
) -> dd.DataFrame:
return self.native.merge(
other.native,
left_on=left_on,
right_on=right_on,
how="inner",
suffixes=("", suffix),
)
def _join_left(
self, other: Self, *, left_on: Sequence[str], right_on: Sequence[str], suffix: str
) -> dd.DataFrame:
result_native = self.native.merge(
other.native,
how="left",
left_on=left_on,
right_on=right_on,
suffixes=("", suffix),
)
extra = [
right_key if right_key not in self.columns else f"{right_key}{suffix}"
for left_key, right_key in zip(left_on, right_on)
if right_key != left_key
]
return result_native.drop(columns=extra)
def _join_full(
self, other: Self, *, left_on: Sequence[str], right_on: Sequence[str], suffix: str
) -> dd.DataFrame:
        # Dask does not retain the join keys post-join,
        # so we must append the suffix to each key beforehand.
right_on_mapper = _remap_full_join_keys(left_on, right_on, suffix)
other_native = other.native.rename(columns=right_on_mapper)
check_column_names_are_unique(other_native.columns)
right_suffixed = list(right_on_mapper.values())
return self.native.merge(
other_native,
left_on=left_on,
right_on=right_suffixed,
how="outer",
suffixes=("", suffix),
)
def _join_cross(self, other: Self, *, suffix: str) -> dd.DataFrame:
key_token = generate_temporary_column_name(
n_bytes=8, columns=(*self.columns, *other.columns)
)
return (
self.native.assign(**{key_token: 0})
.merge(
other.native.assign(**{key_token: 0}),
how="inner",
left_on=key_token,
right_on=key_token,
suffixes=("", suffix),
)
.drop(columns=key_token)
)
def _join_semi(
self, other: Self, *, left_on: Sequence[str], right_on: Sequence[str]
) -> dd.DataFrame:
other_native = self._join_filter_rename(
other=other,
columns_to_select=list(right_on),
columns_mapping=dict(zip(right_on, left_on)),
)
return self.native.merge(
other_native, how="inner", left_on=left_on, right_on=left_on
)
def _join_anti(
self, other: Self, *, left_on: Sequence[str], right_on: Sequence[str]
) -> dd.DataFrame:
indicator_token = generate_temporary_column_name(
n_bytes=8, columns=(*self.columns, *other.columns)
)
other_native = self._join_filter_rename(
other=other,
columns_to_select=list(right_on),
columns_mapping=dict(zip(right_on, left_on)),
)
df = self.native.merge(
other_native,
how="left",
indicator=indicator_token, # pyright: ignore[reportArgumentType]
left_on=left_on,
right_on=left_on,
)
return df[df[indicator_token] == "left_only"].drop(columns=[indicator_token])
def _join_filter_rename(
self, other: Self, columns_to_select: list[str], columns_mapping: dict[str, str]
) -> dd.DataFrame:
"""Helper function to avoid creating extra columns and row duplication.
Used in `"anti"` and `"semi`" join's.
Notice that a native object is returned.
"""
other_native: Incomplete = other.native
# rename to avoid creating extra columns in join
return (
select_columns_by_name(other_native, columns_to_select, self._implementation)
.rename(columns=columns_mapping)
.drop_duplicates()
)
def join(
self,
other: Self,
*,
how: JoinStrategy,
left_on: Sequence[str] | None,
right_on: Sequence[str] | None,
suffix: str,
) -> Self:
if how == "cross":
result = self._join_cross(other=other, suffix=suffix)
elif left_on is None or right_on is None: # pragma: no cover
raise ValueError(left_on, right_on)
elif how == "inner":
result = self._join_inner(
other=other, left_on=left_on, right_on=right_on, suffix=suffix
)
elif how == "anti":
result = self._join_anti(other=other, left_on=left_on, right_on=right_on)
elif how == "semi":
result = self._join_semi(other=other, left_on=left_on, right_on=right_on)
elif how == "left":
result = self._join_left(
other=other, left_on=left_on, right_on=right_on, suffix=suffix
)
elif how == "full":
result = self._join_full(
other=other, left_on=left_on, right_on=right_on, suffix=suffix
)
else:
assert_never(how)
return self._with_native(result)
def join_asof(
self,
other: Self,
*,
left_on: str,
right_on: str,
by_left: Sequence[str] | None,
by_right: Sequence[str] | None,
strategy: AsofJoinStrategy,
suffix: str,
) -> Self:
plx = self.__native_namespace__()
return self._with_native(
plx.merge_asof(
self.native,
other.native,
left_on=left_on,
right_on=right_on,
left_by=by_left,
right_by=by_right,
direction=strategy,
suffixes=("", suffix),
)
)
def group_by(
self, keys: Sequence[str] | Sequence[DaskExpr], *, drop_null_keys: bool
) -> DaskLazyGroupBy:
from narwhals._dask.group_by import DaskLazyGroupBy
return DaskLazyGroupBy(self, keys, drop_null_keys=drop_null_keys)
def tail(self, n: int) -> Self: # pragma: no cover
native_frame = self.native
n_partitions = native_frame.npartitions
if n_partitions == 1:
return self._with_native(self.native.tail(n=n, compute=False))
else:
msg = "`LazyFrame.tail` is not supported for Dask backend with multiple partitions."
raise NotImplementedError(msg)
def gather_every(self, n: int, offset: int) -> Self:
row_index_token = generate_temporary_column_name(n_bytes=8, columns=self.columns)
plx = self.__narwhals_namespace__()
return (
self.with_row_index(row_index_token, order_by=None)
.filter(
(plx.col(row_index_token) >= offset)
& ((plx.col(row_index_token) - offset) % n == 0)
)
.drop([row_index_token], strict=False)
)
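    # e.g. `gather_every(n=3, offset=1)` keeps the rows whose temporary index
    # is 1, 4, 7, ...: exactly those with i >= offset and (i - offset) % n == 0.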
def unpivot(
self,
on: Sequence[str] | None,
index: Sequence[str] | None,
variable_name: str,
value_name: str,
) -> Self:
return self._with_native(
self.native.melt(
id_vars=index,
value_vars=on,
var_name=variable_name,
value_name=value_name,
)
)
def sink_parquet(self, file: str | Path | BytesIO) -> None:
self.native.to_parquet(file)
explode = not_implemented()
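# Hedged end-to-end sketch of the class above (assumes pandas and
# dask[dataframe] are installed): narwhals wraps the Dask frame lazily, and
# `collect()` materialises it through `DaskLazyFrame.collect`.
if __name__ == "__main__":  # pragma: no cover
    import dask.dataframe as dd
    import pandas as pd

    import narwhals as nw

    native = dd.from_pandas(pd.DataFrame({"a": [1, 2, 2]}), npartitions=1)
    ldf = nw.from_native(native)
    print(ldf.unique(subset=["a"], keep="none").collect().to_native())  # a: [1]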

View File

@ -0,0 +1,692 @@
from __future__ import annotations
import warnings
from typing import TYPE_CHECKING, Any, Callable, Literal
from narwhals._compliant import DepthTrackingExpr, LazyExpr
from narwhals._dask.expr_dt import DaskExprDateTimeNamespace
from narwhals._dask.expr_str import DaskExprStringNamespace
from narwhals._dask.utils import (
add_row_index,
maybe_evaluate_expr,
narwhals_to_native_dtype,
)
from narwhals._expression_parsing import ExprKind, evaluate_output_names_and_aliases
from narwhals._pandas_like.utils import native_to_narwhals_dtype
from narwhals._utils import (
Implementation,
generate_temporary_column_name,
not_implemented,
)
from narwhals.exceptions import InvalidOperationError
if TYPE_CHECKING:
from collections.abc import Sequence
import dask.dataframe.dask_expr as dx
from typing_extensions import Self
from narwhals._compliant.typing import AliasNames, EvalNames, EvalSeries, ScalarKwargs
from narwhals._dask.dataframe import DaskLazyFrame
from narwhals._dask.namespace import DaskNamespace
    from narwhals._expression_parsing import ExprMetadata
from narwhals._utils import Version, _LimitedContext
from narwhals.typing import (
FillNullStrategy,
IntoDType,
NonNestedLiteral,
NumericLiteral,
RollingInterpolationMethod,
TemporalLiteral,
)
class DaskExpr(
LazyExpr["DaskLazyFrame", "dx.Series"],
DepthTrackingExpr["DaskLazyFrame", "dx.Series"],
):
_implementation: Implementation = Implementation.DASK
def __init__(
self,
call: EvalSeries[DaskLazyFrame, dx.Series],
*,
depth: int,
function_name: str,
evaluate_output_names: EvalNames[DaskLazyFrame],
alias_output_names: AliasNames | None,
version: Version,
scalar_kwargs: ScalarKwargs | None = None,
) -> None:
self._call = call
self._depth = depth
self._function_name = function_name
self._evaluate_output_names = evaluate_output_names
self._alias_output_names = alias_output_names
self._version = version
self._scalar_kwargs = scalar_kwargs or {}
self._metadata: ExprMetadata | None = None
def __call__(self, df: DaskLazyFrame) -> Sequence[dx.Series]:
return self._call(df)
def __narwhals_expr__(self) -> None: ...
def __narwhals_namespace__(self) -> DaskNamespace: # pragma: no cover
from narwhals._dask.namespace import DaskNamespace
return DaskNamespace(version=self._version)
def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Self:
def func(df: DaskLazyFrame) -> list[dx.Series]:
# result.loc[0][0] is a workaround for dask~<=2024.10.0/dask_expr~<=1.1.16
# that raised a KeyError for result[0] during collection.
return [result.loc[0][0] for result in self(df)]
return self.__class__(
func,
depth=self._depth,
function_name=self._function_name,
evaluate_output_names=self._evaluate_output_names,
alias_output_names=self._alias_output_names,
version=self._version,
scalar_kwargs=self._scalar_kwargs,
)
@classmethod
def from_column_names(
cls: type[Self],
evaluate_column_names: EvalNames[DaskLazyFrame],
/,
*,
context: _LimitedContext,
function_name: str = "",
) -> Self:
def func(df: DaskLazyFrame) -> list[dx.Series]:
try:
return [
df._native_frame[column_name]
for column_name in evaluate_column_names(df)
]
except KeyError as e:
if error := df._check_columns_exist(evaluate_column_names(df)):
raise error from e
raise
return cls(
func,
depth=0,
function_name=function_name,
evaluate_output_names=evaluate_column_names,
alias_output_names=None,
version=context._version,
)
@classmethod
def from_column_indices(cls, *column_indices: int, context: _LimitedContext) -> Self:
def func(df: DaskLazyFrame) -> list[dx.Series]:
return [df.native.iloc[:, i] for i in column_indices]
return cls(
func,
depth=0,
function_name="nth",
evaluate_output_names=cls._eval_names_indices(column_indices),
alias_output_names=None,
version=context._version,
)
def _with_callable(
self,
# First argument to `call` should be `dx.Series`
call: Callable[..., dx.Series],
/,
expr_name: str = "",
scalar_kwargs: ScalarKwargs | None = None,
**expressifiable_args: Self | Any,
) -> Self:
def func(df: DaskLazyFrame) -> list[dx.Series]:
native_results: list[dx.Series] = []
native_series_list = self._call(df)
other_native_series = {
key: maybe_evaluate_expr(df, value)
for key, value in expressifiable_args.items()
}
for native_series in native_series_list:
result_native = call(native_series, **other_native_series)
native_results.append(result_native)
return native_results
return self.__class__(
func,
depth=self._depth + 1,
function_name=f"{self._function_name}->{expr_name}",
evaluate_output_names=self._evaluate_output_names,
alias_output_names=self._alias_output_names,
version=self._version,
scalar_kwargs=scalar_kwargs,
)
def _with_alias_output_names(self, func: AliasNames | None, /) -> Self:
current_alias_output_names = self._alias_output_names
alias_output_names = (
None
if func is None
else func
if current_alias_output_names is None
else lambda output_names: func(current_alias_output_names(output_names))
)
return type(self)(
call=self._call,
depth=self._depth,
function_name=self._function_name,
evaluate_output_names=self._evaluate_output_names,
alias_output_names=alias_output_names,
version=self._version,
scalar_kwargs=self._scalar_kwargs,
)
def __add__(self, other: Any) -> Self:
return self._with_callable(
lambda expr, other: expr.__add__(other), "__add__", other=other
)
def __sub__(self, other: Any) -> Self:
return self._with_callable(
lambda expr, other: expr.__sub__(other), "__sub__", other=other
)
def __rsub__(self, other: Any) -> Self:
return self._with_callable(
lambda expr, other: other - expr, "__rsub__", other=other
).alias("literal")
def __mul__(self, other: Any) -> Self:
return self._with_callable(
lambda expr, other: expr.__mul__(other), "__mul__", other=other
)
def __truediv__(self, other: Any) -> Self:
return self._with_callable(
lambda expr, other: expr.__truediv__(other), "__truediv__", other=other
)
def __rtruediv__(self, other: Any) -> Self:
return self._with_callable(
lambda expr, other: other / expr, "__rtruediv__", other=other
).alias("literal")
def __floordiv__(self, other: Any) -> Self:
return self._with_callable(
lambda expr, other: expr.__floordiv__(other), "__floordiv__", other=other
)
def __rfloordiv__(self, other: Any) -> Self:
return self._with_callable(
lambda expr, other: other // expr, "__rfloordiv__", other=other
).alias("literal")
def __pow__(self, other: Any) -> Self:
return self._with_callable(
lambda expr, other: expr.__pow__(other), "__pow__", other=other
)
def __rpow__(self, other: Any) -> Self:
return self._with_callable(
lambda expr, other: other**expr, "__rpow__", other=other
).alias("literal")
def __mod__(self, other: Any) -> Self:
return self._with_callable(
lambda expr, other: expr.__mod__(other), "__mod__", other=other
)
def __rmod__(self, other: Any) -> Self:
return self._with_callable(
lambda expr, other: other % expr, "__rmod__", other=other
).alias("literal")
def __eq__(self, other: DaskExpr) -> Self: # type: ignore[override]
return self._with_callable(
lambda expr, other: expr.__eq__(other), "__eq__", other=other
)
def __ne__(self, other: DaskExpr) -> Self: # type: ignore[override]
return self._with_callable(
lambda expr, other: expr.__ne__(other), "__ne__", other=other
)
def __ge__(self, other: DaskExpr | Any) -> Self:
return self._with_callable(
lambda expr, other: expr.__ge__(other), "__ge__", other=other
)
def __gt__(self, other: DaskExpr) -> Self:
return self._with_callable(
lambda expr, other: expr.__gt__(other), "__gt__", other=other
)
def __le__(self, other: DaskExpr) -> Self:
return self._with_callable(
lambda expr, other: expr.__le__(other), "__le__", other=other
)
def __lt__(self, other: DaskExpr) -> Self:
return self._with_callable(
lambda expr, other: expr.__lt__(other), "__lt__", other=other
)
def __and__(self, other: DaskExpr | Any) -> Self:
return self._with_callable(
lambda expr, other: expr.__and__(other), "__and__", other=other
)
def __or__(self, other: DaskExpr) -> Self:
return self._with_callable(
lambda expr, other: expr.__or__(other), "__or__", other=other
)
def __invert__(self) -> Self:
return self._with_callable(lambda expr: expr.__invert__(), "__invert__")
def mean(self) -> Self:
return self._with_callable(lambda expr: expr.mean().to_series(), "mean")
def median(self) -> Self:
def func(s: dx.Series) -> dx.Series:
dtype = native_to_narwhals_dtype(s.dtype, self._version, Implementation.DASK)
if not dtype.is_numeric():
msg = "`median` operation not supported for non-numeric input type."
raise InvalidOperationError(msg)
return s.median_approximate().to_series()
return self._with_callable(func, "median")
def min(self) -> Self:
return self._with_callable(lambda expr: expr.min().to_series(), "min")
def max(self) -> Self:
return self._with_callable(lambda expr: expr.max().to_series(), "max")
def std(self, ddof: int) -> Self:
return self._with_callable(
lambda expr: expr.std(ddof=ddof).to_series(),
"std",
scalar_kwargs={"ddof": ddof},
)
def var(self, ddof: int) -> Self:
return self._with_callable(
lambda expr: expr.var(ddof=ddof).to_series(),
"var",
scalar_kwargs={"ddof": ddof},
)
def skew(self) -> Self:
return self._with_callable(lambda expr: expr.skew().to_series(), "skew")
def kurtosis(self) -> Self:
return self._with_callable(lambda expr: expr.kurtosis().to_series(), "kurtosis")
def shift(self, n: int) -> Self:
return self._with_callable(lambda expr: expr.shift(n), "shift")
def cum_sum(self, *, reverse: bool) -> Self:
if reverse: # pragma: no cover
# https://github.com/dask/dask/issues/11802
msg = "`cum_sum(reverse=True)` is not supported with Dask backend"
raise NotImplementedError(msg)
return self._with_callable(lambda expr: expr.cumsum(), "cum_sum")
def cum_count(self, *, reverse: bool) -> Self:
if reverse: # pragma: no cover
msg = "`cum_count(reverse=True)` is not supported with Dask backend"
raise NotImplementedError(msg)
return self._with_callable(
lambda expr: (~expr.isna()).astype(int).cumsum(), "cum_count"
)
def cum_min(self, *, reverse: bool) -> Self:
if reverse: # pragma: no cover
msg = "`cum_min(reverse=True)` is not supported with Dask backend"
raise NotImplementedError(msg)
return self._with_callable(lambda expr: expr.cummin(), "cum_min")
def cum_max(self, *, reverse: bool) -> Self:
if reverse: # pragma: no cover
msg = "`cum_max(reverse=True)` is not supported with Dask backend"
raise NotImplementedError(msg)
return self._with_callable(lambda expr: expr.cummax(), "cum_max")
def cum_prod(self, *, reverse: bool) -> Self:
if reverse: # pragma: no cover
msg = "`cum_prod(reverse=True)` is not supported with Dask backend"
raise NotImplementedError(msg)
return self._with_callable(lambda expr: expr.cumprod(), "cum_prod")
def rolling_sum(self, window_size: int, *, min_samples: int, center: bool) -> Self:
return self._with_callable(
lambda expr: expr.rolling(
window=window_size, min_periods=min_samples, center=center
).sum(),
"rolling_sum",
)
def rolling_mean(self, window_size: int, *, min_samples: int, center: bool) -> Self:
return self._with_callable(
lambda expr: expr.rolling(
window=window_size, min_periods=min_samples, center=center
).mean(),
"rolling_mean",
)
def rolling_var(
self, window_size: int, *, min_samples: int, center: bool, ddof: int
) -> Self:
if ddof == 1:
return self._with_callable(
lambda expr: expr.rolling(
window=window_size, min_periods=min_samples, center=center
).var(),
"rolling_var",
)
else:
msg = "Dask backend only supports `ddof=1` for `rolling_var`"
raise NotImplementedError(msg)
def rolling_std(
self, window_size: int, *, min_samples: int, center: bool, ddof: int
) -> Self:
if ddof == 1:
return self._with_callable(
lambda expr: expr.rolling(
window=window_size, min_periods=min_samples, center=center
).std(),
"rolling_std",
)
else:
msg = "Dask backend only supports `ddof=1` for `rolling_std`"
raise NotImplementedError(msg)
def sum(self) -> Self:
return self._with_callable(lambda expr: expr.sum().to_series(), "sum")
def count(self) -> Self:
return self._with_callable(lambda expr: expr.count().to_series(), "count")
def round(self, decimals: int) -> Self:
return self._with_callable(lambda expr: expr.round(decimals), "round")
def unique(self) -> Self:
return self._with_callable(lambda expr: expr.unique(), "unique")
def drop_nulls(self) -> Self:
return self._with_callable(lambda expr: expr.dropna(), "drop_nulls")
def abs(self) -> Self:
return self._with_callable(lambda expr: expr.abs(), "abs")
def all(self) -> Self:
return self._with_callable(
lambda expr: expr.all(
axis=None, skipna=True, split_every=False, out=None
).to_series(),
"all",
)
def any(self) -> Self:
return self._with_callable(
lambda expr: expr.any(axis=0, skipna=True, split_every=False).to_series(),
"any",
)
def fill_null(
self,
value: Self | NonNestedLiteral,
strategy: FillNullStrategy | None,
limit: int | None,
) -> Self:
def func(expr: dx.Series) -> dx.Series:
if value is not None:
res_ser = expr.fillna(value)
else:
res_ser = (
expr.ffill(limit=limit)
if strategy == "forward"
else expr.bfill(limit=limit)
)
return res_ser
return self._with_callable(func, "fillna")
def clip(
self,
lower_bound: Self | NumericLiteral | TemporalLiteral | None,
upper_bound: Self | NumericLiteral | TemporalLiteral | None,
) -> Self:
return self._with_callable(
lambda expr, lower_bound, upper_bound: expr.clip(
lower=lower_bound, upper=upper_bound
),
"clip",
lower_bound=lower_bound,
upper_bound=upper_bound,
)
def diff(self) -> Self:
return self._with_callable(lambda expr: expr.diff(), "diff")
def n_unique(self) -> Self:
return self._with_callable(
lambda expr: expr.nunique(dropna=False).to_series(), "n_unique"
)
def is_null(self) -> Self:
return self._with_callable(lambda expr: expr.isna(), "is_null")
def is_nan(self) -> Self:
def func(expr: dx.Series) -> dx.Series:
dtype = native_to_narwhals_dtype(
expr.dtype, self._version, self._implementation
)
if dtype.is_numeric():
return expr != expr # pyright: ignore[reportReturnType] # noqa: PLR0124
msg = f"`.is_nan` only supported for numeric dtypes and not {dtype}, did you mean `.is_null`?"
raise InvalidOperationError(msg)
return self._with_callable(func, "is_null")
def len(self) -> Self:
return self._with_callable(lambda expr: expr.size.to_series(), "len")
def quantile(
self, quantile: float, interpolation: RollingInterpolationMethod
) -> Self:
if interpolation == "linear":
def func(expr: dx.Series, quantile: float) -> dx.Series:
if expr.npartitions > 1:
msg = "`Expr.quantile` is not supported for Dask backend with multiple partitions."
raise NotImplementedError(msg)
return expr.quantile(
q=quantile, method="dask"
).to_series() # pragma: no cover
return self._with_callable(func, "quantile", quantile=quantile)
else:
msg = "`higher`, `lower`, `midpoint`, `nearest` - interpolation methods are not supported by Dask. Please use `linear` instead."
raise NotImplementedError(msg)
def is_first_distinct(self) -> Self:
def func(expr: dx.Series) -> dx.Series:
_name = expr.name
col_token = generate_temporary_column_name(n_bytes=8, columns=[_name])
frame = add_row_index(expr.to_frame(), col_token)
first_distinct_index = frame.groupby(_name).agg({col_token: "min"})[col_token]
return frame[col_token].isin(first_distinct_index)
return self._with_callable(func, "is_first_distinct")
def is_last_distinct(self) -> Self:
def func(expr: dx.Series) -> dx.Series:
_name = expr.name
col_token = generate_temporary_column_name(n_bytes=8, columns=[_name])
frame = add_row_index(expr.to_frame(), col_token)
last_distinct_index = frame.groupby(_name).agg({col_token: "max"})[col_token]
return frame[col_token].isin(last_distinct_index)
return self._with_callable(func, "is_last_distinct")
def is_unique(self) -> Self:
def func(expr: dx.Series) -> dx.Series:
_name = expr.name
return (
expr.to_frame()
.groupby(_name, dropna=False)
.transform("size", meta=(_name, int))
== 1
)
return self._with_callable(func, "is_unique")
def is_in(self, other: Any) -> Self:
return self._with_callable(lambda expr: expr.isin(other), "is_in")
def null_count(self) -> Self:
return self._with_callable(
lambda expr: expr.isna().sum().to_series(), "null_count"
)
def over(self, partition_by: Sequence[str], order_by: Sequence[str]) -> Self:
# pandas is a required dependency of dask so it's safe to import this
from narwhals._pandas_like.group_by import PandasLikeGroupBy
if not partition_by:
assert order_by # noqa: S101
# This is something like `nw.col('a').cum_sum().order_by(key)`
# which we can always easily support, as it doesn't require grouping.
def func(df: DaskLazyFrame) -> Sequence[dx.Series]:
return self(df.sort(*order_by, descending=False, nulls_last=False))
elif not self._is_elementary(): # pragma: no cover
msg = (
"Only elementary expressions are supported for `.over` in dask.\n\n"
"Please see: "
"https://narwhals-dev.github.io/narwhals/concepts/improve_group_by_operation/"
)
raise NotImplementedError(msg)
elif order_by:
# Wrong results https://github.com/dask/dask/issues/11806.
msg = "`over` with `order_by` is not yet supported in Dask."
raise NotImplementedError(msg)
else:
function_name = PandasLikeGroupBy._leaf_name(self)
try:
dask_function_name = PandasLikeGroupBy._REMAP_AGGS[function_name]
except KeyError:
# window functions are unsupported: https://github.com/dask/dask/issues/11806
msg = (
f"Unsupported function: {function_name} in `over` context.\n\n"
f"Supported functions are {', '.join(PandasLikeGroupBy._REMAP_AGGS)}\n"
)
raise NotImplementedError(msg) from None
def func(df: DaskLazyFrame) -> Sequence[dx.Series]:
output_names, aliases = evaluate_output_names_and_aliases(self, df, [])
with warnings.catch_warnings():
# https://github.com/dask/dask/issues/11804
warnings.filterwarnings(
"ignore",
message=".*`meta` is not specified",
category=UserWarning,
)
grouped = df.native.groupby(partition_by)
if dask_function_name == "size":
if len(output_names) != 1: # pragma: no cover
msg = "Safety check failed, please report a bug."
raise AssertionError(msg)
res_native = grouped.transform(
dask_function_name, **self._scalar_kwargs
).to_frame(output_names[0])
else:
res_native = grouped[list(output_names)].transform(
dask_function_name, **self._scalar_kwargs
)
result_frame = df._with_native(
res_native.rename(columns=dict(zip(output_names, aliases)))
).native
return [result_frame[name] for name in aliases]
return self.__class__(
func,
depth=self._depth + 1,
function_name=self._function_name + "->over",
evaluate_output_names=self._evaluate_output_names,
alias_output_names=self._alias_output_names,
version=self._version,
)
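# Rough illustration of the grouped branch above, assuming an elementary
# expression such as `nw.col("a").sum().over("key")`: it is lowered to roughly
#     df.groupby(["key"])[["a"]].transform("sum")
# which broadcasts each group's aggregate back to the original row positions.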
def cast(self, dtype: IntoDType) -> Self:
def func(expr: dx.Series) -> dx.Series:
native_dtype = narwhals_to_native_dtype(dtype, self._version)
return expr.astype(native_dtype)
return self._with_callable(func, "cast")
def is_finite(self) -> Self:
import dask.array as da
return self._with_callable(da.isfinite, "is_finite")
def log(self, base: float) -> Self:
import dask.array as da
def _log(expr: dx.Series) -> dx.Series:
return da.log(expr) / da.log(base)
return self._with_callable(_log, "log")
def exp(self) -> Self:
import dask.array as da
return self._with_callable(da.exp, "exp")
def sqrt(self) -> Self:
import dask.array as da
return self._with_callable(da.sqrt, "sqrt")
@property
def str(self) -> DaskExprStringNamespace:
return DaskExprStringNamespace(self)
@property
def dt(self) -> DaskExprDateTimeNamespace:
return DaskExprDateTimeNamespace(self)
arg_max: not_implemented = not_implemented()
arg_min: not_implemented = not_implemented()
arg_true: not_implemented = not_implemented()
ewm_mean: not_implemented = not_implemented()
gather_every: not_implemented = not_implemented()
head: not_implemented = not_implemented()
map_batches: not_implemented = not_implemented()
mode: not_implemented = not_implemented()
sample: not_implemented = not_implemented()
rank: not_implemented = not_implemented()
replace_strict: not_implemented = not_implemented()
sort: not_implemented = not_implemented()
tail: not_implemented = not_implemented()
# namespaces
list: not_implemented = not_implemented() # type: ignore[assignment]
cat: not_implemented = not_implemented() # type: ignore[assignment]
struct: not_implemented = not_implemented() # type: ignore[assignment]

View File

@ -0,0 +1,176 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from narwhals._compliant import LazyExprNamespace
from narwhals._compliant.any_namespace import DateTimeNamespace
from narwhals._constants import MS_PER_SECOND, NS_PER_SECOND, US_PER_SECOND
from narwhals._duration import Interval
from narwhals._pandas_like.utils import (
ALIAS_DICT,
calculate_timestamp_date,
calculate_timestamp_datetime,
native_to_narwhals_dtype,
)
from narwhals._utils import Implementation
if TYPE_CHECKING:
import dask.dataframe.dask_expr as dx
from narwhals._dask.expr import DaskExpr
from narwhals.typing import TimeUnit
class DaskExprDateTimeNamespace(
LazyExprNamespace["DaskExpr"], DateTimeNamespace["DaskExpr"]
):
def date(self) -> DaskExpr:
return self.compliant._with_callable(lambda expr: expr.dt.date, "date")
def year(self) -> DaskExpr:
return self.compliant._with_callable(lambda expr: expr.dt.year, "year")
def month(self) -> DaskExpr:
return self.compliant._with_callable(lambda expr: expr.dt.month, "month")
def day(self) -> DaskExpr:
return self.compliant._with_callable(lambda expr: expr.dt.day, "day")
def hour(self) -> DaskExpr:
return self.compliant._with_callable(lambda expr: expr.dt.hour, "hour")
def minute(self) -> DaskExpr:
return self.compliant._with_callable(lambda expr: expr.dt.minute, "minute")
def second(self) -> DaskExpr:
return self.compliant._with_callable(lambda expr: expr.dt.second, "second")
def millisecond(self) -> DaskExpr:
return self.compliant._with_callable(
lambda expr: expr.dt.microsecond // 1000, "millisecond"
)
def microsecond(self) -> DaskExpr:
return self.compliant._with_callable(
lambda expr: expr.dt.microsecond, "microsecond"
)
def nanosecond(self) -> DaskExpr:
return self.compliant._with_callable(
lambda expr: expr.dt.microsecond * 1000 + expr.dt.nanosecond, "nanosecond"
)
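# Worked example for the two derivations above (illustrative): a timestamp with
# .dt.microsecond == 123456 and .dt.nanosecond == 789 yields
#     millisecond -> 123456 // 1000 == 123
#     nanosecond  -> 123456 * 1000 + 789 == 123456789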
def ordinal_day(self) -> DaskExpr:
return self.compliant._with_callable(
lambda expr: expr.dt.dayofyear, "ordinal_day"
)
def weekday(self) -> DaskExpr:
return self.compliant._with_callable(
lambda expr: expr.dt.weekday + 1, # Dask is 0-6
"weekday",
)
def to_string(self, format: str) -> DaskExpr:
return self.compliant._with_callable(
lambda expr, format: expr.dt.strftime(format.replace("%.f", ".%f")),
"strftime",
format=format,
)
def replace_time_zone(self, time_zone: str | None) -> DaskExpr:
return self.compliant._with_callable(
lambda expr, time_zone: expr.dt.tz_localize(None).dt.tz_localize(time_zone)
if time_zone is not None
else expr.dt.tz_localize(None),
"tz_localize",
time_zone=time_zone,
)
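# Semantics note: `tz_localize(None)` drops the zone while keeping the wall
# time, and the second `tz_localize(time_zone)` stamps the new zone onto that
# same wall time; unlike `tz_convert`, no instant-preserving shift happens.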
def convert_time_zone(self, time_zone: str) -> DaskExpr:
def func(s: dx.Series, time_zone: str) -> dx.Series:
dtype = native_to_narwhals_dtype(
s.dtype, self.compliant._version, Implementation.DASK
)
if dtype.time_zone is None: # type: ignore[attr-defined]
return s.dt.tz_localize("UTC").dt.tz_convert(time_zone) # pyright: ignore[reportAttributeAccessIssue]
else:
return s.dt.tz_convert(time_zone) # pyright: ignore[reportAttributeAccessIssue]
return self.compliant._with_callable(func, "tz_convert", time_zone=time_zone)
# ignoring coverage due to https://github.com/narwhals-dev/narwhals/issues/2808.
def timestamp(self, time_unit: TimeUnit) -> DaskExpr: # pragma: no cover
def func(s: dx.Series, time_unit: TimeUnit) -> dx.Series:
dtype = native_to_narwhals_dtype(
s.dtype, self.compliant._version, Implementation.DASK
)
is_pyarrow_dtype = "pyarrow" in str(dtype)
mask_na = s.isna()
dtypes = self.compliant._version.dtypes
if dtype == dtypes.Date:
# Date is only supported in pandas dtypes if pyarrow-backed
s_cast = s.astype("Int32[pyarrow]")
result = calculate_timestamp_date(s_cast, time_unit)
elif isinstance(dtype, dtypes.Datetime):
original_time_unit = dtype.time_unit
s_cast = (
s.astype("Int64[pyarrow]") if is_pyarrow_dtype else s.astype("int64")
)
result = calculate_timestamp_datetime(
s_cast, original_time_unit, time_unit
)
else:
msg = "Input should be either of Date or Datetime type"
raise TypeError(msg)
return result.where(~mask_na) # pyright: ignore[reportReturnType]
return self.compliant._with_callable(func, "datetime", time_unit=time_unit)
def total_minutes(self) -> DaskExpr:
return self.compliant._with_callable(
lambda expr: expr.dt.total_seconds() // 60, "total_minutes"
)
def total_seconds(self) -> DaskExpr:
return self.compliant._with_callable(
lambda expr: expr.dt.total_seconds() // 1, "total_seconds"
)
def total_milliseconds(self) -> DaskExpr:
return self.compliant._with_callable(
lambda expr: expr.dt.total_seconds() * MS_PER_SECOND // 1,
"total_milliseconds",
)
def total_microseconds(self) -> DaskExpr:
return self.compliant._with_callable(
lambda expr: expr.dt.total_seconds() * US_PER_SECOND // 1,
"total_microseconds",
)
def total_nanoseconds(self) -> DaskExpr:
return self.compliant._with_callable(
lambda expr: expr.dt.total_seconds() * NS_PER_SECOND // 1, "total_nanoseconds"
)
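# The `// 1` in the methods above floors the float returned by
# `total_seconds()` so each total_* result is integral; e.g. a 1.5 s timedelta
# gives total_milliseconds -> 1.5 * MS_PER_SECOND // 1 == 1500.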
def truncate(self, every: str) -> DaskExpr:
interval = Interval.parse(every)
unit = interval.unit
if unit in {"mo", "q", "y"}:
msg = f"Truncating to {unit} is not yet supported for dask."
raise NotImplementedError(msg)
freq = f"{interval.multiple}{ALIAS_DICT.get(unit, unit)}"
return self.compliant._with_callable(lambda expr: expr.dt.floor(freq), "truncate")
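# Example mapping (assuming ALIAS_DICT leaves "h" unchanged): every="2h" parses
# to multiple=2 and unit="h", so freq == "2h" and the lowering is
# expr.dt.floor("2h"), e.g. 2025-01-01 05:30 -> 2025-01-01 04:00.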
def offset_by(self, by: str) -> DaskExpr:
def func(s: dx.Series, by: str) -> dx.Series:
interval = Interval.parse_no_constraints(by)
unit = interval.unit
if unit in {"y", "q", "mo", "d", "ns"}:
msg = f"Offsetting by {unit} is not yet supported for dask."
raise NotImplementedError(msg)
offset = interval.to_timedelta()
return s.add(offset)
return self.compliant._with_callable(func, "offset_by", by=by)

Some files were not shown because too many files have changed in this diff