some new features
(9 binary files changed, not shown)
new file: .venv/lib/python3.12/site-packages/narwhals/_dask/dataframe.py (483 lines)
@@ -0,0 +1,483 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any

import dask.dataframe as dd

from narwhals._dask.utils import add_row_index, evaluate_exprs
from narwhals._expression_parsing import ExprKind
from narwhals._pandas_like.utils import native_to_narwhals_dtype, select_columns_by_name
from narwhals._typing_compat import assert_never
from narwhals._utils import (
    Implementation,
    ValidateBackendVersion,
    _remap_full_join_keys,
    check_column_names_are_unique,
    generate_temporary_column_name,
    not_implemented,
    parse_columns_to_drop,
)
from narwhals.typing import CompliantLazyFrame

if TYPE_CHECKING:
    from collections.abc import Iterator, Mapping, Sequence
    from io import BytesIO
    from pathlib import Path
    from types import ModuleType

    import dask.dataframe.dask_expr as dx
    from typing_extensions import Self, TypeAlias, TypeIs

    from narwhals._compliant.typing import CompliantDataFrameAny
    from narwhals._dask.expr import DaskExpr
    from narwhals._dask.group_by import DaskLazyGroupBy
    from narwhals._dask.namespace import DaskNamespace
    from narwhals._utils import Version, _LimitedContext
    from narwhals.dataframe import LazyFrame
    from narwhals.dtypes import DType
    from narwhals.typing import AsofJoinStrategy, JoinStrategy, LazyUniqueKeepStrategy

Incomplete: TypeAlias = "Any"
"""Using `_pandas_like` utils with `_dask`.

Typing this correctly would complicate the `_pandas_like` side.
Very low priority until `dask` adds typing.
"""


class DaskLazyFrame(
    CompliantLazyFrame["DaskExpr", "dd.DataFrame", "LazyFrame[dd.DataFrame]"],
    ValidateBackendVersion,
):
    _implementation = Implementation.DASK

    def __init__(
        self,
        native_dataframe: dd.DataFrame,
        *,
        version: Version,
        validate_backend_version: bool = False,
    ) -> None:
        self._native_frame: dd.DataFrame = native_dataframe
        self._version = version
        self._cached_schema: dict[str, DType] | None = None
        self._cached_columns: list[str] | None = None
        if validate_backend_version:
            self._validate_backend_version()

    @staticmethod
    def _is_native(obj: dd.DataFrame | Any) -> TypeIs[dd.DataFrame]:
        return isinstance(obj, dd.DataFrame)

    @classmethod
    def from_native(cls, data: dd.DataFrame, /, *, context: _LimitedContext) -> Self:
        return cls(data, version=context._version)

    def to_narwhals(self) -> LazyFrame[dd.DataFrame]:
        return self._version.lazyframe(self, level="lazy")

    def __native_namespace__(self) -> ModuleType:
        if self._implementation is Implementation.DASK:
            return self._implementation.to_native_namespace()

        msg = f"Expected dask, got: {type(self._implementation)}"  # pragma: no cover
        raise AssertionError(msg)

    def __narwhals_namespace__(self) -> DaskNamespace:
        from narwhals._dask.namespace import DaskNamespace

        return DaskNamespace(version=self._version)

    def __narwhals_lazyframe__(self) -> Self:
        return self

    def _with_version(self, version: Version) -> Self:
        return self.__class__(self.native, version=version)

    def _with_native(self, df: Any) -> Self:
        return self.__class__(df, version=self._version)

    def _iter_columns(self) -> Iterator[dx.Series]:
        for _col, ser in self.native.items():  # noqa: PERF102
            yield ser

    def with_columns(self, *exprs: DaskExpr) -> Self:
        new_series = evaluate_exprs(self, *exprs)
        return self._with_native(self.native.assign(**dict(new_series)))

    def collect(
        self, backend: Implementation | None, **kwargs: Any
    ) -> CompliantDataFrameAny:
        result = self.native.compute(**kwargs)

        if backend is None or backend is Implementation.PANDAS:
            from narwhals._pandas_like.dataframe import PandasLikeDataFrame

            return PandasLikeDataFrame(
                result,
                implementation=Implementation.PANDAS,
                validate_backend_version=True,
                version=self._version,
                validate_column_names=True,
            )

        if backend is Implementation.POLARS:
            import polars as pl  # ignore-banned-import

            from narwhals._polars.dataframe import PolarsDataFrame

            return PolarsDataFrame(
                pl.from_pandas(result),
                validate_backend_version=True,
                version=self._version,
            )

        if backend is Implementation.PYARROW:
            import pyarrow as pa  # ignore-banned-import

            from narwhals._arrow.dataframe import ArrowDataFrame

            return ArrowDataFrame(
                pa.Table.from_pandas(result),
                validate_backend_version=True,
                version=self._version,
                validate_column_names=True,
            )

        msg = f"Unsupported `backend` value: {backend}"  # pragma: no cover
        raise ValueError(msg)  # pragma: no cover

    @property
    def columns(self) -> list[str]:
        if self._cached_columns is None:
            self._cached_columns = (
                list(self.schema)
                if self._cached_schema is not None
                else self.native.columns.tolist()
            )
        return self._cached_columns

    def filter(self, predicate: DaskExpr) -> Self:
        # `[0]` is safe as the predicate's expression only returns a single column
        mask = predicate(self)[0]
        return self._with_native(self.native.loc[mask])

    def simple_select(self, *column_names: str) -> Self:
        df: Incomplete = self.native
        native = select_columns_by_name(df, list(column_names), self._implementation)
        return self._with_native(native)

    def aggregate(self, *exprs: DaskExpr) -> Self:
        new_series = evaluate_exprs(self, *exprs)
        df = dd.concat([val.rename(name) for name, val in new_series], axis=1)
        return self._with_native(df)

    def select(self, *exprs: DaskExpr) -> Self:
        new_series = evaluate_exprs(self, *exprs)
        df: Incomplete = self.native
        df = select_columns_by_name(
            df.assign(**dict(new_series)),
            [s[0] for s in new_series],
            self._implementation,
        )
        return self._with_native(df)

    def drop_nulls(self, subset: Sequence[str] | None) -> Self:
        if subset is None:
            return self._with_native(self.native.dropna())
        plx = self.__narwhals_namespace__()
        mask = ~plx.any_horizontal(plx.col(*subset).is_null(), ignore_nulls=True)
        return self.filter(mask)

    @property
    def schema(self) -> dict[str, DType]:
        if self._cached_schema is None:
            native_dtypes = self.native.dtypes
            self._cached_schema = {
                col: native_to_narwhals_dtype(
                    native_dtypes[col], self._version, self._implementation
                )
                for col in self.native.columns
            }
        return self._cached_schema

    def collect_schema(self) -> dict[str, DType]:
        return self.schema

    def drop(self, columns: Sequence[str], *, strict: bool) -> Self:
        to_drop = parse_columns_to_drop(self, columns, strict=strict)

        return self._with_native(self.native.drop(columns=to_drop))

    def with_row_index(self, name: str, order_by: Sequence[str] | None) -> Self:
        # Implementation is based on the following StackOverflow reply:
        # https://stackoverflow.com/questions/60831518/in-dask-how-does-one-add-a-range-of-integersauto-increment-to-a-new-column/60852409#60852409
        if order_by is None:
            return self._with_native(add_row_index(self.native, name))
        else:
            plx = self.__narwhals_namespace__()
            columns = self.columns
            const_expr = (
                plx.lit(value=1, dtype=None).alias(name).broadcast(ExprKind.LITERAL)
            )
            row_index_expr = (
                plx.col(name)
                .cum_sum(reverse=False)
                .over(partition_by=[], order_by=order_by)
                - 1
            )
            return self.with_columns(const_expr).select(row_index_expr, plx.col(*columns))
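
    # Illustrative sketch (reviewer annotation, not part of the upstream file):
    # the `order_by` branch above builds the row index from narwhals expressions.
    # For a hypothetical frame with columns "a" and "ts" it composes as:
    #
    #     plx = frame.__narwhals_namespace__()
    #     frame.with_columns(
    #         plx.lit(value=1, dtype=None).alias("idx").broadcast(ExprKind.LITERAL)
    #     ).select(
    #         plx.col("idx").cum_sum(reverse=False).over(partition_by=[], order_by=["ts"]) - 1,
    #         plx.col("a", "ts"),
    #     )
    #
    # i.e. a column of ones is cumulatively summed in `order_by` order, then
    # shifted to start at zero.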

    def rename(self, mapping: Mapping[str, str]) -> Self:
        return self._with_native(self.native.rename(columns=mapping))

    def head(self, n: int) -> Self:
        return self._with_native(self.native.head(n=n, compute=False, npartitions=-1))

    def unique(
        self, subset: Sequence[str] | None, *, keep: LazyUniqueKeepStrategy
    ) -> Self:
        if subset and (error := self._check_columns_exist(subset)):
            raise error
        if keep == "none":
            subset = subset or self.columns
            token = generate_temporary_column_name(n_bytes=8, columns=subset)
            ser = self.native.groupby(subset).size().rename(token)
            ser = ser[ser == 1]
            unique = ser.reset_index().drop(columns=token)
            result = self.native.merge(unique, on=subset, how="inner")
        else:
            mapped_keep = {"any": "first"}.get(keep, keep)
            result = self.native.drop_duplicates(subset=subset, keep=mapped_keep)
        return self._with_native(result)
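
    # Illustrative sketch (reviewer annotation): for `keep="none"` the group-size
    # trick above keeps only rows whose key combination occurs exactly once.
    # With a hypothetical single-column frame:
    #
    #     import pandas as pd
    #     ddf = dd.from_pandas(pd.DataFrame({"a": [1, 1, 2]}), npartitions=1)
    #     sizes = ddf.groupby(["a"]).size().rename("_tmp")
    #     once = sizes[sizes == 1].reset_index().drop(columns="_tmp")
    #     ddf.merge(once, on=["a"], how="inner").compute()  # only the a == 2 row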

    def sort(self, *by: str, descending: bool | Sequence[bool], nulls_last: bool) -> Self:
        if isinstance(descending, bool):
            ascending: bool | list[bool] = not descending
        else:
            ascending = [not d for d in descending]
        position = "last" if nulls_last else "first"
        return self._with_native(
            self.native.sort_values(list(by), ascending=ascending, na_position=position)
        )

    def _join_inner(
        self, other: Self, *, left_on: Sequence[str], right_on: Sequence[str], suffix: str
    ) -> dd.DataFrame:
        return self.native.merge(
            other.native,
            left_on=left_on,
            right_on=right_on,
            how="inner",
            suffixes=("", suffix),
        )

    def _join_left(
        self, other: Self, *, left_on: Sequence[str], right_on: Sequence[str], suffix: str
    ) -> dd.DataFrame:
        result_native = self.native.merge(
            other.native,
            how="left",
            left_on=left_on,
            right_on=right_on,
            suffixes=("", suffix),
        )
        extra = [
            right_key if right_key not in self.columns else f"{right_key}{suffix}"
            for left_key, right_key in zip(left_on, right_on)
            if right_key != left_key
        ]
        return result_native.drop(columns=extra)

    def _join_full(
        self, other: Self, *, left_on: Sequence[str], right_on: Sequence[str], suffix: str
    ) -> dd.DataFrame:
        # dask does not retain keys post-join,
        # so we must append the suffix to each key beforehand

        right_on_mapper = _remap_full_join_keys(left_on, right_on, suffix)
        other_native = other.native.rename(columns=right_on_mapper)
        check_column_names_are_unique(other_native.columns)
        right_suffixed = list(right_on_mapper.values())
        return self.native.merge(
            other_native,
            left_on=left_on,
            right_on=right_suffixed,
            how="outer",
            suffixes=("", suffix),
        )

    def _join_cross(self, other: Self, *, suffix: str) -> dd.DataFrame:
        key_token = generate_temporary_column_name(
            n_bytes=8, columns=(*self.columns, *other.columns)
        )
        return (
            self.native.assign(**{key_token: 0})
            .merge(
                other.native.assign(**{key_token: 0}),
                how="inner",
                left_on=key_token,
                right_on=key_token,
                suffixes=("", suffix),
            )
            .drop(columns=key_token)
        )

    def _join_semi(
        self, other: Self, *, left_on: Sequence[str], right_on: Sequence[str]
    ) -> dd.DataFrame:
        other_native = self._join_filter_rename(
            other=other,
            columns_to_select=list(right_on),
            columns_mapping=dict(zip(right_on, left_on)),
        )
        return self.native.merge(
            other_native, how="inner", left_on=left_on, right_on=left_on
        )

    def _join_anti(
        self, other: Self, *, left_on: Sequence[str], right_on: Sequence[str]
    ) -> dd.DataFrame:
        indicator_token = generate_temporary_column_name(
            n_bytes=8, columns=(*self.columns, *other.columns)
        )
        other_native = self._join_filter_rename(
            other=other,
            columns_to_select=list(right_on),
            columns_mapping=dict(zip(right_on, left_on)),
        )
        df = self.native.merge(
            other_native,
            how="left",
            indicator=indicator_token,  # pyright: ignore[reportArgumentType]
            left_on=left_on,
            right_on=left_on,
        )
        return df[df[indicator_token] == "left_only"].drop(columns=[indicator_token])

    def _join_filter_rename(
        self, other: Self, columns_to_select: list[str], columns_mapping: dict[str, str]
    ) -> dd.DataFrame:
        """Helper function to avoid creating extra columns and row duplication.

        Used in `"anti"` and `"semi"` joins.

        Notice that a native object is returned.
        """
        other_native: Incomplete = other.native
        # rename to avoid creating extra columns in join
        return (
            select_columns_by_name(other_native, columns_to_select, self._implementation)
            .rename(columns=columns_mapping)
            .drop_duplicates()
        )
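
    # Illustrative sketch (reviewer annotation): `_join_semi` and `_join_anti`
    # both pass the right frame through `_join_filter_rename` first, so the merge
    # sees a de-duplicated, key-only frame already named after the *left* keys.
    # For hypothetical keys `left_on=["a"]`, `right_on=["b"]` this looks like:
    #
    #     keys = self._join_filter_rename(
    #         other=other, columns_to_select=["b"], columns_mapping={"b": "a"}
    #     )
    #     self.native.merge(keys, how="inner", left_on=["a"], right_on=["a"])  # semi
    #
    # and the anti-join additionally uses a temporary `indicator` column to keep
    # only the `"left_only"` rows.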

    def join(
        self,
        other: Self,
        *,
        how: JoinStrategy,
        left_on: Sequence[str] | None,
        right_on: Sequence[str] | None,
        suffix: str,
    ) -> Self:
        if how == "cross":
            result = self._join_cross(other=other, suffix=suffix)

        elif left_on is None or right_on is None:  # pragma: no cover
            raise ValueError(left_on, right_on)

        elif how == "inner":
            result = self._join_inner(
                other=other, left_on=left_on, right_on=right_on, suffix=suffix
            )
        elif how == "anti":
            result = self._join_anti(other=other, left_on=left_on, right_on=right_on)
        elif how == "semi":
            result = self._join_semi(other=other, left_on=left_on, right_on=right_on)
        elif how == "left":
            result = self._join_left(
                other=other, left_on=left_on, right_on=right_on, suffix=suffix
            )
        elif how == "full":
            result = self._join_full(
                other=other, left_on=left_on, right_on=right_on, suffix=suffix
            )
        else:
            assert_never(how)
        return self._with_native(result)

    def join_asof(
        self,
        other: Self,
        *,
        left_on: str,
        right_on: str,
        by_left: Sequence[str] | None,
        by_right: Sequence[str] | None,
        strategy: AsofJoinStrategy,
        suffix: str,
    ) -> Self:
        plx = self.__native_namespace__()
        return self._with_native(
            plx.merge_asof(
                self.native,
                other.native,
                left_on=left_on,
                right_on=right_on,
                left_by=by_left,
                right_by=by_right,
                direction=strategy,
                suffixes=("", suffix),
            )
        )

    def group_by(
        self, keys: Sequence[str] | Sequence[DaskExpr], *, drop_null_keys: bool
    ) -> DaskLazyGroupBy:
        from narwhals._dask.group_by import DaskLazyGroupBy

        return DaskLazyGroupBy(self, keys, drop_null_keys=drop_null_keys)

    def tail(self, n: int) -> Self:  # pragma: no cover
        native_frame = self.native
        n_partitions = native_frame.npartitions

        if n_partitions == 1:
            return self._with_native(self.native.tail(n=n, compute=False))
        else:
            msg = "`LazyFrame.tail` is not supported for Dask backend with multiple partitions."
            raise NotImplementedError(msg)

    def gather_every(self, n: int, offset: int) -> Self:
        row_index_token = generate_temporary_column_name(n_bytes=8, columns=self.columns)
        plx = self.__narwhals_namespace__()
        return (
            self.with_row_index(row_index_token, order_by=None)
            .filter(
                (plx.col(row_index_token) >= offset)
                & ((plx.col(row_index_token) - offset) % n == 0)
            )
            .drop([row_index_token], strict=False)
        )
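
    # Illustrative sketch (reviewer annotation): `gather_every(n=2, offset=1)`
    # keeps rows whose temporary index i satisfies i >= 1 and (i - 1) % 2 == 0,
    # i.e. rows 1, 3, 5, ... of the frame.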

    def unpivot(
        self,
        on: Sequence[str] | None,
        index: Sequence[str] | None,
        variable_name: str,
        value_name: str,
    ) -> Self:
        return self._with_native(
            self.native.melt(
                id_vars=index,
                value_vars=on,
                var_name=variable_name,
                value_name=value_name,
            )
        )

    def sink_parquet(self, file: str | Path | BytesIO) -> None:
        self.native.to_parquet(file)

    explode = not_implemented()
new file: .venv/lib/python3.12/site-packages/narwhals/_dask/expr.py (692 lines)
@@ -0,0 +1,692 @@
from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Any, Callable, Literal

from narwhals._compliant import DepthTrackingExpr, LazyExpr
from narwhals._dask.expr_dt import DaskExprDateTimeNamespace
from narwhals._dask.expr_str import DaskExprStringNamespace
from narwhals._dask.utils import (
    add_row_index,
    maybe_evaluate_expr,
    narwhals_to_native_dtype,
)
from narwhals._expression_parsing import ExprKind, evaluate_output_names_and_aliases
from narwhals._pandas_like.utils import native_to_narwhals_dtype
from narwhals._utils import (
    Implementation,
    generate_temporary_column_name,
    not_implemented,
)
from narwhals.exceptions import InvalidOperationError

if TYPE_CHECKING:
    from collections.abc import Sequence

    import dask.dataframe.dask_expr as dx
    from typing_extensions import Self

    from narwhals._compliant.typing import AliasNames, EvalNames, EvalSeries, ScalarKwargs
    from narwhals._dask.dataframe import DaskLazyFrame
    from narwhals._dask.namespace import DaskNamespace
    from narwhals._expression_parsing import ExprKind, ExprMetadata
    from narwhals._utils import Version, _LimitedContext
    from narwhals.typing import (
        FillNullStrategy,
        IntoDType,
        NonNestedLiteral,
        NumericLiteral,
        RollingInterpolationMethod,
        TemporalLiteral,
    )


class DaskExpr(
    LazyExpr["DaskLazyFrame", "dx.Series"],
    DepthTrackingExpr["DaskLazyFrame", "dx.Series"],
):
    _implementation: Implementation = Implementation.DASK

    def __init__(
        self,
        call: EvalSeries[DaskLazyFrame, dx.Series],
        *,
        depth: int,
        function_name: str,
        evaluate_output_names: EvalNames[DaskLazyFrame],
        alias_output_names: AliasNames | None,
        version: Version,
        scalar_kwargs: ScalarKwargs | None = None,
    ) -> None:
        self._call = call
        self._depth = depth
        self._function_name = function_name
        self._evaluate_output_names = evaluate_output_names
        self._alias_output_names = alias_output_names
        self._version = version
        self._scalar_kwargs = scalar_kwargs or {}
        self._metadata: ExprMetadata | None = None

    def __call__(self, df: DaskLazyFrame) -> Sequence[dx.Series]:
        return self._call(df)

    def __narwhals_expr__(self) -> None: ...

    def __narwhals_namespace__(self) -> DaskNamespace:  # pragma: no cover
        from narwhals._dask.namespace import DaskNamespace

        return DaskNamespace(version=self._version)

    def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Self:
        def func(df: DaskLazyFrame) -> list[dx.Series]:
            # result.loc[0][0] is a workaround for dask~<=2024.10.0/dask_expr~<=1.1.16
            # that raised a KeyError for result[0] during collection.
            return [result.loc[0][0] for result in self(df)]

        return self.__class__(
            func,
            depth=self._depth,
            function_name=self._function_name,
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            version=self._version,
            scalar_kwargs=self._scalar_kwargs,
        )
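
    # Illustrative sketch (reviewer annotation): after an aggregation such as
    # `nw.col("a").mean()`, `self(df)` yields one-row series; `result.loc[0][0]`
    # extracts each as a (lazy) scalar so that `DaskLazyFrame.with_columns` can
    # assign it and let dask broadcast it to the frame's full length, e.g.
    # `df.with_columns(a_mean=nw.col("a").mean())` at the narwhals level.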

    @classmethod
    def from_column_names(
        cls: type[Self],
        evaluate_column_names: EvalNames[DaskLazyFrame],
        /,
        *,
        context: _LimitedContext,
        function_name: str = "",
    ) -> Self:
        def func(df: DaskLazyFrame) -> list[dx.Series]:
            try:
                return [
                    df._native_frame[column_name]
                    for column_name in evaluate_column_names(df)
                ]
            except KeyError as e:
                if error := df._check_columns_exist(evaluate_column_names(df)):
                    raise error from e
                raise

        return cls(
            func,
            depth=0,
            function_name=function_name,
            evaluate_output_names=evaluate_column_names,
            alias_output_names=None,
            version=context._version,
        )

    @classmethod
    def from_column_indices(cls, *column_indices: int, context: _LimitedContext) -> Self:
        def func(df: DaskLazyFrame) -> list[dx.Series]:
            return [df.native.iloc[:, i] for i in column_indices]

        return cls(
            func,
            depth=0,
            function_name="nth",
            evaluate_output_names=cls._eval_names_indices(column_indices),
            alias_output_names=None,
            version=context._version,
        )

    def _with_callable(
        self,
        # First argument to `call` should be `dx.Series`
        call: Callable[..., dx.Series],
        /,
        expr_name: str = "",
        scalar_kwargs: ScalarKwargs | None = None,
        **expressifiable_args: Self | Any,
    ) -> Self:
        def func(df: DaskLazyFrame) -> list[dx.Series]:
            native_results: list[dx.Series] = []
            native_series_list = self._call(df)
            other_native_series = {
                key: maybe_evaluate_expr(df, value)
                for key, value in expressifiable_args.items()
            }
            for native_series in native_series_list:
                result_native = call(native_series, **other_native_series)
                native_results.append(result_native)
            return native_results

        return self.__class__(
            func,
            depth=self._depth + 1,
            function_name=f"{self._function_name}->{expr_name}",
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            version=self._version,
            scalar_kwargs=scalar_kwargs,
        )
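
    # Illustrative sketch (reviewer annotation): `_with_callable` is the
    # workhorse behind most methods below. A unary method only supplies the
    # native callable and a name, e.g.
    #
    #     self._with_callable(lambda expr: expr.abs(), "abs")
    #
    # while a binary method passes the other operand as a keyword, which
    # `maybe_evaluate_expr` resolves to a native series or scalar per frame:
    #
    #     self._with_callable(lambda expr, other: expr + other, "__add__", other=other)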

    def _with_alias_output_names(self, func: AliasNames | None, /) -> Self:
        current_alias_output_names = self._alias_output_names
        alias_output_names = (
            None
            if func is None
            else func
            if current_alias_output_names is None
            else lambda output_names: func(current_alias_output_names(output_names))
        )
        return type(self)(
            call=self._call,
            depth=self._depth,
            function_name=self._function_name,
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=alias_output_names,
            version=self._version,
            scalar_kwargs=self._scalar_kwargs,
        )

    def __add__(self, other: Any) -> Self:
        return self._with_callable(
            lambda expr, other: expr.__add__(other), "__add__", other=other
        )

    def __sub__(self, other: Any) -> Self:
        return self._with_callable(
            lambda expr, other: expr.__sub__(other), "__sub__", other=other
        )

    def __rsub__(self, other: Any) -> Self:
        return self._with_callable(
            lambda expr, other: other - expr, "__rsub__", other=other
        ).alias("literal")

    def __mul__(self, other: Any) -> Self:
        return self._with_callable(
            lambda expr, other: expr.__mul__(other), "__mul__", other=other
        )

    def __truediv__(self, other: Any) -> Self:
        return self._with_callable(
            lambda expr, other: expr.__truediv__(other), "__truediv__", other=other
        )

    def __rtruediv__(self, other: Any) -> Self:
        return self._with_callable(
            lambda expr, other: other / expr, "__rtruediv__", other=other
        ).alias("literal")

    def __floordiv__(self, other: Any) -> Self:
        return self._with_callable(
            lambda expr, other: expr.__floordiv__(other), "__floordiv__", other=other
        )

    def __rfloordiv__(self, other: Any) -> Self:
        return self._with_callable(
            lambda expr, other: other // expr, "__rfloordiv__", other=other
        ).alias("literal")

    def __pow__(self, other: Any) -> Self:
        return self._with_callable(
            lambda expr, other: expr.__pow__(other), "__pow__", other=other
        )

    def __rpow__(self, other: Any) -> Self:
        return self._with_callable(
            lambda expr, other: other**expr, "__rpow__", other=other
        ).alias("literal")

    def __mod__(self, other: Any) -> Self:
        return self._with_callable(
            lambda expr, other: expr.__mod__(other), "__mod__", other=other
        )

    def __rmod__(self, other: Any) -> Self:
        return self._with_callable(
            lambda expr, other: other % expr, "__rmod__", other=other
        ).alias("literal")

    def __eq__(self, other: DaskExpr) -> Self:  # type: ignore[override]
        return self._with_callable(
            lambda expr, other: expr.__eq__(other), "__eq__", other=other
        )

    def __ne__(self, other: DaskExpr) -> Self:  # type: ignore[override]
        return self._with_callable(
            lambda expr, other: expr.__ne__(other), "__ne__", other=other
        )

    def __ge__(self, other: DaskExpr | Any) -> Self:
        return self._with_callable(
            lambda expr, other: expr.__ge__(other), "__ge__", other=other
        )

    def __gt__(self, other: DaskExpr) -> Self:
        return self._with_callable(
            lambda expr, other: expr.__gt__(other), "__gt__", other=other
        )

    def __le__(self, other: DaskExpr) -> Self:
        return self._with_callable(
            lambda expr, other: expr.__le__(other), "__le__", other=other
        )

    def __lt__(self, other: DaskExpr) -> Self:
        return self._with_callable(
            lambda expr, other: expr.__lt__(other), "__lt__", other=other
        )

    def __and__(self, other: DaskExpr | Any) -> Self:
        return self._with_callable(
            lambda expr, other: expr.__and__(other), "__and__", other=other
        )

    def __or__(self, other: DaskExpr) -> Self:
        return self._with_callable(
            lambda expr, other: expr.__or__(other), "__or__", other=other
        )

    def __invert__(self) -> Self:
        return self._with_callable(lambda expr: expr.__invert__(), "__invert__")

    def mean(self) -> Self:
        return self._with_callable(lambda expr: expr.mean().to_series(), "mean")

    def median(self) -> Self:
        from narwhals.exceptions import InvalidOperationError

        def func(s: dx.Series) -> dx.Series:
            dtype = native_to_narwhals_dtype(s.dtype, self._version, Implementation.DASK)
            if not dtype.is_numeric():
                msg = "`median` operation not supported for non-numeric input type."
                raise InvalidOperationError(msg)
            return s.median_approximate().to_series()

        return self._with_callable(func, "median")

    def min(self) -> Self:
        return self._with_callable(lambda expr: expr.min().to_series(), "min")

    def max(self) -> Self:
        return self._with_callable(lambda expr: expr.max().to_series(), "max")

    def std(self, ddof: int) -> Self:
        return self._with_callable(
            lambda expr: expr.std(ddof=ddof).to_series(),
            "std",
            scalar_kwargs={"ddof": ddof},
        )

    def var(self, ddof: int) -> Self:
        return self._with_callable(
            lambda expr: expr.var(ddof=ddof).to_series(),
            "var",
            scalar_kwargs={"ddof": ddof},
        )

    def skew(self) -> Self:
        return self._with_callable(lambda expr: expr.skew().to_series(), "skew")

    def kurtosis(self) -> Self:
        return self._with_callable(lambda expr: expr.kurtosis().to_series(), "kurtosis")

    def shift(self, n: int) -> Self:
        return self._with_callable(lambda expr: expr.shift(n), "shift")

    def cum_sum(self, *, reverse: bool) -> Self:
        if reverse:  # pragma: no cover
            # https://github.com/dask/dask/issues/11802
            msg = "`cum_sum(reverse=True)` is not supported with Dask backend"
            raise NotImplementedError(msg)

        return self._with_callable(lambda expr: expr.cumsum(), "cum_sum")

    def cum_count(self, *, reverse: bool) -> Self:
        if reverse:  # pragma: no cover
            msg = "`cum_count(reverse=True)` is not supported with Dask backend"
            raise NotImplementedError(msg)

        return self._with_callable(
            lambda expr: (~expr.isna()).astype(int).cumsum(), "cum_count"
        )

    def cum_min(self, *, reverse: bool) -> Self:
        if reverse:  # pragma: no cover
            msg = "`cum_min(reverse=True)` is not supported with Dask backend"
            raise NotImplementedError(msg)

        return self._with_callable(lambda expr: expr.cummin(), "cum_min")

    def cum_max(self, *, reverse: bool) -> Self:
        if reverse:  # pragma: no cover
            msg = "`cum_max(reverse=True)` is not supported with Dask backend"
            raise NotImplementedError(msg)

        return self._with_callable(lambda expr: expr.cummax(), "cum_max")

    def cum_prod(self, *, reverse: bool) -> Self:
        if reverse:  # pragma: no cover
            msg = "`cum_prod(reverse=True)` is not supported with Dask backend"
            raise NotImplementedError(msg)

        return self._with_callable(lambda expr: expr.cumprod(), "cum_prod")

    def rolling_sum(self, window_size: int, *, min_samples: int, center: bool) -> Self:
        return self._with_callable(
            lambda expr: expr.rolling(
                window=window_size, min_periods=min_samples, center=center
            ).sum(),
            "rolling_sum",
        )

    def rolling_mean(self, window_size: int, *, min_samples: int, center: bool) -> Self:
        return self._with_callable(
            lambda expr: expr.rolling(
                window=window_size, min_periods=min_samples, center=center
            ).mean(),
            "rolling_mean",
        )

    def rolling_var(
        self, window_size: int, *, min_samples: int, center: bool, ddof: int
    ) -> Self:
        if ddof == 1:
            return self._with_callable(
                lambda expr: expr.rolling(
                    window=window_size, min_periods=min_samples, center=center
                ).var(),
                "rolling_var",
            )
        else:
            msg = "Dask backend only supports `ddof=1` for `rolling_var`"
            raise NotImplementedError(msg)

    def rolling_std(
        self, window_size: int, *, min_samples: int, center: bool, ddof: int
    ) -> Self:
        if ddof == 1:
            return self._with_callable(
                lambda expr: expr.rolling(
                    window=window_size, min_periods=min_samples, center=center
                ).std(),
                "rolling_std",
            )
        else:
            msg = "Dask backend only supports `ddof=1` for `rolling_std`"
            raise NotImplementedError(msg)

    def sum(self) -> Self:
        return self._with_callable(lambda expr: expr.sum().to_series(), "sum")

    def count(self) -> Self:
        return self._with_callable(lambda expr: expr.count().to_series(), "count")

    def round(self, decimals: int) -> Self:
        return self._with_callable(lambda expr: expr.round(decimals), "round")

    def unique(self) -> Self:
        return self._with_callable(lambda expr: expr.unique(), "unique")

    def drop_nulls(self) -> Self:
        return self._with_callable(lambda expr: expr.dropna(), "drop_nulls")

    def abs(self) -> Self:
        return self._with_callable(lambda expr: expr.abs(), "abs")

    def all(self) -> Self:
        return self._with_callable(
            lambda expr: expr.all(
                axis=None, skipna=True, split_every=False, out=None
            ).to_series(),
            "all",
        )

    def any(self) -> Self:
        return self._with_callable(
            lambda expr: expr.any(axis=0, skipna=True, split_every=False).to_series(),
            "any",
        )

    def fill_null(
        self,
        value: Self | NonNestedLiteral,
        strategy: FillNullStrategy | None,
        limit: int | None,
    ) -> Self:
        def func(expr: dx.Series) -> dx.Series:
            if value is not None:
                res_ser = expr.fillna(value)
            else:
                res_ser = (
                    expr.ffill(limit=limit)
                    if strategy == "forward"
                    else expr.bfill(limit=limit)
                )
            return res_ser

        return self._with_callable(func, "fillna")

    def clip(
        self,
        lower_bound: Self | NumericLiteral | TemporalLiteral | None,
        upper_bound: Self | NumericLiteral | TemporalLiteral | None,
    ) -> Self:
        return self._with_callable(
            lambda expr, lower_bound, upper_bound: expr.clip(
                lower=lower_bound, upper=upper_bound
            ),
            "clip",
            lower_bound=lower_bound,
            upper_bound=upper_bound,
        )

    def diff(self) -> Self:
        return self._with_callable(lambda expr: expr.diff(), "diff")

    def n_unique(self) -> Self:
        return self._with_callable(
            lambda expr: expr.nunique(dropna=False).to_series(), "n_unique"
        )

    def is_null(self) -> Self:
        return self._with_callable(lambda expr: expr.isna(), "is_null")

    def is_nan(self) -> Self:
        def func(expr: dx.Series) -> dx.Series:
            dtype = native_to_narwhals_dtype(
                expr.dtype, self._version, self._implementation
            )
            if dtype.is_numeric():
                return expr != expr  # pyright: ignore[reportReturnType] # noqa: PLR0124
            msg = f"`.is_nan` only supported for numeric dtypes and not {dtype}, did you mean `.is_null`?"
            raise InvalidOperationError(msg)

        return self._with_callable(func, "is_nan")

    def len(self) -> Self:
        return self._with_callable(lambda expr: expr.size.to_series(), "len")

    def quantile(
        self, quantile: float, interpolation: RollingInterpolationMethod
    ) -> Self:
        if interpolation == "linear":

            def func(expr: dx.Series, quantile: float) -> dx.Series:
                if expr.npartitions > 1:
                    msg = "`Expr.quantile` is not supported for Dask backend with multiple partitions."
                    raise NotImplementedError(msg)
                return expr.quantile(
                    q=quantile, method="dask"
                ).to_series()  # pragma: no cover

            return self._with_callable(func, "quantile", quantile=quantile)
        else:
            msg = "`higher`, `lower`, `midpoint`, `nearest` - interpolation methods are not supported by Dask. Please use `linear` instead."
            raise NotImplementedError(msg)

    def is_first_distinct(self) -> Self:
        def func(expr: dx.Series) -> dx.Series:
            _name = expr.name
            col_token = generate_temporary_column_name(n_bytes=8, columns=[_name])
            frame = add_row_index(expr.to_frame(), col_token)
            first_distinct_index = frame.groupby(_name).agg({col_token: "min"})[col_token]
            return frame[col_token].isin(first_distinct_index)

        return self._with_callable(func, "is_first_distinct")
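
    # Illustrative sketch (reviewer annotation): the row-index trick above flags
    # a row when its position equals the minimum position seen for its value.
    # For values [1, 1, 2] the temporary frame and grouped minima are:
    #
    #     a  idx        groupby("a").agg({"idx": "min"}) -> {1: 0, 2: 2}
    #     1    0
    #     1    1
    #     2    2
    #
    # so rows 0 and 2 are flagged True; `is_last_distinct` below mirrors this
    # with "max".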

    def is_last_distinct(self) -> Self:
        def func(expr: dx.Series) -> dx.Series:
            _name = expr.name
            col_token = generate_temporary_column_name(n_bytes=8, columns=[_name])
            frame = add_row_index(expr.to_frame(), col_token)
            last_distinct_index = frame.groupby(_name).agg({col_token: "max"})[col_token]
            return frame[col_token].isin(last_distinct_index)

        return self._with_callable(func, "is_last_distinct")

    def is_unique(self) -> Self:
        def func(expr: dx.Series) -> dx.Series:
            _name = expr.name
            return (
                expr.to_frame()
                .groupby(_name, dropna=False)
                .transform("size", meta=(_name, int))
                == 1
            )

        return self._with_callable(func, "is_unique")

    def is_in(self, other: Any) -> Self:
        return self._with_callable(lambda expr: expr.isin(other), "is_in")

    def null_count(self) -> Self:
        return self._with_callable(
            lambda expr: expr.isna().sum().to_series(), "null_count"
        )

    def over(self, partition_by: Sequence[str], order_by: Sequence[str]) -> Self:
        # pandas is a required dependency of dask so it's safe to import this
        from narwhals._pandas_like.group_by import PandasLikeGroupBy

        if not partition_by:
            assert order_by  # noqa: S101

            # This is something like `nw.col('a').cum_sum().order_by(key)`
            # which we can always easily support, as it doesn't require grouping.
            def func(df: DaskLazyFrame) -> Sequence[dx.Series]:
                return self(df.sort(*order_by, descending=False, nulls_last=False))
        elif not self._is_elementary():  # pragma: no cover
            msg = (
                "Only elementary expressions are supported for `.over` in dask.\n\n"
                "Please see: "
                "https://narwhals-dev.github.io/narwhals/concepts/improve_group_by_operation/"
            )
            raise NotImplementedError(msg)
        elif order_by:
            # Wrong results https://github.com/dask/dask/issues/11806.
            msg = "`over` with `order_by` is not yet supported in Dask."
            raise NotImplementedError(msg)
        else:
            function_name = PandasLikeGroupBy._leaf_name(self)
            try:
                dask_function_name = PandasLikeGroupBy._REMAP_AGGS[function_name]
            except KeyError:
                # window functions are unsupported: https://github.com/dask/dask/issues/11806
                msg = (
                    f"Unsupported function: {function_name} in `over` context.\n\n"
                    f"Supported functions are {', '.join(PandasLikeGroupBy._REMAP_AGGS)}\n"
                )
                raise NotImplementedError(msg) from None

            def func(df: DaskLazyFrame) -> Sequence[dx.Series]:
                output_names, aliases = evaluate_output_names_and_aliases(self, df, [])

                with warnings.catch_warnings():
                    # https://github.com/dask/dask/issues/11804
                    warnings.filterwarnings(
                        "ignore",
                        message=".*`meta` is not specified",
                        category=UserWarning,
                    )
                    grouped = df.native.groupby(partition_by)
                    if dask_function_name == "size":
                        if len(output_names) != 1:  # pragma: no cover
                            msg = "Safety check failed, please report a bug."
                            raise AssertionError(msg)
                        res_native = grouped.transform(
                            dask_function_name, **self._scalar_kwargs
                        ).to_frame(output_names[0])
                    else:
                        res_native = grouped[list(output_names)].transform(
                            dask_function_name, **self._scalar_kwargs
                        )
                result_frame = df._with_native(
                    res_native.rename(columns=dict(zip(output_names, aliases)))
                ).native
                return [result_frame[name] for name in aliases]

        return self.__class__(
            func,
            depth=self._depth + 1,
            function_name=self._function_name + "->over",
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            version=self._version,
        )
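
    # Illustrative sketch (reviewer annotation): for a supported aggregation the
    # final branch of `over` reduces to a grouped transform. With hypothetical
    # columns "g" (partition key) and "a" (values):
    #
    #     import dask.dataframe as dd
    #     import pandas as pd
    #     ddf = dd.from_pandas(pd.DataFrame({"g": [0, 0, 1], "a": [1, 2, 3]}), npartitions=1)
    #     ddf.groupby(["g"])[["a"]].transform("sum")  # what `nw.col("a").sum().over("g")` lowers to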

    def cast(self, dtype: IntoDType) -> Self:
        def func(expr: dx.Series) -> dx.Series:
            native_dtype = narwhals_to_native_dtype(dtype, self._version)
            return expr.astype(native_dtype)

        return self._with_callable(func, "cast")

    def is_finite(self) -> Self:
        import dask.array as da

        return self._with_callable(da.isfinite, "is_finite")

    def log(self, base: float) -> Self:
        import dask.array as da

        def _log(expr: dx.Series) -> dx.Series:
            return da.log(expr) / da.log(base)

        return self._with_callable(_log, "log")

    def exp(self) -> Self:
        import dask.array as da

        return self._with_callable(da.exp, "exp")

    def sqrt(self) -> Self:
        import dask.array as da

        return self._with_callable(da.sqrt, "sqrt")

    @property
    def str(self) -> DaskExprStringNamespace:
        return DaskExprStringNamespace(self)

    @property
    def dt(self) -> DaskExprDateTimeNamespace:
        return DaskExprDateTimeNamespace(self)

    arg_max: not_implemented = not_implemented()
    arg_min: not_implemented = not_implemented()
    arg_true: not_implemented = not_implemented()
    ewm_mean: not_implemented = not_implemented()
    gather_every: not_implemented = not_implemented()
    head: not_implemented = not_implemented()
    map_batches: not_implemented = not_implemented()
    mode: not_implemented = not_implemented()
    sample: not_implemented = not_implemented()
    rank: not_implemented = not_implemented()
    replace_strict: not_implemented = not_implemented()
    sort: not_implemented = not_implemented()
    tail: not_implemented = not_implemented()

    # namespaces
    list: not_implemented = not_implemented()  # type: ignore[assignment]
    cat: not_implemented = not_implemented()  # type: ignore[assignment]
    struct: not_implemented = not_implemented()  # type: ignore[assignment]
new file: .venv/lib/python3.12/site-packages/narwhals/_dask/expr_dt.py (176 lines)
@@ -0,0 +1,176 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from narwhals._compliant import LazyExprNamespace
from narwhals._compliant.any_namespace import DateTimeNamespace
from narwhals._constants import MS_PER_SECOND, NS_PER_SECOND, US_PER_SECOND
from narwhals._duration import Interval
from narwhals._pandas_like.utils import (
    ALIAS_DICT,
    calculate_timestamp_date,
    calculate_timestamp_datetime,
    native_to_narwhals_dtype,
)
from narwhals._utils import Implementation

if TYPE_CHECKING:
    import dask.dataframe.dask_expr as dx

    from narwhals._dask.expr import DaskExpr
    from narwhals.typing import TimeUnit


class DaskExprDateTimeNamespace(
    LazyExprNamespace["DaskExpr"], DateTimeNamespace["DaskExpr"]
):
    def date(self) -> DaskExpr:
        return self.compliant._with_callable(lambda expr: expr.dt.date, "date")

    def year(self) -> DaskExpr:
        return self.compliant._with_callable(lambda expr: expr.dt.year, "year")

    def month(self) -> DaskExpr:
        return self.compliant._with_callable(lambda expr: expr.dt.month, "month")

    def day(self) -> DaskExpr:
        return self.compliant._with_callable(lambda expr: expr.dt.day, "day")

    def hour(self) -> DaskExpr:
        return self.compliant._with_callable(lambda expr: expr.dt.hour, "hour")

    def minute(self) -> DaskExpr:
        return self.compliant._with_callable(lambda expr: expr.dt.minute, "minute")

    def second(self) -> DaskExpr:
        return self.compliant._with_callable(lambda expr: expr.dt.second, "second")

    def millisecond(self) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr: expr.dt.microsecond // 1000, "millisecond"
        )

    def microsecond(self) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr: expr.dt.microsecond, "microsecond"
        )

    def nanosecond(self) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr: expr.dt.microsecond * 1000 + expr.dt.nanosecond, "nanosecond"
        )

    def ordinal_day(self) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr: expr.dt.dayofyear, "ordinal_day"
        )

    def weekday(self) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr: expr.dt.weekday + 1,  # Dask is 0-6
            "weekday",
        )

    def to_string(self, format: str) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr, format: expr.dt.strftime(format.replace("%.f", ".%f")),
            "strftime",
            format=format,
        )

    def replace_time_zone(self, time_zone: str | None) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr, time_zone: expr.dt.tz_localize(None).dt.tz_localize(time_zone)
            if time_zone is not None
            else expr.dt.tz_localize(None),
            "tz_localize",
            time_zone=time_zone,
        )

    def convert_time_zone(self, time_zone: str) -> DaskExpr:
        def func(s: dx.Series, time_zone: str) -> dx.Series:
            dtype = native_to_narwhals_dtype(
                s.dtype, self.compliant._version, Implementation.DASK
            )
            if dtype.time_zone is None:  # type: ignore[attr-defined]
                return s.dt.tz_localize("UTC").dt.tz_convert(time_zone)  # pyright: ignore[reportAttributeAccessIssue]
            else:
                return s.dt.tz_convert(time_zone)  # pyright: ignore[reportAttributeAccessIssue]

        return self.compliant._with_callable(func, "tz_convert", time_zone=time_zone)

    # ignoring coverage due to https://github.com/narwhals-dev/narwhals/issues/2808.
    def timestamp(self, time_unit: TimeUnit) -> DaskExpr:  # pragma: no cover
        def func(s: dx.Series, time_unit: TimeUnit) -> dx.Series:
            dtype = native_to_narwhals_dtype(
                s.dtype, self.compliant._version, Implementation.DASK
            )
            is_pyarrow_dtype = "pyarrow" in str(dtype)
            mask_na = s.isna()
            dtypes = self.compliant._version.dtypes
            if dtype == dtypes.Date:
                # Date is only supported in pandas dtypes if pyarrow-backed
                s_cast = s.astype("Int32[pyarrow]")
                result = calculate_timestamp_date(s_cast, time_unit)
            elif isinstance(dtype, dtypes.Datetime):
                original_time_unit = dtype.time_unit
                s_cast = (
                    s.astype("Int64[pyarrow]") if is_pyarrow_dtype else s.astype("int64")
                )
                result = calculate_timestamp_datetime(
                    s_cast, original_time_unit, time_unit
                )
            else:
                msg = "Input should be either of Date or Datetime type"
                raise TypeError(msg)
            return result.where(~mask_na)  # pyright: ignore[reportReturnType]

        return self.compliant._with_callable(func, "datetime", time_unit=time_unit)
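
    # Illustrative sketch (reviewer annotation): `timestamp` casts to an integer
    # count in the column's own unit and then rescales to `time_unit`, e.g. a
    # Datetime("ms") column requested as "us" is roughly `int64_value * 1_000`,
    # while Date columns are first read as days since the UNIX epoch (Int32)
    # before scaling; nulls are masked back in at the end via `result.where(~mask_na)`.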

    def total_minutes(self) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr: expr.dt.total_seconds() // 60, "total_minutes"
        )

    def total_seconds(self) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr: expr.dt.total_seconds() // 1, "total_seconds"
        )

    def total_milliseconds(self) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr: expr.dt.total_seconds() * MS_PER_SECOND // 1,
            "total_milliseconds",
        )

    def total_microseconds(self) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr: expr.dt.total_seconds() * US_PER_SECOND // 1,
            "total_microseconds",
        )

    def total_nanoseconds(self) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr: expr.dt.total_seconds() * NS_PER_SECOND // 1, "total_nanoseconds"
        )

    def truncate(self, every: str) -> DaskExpr:
        interval = Interval.parse(every)
        unit = interval.unit
        if unit in {"mo", "q", "y"}:
            msg = f"Truncating to {unit} is not yet supported for dask."
            raise NotImplementedError(msg)
        freq = f"{interval.multiple}{ALIAS_DICT.get(unit, unit)}"
        return self.compliant._with_callable(lambda expr: expr.dt.floor(freq), "truncate")

    def offset_by(self, by: str) -> DaskExpr:
        def func(s: dx.Series, by: str) -> dx.Series:
            interval = Interval.parse_no_constraints(by)
            unit = interval.unit
            if unit in {"y", "q", "mo", "d", "ns"}:
                msg = f"Offsetting by {unit} is not yet supported for dask."
                raise NotImplementedError(msg)
            offset = interval.to_timedelta()
            return s.add(offset)

        return self.compliant._with_callable(func, "offset_by", by=by)
new file: .venv/lib/python3.12/site-packages/narwhals/_dask/expr_str.py (106 lines)
@@ -0,0 +1,106 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import dask.dataframe as dd

from narwhals._compliant import LazyExprNamespace
from narwhals._compliant.any_namespace import StringNamespace
from narwhals._utils import not_implemented

if TYPE_CHECKING:
    from narwhals._dask.expr import DaskExpr


class DaskExprStringNamespace(LazyExprNamespace["DaskExpr"], StringNamespace["DaskExpr"]):
    def len_chars(self) -> DaskExpr:
        return self.compliant._with_callable(lambda expr: expr.str.len(), "len")

    def replace(self, pattern: str, value: str, *, literal: bool, n: int) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr, pattern, value, literal, n: expr.str.replace(
                pattern, value, regex=not literal, n=n
            ),
            "replace",
            pattern=pattern,
            value=value,
            literal=literal,
            n=n,
        )

    def replace_all(self, pattern: str, value: str, *, literal: bool) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr, pattern, value, literal: expr.str.replace(
                pattern, value, n=-1, regex=not literal
            ),
            "replace",
            pattern=pattern,
            value=value,
            literal=literal,
        )

    def strip_chars(self, characters: str | None) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr, characters: expr.str.strip(characters),
            "strip",
            characters=characters,
        )

    def starts_with(self, prefix: str) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr, prefix: expr.str.startswith(prefix), "starts_with", prefix=prefix
        )

    def ends_with(self, suffix: str) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr, suffix: expr.str.endswith(suffix), "ends_with", suffix=suffix
        )

    def contains(self, pattern: str, *, literal: bool) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr, pattern, literal: expr.str.contains(
                pat=pattern, regex=not literal
            ),
            "contains",
            pattern=pattern,
            literal=literal,
        )

    def slice(self, offset: int, length: int | None) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr, offset, length: expr.str.slice(
                start=offset, stop=offset + length if length else None
            ),
            "slice",
            offset=offset,
            length=length,
        )

    def split(self, by: str) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr, by: expr.str.split(pat=by), "split", by=by
        )

    def to_datetime(self, format: str | None) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr, format: dd.to_datetime(expr, format=format),
            "to_datetime",
            format=format,
        )

    def to_uppercase(self) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr: expr.str.upper(), "to_uppercase"
        )

    def to_lowercase(self) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr: expr.str.lower(), "to_lowercase"
        )

    def zfill(self, width: int) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr, width: expr.str.zfill(width), "zfill", width=width
        )

    to_date = not_implemented()
new file: .venv/lib/python3.12/site-packages/narwhals/_dask/group_by.py (146 lines)
@@ -0,0 +1,146 @@
from __future__ import annotations

from functools import partial
from typing import TYPE_CHECKING, Any, Callable, ClassVar

import dask.dataframe as dd

from narwhals._compliant import DepthTrackingGroupBy
from narwhals._expression_parsing import evaluate_output_names_and_aliases

if TYPE_CHECKING:
    from collections.abc import Mapping, Sequence

    import pandas as pd
    from dask.dataframe.api import GroupBy as _DaskGroupBy
    from pandas.core.groupby import SeriesGroupBy as _PandasSeriesGroupBy
    from typing_extensions import TypeAlias

    from narwhals._compliant.typing import NarwhalsAggregation
    from narwhals._dask.dataframe import DaskLazyFrame
    from narwhals._dask.expr import DaskExpr

    PandasSeriesGroupBy: TypeAlias = _PandasSeriesGroupBy[Any, Any]
    _AggFn: TypeAlias = Callable[..., Any]

else:
    try:
        import dask.dataframe.dask_expr as dx
    except ModuleNotFoundError:  # pragma: no cover
        import dask_expr as dx
    _DaskGroupBy = dx._groupby.GroupBy

Aggregation: TypeAlias = "str | _AggFn"
"""The name of an aggregation function, or the function itself."""


def n_unique() -> dd.Aggregation:
    def chunk(s: PandasSeriesGroupBy) -> pd.Series[Any]:
        return s.nunique(dropna=False)

    def agg(s0: PandasSeriesGroupBy) -> pd.Series[Any]:
        return s0.sum()

    return dd.Aggregation(name="nunique", chunk=chunk, agg=agg)
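

# Illustrative sketch (reviewer annotation): `dd.Aggregation` describes a custom
# grouped reduction in two steps: `chunk` runs on each partition's SeriesGroupBy
# and `agg` combines the per-partition results. A hypothetical null-count
# aggregation would follow the same shape:
#
#     def null_count() -> dd.Aggregation:
#         def chunk(s):  # per partition: nulls per group
#             return s.apply(lambda x: x.isna().sum())
#
#         def agg(s0):  # combine the partial counts
#             return s0.sum()
#
#         return dd.Aggregation(name="null_count", chunk=chunk, agg=agg)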


def _all() -> dd.Aggregation:
    def chunk(s: PandasSeriesGroupBy) -> pd.Series[Any]:
        return s.all(skipna=True)

    def agg(s0: PandasSeriesGroupBy) -> pd.Series[Any]:
        return s0.all(skipna=True)

    return dd.Aggregation(name="all", chunk=chunk, agg=agg)


def _any() -> dd.Aggregation:
    def chunk(s: PandasSeriesGroupBy) -> pd.Series[Any]:
        return s.any(skipna=True)

    def agg(s0: PandasSeriesGroupBy) -> pd.Series[Any]:
        return s0.any(skipna=True)

    return dd.Aggregation(name="any", chunk=chunk, agg=agg)


def var(ddof: int) -> _AggFn:
    return partial(_DaskGroupBy.var, ddof=ddof)


def std(ddof: int) -> _AggFn:
    return partial(_DaskGroupBy.std, ddof=ddof)


class DaskLazyGroupBy(DepthTrackingGroupBy["DaskLazyFrame", "DaskExpr", Aggregation]):
    _REMAP_AGGS: ClassVar[Mapping[NarwhalsAggregation, Aggregation]] = {
        "sum": "sum",
        "mean": "mean",
        "median": "median",
        "max": "max",
        "min": "min",
        "std": std,
        "var": var,
        "len": "size",
        "n_unique": n_unique,
        "count": "count",
        "quantile": "quantile",
        "all": _all,
        "any": _any,
    }

    def __init__(
        self,
        df: DaskLazyFrame,
        keys: Sequence[DaskExpr] | Sequence[str],
        /,
        *,
        drop_null_keys: bool,
    ) -> None:
        self._compliant_frame, self._keys, self._output_key_names = self._parse_keys(
            df, keys=keys
        )
        self._grouped = self.compliant.native.groupby(
            self._keys, dropna=drop_null_keys, observed=True
        )

    def agg(self, *exprs: DaskExpr) -> DaskLazyFrame:
        from narwhals._dask.dataframe import DaskLazyFrame

        if not exprs:
            # No aggregation provided
            return (
                self.compliant.simple_select(*self._keys)
                .unique(self._keys, keep="any")
                .rename(dict(zip(self._keys, self._output_key_names)))
            )

        self._ensure_all_simple(exprs)
        # This should be the fastpath, but cuDF is too far behind to use it.
        # - https://github.com/rapidsai/cudf/issues/15118
        # - https://github.com/rapidsai/cudf/issues/15084
        simple_aggregations: dict[str, tuple[str, Aggregation]] = {}
        exclude = (*self._keys, *self._output_key_names)
        for expr in exprs:
            output_names, aliases = evaluate_output_names_and_aliases(
                expr, self.compliant, exclude
            )
            if expr._depth == 0:
                # e.g. `agg(nw.len())`
                column = self._keys[0]
                agg_fn = self._remap_expr_name(expr._function_name)
                simple_aggregations.update(dict.fromkeys(aliases, (column, agg_fn)))
                continue

            # e.g. `agg(nw.mean('a'))`
            agg_fn = self._remap_expr_name(self._leaf_name(expr))
            # deal with n_unique case in a "lazy" mode to not depend on dask globally
            agg_fn = agg_fn(**expr._scalar_kwargs) if callable(agg_fn) else agg_fn
            simple_aggregations.update(
                (alias, (output_name, agg_fn))
                for alias, output_name in zip(aliases, output_names)
            )
        return DaskLazyFrame(
            self._grouped.agg(**simple_aggregations).reset_index(),
            version=self.compliant._version,
        ).rename(dict(zip(self._keys, self._output_key_names)))
|
||||
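
# Illustrative mapping built by `DaskLazyGroupBy.agg` above (column names are
# assumed): `group_by("g").agg(nw.col("a").mean().alias("a_mean"))` makes the
# loop produce `simple_aggregations == {"a_mean": ("a", "mean")}`, which is
# then splatted into the native named-aggregation call
# `grouped.agg(a_mean=("a", "mean")).reset_index()`.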
350
.venv/lib/python3.12/site-packages/narwhals/_dask/namespace.py
Normal file
@ -0,0 +1,350 @@
from __future__ import annotations

import operator
from functools import reduce
from typing import TYPE_CHECKING, cast

import dask.dataframe as dd
import pandas as pd

from narwhals._compliant import (
    CompliantThen,
    CompliantWhen,
    DepthTrackingNamespace,
    LazyNamespace,
)
from narwhals._dask.dataframe import DaskLazyFrame
from narwhals._dask.expr import DaskExpr
from narwhals._dask.selectors import DaskSelectorNamespace
from narwhals._dask.utils import (
    align_series_full_broadcast,
    narwhals_to_native_dtype,
    validate_comparand,
)
from narwhals._expression_parsing import (
    ExprKind,
    combine_alias_output_names,
    combine_evaluate_output_names,
)
from narwhals._utils import Implementation

if TYPE_CHECKING:
    from collections.abc import Iterable, Sequence

    import dask.dataframe.dask_expr as dx

    from narwhals._compliant.typing import ScalarKwargs
    from narwhals._utils import Version
    from narwhals.typing import ConcatMethod, IntoDType, NonNestedLiteral


class DaskNamespace(
    LazyNamespace[DaskLazyFrame, DaskExpr, dd.DataFrame],
    DepthTrackingNamespace[DaskLazyFrame, DaskExpr],
):
    _implementation: Implementation = Implementation.DASK

    @property
    def selectors(self) -> DaskSelectorNamespace:
        return DaskSelectorNamespace.from_namespace(self)

    @property
    def _expr(self) -> type[DaskExpr]:
        return DaskExpr

    @property
    def _lazyframe(self) -> type[DaskLazyFrame]:
        return DaskLazyFrame

    def __init__(self, *, version: Version) -> None:
        self._version = version

    def lit(self, value: NonNestedLiteral, dtype: IntoDType | None) -> DaskExpr:
        def func(df: DaskLazyFrame) -> list[dx.Series]:
            if dtype is not None:
                native_dtype = narwhals_to_native_dtype(dtype, self._version)
                native_pd_series = pd.Series([value], dtype=native_dtype, name="literal")
            else:
                native_pd_series = pd.Series([value], name="literal")
            npartitions = df._native_frame.npartitions
            dask_series = dd.from_pandas(native_pd_series, npartitions=npartitions)
            return [dask_series[0].to_series()]

        return self._expr(
            func,
            depth=0,
            function_name="lit",
            evaluate_output_names=lambda _df: ["literal"],
            alias_output_names=None,
            version=self._version,
        )

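    # Sketch of `lit` above (illustrative): `nw.lit(1)` builds a one-element
    # pandas Series named "literal", wraps it as a dask series, and
    # `dask_series[0].to_series()` turns the element at index 0 back into a
    # length-1 Series that downstream broadcasting can align to the frame.
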
    def len(self) -> DaskExpr:
        def func(df: DaskLazyFrame) -> list[dx.Series]:
            # We don't allow dataframes with 0 columns, so `[0]` is safe.
            return [df._native_frame[df.columns[0]].size.to_series()]

        return self._expr(
            func,
            depth=0,
            function_name="len",
            evaluate_output_names=lambda _df: ["len"],
            alias_output_names=None,
            version=self._version,
        )

    def all_horizontal(self, *exprs: DaskExpr, ignore_nulls: bool) -> DaskExpr:
        def func(df: DaskLazyFrame) -> list[dx.Series]:
            series = (s for _expr in exprs for s in _expr(df))
            # Note on `ignore_nulls`: Dask doesn't support storing arbitrary Python
            # objects in `object` dtype, so we don't need the same check we have for pandas-like.
            it = (
                (
                    # NumPy-backed 'bool' dtype can't contain nulls so doesn't need filling.
                    s if s.dtype == "bool" else s.fillna(True)  # noqa: FBT003
                    for s in series
                )
                if ignore_nulls
                else series
            )
            return [reduce(operator.and_, align_series_full_broadcast(df, *it))]

        return self._expr(
            call=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="all_horizontal",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            version=self._version,
        )

    def any_horizontal(self, *exprs: DaskExpr, ignore_nulls: bool) -> DaskExpr:
        def func(df: DaskLazyFrame) -> list[dx.Series]:
            series = (s for _expr in exprs for s in _expr(df))
            # Note on `ignore_nulls`: Dask doesn't support storing arbitrary Python
            # objects in `object` dtype, so we don't need the same check we have for pandas-like.
            it = (
                (
                    # NumPy-backed 'bool' dtype can't contain nulls so doesn't need filling.
                    s if s.dtype == "bool" else s.fillna(False)  # noqa: FBT003
                    for s in series
                )
                if ignore_nulls
                else series
            )
            return [reduce(operator.or_, align_series_full_broadcast(df, *it))]

        return self._expr(
            call=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="any_horizontal",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            version=self._version,
        )

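    # Null-handling sketch for the two reductions above, with ignore_nulls=True
    # (row values illustrative): `all_horizontal` fills nulls with True, the
    # identity of `&`, and `any_horizontal` with False, the identity of `|`.
    # So [True, None] gives all=True and any=True, while [False, None] gives
    # all=False and any=False.
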
    def sum_horizontal(self, *exprs: DaskExpr) -> DaskExpr:
        def func(df: DaskLazyFrame) -> list[dx.Series]:
            series = align_series_full_broadcast(
                df, *(s for _expr in exprs for s in _expr(df))
            )
            return [dd.concat(series, axis=1).sum(axis=1)]

        return self._expr(
            call=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="sum_horizontal",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            version=self._version,
        )

    def concat(
        self, items: Iterable[DaskLazyFrame], *, how: ConcatMethod
    ) -> DaskLazyFrame:
        if not items:
            msg = "No items to concatenate"  # pragma: no cover
            raise AssertionError(msg)
        dfs = [i._native_frame for i in items]
        cols_0 = dfs[0].columns
        if how == "vertical":
            for i, df in enumerate(dfs[1:], start=1):
                cols_current = df.columns
                if not (
                    (len(cols_current) == len(cols_0)) and (cols_current == cols_0).all()
                ):
                    msg = (
                        "unable to vstack, column names don't match:\n"
                        f" - dataframe 0: {cols_0.to_list()}\n"
                        f" - dataframe {i}: {cols_current.to_list()}\n"
                    )
                    raise TypeError(msg)
            return DaskLazyFrame(
                dd.concat(dfs, axis=0, join="inner"), version=self._version
            )
        if how == "diagonal":
            return DaskLazyFrame(
                dd.concat(dfs, axis=0, join="outer"), version=self._version
            )

        raise NotImplementedError

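    # Strategy sketch for `concat` above: "vertical" insists on identical
    # column names and maps to `dd.concat(..., join="inner")`, while
    # "diagonal" unions the schemas via `join="outer"`, so columns missing
    # from one of the frames come back as nulls.
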
    def mean_horizontal(self, *exprs: DaskExpr) -> DaskExpr:
        def func(df: DaskLazyFrame) -> list[dx.Series]:
            expr_results = [s for _expr in exprs for s in _expr(df)]
            series = align_series_full_broadcast(df, *(s.fillna(0) for s in expr_results))
            non_na = align_series_full_broadcast(
                df, *(1 - s.isna() for s in expr_results)
            )
            num = reduce(lambda x, y: x + y, series)  # pyright: ignore[reportOperatorIssue]
            den = reduce(lambda x, y: x + y, non_na)  # pyright: ignore[reportOperatorIssue]
            return [cast("dx.Series", num / den)]  # pyright: ignore[reportOperatorIssue]

        return self._expr(
            call=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="mean_horizontal",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            version=self._version,
        )

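    # Worked example for the null-aware mean above (values illustrative): for
    # the row [1.0, None, 3.0], num = 1 + 0 + 3 = 4 (nulls filled with 0) and
    # den = 1 + 0 + 1 = 2 (the non-null count), so the result is 4 / 2 = 2.0.
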
    def min_horizontal(self, *exprs: DaskExpr) -> DaskExpr:
        def func(df: DaskLazyFrame) -> list[dx.Series]:
            series = align_series_full_broadcast(
                df, *(s for _expr in exprs for s in _expr(df))
            )

            return [dd.concat(series, axis=1).min(axis=1)]

        return self._expr(
            call=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="min_horizontal",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            version=self._version,
        )

    def max_horizontal(self, *exprs: DaskExpr) -> DaskExpr:
        def func(df: DaskLazyFrame) -> list[dx.Series]:
            series = align_series_full_broadcast(
                df, *(s for _expr in exprs for s in _expr(df))
            )

            return [dd.concat(series, axis=1).max(axis=1)]

        return self._expr(
            call=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="max_horizontal",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            version=self._version,
        )

    def when(self, predicate: DaskExpr) -> DaskWhen:
        return DaskWhen.from_expr(predicate, context=self)

    def concat_str(
        self, *exprs: DaskExpr, separator: str, ignore_nulls: bool
    ) -> DaskExpr:
        def func(df: DaskLazyFrame) -> list[dx.Series]:
            expr_results = [s for _expr in exprs for s in _expr(df)]
            series = (
                s.astype(str) for s in align_series_full_broadcast(df, *expr_results)
            )
            null_mask = [s.isna() for s in align_series_full_broadcast(df, *expr_results)]

            if not ignore_nulls:
                null_mask_result = reduce(operator.or_, null_mask)
                result = reduce(lambda x, y: x + separator + y, series).where(
                    ~null_mask_result, None
                )
            else:
                init_value, *values = [
                    s.where(~nm, "") for s, nm in zip(series, null_mask)
                ]

                separators = (
                    nm.map({True: "", False: separator}, meta=str)
                    for nm in null_mask[:-1]
                )
                result = reduce(
                    operator.add, (s + v for s, v in zip(separators, values)), init_value
                )

            return [result]

        return self._expr(
            call=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="concat_str",
            evaluate_output_names=getattr(
                exprs[0], "_evaluate_output_names", lambda _df: ["literal"]
            ),
            alias_output_names=getattr(exprs[0], "_alias_output_names", None),
            version=self._version,
        )

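    # Null-handling sketch for `concat_str` above, with separator="-" on the
    # row ["a", None, "c"] (values illustrative): with ignore_nulls=False the
    # combined null mask is true for the row, so the result is null; with
    # ignore_nulls=True the null contributes neither a value nor a separator,
    # giving "a-c".
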
    def coalesce(self, *exprs: DaskExpr) -> DaskExpr:
        def func(df: DaskLazyFrame) -> list[dx.Series]:
            series = align_series_full_broadcast(
                df, *(s for _expr in exprs for s in _expr(df))
            )
            return [reduce(lambda x, y: x.fillna(y), series)]

        return self._expr(
            call=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="coalesce",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            version=self._version,
        )


class DaskWhen(CompliantWhen[DaskLazyFrame, "dx.Series", DaskExpr]):
    @property
    def _then(self) -> type[DaskThen]:
        return DaskThen

    def __call__(self, df: DaskLazyFrame) -> Sequence[dx.Series]:
        then_value = (
            self._then_value(df)[0]
            if isinstance(self._then_value, DaskExpr)
            else self._then_value
        )
        otherwise_value = (
            self._otherwise_value(df)[0]
            if isinstance(self._otherwise_value, DaskExpr)
            else self._otherwise_value
        )

        condition = self._condition(df)[0]
        # re-evaluate DataFrame if the condition aggregates to force
        # then/otherwise to be evaluated against the aggregated frame
        assert self._condition._metadata is not None  # noqa: S101
        if self._condition._metadata.is_scalar_like:
            new_df = df._with_native(condition.to_frame())
            condition = self._condition.broadcast(ExprKind.AGGREGATION)(df)[0]
            df = new_df

        if self._otherwise_value is None:
            (condition, then_series) = align_series_full_broadcast(
                df, condition, then_value
            )
            validate_comparand(condition, then_series)
            return [then_series.where(condition)]  # pyright: ignore[reportArgumentType]
        (condition, then_series, otherwise_series) = align_series_full_broadcast(
            df, condition, then_value, otherwise_value
        )
        validate_comparand(condition, then_series)
        validate_comparand(condition, otherwise_series)
        return [then_series.where(condition, otherwise_series)]  # pyright: ignore[reportArgumentType]

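# Illustrative reduction performed by `DaskWhen.__call__` above: an expression
# like `nw.when(nw.col("a") > 1).then(0).otherwise(1)` (column name assumed)
# ends up as the native `then_series.where(condition, otherwise_series)`,
# after scalar then/otherwise values are broadcast to Series aligned with the
# condition.
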

class DaskThen(CompliantThen[DaskLazyFrame, "dx.Series", DaskExpr, DaskWhen], DaskExpr):
    _depth: int = 0
    _scalar_kwargs: ScalarKwargs = {}  # noqa: RUF012
    _function_name: str = "whenthen"
34
.venv/lib/python3.12/site-packages/narwhals/_dask/selectors.py
Normal file
@ -0,0 +1,34 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from narwhals._compliant import CompliantSelector, LazySelectorNamespace
from narwhals._dask.expr import DaskExpr

if TYPE_CHECKING:
    import dask.dataframe.dask_expr as dx  # noqa: F401

    from narwhals._compliant.typing import ScalarKwargs
    from narwhals._dask.dataframe import DaskLazyFrame  # noqa: F401


class DaskSelectorNamespace(LazySelectorNamespace["DaskLazyFrame", "dx.Series"]):
    @property
    def _selector(self) -> type[DaskSelector]:
        return DaskSelector


class DaskSelector(CompliantSelector["DaskLazyFrame", "dx.Series"], DaskExpr):  # type: ignore[misc]
    _depth: int = 0
    _scalar_kwargs: ScalarKwargs = {}  # noqa: RUF012
    _function_name: str = "selector"

    def _to_expr(self) -> DaskExpr:
        return DaskExpr(
            self._call,
            depth=self._depth,
            function_name=self._function_name,
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            version=self._version,
        )
151
.venv/lib/python3.12/site-packages/narwhals/_dask/utils.py
Normal file
@ -0,0 +1,151 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any

from narwhals._pandas_like.utils import select_columns_by_name
from narwhals._utils import Implementation, Version, isinstance_or_issubclass
from narwhals.dependencies import get_pyarrow

if TYPE_CHECKING:
    from collections.abc import Sequence

    import dask.dataframe as dd
    import dask.dataframe.dask_expr as dx

    from narwhals._dask.dataframe import DaskLazyFrame, Incomplete
    from narwhals._dask.expr import DaskExpr
    from narwhals.typing import IntoDType
else:
    try:
        import dask.dataframe.dask_expr as dx
    except ModuleNotFoundError:  # pragma: no cover
        import dask_expr as dx

def maybe_evaluate_expr(df: DaskLazyFrame, obj: DaskExpr | object) -> dx.Series | object:
    from narwhals._dask.expr import DaskExpr

    if isinstance(obj, DaskExpr):
        results = obj._call(df)
        assert len(results) == 1  # debug assertion  # noqa: S101
        return results[0]
    return obj


def evaluate_exprs(df: DaskLazyFrame, /, *exprs: DaskExpr) -> list[tuple[str, dx.Series]]:
    native_results: list[tuple[str, dx.Series]] = []
    for expr in exprs:
        native_series_list = expr(df)
        aliases = expr._evaluate_aliases(df)
        if len(aliases) != len(native_series_list):  # pragma: no cover
            msg = f"Internal error: got aliases {aliases}, but only got {len(native_series_list)} results"
            raise AssertionError(msg)
        native_results.extend(zip(aliases, native_series_list))
    return native_results

def align_series_full_broadcast(
    df: DaskLazyFrame, *series: dx.Series | object
) -> Sequence[dx.Series]:
    # Scalars are broadcast to the frame's index by assigning them as a
    # temporary column and selecting that column back out.
    return [
        s if isinstance(s, dx.Series) else df._native_frame.assign(_tmp=s)["_tmp"]
        for s in series
    ]  # pyright: ignore[reportReturnType]

def add_row_index(frame: dd.DataFrame, name: str) -> dd.DataFrame:
    original_cols = frame.columns
    df: Incomplete = frame.assign(**{name: 1})
    return select_columns_by_name(
        df.assign(**{name: df[name].cumsum(method="blelloch") - 1}),
        [name, *original_cols],
        Implementation.DASK,
    )

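
# Worked sketch of `add_row_index` above: a column of ones is cumulatively
# summed with dask's "blelloch" prefix-scan method and shifted down by one,
# so a frame whose rows are [x, y, z] gains the index column [0, 1, 2]
# regardless of how those rows are partitioned.
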
def validate_comparand(lhs: dx.Series, rhs: dx.Series) -> None:
    if not dx.expr.are_co_aligned(lhs._expr, rhs._expr):  # pragma: no cover
        # are_co_aligned is a method which cheaply checks if two Dask expressions
        # have the same index, and therefore don't require index alignment.
        # If someone only operates on a Dask DataFrame via expressions, then this
        # should always be the case: expression outputs (by definition) all come from the
        # same input dataframe, and Dask Series does not have any operations which
        # change the index. Nonetheless, we perform this safety check anyway.

        # However, we still need to carefully vet which methods we support for Dask, to
        # avoid issues where `are_co_aligned` doesn't do what we want it to do:
        # https://github.com/dask/dask-expr/issues/1112.
        msg = "Objects are not co-aligned, so this operation is not supported for Dask backend"
        raise RuntimeError(msg)


def narwhals_to_native_dtype(dtype: IntoDType, version: Version) -> Any:  # noqa: C901, PLR0912
    dtypes = version.dtypes
    if isinstance_or_issubclass(dtype, dtypes.Float64):
        return "float64"
    if isinstance_or_issubclass(dtype, dtypes.Float32):
        return "float32"
    if isinstance_or_issubclass(dtype, dtypes.Int64):
        return "int64"
    if isinstance_or_issubclass(dtype, dtypes.Int32):
        return "int32"
    if isinstance_or_issubclass(dtype, dtypes.Int16):
        return "int16"
    if isinstance_or_issubclass(dtype, dtypes.Int8):
        return "int8"
    if isinstance_or_issubclass(dtype, dtypes.UInt64):
        return "uint64"
    if isinstance_or_issubclass(dtype, dtypes.UInt32):
        return "uint32"
    if isinstance_or_issubclass(dtype, dtypes.UInt16):
        return "uint16"
    if isinstance_or_issubclass(dtype, dtypes.UInt8):
        return "uint8"
    if isinstance_or_issubclass(dtype, dtypes.String):
        if Implementation.PANDAS._backend_version() >= (2, 0, 0):
            if get_pyarrow() is not None:
                return "string[pyarrow]"
            return "string[python]"  # pragma: no cover
        return "object"  # pragma: no cover
    if isinstance_or_issubclass(dtype, dtypes.Boolean):
        return "bool"
    if isinstance_or_issubclass(dtype, dtypes.Enum):
        if version is Version.V1:
            msg = "Converting to Enum is not supported in narwhals.stable.v1"
            raise NotImplementedError(msg)
        if isinstance(dtype, dtypes.Enum):
            import pandas as pd

            # NOTE: `pandas-stubs.core.dtypes.dtypes.CategoricalDtype.categories` is too narrow
            # Should be one of the `ListLike*` types
            # https://github.com/pandas-dev/pandas-stubs/blob/8434bde95460b996323cc8c0fea7b0a8bb00ea26/pandas-stubs/_typing.pyi#L497-L505
            return pd.CategoricalDtype(dtype.categories, ordered=True)  # type: ignore[arg-type]
        msg = "Can not cast / initialize Enum without categories present"
        raise ValueError(msg)

    if isinstance_or_issubclass(dtype, dtypes.Categorical):
        return "category"
    if isinstance_or_issubclass(dtype, dtypes.Datetime):
        return "datetime64[us]"
    if isinstance_or_issubclass(dtype, dtypes.Date):
        return "date32[day][pyarrow]"
    if isinstance_or_issubclass(dtype, dtypes.Duration):
        return "timedelta64[ns]"
    if isinstance_or_issubclass(dtype, dtypes.List):  # pragma: no cover
        msg = "Converting to List dtype is not supported yet"
        raise NotImplementedError(msg)
    if isinstance_or_issubclass(dtype, dtypes.Struct):  # pragma: no cover
        msg = "Converting to Struct dtype is not supported yet"
        raise NotImplementedError(msg)
    if isinstance_or_issubclass(dtype, dtypes.Array):  # pragma: no cover
        msg = "Converting to Array dtype is not supported yet"
        raise NotImplementedError(msg)
    if isinstance_or_issubclass(dtype, dtypes.Time):  # pragma: no cover
        msg = "Converting to Time dtype is not supported yet"
        raise NotImplementedError(msg)
    if isinstance_or_issubclass(dtype, dtypes.Binary):  # pragma: no cover
        msg = "Converting to Binary dtype is not supported yet"
        raise NotImplementedError(msg)

    msg = f"Unknown dtype: {dtype}"  # pragma: no cover
    raise AssertionError(msg)
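
# Illustrative round-trips for `narwhals_to_native_dtype` above: nw.Int64 maps
# to the string "int64", nw.Datetime to "datetime64[us]", and an Enum carrying
# categories to an ordered pandas CategoricalDtype, while nested dtypes such
# as List, Struct, and Array are currently rejected with NotImplementedError.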