some new features
@@ -0,0 +1,482 @@
from itertools import product

import numpy as np

from ...base import is_classifier
from ...utils._optional_dependencies import check_matplotlib_support
from ...utils.multiclass import unique_labels
from .. import confusion_matrix

class ConfusionMatrixDisplay:
|
||||
"""Confusion Matrix visualization.
|
||||
|
||||
It is recommended to use
|
||||
:func:`~sklearn.metrics.ConfusionMatrixDisplay.from_estimator` or
|
||||
:func:`~sklearn.metrics.ConfusionMatrixDisplay.from_predictions` to
|
||||
create a :class:`ConfusionMatrixDisplay`. All parameters are stored as
|
||||
attributes.
|
||||
|
||||
Read more in the :ref:`User Guide <visualizations>`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
confusion_matrix : ndarray of shape (n_classes, n_classes)
|
||||
Confusion matrix.
|
||||
|
||||
display_labels : ndarray of shape (n_classes,), default=None
|
||||
Display labels for plot. If None, display labels are set from 0 to
|
||||
`n_classes - 1`.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
im_ : matplotlib AxesImage
|
||||
Image representing the confusion matrix.
|
||||
|
||||
text_ : ndarray of shape (n_classes, n_classes), dtype=matplotlib Text, \
|
||||
or None
|
||||
Array of matplotlib Text objects, or `None` if `include_values` is false.
|
||||
|
||||
ax_ : matplotlib Axes
|
||||
Axes with confusion matrix.
|
||||
|
||||
figure_ : matplotlib Figure
|
||||
Figure containing the confusion matrix.
|
||||
|
||||
See Also
|
||||
--------
|
||||
confusion_matrix : Compute Confusion Matrix to evaluate the accuracy of a
|
||||
classification.
|
||||
ConfusionMatrixDisplay.from_estimator : Plot the confusion matrix
|
||||
given an estimator, the data, and the label.
|
||||
ConfusionMatrixDisplay.from_predictions : Plot the confusion matrix
|
||||
given the true and predicted labels.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import matplotlib.pyplot as plt
|
||||
>>> from sklearn.datasets import make_classification
|
||||
>>> from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
|
||||
>>> from sklearn.model_selection import train_test_split
|
||||
>>> from sklearn.svm import SVC
|
||||
>>> X, y = make_classification(random_state=0)
|
||||
>>> X_train, X_test, y_train, y_test = train_test_split(X, y,
|
||||
... random_state=0)
|
||||
>>> clf = SVC(random_state=0)
|
||||
>>> clf.fit(X_train, y_train)
|
||||
SVC(random_state=0)
|
||||
>>> predictions = clf.predict(X_test)
|
||||
>>> cm = confusion_matrix(y_test, predictions, labels=clf.classes_)
|
||||
>>> disp = ConfusionMatrixDisplay(confusion_matrix=cm,
|
||||
... display_labels=clf.classes_)
|
||||
>>> disp.plot()
|
||||
<...>
|
||||
>>> plt.show()
|
||||
"""
|
||||
|
||||
def __init__(self, confusion_matrix, *, display_labels=None):
|
||||
self.confusion_matrix = confusion_matrix
|
||||
self.display_labels = display_labels
|
||||
|
||||
def plot(
|
||||
self,
|
||||
*,
|
||||
include_values=True,
|
||||
cmap="viridis",
|
||||
xticks_rotation="horizontal",
|
||||
values_format=None,
|
||||
ax=None,
|
||||
colorbar=True,
|
||||
im_kw=None,
|
||||
text_kw=None,
|
||||
):
|
||||
"""Plot visualization.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
include_values : bool, default=True
|
||||
Includes values in confusion matrix.
|
||||
|
||||
cmap : str or matplotlib Colormap, default='viridis'
|
||||
Colormap recognized by matplotlib.
|
||||
|
||||
xticks_rotation : {'vertical', 'horizontal'} or float, \
|
||||
default='horizontal'
|
||||
Rotation of xtick labels.
|
||||
|
||||
values_format : str, default=None
|
||||
Format specification for values in confusion matrix. If `None`,
|
||||
the format specification is 'd' or '.2g' whichever is shorter.
|
||||
|
||||
ax : matplotlib axes, default=None
|
||||
Axes object to plot on. If `None`, a new figure and axes is
|
||||
created.
|
||||
|
||||
colorbar : bool, default=True
|
||||
Whether or not to add a colorbar to the plot.
|
||||
|
||||
im_kw : dict, default=None
|
||||
Dict with keywords passed to `matplotlib.pyplot.imshow` call.
|
||||
|
||||
text_kw : dict, default=None
|
||||
Dict with keywords passed to `matplotlib.pyplot.text` call.
|
||||
|
||||
.. versionadded:: 1.2
|
||||
|
||||
Returns
|
||||
-------
|
||||
display : :class:`~sklearn.metrics.ConfusionMatrixDisplay`
|
||||
Returns a :class:`~sklearn.metrics.ConfusionMatrixDisplay` instance
|
||||
that contains all the information to plot the confusion matrix.
|
||||
"""
|
||||
check_matplotlib_support("ConfusionMatrixDisplay.plot")
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
if ax is None:
|
||||
fig, ax = plt.subplots()
|
||||
else:
|
||||
fig = ax.figure
|
||||
|
||||
cm = self.confusion_matrix
|
||||
n_classes = cm.shape[0]
|
||||
|
||||
default_im_kw = dict(interpolation="nearest", cmap=cmap)
|
||||
im_kw = im_kw or {}
|
||||
im_kw = {**default_im_kw, **im_kw}
|
||||
text_kw = text_kw or {}
|
||||
|
||||
self.im_ = ax.imshow(cm, **im_kw)
|
||||
self.text_ = None
|
||||
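# Colors at the two extremes of the colormap; used below to pick a readable text color per cell.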
cmap_min, cmap_max = self.im_.cmap(0), self.im_.cmap(1.0)
|
||||
|
||||
if include_values:
|
||||
self.text_ = np.empty_like(cm, dtype=object)
|
||||
|
||||
# print text with appropriate color depending on background
|
||||
thresh = (cm.max() + cm.min()) / 2.0
|
||||
|
||||
for i, j in product(range(n_classes), range(n_classes)):
|
||||
color = cmap_max if cm[i, j] < thresh else cmap_min
|
||||
|
||||
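# No explicit format requested: render the cell with '.2g' and, for integer matrices, fall back to 'd' when that is shorter.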
if values_format is None:
|
||||
text_cm = format(cm[i, j], ".2g")
|
||||
if cm.dtype.kind != "f":
|
||||
text_d = format(cm[i, j], "d")
|
||||
if len(text_d) < len(text_cm):
|
||||
text_cm = text_d
|
||||
else:
|
||||
text_cm = format(cm[i, j], values_format)
|
||||
|
||||
default_text_kwargs = dict(ha="center", va="center", color=color)
|
||||
text_kwargs = {**default_text_kwargs, **text_kw}
|
||||
|
||||
self.text_[i, j] = ax.text(j, i, text_cm, **text_kwargs)
|
||||
|
||||
if self.display_labels is None:
|
||||
display_labels = np.arange(n_classes)
|
||||
else:
|
||||
display_labels = self.display_labels
|
||||
if colorbar:
|
||||
fig.colorbar(self.im_, ax=ax)
|
||||
ax.set(
|
||||
xticks=np.arange(n_classes),
|
||||
yticks=np.arange(n_classes),
|
||||
xticklabels=display_labels,
|
||||
yticklabels=display_labels,
|
||||
ylabel="True label",
|
||||
xlabel="Predicted label",
|
||||
)
|
||||
|
||||
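# Keep the y-limits covering all rows, with the first class on the top row (matching the imshow orientation).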
ax.set_ylim((n_classes - 0.5, -0.5))
|
||||
plt.setp(ax.get_xticklabels(), rotation=xticks_rotation)
|
||||
|
||||
self.figure_ = fig
|
||||
self.ax_ = ax
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def from_estimator(
|
||||
cls,
|
||||
estimator,
|
||||
X,
|
||||
y,
|
||||
*,
|
||||
labels=None,
|
||||
sample_weight=None,
|
||||
normalize=None,
|
||||
display_labels=None,
|
||||
include_values=True,
|
||||
xticks_rotation="horizontal",
|
||||
values_format=None,
|
||||
cmap="viridis",
|
||||
ax=None,
|
||||
colorbar=True,
|
||||
im_kw=None,
|
||||
text_kw=None,
|
||||
):
|
||||
"""Plot Confusion Matrix given an estimator and some data.
|
||||
|
||||
Read more in the :ref:`User Guide <confusion_matrix>`.
|
||||
|
||||
.. versionadded:: 1.0
|
||||
|
||||
Parameters
|
||||
----------
|
||||
estimator : estimator instance
|
||||
Fitted classifier or a fitted :class:`~sklearn.pipeline.Pipeline`
|
||||
in which the last estimator is a classifier.
|
||||
|
||||
X : {array-like, sparse matrix} of shape (n_samples, n_features)
|
||||
Input values.
|
||||
|
||||
y : array-like of shape (n_samples,)
|
||||
Target values.
|
||||
|
||||
labels : array-like of shape (n_classes,), default=None
|
||||
List of labels to index the confusion matrix. This may be used to
|
||||
reorder or select a subset of labels. If `None` is given, those
|
||||
that appear at least once in `y_true` or `y_pred` are used in
|
||||
sorted order.
|
||||
|
||||
sample_weight : array-like of shape (n_samples,), default=None
|
||||
Sample weights.
|
||||
|
||||
normalize : {'true', 'pred', 'all'}, default=None
|
||||
Whether to normalize the counts displayed in the matrix:
|
||||
|
||||
- if `'true'`, the confusion matrix is normalized over the true
|
||||
conditions (e.g. rows);
|
||||
- if `'pred'`, the confusion matrix is normalized over the
|
||||
predicted conditions (e.g. columns);
|
||||
- if `'all'`, the confusion matrix is normalized by the total
|
||||
number of samples;
|
||||
- if `None` (default), the confusion matrix will not be normalized.
|
||||
|
||||
display_labels : array-like of shape (n_classes,), default=None
|
||||
Target names used for plotting. By default, `labels` will be used
|
||||
if it is defined, otherwise the unique labels of `y_true` and
|
||||
`y_pred` will be used.
|
||||
|
||||
include_values : bool, default=True
|
||||
Includes values in confusion matrix.
|
||||
|
||||
xticks_rotation : {'vertical', 'horizontal'} or float, \
|
||||
default='horizontal'
|
||||
Rotation of xtick labels.
|
||||
|
||||
values_format : str, default=None
|
||||
Format specification for values in confusion matrix. If `None`, the
|
||||
format specification is 'd' or '.2g' whichever is shorter.
|
||||
|
||||
cmap : str or matplotlib Colormap, default='viridis'
|
||||
Colormap recognized by matplotlib.
|
||||
|
||||
ax : matplotlib Axes, default=None
|
||||
Axes object to plot on. If `None`, a new figure and axes is
|
||||
created.
|
||||
|
||||
colorbar : bool, default=True
|
||||
Whether or not to add a colorbar to the plot.
|
||||
|
||||
im_kw : dict, default=None
|
||||
Dict with keywords passed to `matplotlib.pyplot.imshow` call.
|
||||
|
||||
text_kw : dict, default=None
|
||||
Dict with keywords passed to `matplotlib.pyplot.text` call.
|
||||
|
||||
.. versionadded:: 1.2
|
||||
|
||||
Returns
|
||||
-------
|
||||
display : :class:`~sklearn.metrics.ConfusionMatrixDisplay`
|
||||
|
||||
See Also
|
||||
--------
|
||||
ConfusionMatrixDisplay.from_predictions : Plot the confusion matrix
|
||||
given the true and predicted labels.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import matplotlib.pyplot as plt
|
||||
>>> from sklearn.datasets import make_classification
|
||||
>>> from sklearn.metrics import ConfusionMatrixDisplay
|
||||
>>> from sklearn.model_selection import train_test_split
|
||||
>>> from sklearn.svm import SVC
|
||||
>>> X, y = make_classification(random_state=0)
|
||||
>>> X_train, X_test, y_train, y_test = train_test_split(
|
||||
... X, y, random_state=0)
|
||||
>>> clf = SVC(random_state=0)
|
||||
>>> clf.fit(X_train, y_train)
|
||||
SVC(random_state=0)
|
||||
>>> ConfusionMatrixDisplay.from_estimator(
|
||||
... clf, X_test, y_test)
|
||||
<...>
|
||||
>>> plt.show()
|
||||
"""
|
||||
method_name = f"{cls.__name__}.from_estimator"
|
||||
check_matplotlib_support(method_name)
|
||||
if not is_classifier(estimator):
|
||||
raise ValueError(f"{method_name} only supports classifiers")
|
||||
y_pred = estimator.predict(X)
|
||||
|
||||
return cls.from_predictions(
|
||||
y,
|
||||
y_pred,
|
||||
sample_weight=sample_weight,
|
||||
labels=labels,
|
||||
normalize=normalize,
|
||||
display_labels=display_labels,
|
||||
include_values=include_values,
|
||||
cmap=cmap,
|
||||
ax=ax,
|
||||
xticks_rotation=xticks_rotation,
|
||||
values_format=values_format,
|
||||
colorbar=colorbar,
|
||||
im_kw=im_kw,
|
||||
text_kw=text_kw,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_predictions(
|
||||
cls,
|
||||
y_true,
|
||||
y_pred,
|
||||
*,
|
||||
labels=None,
|
||||
sample_weight=None,
|
||||
normalize=None,
|
||||
display_labels=None,
|
||||
include_values=True,
|
||||
xticks_rotation="horizontal",
|
||||
values_format=None,
|
||||
cmap="viridis",
|
||||
ax=None,
|
||||
colorbar=True,
|
||||
im_kw=None,
|
||||
text_kw=None,
|
||||
):
|
||||
"""Plot Confusion Matrix given true and predicted labels.
|
||||
|
||||
Read more in the :ref:`User Guide <confusion_matrix>`.
|
||||
|
||||
.. versionadded:: 1.0
|
||||
|
||||
Parameters
|
||||
----------
|
||||
y_true : array-like of shape (n_samples,)
|
||||
True labels.
|
||||
|
||||
y_pred : array-like of shape (n_samples,)
|
||||
The predicted labels given by the method `predict` of a
|
||||
classifier.
|
||||
|
||||
labels : array-like of shape (n_classes,), default=None
|
||||
List of labels to index the confusion matrix. This may be used to
|
||||
reorder or select a subset of labels. If `None` is given, those
|
||||
that appear at least once in `y_true` or `y_pred` are used in
|
||||
sorted order.
|
||||
|
||||
sample_weight : array-like of shape (n_samples,), default=None
|
||||
Sample weights.
|
||||
|
||||
normalize : {'true', 'pred', 'all'}, default=None
|
||||
Whether to normalize the counts displayed in the matrix:
|
||||
|
||||
- if `'true'`, the confusion matrix is normalized over the true
|
||||
conditions (e.g. rows);
|
||||
- if `'pred'`, the confusion matrix is normalized over the
|
||||
predicted conditions (e.g. columns);
|
||||
- if `'all'`, the confusion matrix is normalized by the total
|
||||
number of samples;
|
||||
- if `None` (default), the confusion matrix will not be normalized.
|
||||
|
||||
display_labels : array-like of shape (n_classes,), default=None
|
||||
Target names used for plotting. By default, `labels` will be used
|
||||
if it is defined, otherwise the unique labels of `y_true` and
|
||||
`y_pred` will be used.
|
||||
|
||||
include_values : bool, default=True
|
||||
Includes values in confusion matrix.
|
||||
|
||||
xticks_rotation : {'vertical', 'horizontal'} or float, \
|
||||
default='horizontal'
|
||||
Rotation of xtick labels.
|
||||
|
||||
values_format : str, default=None
|
||||
Format specification for values in confusion matrix. If `None`, the
|
||||
format specification is 'd' or '.2g' whichever is shorter.
|
||||
|
||||
cmap : str or matplotlib Colormap, default='viridis'
|
||||
Colormap recognized by matplotlib.
|
||||
|
||||
ax : matplotlib Axes, default=None
|
||||
Axes object to plot on. If `None`, a new figure and axes is
|
||||
created.
|
||||
|
||||
colorbar : bool, default=True
|
||||
Whether or not to add a colorbar to the plot.
|
||||
|
||||
im_kw : dict, default=None
|
||||
Dict with keywords passed to `matplotlib.pyplot.imshow` call.
|
||||
|
||||
text_kw : dict, default=None
|
||||
Dict with keywords passed to `matplotlib.pyplot.text` call.
|
||||
|
||||
.. versionadded:: 1.2
|
||||
|
||||
Returns
|
||||
-------
|
||||
display : :class:`~sklearn.metrics.ConfusionMatrixDisplay`
|
||||
|
||||
See Also
|
||||
--------
|
||||
ConfusionMatrixDisplay.from_estimator : Plot the confusion matrix
|
||||
given an estimator, the data, and the label.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import matplotlib.pyplot as plt
|
||||
>>> from sklearn.datasets import make_classification
|
||||
>>> from sklearn.metrics import ConfusionMatrixDisplay
|
||||
>>> from sklearn.model_selection import train_test_split
|
||||
>>> from sklearn.svm import SVC
|
||||
>>> X, y = make_classification(random_state=0)
|
||||
>>> X_train, X_test, y_train, y_test = train_test_split(
|
||||
... X, y, random_state=0)
|
||||
>>> clf = SVC(random_state=0)
|
||||
>>> clf.fit(X_train, y_train)
|
||||
SVC(random_state=0)
|
||||
>>> y_pred = clf.predict(X_test)
|
||||
>>> ConfusionMatrixDisplay.from_predictions(
|
||||
... y_test, y_pred)
|
||||
<...>
|
||||
>>> plt.show()
|
||||
"""
|
||||
check_matplotlib_support(f"{cls.__name__}.from_predictions")
|
||||
|
||||
if display_labels is None:
|
||||
if labels is None:
|
||||
display_labels = unique_labels(y_true, y_pred)
|
||||
else:
|
||||
display_labels = labels
|
||||
|
||||
cm = confusion_matrix(
|
||||
y_true,
|
||||
y_pred,
|
||||
sample_weight=sample_weight,
|
||||
labels=labels,
|
||||
normalize=normalize,
|
||||
)
|
||||
|
||||
disp = cls(confusion_matrix=cm, display_labels=display_labels)
|
||||
|
||||
return disp.plot(
|
||||
include_values=include_values,
|
||||
cmap=cmap,
|
||||
ax=ax,
|
||||
xticks_rotation=xticks_rotation,
|
||||
values_format=values_format,
|
||||
colorbar=colorbar,
|
||||
im_kw=im_kw,
|
||||
text_kw=text_kw,
|
||||
)
|
||||
@@ -0,0 +1,332 @@
import scipy as sp

from ...utils._plotting import _BinaryClassifierCurveDisplayMixin
from .._ranking import det_curve

class DetCurveDisplay(_BinaryClassifierCurveDisplayMixin):
|
||||
"""DET curve visualization.
|
||||
|
||||
It is recommended to use :func:`~sklearn.metrics.DetCurveDisplay.from_estimator`
|
||||
or :func:`~sklearn.metrics.DetCurveDisplay.from_predictions` to create a
|
||||
visualizer. All parameters are stored as attributes.
|
||||
|
||||
Read more in the :ref:`User Guide <visualizations>`.
|
||||
|
||||
.. versionadded:: 0.24
|
||||
|
||||
Parameters
|
||||
----------
|
||||
fpr : ndarray
|
||||
False positive rate.
|
||||
|
||||
fnr : ndarray
|
||||
False negative rate.
|
||||
|
||||
estimator_name : str, default=None
|
||||
Name of estimator. If None, the estimator name is not shown.
|
||||
|
||||
pos_label : int, float, bool or str, default=None
|
||||
The label of the positive class.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
line_ : matplotlib Artist
|
||||
DET Curve.
|
||||
|
||||
ax_ : matplotlib Axes
|
||||
Axes with DET Curve.
|
||||
|
||||
figure_ : matplotlib Figure
|
||||
Figure containing the curve.
|
||||
|
||||
See Also
|
||||
--------
|
||||
det_curve : Compute error rates for different probability thresholds.
|
||||
DetCurveDisplay.from_estimator : Plot DET curve given an estimator and
|
||||
some data.
|
||||
DetCurveDisplay.from_predictions : Plot DET curve given the true and
|
||||
predicted labels.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import matplotlib.pyplot as plt
|
||||
>>> from sklearn.datasets import make_classification
|
||||
>>> from sklearn.metrics import det_curve, DetCurveDisplay
|
||||
>>> from sklearn.model_selection import train_test_split
|
||||
>>> from sklearn.svm import SVC
|
||||
>>> X, y = make_classification(n_samples=1000, random_state=0)
|
||||
>>> X_train, X_test, y_train, y_test = train_test_split(
|
||||
... X, y, test_size=0.4, random_state=0)
|
||||
>>> clf = SVC(random_state=0).fit(X_train, y_train)
|
||||
>>> y_pred = clf.decision_function(X_test)
|
||||
>>> fpr, fnr, _ = det_curve(y_test, y_pred)
|
||||
>>> display = DetCurveDisplay(
|
||||
... fpr=fpr, fnr=fnr, estimator_name="SVC"
|
||||
... )
|
||||
>>> display.plot()
|
||||
<...>
|
||||
>>> plt.show()
|
||||
"""
|
||||
|
||||
def __init__(self, *, fpr, fnr, estimator_name=None, pos_label=None):
|
||||
self.fpr = fpr
|
||||
self.fnr = fnr
|
||||
self.estimator_name = estimator_name
|
||||
self.pos_label = pos_label
|
||||
|
||||
@classmethod
|
||||
def from_estimator(
|
||||
cls,
|
||||
estimator,
|
||||
X,
|
||||
y,
|
||||
*,
|
||||
sample_weight=None,
|
||||
response_method="auto",
|
||||
pos_label=None,
|
||||
name=None,
|
||||
ax=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Plot DET curve given an estimator and data.
|
||||
|
||||
Read more in the :ref:`User Guide <visualizations>`.
|
||||
|
||||
.. versionadded:: 1.0
|
||||
|
||||
Parameters
|
||||
----------
|
||||
estimator : estimator instance
|
||||
Fitted classifier or a fitted :class:`~sklearn.pipeline.Pipeline`
|
||||
in which the last estimator is a classifier.
|
||||
|
||||
X : {array-like, sparse matrix} of shape (n_samples, n_features)
|
||||
Input values.
|
||||
|
||||
y : array-like of shape (n_samples,)
|
||||
Target values.
|
||||
|
||||
sample_weight : array-like of shape (n_samples,), default=None
|
||||
Sample weights.
|
||||
|
||||
response_method : {'predict_proba', 'decision_function', 'auto'}, \
|
||||
default='auto'
|
||||
Specifies whether to use :term:`predict_proba` or
|
||||
:term:`decision_function` as the predicted target response. If set
|
||||
to 'auto', :term:`predict_proba` is tried first and if it does not
|
||||
exist :term:`decision_function` is tried next.
|
||||
|
||||
pos_label : int, float, bool or str, default=None
|
||||
The label of the positive class. When `pos_label=None`, if `y_true`
|
||||
is in {-1, 1} or {0, 1}, `pos_label` is set to 1, otherwise an
|
||||
error will be raised.
|
||||
|
||||
name : str, default=None
|
||||
Name of DET curve for labeling. If `None`, use the name of the
|
||||
estimator.
|
||||
|
||||
ax : matplotlib axes, default=None
|
||||
Axes object to plot on. If `None`, a new figure and axes is
|
||||
created.
|
||||
|
||||
**kwargs : dict
|
||||
Additional keywords arguments passed to matplotlib `plot` function.
|
||||
|
||||
Returns
|
||||
-------
|
||||
display : :class:`~sklearn.metrics.DetCurveDisplay`
|
||||
Object that stores computed values.
|
||||
|
||||
See Also
|
||||
--------
|
||||
det_curve : Compute error rates for different probability thresholds.
|
||||
DetCurveDisplay.from_predictions : Plot DET curve given the true and
|
||||
predicted labels.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import matplotlib.pyplot as plt
|
||||
>>> from sklearn.datasets import make_classification
|
||||
>>> from sklearn.metrics import DetCurveDisplay
|
||||
>>> from sklearn.model_selection import train_test_split
|
||||
>>> from sklearn.svm import SVC
|
||||
>>> X, y = make_classification(n_samples=1000, random_state=0)
|
||||
>>> X_train, X_test, y_train, y_test = train_test_split(
|
||||
... X, y, test_size=0.4, random_state=0)
|
||||
>>> clf = SVC(random_state=0).fit(X_train, y_train)
|
||||
>>> DetCurveDisplay.from_estimator(
|
||||
... clf, X_test, y_test)
|
||||
<...>
|
||||
>>> plt.show()
|
||||
"""
|
||||
y_pred, pos_label, name = cls._validate_and_get_response_values(
|
||||
estimator,
|
||||
X,
|
||||
y,
|
||||
response_method=response_method,
|
||||
pos_label=pos_label,
|
||||
name=name,
|
||||
)
|
||||
|
||||
return cls.from_predictions(
|
||||
y_true=y,
|
||||
y_pred=y_pred,
|
||||
sample_weight=sample_weight,
|
||||
name=name,
|
||||
ax=ax,
|
||||
pos_label=pos_label,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_predictions(
|
||||
cls,
|
||||
y_true,
|
||||
y_pred,
|
||||
*,
|
||||
sample_weight=None,
|
||||
pos_label=None,
|
||||
name=None,
|
||||
ax=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Plot the DET curve given the true and predicted labels.
|
||||
|
||||
Read more in the :ref:`User Guide <visualizations>`.
|
||||
|
||||
.. versionadded:: 1.0
|
||||
|
||||
Parameters
|
||||
----------
|
||||
y_true : array-like of shape (n_samples,)
|
||||
True labels.
|
||||
|
||||
y_pred : array-like of shape (n_samples,)
|
||||
Target scores, can either be probability estimates of the positive
|
||||
class, confidence values, or non-thresholded measure of decisions
|
||||
(as returned by `decision_function` on some classifiers).
|
||||
|
||||
sample_weight : array-like of shape (n_samples,), default=None
|
||||
Sample weights.
|
||||
|
||||
pos_label : int, float, bool or str, default=None
|
||||
The label of the positive class. When `pos_label=None`, if `y_true`
|
||||
is in {-1, 1} or {0, 1}, `pos_label` is set to 1, otherwise an
|
||||
error will be raised.
|
||||
|
||||
name : str, default=None
|
||||
Name of DET curve for labeling. If `None`, name will be set to
|
||||
`"Classifier"`.
|
||||
|
||||
ax : matplotlib axes, default=None
|
||||
Axes object to plot on. If `None`, a new figure and axes is
|
||||
created.
|
||||
|
||||
**kwargs : dict
|
||||
Additional keywords arguments passed to matplotlib `plot` function.
|
||||
|
||||
Returns
|
||||
-------
|
||||
display : :class:`~sklearn.metrics.DetCurveDisplay`
|
||||
Object that stores computed values.
|
||||
|
||||
See Also
|
||||
--------
|
||||
det_curve : Compute error rates for different probability thresholds.
|
||||
DetCurveDisplay.from_estimator : Plot DET curve given an estimator and
|
||||
some data.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import matplotlib.pyplot as plt
|
||||
>>> from sklearn.datasets import make_classification
|
||||
>>> from sklearn.metrics import DetCurveDisplay
|
||||
>>> from sklearn.model_selection import train_test_split
|
||||
>>> from sklearn.svm import SVC
|
||||
>>> X, y = make_classification(n_samples=1000, random_state=0)
|
||||
>>> X_train, X_test, y_train, y_test = train_test_split(
|
||||
... X, y, test_size=0.4, random_state=0)
|
||||
>>> clf = SVC(random_state=0).fit(X_train, y_train)
|
||||
>>> y_pred = clf.decision_function(X_test)
|
||||
>>> DetCurveDisplay.from_predictions(
|
||||
... y_test, y_pred)
|
||||
<...>
|
||||
>>> plt.show()
|
||||
"""
|
||||
pos_label_validated, name = cls._validate_from_predictions_params(
|
||||
y_true, y_pred, sample_weight=sample_weight, pos_label=pos_label, name=name
|
||||
)
|
||||
|
||||
fpr, fnr, _ = det_curve(
|
||||
y_true,
|
||||
y_pred,
|
||||
pos_label=pos_label,
|
||||
sample_weight=sample_weight,
|
||||
)
|
||||
|
||||
viz = cls(
|
||||
fpr=fpr,
|
||||
fnr=fnr,
|
||||
estimator_name=name,
|
||||
pos_label=pos_label_validated,
|
||||
)
|
||||
|
||||
return viz.plot(ax=ax, name=name, **kwargs)
|
||||
|
||||
def plot(self, ax=None, *, name=None, **kwargs):
|
||||
"""Plot visualization.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ax : matplotlib axes, default=None
|
||||
Axes object to plot on. If `None`, a new figure and axes is
|
||||
created.
|
||||
|
||||
name : str, default=None
|
||||
Name of DET curve for labeling. If `None`, use `estimator_name` if
|
||||
it is not `None`, otherwise no labeling is shown.
|
||||
|
||||
**kwargs : dict
|
||||
Additional keywords arguments passed to matplotlib `plot` function.
|
||||
|
||||
Returns
|
||||
-------
|
||||
display : :class:`~sklearn.metrics.DetCurveDisplay`
|
||||
Object that stores computed values.
|
||||
"""
|
||||
self.ax_, self.figure_, name = self._validate_plot_params(ax=ax, name=name)
|
||||
|
||||
line_kwargs = {} if name is None else {"label": name}
|
||||
line_kwargs.update(**kwargs)
|
||||
|
||||
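# DET curves are drawn on a normal-deviate (probit) scale: both error rates are passed through the inverse of the standard normal CDF (norm.ppf).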
(self.line_,) = self.ax_.plot(
|
||||
sp.stats.norm.ppf(self.fpr),
|
||||
sp.stats.norm.ppf(self.fnr),
|
||||
**line_kwargs,
|
||||
)
|
||||
info_pos_label = (
|
||||
f" (Positive label: {self.pos_label})" if self.pos_label is not None else ""
|
||||
)
|
||||
|
||||
xlabel = "False Positive Rate" + info_pos_label
|
||||
ylabel = "False Negative Rate" + info_pos_label
|
||||
self.ax_.set(xlabel=xlabel, ylabel=ylabel)
|
||||
|
||||
if "label" in line_kwargs:
|
||||
self.ax_.legend(loc="lower right")
|
||||
|
||||
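# Tick positions are chosen in probability space and mapped through the probit transform; the axes span roughly -3 to 3 standard deviations.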
ticks = [0.001, 0.01, 0.05, 0.20, 0.5, 0.80, 0.95, 0.99, 0.999]
|
||||
tick_locations = sp.stats.norm.ppf(ticks)
|
||||
tick_labels = [
|
||||
"{:.0%}".format(s) if (100 * s).is_integer() else "{:.1%}".format(s)
|
||||
for s in ticks
|
||||
]
|
||||
self.ax_.set_xticks(tick_locations)
|
||||
self.ax_.set_xticklabels(tick_labels)
|
||||
self.ax_.set_xlim(-3, 3)
|
||||
self.ax_.set_yticks(tick_locations)
|
||||
self.ax_.set_yticklabels(tick_labels)
|
||||
self.ax_.set_ylim(-3, 3)
|
||||
|
||||
return self
|
||||
@@ -0,0 +1,504 @@
from collections import Counter

from ...utils._plotting import _BinaryClassifierCurveDisplayMixin
from .._ranking import average_precision_score, precision_recall_curve

class PrecisionRecallDisplay(_BinaryClassifierCurveDisplayMixin):
|
||||
"""Precision Recall visualization.
|
||||
|
||||
It is recommended to use
|
||||
:func:`~sklearn.metrics.PrecisionRecallDisplay.from_estimator` or
|
||||
:func:`~sklearn.metrics.PrecisionRecallDisplay.from_predictions` to create
|
||||
a :class:`~sklearn.metrics.PrecisionRecallDisplay`. All parameters are
|
||||
stored as attributes.
|
||||
|
||||
Read more in the :ref:`User Guide <visualizations>`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
precision : ndarray
|
||||
Precision values.
|
||||
|
||||
recall : ndarray
|
||||
Recall values.
|
||||
|
||||
average_precision : float, default=None
|
||||
Average precision. If None, the average precision is not shown.
|
||||
|
||||
estimator_name : str, default=None
|
||||
Name of estimator. If None, then the estimator name is not shown.
|
||||
|
||||
pos_label : int, float, bool or str, default=None
|
||||
The class considered as the positive class. If None, the class will not
|
||||
be shown in the legend.
|
||||
|
||||
.. versionadded:: 0.24
|
||||
|
||||
prevalence_pos_label : float, default=None
|
||||
The prevalence of the positive label. It is used for plotting the
|
||||
chance level line. If None, the chance level line will not be plotted
|
||||
even if `plot_chance_level` is set to True when plotting.
|
||||
|
||||
.. versionadded:: 1.3
|
||||
|
||||
Attributes
|
||||
----------
|
||||
line_ : matplotlib Artist
|
||||
Precision recall curve.
|
||||
|
||||
chance_level_ : matplotlib Artist or None
|
||||
The chance level line. It is `None` if the chance level is not plotted.
|
||||
|
||||
.. versionadded:: 1.3
|
||||
|
||||
ax_ : matplotlib Axes
|
||||
Axes with precision recall curve.
|
||||
|
||||
figure_ : matplotlib Figure
|
||||
Figure containing the curve.
|
||||
|
||||
See Also
|
||||
--------
|
||||
precision_recall_curve : Compute precision-recall pairs for different
|
||||
probability thresholds.
|
||||
PrecisionRecallDisplay.from_estimator : Plot Precision Recall Curve given
|
||||
a binary classifier.
|
||||
PrecisionRecallDisplay.from_predictions : Plot Precision Recall Curve
|
||||
using predictions from a binary classifier.
|
||||
|
||||
Notes
|
||||
-----
|
||||
The average precision (cf. :func:`~sklearn.metrics.average_precision_score`) in
|
||||
scikit-learn is computed without any interpolation. To be consistent with
|
||||
this metric, the precision-recall curve is plotted without any
|
||||
interpolation as well (step-wise style).
|
||||
|
||||
You can change this style by passing the keyword argument
|
||||
`drawstyle="default"` in :meth:`plot`, :meth:`from_estimator`, or
|
||||
:meth:`from_predictions`. However, the curve will not be strictly
|
||||
consistent with the reported average precision.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import matplotlib.pyplot as plt
|
||||
>>> from sklearn.datasets import make_classification
|
||||
>>> from sklearn.metrics import (precision_recall_curve,
|
||||
... PrecisionRecallDisplay)
|
||||
>>> from sklearn.model_selection import train_test_split
|
||||
>>> from sklearn.svm import SVC
|
||||
>>> X, y = make_classification(random_state=0)
|
||||
>>> X_train, X_test, y_train, y_test = train_test_split(X, y,
|
||||
... random_state=0)
|
||||
>>> clf = SVC(random_state=0)
|
||||
>>> clf.fit(X_train, y_train)
|
||||
SVC(random_state=0)
|
||||
>>> predictions = clf.predict(X_test)
|
||||
>>> precision, recall, _ = precision_recall_curve(y_test, predictions)
|
||||
>>> disp = PrecisionRecallDisplay(precision=precision, recall=recall)
|
||||
>>> disp.plot()
|
||||
<...>
|
||||
>>> plt.show()
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
precision,
|
||||
recall,
|
||||
*,
|
||||
average_precision=None,
|
||||
estimator_name=None,
|
||||
pos_label=None,
|
||||
prevalence_pos_label=None,
|
||||
):
|
||||
self.estimator_name = estimator_name
|
||||
self.precision = precision
|
||||
self.recall = recall
|
||||
self.average_precision = average_precision
|
||||
self.pos_label = pos_label
|
||||
self.prevalence_pos_label = prevalence_pos_label
|
||||
|
||||
def plot(
|
||||
self,
|
||||
ax=None,
|
||||
*,
|
||||
name=None,
|
||||
plot_chance_level=False,
|
||||
chance_level_kw=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Plot visualization.
|
||||
|
||||
Extra keyword arguments will be passed to matplotlib's `plot`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ax : Matplotlib Axes, default=None
|
||||
Axes object to plot on. If `None`, a new figure and axes is
|
||||
created.
|
||||
|
||||
name : str, default=None
|
||||
Name of precision recall curve for labeling. If `None`, use
|
||||
`estimator_name` if not `None`, otherwise no labeling is shown.
|
||||
|
||||
plot_chance_level : bool, default=False
|
||||
Whether to plot the chance level. The chance level is the prevalence
|
||||
of the positive label computed from the data passed during
|
||||
:meth:`from_estimator` or :meth:`from_predictions` call.
|
||||
|
||||
.. versionadded:: 1.3
|
||||
|
||||
chance_level_kw : dict, default=None
|
||||
Keyword arguments to be passed to matplotlib's `plot` for rendering
|
||||
the chance level line.
|
||||
|
||||
.. versionadded:: 1.3
|
||||
|
||||
**kwargs : dict
|
||||
Keyword arguments to be passed to matplotlib's `plot`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
display : :class:`~sklearn.metrics.PrecisionRecallDisplay`
|
||||
Object that stores computed values.
|
||||
|
||||
Notes
|
||||
-----
|
||||
The average precision (cf. :func:`~sklearn.metrics.average_precision_score`)
|
||||
in scikit-learn is computed without any interpolation. To be consistent
|
||||
with this metric, the precision-recall curve is plotted without any
|
||||
interpolation as well (step-wise style).
|
||||
|
||||
You can change this style by passing the keyword argument
|
||||
`drawstyle="default"`. However, the curve will not be strictly
|
||||
consistent with the reported average precision.
|
||||
"""
|
||||
self.ax_, self.figure_, name = self._validate_plot_params(ax=ax, name=name)
|
||||
|
||||
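# Draw the curve step-wise (no interpolation) so the plot stays consistent with how average precision is computed.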
line_kwargs = {"drawstyle": "steps-post"}
|
||||
if self.average_precision is not None and name is not None:
|
||||
line_kwargs["label"] = f"{name} (AP = {self.average_precision:0.2f})"
|
||||
elif self.average_precision is not None:
|
||||
line_kwargs["label"] = f"AP = {self.average_precision:0.2f}"
|
||||
elif name is not None:
|
||||
line_kwargs["label"] = name
|
||||
line_kwargs.update(**kwargs)
|
||||
|
||||
(self.line_,) = self.ax_.plot(self.recall, self.precision, **line_kwargs)
|
||||
|
||||
info_pos_label = (
|
||||
f" (Positive label: {self.pos_label})" if self.pos_label is not None else ""
|
||||
)
|
||||
|
||||
xlabel = "Recall" + info_pos_label
|
||||
ylabel = "Precision" + info_pos_label
|
||||
self.ax_.set(
|
||||
xlabel=xlabel,
|
||||
xlim=(-0.01, 1.01),
|
||||
ylabel=ylabel,
|
||||
ylim=(-0.01, 1.01),
|
||||
aspect="equal",
|
||||
)
|
||||
|
||||
if plot_chance_level:
|
||||
if self.prevalence_pos_label is None:
|
||||
raise ValueError(
|
||||
"You must provide prevalence_pos_label when constructing the "
|
||||
"PrecisionRecallDisplay object in order to plot the chance "
|
||||
"level line. Alternatively, you may use "
|
||||
"PrecisionRecallDisplay.from_estimator or "
|
||||
"PrecisionRecallDisplay.from_predictions "
|
||||
"to automatically set prevalence_pos_label"
|
||||
)
|
||||
|
||||
chance_level_line_kw = {
|
||||
"label": f"Chance level (AP = {self.prevalence_pos_label:0.2f})",
|
||||
"color": "k",
|
||||
"linestyle": "--",
|
||||
}
|
||||
if chance_level_kw is not None:
|
||||
chance_level_line_kw.update(chance_level_kw)
|
||||
|
||||
(self.chance_level_,) = self.ax_.plot(
|
||||
(0, 1),
|
||||
(self.prevalence_pos_label, self.prevalence_pos_label),
|
||||
**chance_level_line_kw,
|
||||
)
|
||||
else:
|
||||
self.chance_level_ = None
|
||||
|
||||
if "label" in line_kwargs or plot_chance_level:
|
||||
self.ax_.legend(loc="lower left")
|
||||
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def from_estimator(
|
||||
cls,
|
||||
estimator,
|
||||
X,
|
||||
y,
|
||||
*,
|
||||
sample_weight=None,
|
||||
pos_label=None,
|
||||
drop_intermediate=False,
|
||||
response_method="auto",
|
||||
name=None,
|
||||
ax=None,
|
||||
plot_chance_level=False,
|
||||
chance_level_kw=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Plot precision-recall curve given an estimator and some data.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
estimator : estimator instance
|
||||
Fitted classifier or a fitted :class:`~sklearn.pipeline.Pipeline`
|
||||
in which the last estimator is a classifier.
|
||||
|
||||
X : {array-like, sparse matrix} of shape (n_samples, n_features)
|
||||
Input values.
|
||||
|
||||
y : array-like of shape (n_samples,)
|
||||
Target values.
|
||||
|
||||
sample_weight : array-like of shape (n_samples,), default=None
|
||||
Sample weights.
|
||||
|
||||
pos_label : int, float, bool or str, default=None
|
||||
The class considered as the positive class when computing the
|
||||
precision and recall metrics. By default, `estimator.classes_[1]`
|
||||
is considered as the positive class.
|
||||
|
||||
drop_intermediate : bool, default=False
|
||||
Whether to drop some suboptimal thresholds which would not appear
|
||||
on a plotted precision-recall curve. This is useful in order to
|
||||
create lighter precision-recall curves.
|
||||
|
||||
.. versionadded:: 1.3
|
||||
|
||||
response_method : {'predict_proba', 'decision_function', 'auto'}, \
|
||||
default='auto'
|
||||
Specifies whether to use :term:`predict_proba` or
|
||||
:term:`decision_function` as the target response. If set to 'auto',
|
||||
:term:`predict_proba` is tried first and if it does not exist
|
||||
:term:`decision_function` is tried next.
|
||||
|
||||
name : str, default=None
|
||||
Name for labeling curve. If `None`, no name is used.
|
||||
|
||||
ax : matplotlib axes, default=None
|
||||
Axes object to plot on. If `None`, a new figure and axes is created.
|
||||
|
||||
plot_chance_level : bool, default=False
|
||||
Whether to plot the chance level. The chance level is the prevalence
|
||||
of the positive label computed from the data passed during
|
||||
:meth:`from_estimator` or :meth:`from_predictions` call.
|
||||
|
||||
.. versionadded:: 1.3
|
||||
|
||||
chance_level_kw : dict, default=None
|
||||
Keyword arguments to be passed to matplotlib's `plot` for rendering
|
||||
the chance level line.
|
||||
|
||||
.. versionadded:: 1.3
|
||||
|
||||
**kwargs : dict
|
||||
Keyword arguments to be passed to matplotlib's `plot`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
display : :class:`~sklearn.metrics.PrecisionRecallDisplay`
|
||||
|
||||
See Also
|
||||
--------
|
||||
PrecisionRecallDisplay.from_predictions : Plot precision-recall curve
|
||||
using estimated probabilities or output of decision function.
|
||||
|
||||
Notes
|
||||
-----
|
||||
The average precision (cf. :func:`~sklearn.metrics.average_precision_score`)
|
||||
in scikit-learn is computed without any interpolation. To be consistent
|
||||
with this metric, the precision-recall curve is plotted without any
|
||||
interpolation as well (step-wise style).
|
||||
|
||||
You can change this style by passing the keyword argument
|
||||
`drawstyle="default"`. However, the curve will not be strictly
|
||||
consistent with the reported average precision.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import matplotlib.pyplot as plt
|
||||
>>> from sklearn.datasets import make_classification
|
||||
>>> from sklearn.metrics import PrecisionRecallDisplay
|
||||
>>> from sklearn.model_selection import train_test_split
|
||||
>>> from sklearn.linear_model import LogisticRegression
|
||||
>>> X, y = make_classification(random_state=0)
|
||||
>>> X_train, X_test, y_train, y_test = train_test_split(
|
||||
... X, y, random_state=0)
|
||||
>>> clf = LogisticRegression()
|
||||
>>> clf.fit(X_train, y_train)
|
||||
LogisticRegression()
|
||||
>>> PrecisionRecallDisplay.from_estimator(
|
||||
... clf, X_test, y_test)
|
||||
<...>
|
||||
>>> plt.show()
|
||||
"""
|
||||
y_pred, pos_label, name = cls._validate_and_get_response_values(
|
||||
estimator,
|
||||
X,
|
||||
y,
|
||||
response_method=response_method,
|
||||
pos_label=pos_label,
|
||||
name=name,
|
||||
)
|
||||
|
||||
return cls.from_predictions(
|
||||
y,
|
||||
y_pred,
|
||||
sample_weight=sample_weight,
|
||||
name=name,
|
||||
pos_label=pos_label,
|
||||
drop_intermediate=drop_intermediate,
|
||||
ax=ax,
|
||||
plot_chance_level=plot_chance_level,
|
||||
chance_level_kw=chance_level_kw,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_predictions(
|
||||
cls,
|
||||
y_true,
|
||||
y_pred,
|
||||
*,
|
||||
sample_weight=None,
|
||||
pos_label=None,
|
||||
drop_intermediate=False,
|
||||
name=None,
|
||||
ax=None,
|
||||
plot_chance_level=False,
|
||||
chance_level_kw=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Plot precision-recall curve given binary class predictions.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
y_true : array-like of shape (n_samples,)
|
||||
True binary labels.
|
||||
|
||||
y_pred : array-like of shape (n_samples,)
|
||||
Estimated probabilities or output of decision function.
|
||||
|
||||
sample_weight : array-like of shape (n_samples,), default=None
|
||||
Sample weights.
|
||||
|
||||
pos_label : int, float, bool or str, default=None
|
||||
The class considered as the positive class when computing the
|
||||
precision and recall metrics.
|
||||
|
||||
drop_intermediate : bool, default=False
|
||||
Whether to drop some suboptimal thresholds which would not appear
|
||||
on a plotted precision-recall curve. This is useful in order to
|
||||
create lighter precision-recall curves.
|
||||
|
||||
.. versionadded:: 1.3
|
||||
|
||||
name : str, default=None
|
||||
Name for labeling curve. If `None`, name will be set to
|
||||
`"Classifier"`.
|
||||
|
||||
ax : matplotlib axes, default=None
|
||||
Axes object to plot on. If `None`, a new figure and axes is created.
|
||||
|
||||
plot_chance_level : bool, default=False
|
||||
Whether to plot the chance level. The chance level is the prevalence
|
||||
of the positive label computed from the data passed during
|
||||
:meth:`from_estimator` or :meth:`from_predictions` call.
|
||||
|
||||
.. versionadded:: 1.3
|
||||
|
||||
chance_level_kw : dict, default=None
|
||||
Keyword arguments to be passed to matplotlib's `plot` for rendering
|
||||
the chance level line.
|
||||
|
||||
.. versionadded:: 1.3
|
||||
|
||||
**kwargs : dict
|
||||
Keyword arguments to be passed to matplotlib's `plot`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
display : :class:`~sklearn.metrics.PrecisionRecallDisplay`
|
||||
|
||||
See Also
|
||||
--------
|
||||
PrecisionRecallDisplay.from_estimator : Plot precision-recall curve
|
||||
using an estimator.
|
||||
|
||||
Notes
|
||||
-----
|
||||
The average precision (cf. :func:`~sklearn.metrics.average_precision_score`)
|
||||
in scikit-learn is computed without any interpolation. To be consistent
|
||||
with this metric, the precision-recall curve is plotted without any
|
||||
interpolation as well (step-wise style).
|
||||
|
||||
You can change this style by passing the keyword argument
|
||||
`drawstyle="default"`. However, the curve will not be strictly
|
||||
consistent with the reported average precision.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import matplotlib.pyplot as plt
|
||||
>>> from sklearn.datasets import make_classification
|
||||
>>> from sklearn.metrics import PrecisionRecallDisplay
|
||||
>>> from sklearn.model_selection import train_test_split
|
||||
>>> from sklearn.linear_model import LogisticRegression
|
||||
>>> X, y = make_classification(random_state=0)
|
||||
>>> X_train, X_test, y_train, y_test = train_test_split(
|
||||
... X, y, random_state=0)
|
||||
>>> clf = LogisticRegression()
|
||||
>>> clf.fit(X_train, y_train)
|
||||
LogisticRegression()
|
||||
>>> y_pred = clf.predict_proba(X_test)[:, 1]
|
||||
>>> PrecisionRecallDisplay.from_predictions(
|
||||
... y_test, y_pred)
|
||||
<...>
|
||||
>>> plt.show()
|
||||
"""
|
||||
pos_label, name = cls._validate_from_predictions_params(
|
||||
y_true, y_pred, sample_weight=sample_weight, pos_label=pos_label, name=name
|
||||
)
|
||||
|
||||
precision, recall, _ = precision_recall_curve(
|
||||
y_true,
|
||||
y_pred,
|
||||
pos_label=pos_label,
|
||||
sample_weight=sample_weight,
|
||||
drop_intermediate=drop_intermediate,
|
||||
)
|
||||
average_precision = average_precision_score(
|
||||
y_true, y_pred, pos_label=pos_label, sample_weight=sample_weight
|
||||
)
|
||||
|
||||
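# The chance level of a precision-recall curve is a horizontal line at the prevalence of the positive class in `y_true`.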
class_count = Counter(y_true)
|
||||
prevalence_pos_label = class_count[pos_label] / sum(class_count.values())
|
||||
|
||||
viz = cls(
|
||||
precision=precision,
|
||||
recall=recall,
|
||||
average_precision=average_precision,
|
||||
estimator_name=name,
|
||||
pos_label=pos_label,
|
||||
prevalence_pos_label=prevalence_pos_label,
|
||||
)
|
||||
|
||||
return viz.plot(
|
||||
ax=ax,
|
||||
name=name,
|
||||
plot_chance_level=plot_chance_level,
|
||||
chance_level_kw=chance_level_kw,
|
||||
**kwargs,
|
||||
)
|
||||
@@ -0,0 +1,406 @@
import numbers

import numpy as np

from ...utils import _safe_indexing, check_random_state
from ...utils._optional_dependencies import check_matplotlib_support

class PredictionErrorDisplay:
|
||||
"""Visualization of the prediction error of a regression model.
|
||||
|
||||
This tool can display "residuals vs predicted" or "actual vs predicted"
|
||||
using scatter plots to qualitatively assess the behavior of a regressor,
|
||||
preferably on held-out data points.
|
||||
|
||||
See the details in the docstrings of
|
||||
:func:`~sklearn.metrics.PredictionErrorDisplay.from_estimator` or
|
||||
:func:`~sklearn.metrics.PredictionErrorDisplay.from_predictions` to
|
||||
create a visualizer. All parameters are stored as attributes.
|
||||
|
||||
For general information regarding `scikit-learn` visualization tools, read
|
||||
more in the :ref:`Visualization Guide <visualizations>`.
|
||||
For details regarding interpreting these plots, refer to the
|
||||
:ref:`Model Evaluation Guide <visualization_regression_evaluation>`.
|
||||
|
||||
.. versionadded:: 1.2
|
||||
|
||||
Parameters
|
||||
----------
|
||||
y_true : ndarray of shape (n_samples,)
|
||||
True values.
|
||||
|
||||
y_pred : ndarray of shape (n_samples,)
|
||||
Prediction values.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
line_ : matplotlib Artist
|
||||
Optimal line representing `y_true == y_pred`. Therefore, it is a
diagonal line for `kind="actual_vs_predicted"` and a horizontal line
for `kind="residual_vs_predicted"`.
|
||||
|
||||
errors_lines_ : matplotlib Artist or None
|
||||
Residual lines. If `with_errors=False`, then it is set to `None`.
|
||||
|
||||
scatter_ : matplotlib Artist
|
||||
Scatter data points.
|
||||
|
||||
ax_ : matplotlib Axes
|
||||
Axes containing the plot.
|
||||
|
||||
figure_ : matplotlib Figure
|
||||
Figure containing the scatter and lines.
|
||||
|
||||
See Also
|
||||
--------
|
||||
PredictionErrorDisplay.from_estimator : Prediction error visualization
|
||||
given an estimator and some data.
|
||||
PredictionErrorDisplay.from_predictions : Prediction error visualization
|
||||
given the true and predicted targets.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import matplotlib.pyplot as plt
|
||||
>>> from sklearn.datasets import load_diabetes
|
||||
>>> from sklearn.linear_model import Ridge
|
||||
>>> from sklearn.metrics import PredictionErrorDisplay
|
||||
>>> X, y = load_diabetes(return_X_y=True)
|
||||
>>> ridge = Ridge().fit(X, y)
|
||||
>>> y_pred = ridge.predict(X)
|
||||
>>> display = PredictionErrorDisplay(y_true=y, y_pred=y_pred)
|
||||
>>> display.plot()
|
||||
<...>
|
||||
>>> plt.show()
|
||||
"""
|
||||
|
||||
def __init__(self, *, y_true, y_pred):
|
||||
self.y_true = y_true
|
||||
self.y_pred = y_pred
|
||||
|
||||
def plot(
|
||||
self,
|
||||
ax=None,
|
||||
*,
|
||||
kind="residual_vs_predicted",
|
||||
scatter_kwargs=None,
|
||||
line_kwargs=None,
|
||||
):
|
||||
"""Plot visualization.
|
||||
|
||||
Extra keyword arguments will be passed to matplotlib's ``plot``.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ax : matplotlib axes, default=None
|
||||
Axes object to plot on. If `None`, a new figure and axes is
|
||||
created.
|
||||
|
||||
kind : {"actual_vs_predicted", "residual_vs_predicted"}, \
|
||||
default="residual_vs_predicted"
|
||||
The type of plot to draw:
|
||||
|
||||
- "actual_vs_predicted" draws the observed values (y-axis) vs.
|
||||
the predicted values (x-axis).
|
||||
- "residual_vs_predicted" draws the residuals, i.e. difference
|
||||
between observed and predicted values, (y-axis) vs. the predicted
|
||||
values (x-axis).
|
||||
|
||||
scatter_kwargs : dict, default=None
|
||||
Dictionary with keywords passed to the `matplotlib.pyplot.scatter`
|
||||
call.
|
||||
|
||||
line_kwargs : dict, default=None
|
||||
Dictionary with keyword passed to the `matplotlib.pyplot.plot`
|
||||
call to draw the optimal line.
|
||||
|
||||
Returns
|
||||
-------
|
||||
display : :class:`~sklearn.metrics.PredictionErrorDisplay`
|
||||
|
||||
Object that stores computed values.
|
||||
"""
|
||||
check_matplotlib_support(f"{self.__class__.__name__}.plot")
|
||||
|
||||
expected_kind = ("actual_vs_predicted", "residual_vs_predicted")
|
||||
if kind not in expected_kind:
|
||||
raise ValueError(
|
||||
f"`kind` must be one of {', '.join(expected_kind)}. "
|
||||
f"Got {kind!r} instead."
|
||||
)
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
if scatter_kwargs is None:
|
||||
scatter_kwargs = {}
|
||||
if line_kwargs is None:
|
||||
line_kwargs = {}
|
||||
|
||||
default_scatter_kwargs = {"color": "tab:blue", "alpha": 0.8}
|
||||
default_line_kwargs = {"color": "black", "alpha": 0.7, "linestyle": "--"}
|
||||
|
||||
scatter_kwargs = {**default_scatter_kwargs, **scatter_kwargs}
|
||||
line_kwargs = {**default_line_kwargs, **line_kwargs}
|
||||
|
||||
if ax is None:
|
||||
_, ax = plt.subplots()
|
||||
|
||||
if kind == "actual_vs_predicted":
|
||||
max_value = max(np.max(self.y_true), np.max(self.y_pred))
|
||||
min_value = min(np.min(self.y_true), np.min(self.y_pred))
|
||||
self.line_ = ax.plot(
|
||||
[min_value, max_value], [min_value, max_value], **line_kwargs
|
||||
)[0]
|
||||
|
||||
x_data, y_data = self.y_pred, self.y_true
|
||||
xlabel, ylabel = "Predicted values", "Actual values"
|
||||
|
||||
self.scatter_ = ax.scatter(x_data, y_data, **scatter_kwargs)
|
||||
|
||||
# force to have a squared axis
|
||||
ax.set_aspect("equal", adjustable="datalim")
|
||||
ax.set_xticks(np.linspace(min_value, max_value, num=5))
|
||||
ax.set_yticks(np.linspace(min_value, max_value, num=5))
|
||||
else: # kind == "residual_vs_predicted"
|
||||
self.line_ = ax.plot(
|
||||
[np.min(self.y_pred), np.max(self.y_pred)],
|
||||
[0, 0],
|
||||
**line_kwargs,
|
||||
)[0]
|
||||
self.scatter_ = ax.scatter(
|
||||
self.y_pred, self.y_true - self.y_pred, **scatter_kwargs
|
||||
)
|
||||
xlabel, ylabel = "Predicted values", "Residuals (actual - predicted)"
|
||||
|
||||
ax.set(xlabel=xlabel, ylabel=ylabel)
|
||||
|
||||
self.ax_ = ax
|
||||
self.figure_ = ax.figure
|
||||
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def from_estimator(
|
||||
cls,
|
||||
estimator,
|
||||
X,
|
||||
y,
|
||||
*,
|
||||
kind="residual_vs_predicted",
|
||||
subsample=1_000,
|
||||
random_state=None,
|
||||
ax=None,
|
||||
scatter_kwargs=None,
|
||||
line_kwargs=None,
|
||||
):
|
||||
"""Plot the prediction error given a regressor and some data.
|
||||
|
||||
For general information regarding `scikit-learn` visualization tools,
|
||||
read more in the :ref:`Visualization Guide <visualizations>`.
|
||||
For details regarding interpreting these plots, refer to the
|
||||
:ref:`Model Evaluation Guide <visualization_regression_evaluation>`.
|
||||
|
||||
.. versionadded:: 1.2
|
||||
|
||||
Parameters
|
||||
----------
|
||||
estimator : estimator instance
|
||||
Fitted regressor or a fitted :class:`~sklearn.pipeline.Pipeline`
|
||||
in which the last estimator is a regressor.
|
||||
|
||||
X : {array-like, sparse matrix} of shape (n_samples, n_features)
|
||||
Input values.
|
||||
|
||||
y : array-like of shape (n_samples,)
|
||||
Target values.
|
||||
|
||||
kind : {"actual_vs_predicted", "residual_vs_predicted"}, \
|
||||
default="residual_vs_predicted"
|
||||
The type of plot to draw:
|
||||
|
||||
- "actual_vs_predicted" draws the observed values (y-axis) vs.
|
||||
the predicted values (x-axis).
|
||||
- "residual_vs_predicted" draws the residuals, i.e. difference
|
||||
between observed and predicted values, (y-axis) vs. the predicted
|
||||
values (x-axis).
|
||||
|
||||
subsample : float, int or None, default=1_000
|
||||
Subsampling of the data points shown on the scatter plot. If `float`,
it should be between 0 and 1 and represents the proportion of the
original dataset. If `int`, it represents the number of samples
displayed on the scatter plot. If `None`, no subsampling is applied.
By default, at most 1,000 samples are displayed.
|
||||
|
||||
random_state : int or RandomState, default=None
|
||||
Controls the randomness when `subsample` is not `None`.
|
||||
See :term:`Glossary <random_state>` for details.
|
||||
|
||||
ax : matplotlib axes, default=None
|
||||
Axes object to plot on. If `None`, a new figure and axes is
|
||||
created.
|
||||
|
||||
scatter_kwargs : dict, default=None
|
||||
Dictionary with keywords passed to the `matplotlib.pyplot.scatter`
|
||||
call.
|
||||
|
||||
line_kwargs : dict, default=None
|
||||
Dictionary with keyword passed to the `matplotlib.pyplot.plot`
|
||||
call to draw the optimal line.
|
||||
|
||||
Returns
|
||||
-------
|
||||
display : :class:`~sklearn.metrics.PredictionErrorDisplay`
|
||||
Object that stores the computed values.
|
||||
|
||||
See Also
|
||||
--------
|
||||
PredictionErrorDisplay : Prediction error visualization for regression.
|
||||
PredictionErrorDisplay.from_predictions : Prediction error visualization
|
||||
given the true and predicted targets.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import matplotlib.pyplot as plt
|
||||
>>> from sklearn.datasets import load_diabetes
|
||||
>>> from sklearn.linear_model import Ridge
|
||||
>>> from sklearn.metrics import PredictionErrorDisplay
|
||||
>>> X, y = load_diabetes(return_X_y=True)
|
||||
>>> ridge = Ridge().fit(X, y)
|
||||
>>> disp = PredictionErrorDisplay.from_estimator(ridge, X, y)
|
||||
>>> plt.show()
|
||||
"""
|
||||
check_matplotlib_support(f"{cls.__name__}.from_estimator")
|
||||
|
||||
y_pred = estimator.predict(X)
|
||||
|
||||
return cls.from_predictions(
|
||||
y_true=y,
|
||||
y_pred=y_pred,
|
||||
kind=kind,
|
||||
subsample=subsample,
|
||||
random_state=random_state,
|
||||
ax=ax,
|
||||
scatter_kwargs=scatter_kwargs,
|
||||
line_kwargs=line_kwargs,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_predictions(
|
||||
cls,
|
||||
y_true,
|
||||
y_pred,
|
||||
*,
|
||||
kind="residual_vs_predicted",
|
||||
subsample=1_000,
|
||||
random_state=None,
|
||||
ax=None,
|
||||
scatter_kwargs=None,
|
||||
line_kwargs=None,
|
||||
):
|
||||
"""Plot the prediction error given the true and predicted targets.
|
||||
|
||||
For general information regarding `scikit-learn` visualization tools,
|
||||
read more in the :ref:`Visualization Guide <visualizations>`.
|
||||
For details regarding interpreting these plots, refer to the
|
||||
:ref:`Model Evaluation Guide <visualization_regression_evaluation>`.
|
||||
|
||||
.. versionadded:: 1.2
|
||||
|
||||
Parameters
|
||||
----------
|
||||
y_true : array-like of shape (n_samples,)
|
||||
True target values.
|
||||
|
||||
y_pred : array-like of shape (n_samples,)
|
||||
Predicted target values.
|
||||
|
||||
kind : {"actual_vs_predicted", "residual_vs_predicted"}, \
|
||||
default="residual_vs_predicted"
|
||||
The type of plot to draw:
|
||||
|
||||
- "actual_vs_predicted" draws the observed values (y-axis) vs.
|
||||
the predicted values (x-axis).
|
||||
- "residual_vs_predicted" draws the residuals, i.e. difference
|
||||
between observed and predicted values, (y-axis) vs. the predicted
|
||||
values (x-axis).
|
||||
|
||||
subsample : float, int or None, default=1_000
|
||||
Subsampling of the data points shown on the scatter plot. If `float`,
it should be between 0 and 1 and represents the proportion of the
original dataset. If `int`, it represents the number of samples
displayed on the scatter plot. If `None`, no subsampling is applied.
By default, at most 1,000 samples are displayed.
|
||||
|
||||
random_state : int or RandomState, default=None
|
||||
Controls the randomness when `subsample` is not `None`.
|
||||
See :term:`Glossary <random_state>` for details.
|
||||
|
||||
ax : matplotlib axes, default=None
|
||||
Axes object to plot on. If `None`, a new figure and axes is
|
||||
created.
|
||||
|
||||
scatter_kwargs : dict, default=None
|
||||
Dictionary with keywords passed to the `matplotlib.pyplot.scatter`
|
||||
call.
|
||||
|
||||
line_kwargs : dict, default=None
|
||||
Dictionary with keyword passed to the `matplotlib.pyplot.plot`
|
||||
call to draw the optimal line.
|
||||
|
||||
Returns
|
||||
-------
|
||||
display : :class:`~sklearn.metrics.PredictionErrorDisplay`
|
||||
Object that stores the computed values.
|
||||
|
||||
See Also
|
||||
--------
|
||||
PredictionErrorDisplay : Prediction error visualization for regression.
|
||||
PredictionErrorDisplay.from_estimator : Prediction error visualization
|
||||
given an estimator and some data.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import matplotlib.pyplot as plt
|
||||
>>> from sklearn.datasets import load_diabetes
|
||||
>>> from sklearn.linear_model import Ridge
|
||||
>>> from sklearn.metrics import PredictionErrorDisplay
|
||||
>>> X, y = load_diabetes(return_X_y=True)
|
||||
>>> ridge = Ridge().fit(X, y)
|
||||
>>> y_pred = ridge.predict(X)
|
||||
>>> disp = PredictionErrorDisplay.from_predictions(y_true=y, y_pred=y_pred)
|
||||
>>> plt.show()
|
||||
"""
|
||||
check_matplotlib_support(f"{cls.__name__}.from_predictions")
|
||||
|
||||
random_state = check_random_state(random_state)
|
||||
|
||||
n_samples = len(y_true)
|
||||
if isinstance(subsample, numbers.Integral):
|
||||
if subsample <= 0:
|
||||
raise ValueError(
|
||||
f"When an integer, subsample={subsample} should be positive."
|
||||
)
|
||||
elif isinstance(subsample, numbers.Real):
|
||||
if subsample <= 0 or subsample >= 1:
|
||||
raise ValueError(
|
||||
f"When a floating-point, subsample={subsample} should"
|
||||
" be in the (0, 1) range."
|
||||
)
|
||||
subsample = int(n_samples * subsample)
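# At this point a float `subsample` has been converted into an absolute count;
# below, that many rows are drawn at random so large datasets stay readable,
# while `subsample=None` keeps every sample.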
|
||||
|
||||
if subsample is not None and subsample < n_samples:
|
||||
indices = random_state.choice(np.arange(n_samples), size=subsample)
|
||||
y_true = _safe_indexing(y_true, indices, axis=0)
|
||||
y_pred = _safe_indexing(y_pred, indices, axis=0)
|
||||
|
||||
viz = cls(
|
||||
y_true=y_true,
|
||||
y_pred=y_pred,
|
||||
)
|
||||
|
||||
return viz.plot(
|
||||
ax=ax,
|
||||
kind=kind,
|
||||
scatter_kwargs=scatter_kwargs,
|
||||
line_kwargs=line_kwargs,
|
||||
)
|
||||
@ -0,0 +1,419 @@
|
||||
from ...utils._plotting import _BinaryClassifierCurveDisplayMixin
|
||||
from .._ranking import auc, roc_curve
|
||||
|
||||
|
||||
class RocCurveDisplay(_BinaryClassifierCurveDisplayMixin):
|
||||
"""ROC Curve visualization.
|
||||
|
||||
It is recommended to use
|
||||
:func:`~sklearn.metrics.RocCurveDisplay.from_estimator` or
|
||||
:func:`~sklearn.metrics.RocCurveDisplay.from_predictions` to create
|
||||
a :class:`~sklearn.metrics.RocCurveDisplay`. All parameters are
|
||||
stored as attributes.
|
||||
|
||||
Read more in the :ref:`User Guide <visualizations>`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
fpr : ndarray
|
||||
False positive rate.
|
||||
|
||||
tpr : ndarray
|
||||
True positive rate.
|
||||
|
||||
roc_auc : float, default=None
|
||||
Area under ROC curve. If None, the roc_auc score is not shown.
|
||||
|
||||
estimator_name : str, default=None
|
||||
Name of estimator. If None, the estimator name is not shown.
|
||||
|
||||
pos_label : int, float, bool or str, default=None
|
||||
The class considered as the positive class when computing the ROC AUC
metrics. By default, `estimator.classes_[1]` is considered
|
||||
as the positive class.
|
||||
|
||||
.. versionadded:: 0.24
|
||||
|
||||
Attributes
|
||||
----------
|
||||
line_ : matplotlib Artist
|
||||
ROC Curve.
|
||||
|
||||
chance_level_ : matplotlib Artist or None
|
||||
The chance level line. It is `None` if the chance level is not plotted.
|
||||
|
||||
.. versionadded:: 1.3
|
||||
|
||||
ax_ : matplotlib Axes
|
||||
Axes with ROC Curve.
|
||||
|
||||
figure_ : matplotlib Figure
|
||||
Figure containing the curve.
|
||||
|
||||
See Also
|
||||
--------
|
||||
roc_curve : Compute Receiver operating characteristic (ROC) curve.
|
||||
RocCurveDisplay.from_estimator : Plot Receiver Operating Characteristic
|
||||
(ROC) curve given an estimator and some data.
|
||||
RocCurveDisplay.from_predictions : Plot Receiver Operating Characteristic
|
||||
(ROC) curve given the true and predicted values.
|
||||
roc_auc_score : Compute the area under the ROC curve.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import matplotlib.pyplot as plt
|
||||
>>> import numpy as np
|
||||
>>> from sklearn import metrics
|
||||
>>> y = np.array([0, 0, 1, 1])
|
||||
>>> pred = np.array([0.1, 0.4, 0.35, 0.8])
|
||||
>>> fpr, tpr, thresholds = metrics.roc_curve(y, pred)
|
||||
>>> roc_auc = metrics.auc(fpr, tpr)
|
||||
>>> display = metrics.RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=roc_auc,
|
||||
... estimator_name='example estimator')
|
||||
>>> display.plot()
|
||||
<...>
|
||||
>>> plt.show()
|
||||
"""
|
||||
|
||||
def __init__(self, *, fpr, tpr, roc_auc=None, estimator_name=None, pos_label=None):
|
||||
self.estimator_name = estimator_name
|
||||
self.fpr = fpr
|
||||
self.tpr = tpr
|
||||
self.roc_auc = roc_auc
|
||||
self.pos_label = pos_label
|
||||
|
||||
def plot(
|
||||
self,
|
||||
ax=None,
|
||||
*,
|
||||
name=None,
|
||||
plot_chance_level=False,
|
||||
chance_level_kw=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Plot visualization.
|
||||
|
||||
Extra keyword arguments will be passed to matplotlib's ``plot``.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ax : matplotlib axes, default=None
|
||||
Axes object to plot on. If `None`, a new figure and axes is
|
||||
created.
|
||||
|
||||
name : str, default=None
|
||||
Name of ROC Curve for labeling. If `None`, use `estimator_name` if
|
||||
not `None`, otherwise no labeling is shown.
|
||||
|
||||
plot_chance_level : bool, default=False
|
||||
Whether to plot the chance level.
|
||||
|
||||
.. versionadded:: 1.3
|
||||
|
||||
chance_level_kw : dict, default=None
|
||||
Keyword arguments to be passed to matplotlib's `plot` for rendering
|
||||
the chance level line.
|
||||
|
||||
.. versionadded:: 1.3
|
||||
|
||||
**kwargs : dict
|
||||
Keyword arguments to be passed to matplotlib's `plot`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
display : :class:`~sklearn.metrics.RocCurveDisplay`
|
||||
Object that stores computed values.
|
||||
"""
|
||||
self.ax_, self.figure_, name = self._validate_plot_params(ax=ax, name=name)
|
||||
|
||||
line_kwargs = {}
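# Build the legend label from the curve name and/or the AUC value,
# depending on which of the two is available.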
|
||||
if self.roc_auc is not None and name is not None:
|
||||
line_kwargs["label"] = f"{name} (AUC = {self.roc_auc:0.2f})"
|
||||
elif self.roc_auc is not None:
|
||||
line_kwargs["label"] = f"AUC = {self.roc_auc:0.2f}"
|
||||
elif name is not None:
|
||||
line_kwargs["label"] = name
|
||||
|
||||
line_kwargs.update(**kwargs)
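# Default styling of the optional chance-level diagonal; any user-provided
# `chance_level_kw` entries override these defaults below.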
|
||||
|
||||
chance_level_line_kw = {
|
||||
"label": "Chance level (AUC = 0.5)",
|
||||
"color": "k",
|
||||
"linestyle": "--",
|
||||
}
|
||||
|
||||
if chance_level_kw is not None:
|
||||
chance_level_line_kw.update(**chance_level_kw)
|
||||
|
||||
(self.line_,) = self.ax_.plot(self.fpr, self.tpr, **line_kwargs)
|
||||
info_pos_label = (
|
||||
f" (Positive label: {self.pos_label})" if self.pos_label is not None else ""
|
||||
)
|
||||
|
||||
xlabel = "False Positive Rate" + info_pos_label
|
||||
ylabel = "True Positive Rate" + info_pos_label
|
||||
self.ax_.set(
|
||||
xlabel=xlabel,
|
||||
xlim=(-0.01, 1.01),
|
||||
ylabel=ylabel,
|
||||
ylim=(-0.01, 1.01),
|
||||
aspect="equal",
|
||||
)
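# The chance level of a ROC curve is the diagonal from (0, 0) to (1, 1),
# i.e. a random classifier with AUC = 0.5.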
|
||||
|
||||
if plot_chance_level:
|
||||
(self.chance_level_,) = self.ax_.plot(
|
||||
(0, 1), (0, 1), **chance_level_line_kw
|
||||
)
|
||||
else:
|
||||
self.chance_level_ = None
|
||||
|
||||
if "label" in line_kwargs or "label" in chance_level_line_kw:
|
||||
self.ax_.legend(loc="lower right")
|
||||
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def from_estimator(
|
||||
cls,
|
||||
estimator,
|
||||
X,
|
||||
y,
|
||||
*,
|
||||
sample_weight=None,
|
||||
drop_intermediate=True,
|
||||
response_method="auto",
|
||||
pos_label=None,
|
||||
name=None,
|
||||
ax=None,
|
||||
plot_chance_level=False,
|
||||
chance_level_kw=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Create a ROC Curve display from an estimator.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
estimator : estimator instance
|
||||
Fitted classifier or a fitted :class:`~sklearn.pipeline.Pipeline`
|
||||
in which the last estimator is a classifier.
|
||||
|
||||
X : {array-like, sparse matrix} of shape (n_samples, n_features)
|
||||
Input values.
|
||||
|
||||
y : array-like of shape (n_samples,)
|
||||
Target values.
|
||||
|
||||
sample_weight : array-like of shape (n_samples,), default=None
|
||||
Sample weights.
|
||||
|
||||
drop_intermediate : bool, default=True
|
||||
Whether to drop some suboptimal thresholds which would not appear
|
||||
on a plotted ROC curve. This is useful in order to create lighter
|
||||
ROC curves.
|
||||
|
||||
response_method : {'predict_proba', 'decision_function', 'auto'} \
|
||||
default='auto'
|
||||
Specifies whether to use :term:`predict_proba` or
|
||||
:term:`decision_function` as the target response. If set to 'auto',
|
||||
:term:`predict_proba` is tried first and if it does not exist
|
||||
:term:`decision_function` is tried next.
|
||||
|
||||
pos_label : int, float, bool or str, default=None
|
||||
The class considered as the positive class when computing the ROC AUC
metrics. By default, `estimator.classes_[1]` is considered
|
||||
as the positive class.
|
||||
|
||||
name : str, default=None
|
||||
Name of ROC Curve for labeling. If `None`, use the name of the
|
||||
estimator.
|
||||
|
||||
ax : matplotlib axes, default=None
|
||||
Axes object to plot on. If `None`, a new figure and axes is created.
|
||||
|
||||
plot_chance_level : bool, default=False
|
||||
Whether to plot the chance level.
|
||||
|
||||
.. versionadded:: 1.3
|
||||
|
||||
chance_level_kw : dict, default=None
|
||||
Keyword arguments to be passed to matplotlib's `plot` for rendering
|
||||
the chance level line.
|
||||
|
||||
.. versionadded:: 1.3
|
||||
|
||||
**kwargs : dict
|
||||
Keyword arguments to be passed to matplotlib's `plot`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
display : :class:`~sklearn.metrics.RocCurveDisplay`
|
||||
The ROC Curve display.
|
||||
|
||||
See Also
|
||||
--------
|
||||
roc_curve : Compute Receiver operating characteristic (ROC) curve.
|
||||
RocCurveDisplay.from_predictions : ROC Curve visualization given the
|
||||
probabilities or scores of a classifier.
|
||||
roc_auc_score : Compute the area under the ROC curve.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import matplotlib.pyplot as plt
|
||||
>>> from sklearn.datasets import make_classification
|
||||
>>> from sklearn.metrics import RocCurveDisplay
|
||||
>>> from sklearn.model_selection import train_test_split
|
||||
>>> from sklearn.svm import SVC
|
||||
>>> X, y = make_classification(random_state=0)
|
||||
>>> X_train, X_test, y_train, y_test = train_test_split(
|
||||
... X, y, random_state=0)
|
||||
>>> clf = SVC(random_state=0).fit(X_train, y_train)
|
||||
>>> RocCurveDisplay.from_estimator(
|
||||
... clf, X_test, y_test)
|
||||
<...>
|
||||
>>> plt.show()
|
||||
"""
|
||||
y_pred, pos_label, name = cls._validate_and_get_response_values(
|
||||
estimator,
|
||||
X,
|
||||
y,
|
||||
response_method=response_method,
|
||||
pos_label=pos_label,
|
||||
name=name,
|
||||
)
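# `y_pred` now holds the positive-class scores (probabilities or decision
# values) returned by the validation helper above; the rest of the work is
# shared with `from_predictions`.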
|
||||
|
||||
return cls.from_predictions(
|
||||
y_true=y,
|
||||
y_pred=y_pred,
|
||||
sample_weight=sample_weight,
|
||||
drop_intermediate=drop_intermediate,
|
||||
name=name,
|
||||
ax=ax,
|
||||
pos_label=pos_label,
|
||||
plot_chance_level=plot_chance_level,
|
||||
chance_level_kw=chance_level_kw,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_predictions(
|
||||
cls,
|
||||
y_true,
|
||||
y_pred,
|
||||
*,
|
||||
sample_weight=None,
|
||||
drop_intermediate=True,
|
||||
pos_label=None,
|
||||
name=None,
|
||||
ax=None,
|
||||
plot_chance_level=False,
|
||||
chance_level_kw=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Plot ROC curve given the true and predicted values.
|
||||
|
||||
Read more in the :ref:`User Guide <visualizations>`.
|
||||
|
||||
.. versionadded:: 1.0
|
||||
|
||||
Parameters
|
||||
----------
|
||||
y_true : array-like of shape (n_samples,)
|
||||
True labels.
|
||||
|
||||
y_pred : array-like of shape (n_samples,)
|
||||
Target scores, can either be probability estimates of the positive
|
||||
class, confidence values, or non-thresholded measure of decisions
|
||||
(as returned by "decision_function" on some classifiers).
|
||||
|
||||
sample_weight : array-like of shape (n_samples,), default=None
|
||||
Sample weights.
|
||||
|
||||
drop_intermediate : bool, default=True
|
||||
Whether to drop some suboptimal thresholds which would not appear
|
||||
on a plotted ROC curve. This is useful in order to create lighter
|
||||
ROC curves.
|
||||
|
||||
pos_label : int, float, bool or str, default=None
|
||||
The label of the positive class. When `pos_label=None`, if `y_true`
|
||||
is in {-1, 1} or {0, 1}, `pos_label` is set to 1, otherwise an
|
||||
error will be raised.
|
||||
|
||||
name : str, default=None
|
||||
Name of ROC curve for labeling. If `None`, name will be set to
|
||||
`"Classifier"`.
|
||||
|
||||
ax : matplotlib axes, default=None
|
||||
Axes object to plot on. If `None`, a new figure and axes is
|
||||
created.
|
||||
|
||||
plot_chance_level : bool, default=False
|
||||
Whether to plot the chance level.
|
||||
|
||||
.. versionadded:: 1.3
|
||||
|
||||
chance_level_kw : dict, default=None
|
||||
Keyword arguments to be passed to matplotlib's `plot` for rendering
|
||||
the chance level line.
|
||||
|
||||
.. versionadded:: 1.3
|
||||
|
||||
**kwargs : dict
|
||||
Additional keywords arguments passed to matplotlib `plot` function.
|
||||
|
||||
Returns
|
||||
-------
|
||||
display : :class:`~sklearn.metrics.RocCurveDisplay`
|
||||
Object that stores computed values.
|
||||
|
||||
See Also
|
||||
--------
|
||||
roc_curve : Compute Receiver operating characteristic (ROC) curve.
|
||||
RocCurveDisplay.from_estimator : ROC Curve visualization given an
|
||||
estimator and some data.
|
||||
roc_auc_score : Compute the area under the ROC curve.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import matplotlib.pyplot as plt
|
||||
>>> from sklearn.datasets import make_classification
|
||||
>>> from sklearn.metrics import RocCurveDisplay
|
||||
>>> from sklearn.model_selection import train_test_split
|
||||
>>> from sklearn.svm import SVC
|
||||
>>> X, y = make_classification(random_state=0)
|
||||
>>> X_train, X_test, y_train, y_test = train_test_split(
|
||||
... X, y, random_state=0)
|
||||
>>> clf = SVC(random_state=0).fit(X_train, y_train)
|
||||
>>> y_pred = clf.decision_function(X_test)
|
||||
>>> RocCurveDisplay.from_predictions(
|
||||
... y_test, y_pred)
|
||||
<...>
|
||||
>>> plt.show()
|
||||
"""
|
||||
pos_label_validated, name = cls._validate_from_predictions_params(
|
||||
y_true, y_pred, sample_weight=sample_weight, pos_label=pos_label, name=name
|
||||
)
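# The raw `pos_label` is forwarded to `roc_curve` (which applies the same
# default of treating 1 as positive when the labels are in {0, 1} or
# {-1, 1}), while the validated value is stored on the display for labelling.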
|
||||
|
||||
fpr, tpr, _ = roc_curve(
|
||||
y_true,
|
||||
y_pred,
|
||||
pos_label=pos_label,
|
||||
sample_weight=sample_weight,
|
||||
drop_intermediate=drop_intermediate,
|
||||
)
|
||||
roc_auc = auc(fpr, tpr)
|
||||
|
||||
viz = cls(
|
||||
fpr=fpr,
|
||||
tpr=tpr,
|
||||
roc_auc=roc_auc,
|
||||
estimator_name=name,
|
||||
pos_label=pos_label_validated,
|
||||
)
|
||||
|
||||
return viz.plot(
|
||||
ax=ax,
|
||||
name=name,
|
||||
plot_chance_level=plot_chance_level,
|
||||
chance_level_kw=chance_level_kw,
|
||||
**kwargs,
|
||||
)
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -0,0 +1,269 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from sklearn.base import ClassifierMixin, clone
|
||||
from sklearn.calibration import CalibrationDisplay
|
||||
from sklearn.compose import make_column_transformer
|
||||
from sklearn.datasets import load_iris
|
||||
from sklearn.exceptions import NotFittedError
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
from sklearn.metrics import (
|
||||
ConfusionMatrixDisplay,
|
||||
DetCurveDisplay,
|
||||
PrecisionRecallDisplay,
|
||||
PredictionErrorDisplay,
|
||||
RocCurveDisplay,
|
||||
)
|
||||
from sklearn.pipeline import make_pipeline
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def data():
|
||||
return load_iris(return_X_y=True)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def data_binary(data):
|
||||
X, y = data
|
||||
return X[y < 2], y[y < 2]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"Display",
|
||||
[CalibrationDisplay, DetCurveDisplay, PrecisionRecallDisplay, RocCurveDisplay],
|
||||
)
|
||||
def test_display_curve_error_classifier(pyplot, data, data_binary, Display):
|
||||
"""Check that a proper error is raised when only binary classification is
|
||||
supported."""
|
||||
X, y = data
|
||||
X_binary, y_binary = data_binary
|
||||
clf = DecisionTreeClassifier().fit(X, y)
|
||||
|
||||
# Case 1: multiclass classifier with multiclass target
|
||||
msg = "Expected 'estimator' to be a binary classifier. Got 3 classes instead."
|
||||
with pytest.raises(ValueError, match=msg):
|
||||
Display.from_estimator(clf, X, y)
|
||||
|
||||
# Case 2: multiclass classifier with binary target
|
||||
with pytest.raises(ValueError, match=msg):
|
||||
Display.from_estimator(clf, X_binary, y_binary)
|
||||
|
||||
# Case 3: binary classifier with multiclass target
|
||||
clf = DecisionTreeClassifier().fit(X_binary, y_binary)
|
||||
msg = "The target y is not binary. Got multiclass type of target."
|
||||
with pytest.raises(ValueError, match=msg):
|
||||
Display.from_estimator(clf, X, y)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"Display",
|
||||
[CalibrationDisplay, DetCurveDisplay, PrecisionRecallDisplay, RocCurveDisplay],
|
||||
)
|
||||
def test_display_curve_error_regression(pyplot, data_binary, Display):
|
||||
"""Check that we raise an error with regressor."""
|
||||
|
||||
# Case 1: regressor
|
||||
X, y = data_binary
|
||||
regressor = DecisionTreeRegressor().fit(X, y)
|
||||
|
||||
msg = "Expected 'estimator' to be a binary classifier. Got DecisionTreeRegressor"
|
||||
with pytest.raises(ValueError, match=msg):
|
||||
Display.from_estimator(regressor, X, y)
|
||||
|
||||
# Case 2: regression target
|
||||
classifier = DecisionTreeClassifier().fit(X, y)
|
||||
# Force `y_true` to be seen as a regression problem
|
||||
y = y + 0.5
|
||||
msg = "The target y is not binary. Got continuous type of target."
|
||||
with pytest.raises(ValueError, match=msg):
|
||||
Display.from_estimator(classifier, X, y)
|
||||
with pytest.raises(ValueError, match=msg):
|
||||
Display.from_predictions(y, regressor.fit(X, y).predict(X))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"response_method, msg",
|
||||
[
|
||||
(
|
||||
"predict_proba",
|
||||
"MyClassifier has none of the following attributes: predict_proba.",
|
||||
),
|
||||
(
|
||||
"decision_function",
|
||||
"MyClassifier has none of the following attributes: decision_function.",
|
||||
),
|
||||
(
|
||||
"auto",
|
||||
(
|
||||
"MyClassifier has none of the following attributes: predict_proba,"
|
||||
" decision_function."
|
||||
),
|
||||
),
|
||||
(
|
||||
"bad_method",
|
||||
"MyClassifier has none of the following attributes: bad_method.",
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"Display", [DetCurveDisplay, PrecisionRecallDisplay, RocCurveDisplay]
|
||||
)
|
||||
def test_display_curve_error_no_response(
|
||||
pyplot,
|
||||
data_binary,
|
||||
response_method,
|
||||
msg,
|
||||
Display,
|
||||
):
|
||||
"""Check that a proper error is raised when the response method requested
|
||||
is not defined for the given trained classifier."""
|
||||
X, y = data_binary
|
||||
|
||||
class MyClassifier(ClassifierMixin):
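# Minimal estimator that exposes neither `predict_proba` nor
# `decision_function`, so any requested response method is missing.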
|
||||
def fit(self, X, y):
|
||||
self.classes_ = [0, 1]
|
||||
return self
|
||||
|
||||
clf = MyClassifier().fit(X, y)
|
||||
|
||||
with pytest.raises(AttributeError, match=msg):
|
||||
Display.from_estimator(clf, X, y, response_method=response_method)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"Display", [DetCurveDisplay, PrecisionRecallDisplay, RocCurveDisplay]
|
||||
)
|
||||
@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"])
|
||||
def test_display_curve_estimator_name_multiple_calls(
|
||||
pyplot,
|
||||
data_binary,
|
||||
Display,
|
||||
constructor_name,
|
||||
):
|
||||
"""Check that passing `name` when calling `plot` will overwrite the original name
|
||||
in the legend."""
|
||||
X, y = data_binary
|
||||
clf_name = "my hand-crafted name"
|
||||
clf = LogisticRegression().fit(X, y)
|
||||
y_pred = clf.predict_proba(X)[:, 1]
|
||||
|
||||
# safe guard for the binary if/else construction
|
||||
assert constructor_name in ("from_estimator", "from_predictions")
|
||||
|
||||
if constructor_name == "from_estimator":
|
||||
disp = Display.from_estimator(clf, X, y, name=clf_name)
|
||||
else:
|
||||
disp = Display.from_predictions(y, y_pred, name=clf_name)
|
||||
assert disp.estimator_name == clf_name
|
||||
pyplot.close("all")
|
||||
disp.plot()
|
||||
assert clf_name in disp.line_.get_label()
|
||||
pyplot.close("all")
|
||||
clf_name = "another_name"
|
||||
disp.plot(name=clf_name)
|
||||
assert clf_name in disp.line_.get_label()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"clf",
|
||||
[
|
||||
LogisticRegression(),
|
||||
make_pipeline(StandardScaler(), LogisticRegression()),
|
||||
make_pipeline(
|
||||
make_column_transformer((StandardScaler(), [0, 1])), LogisticRegression()
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"Display", [DetCurveDisplay, PrecisionRecallDisplay, RocCurveDisplay]
|
||||
)
|
||||
def test_display_curve_not_fitted_errors(pyplot, data_binary, clf, Display):
|
||||
"""Check that a proper error is raised when the classifier is not
|
||||
fitted."""
|
||||
X, y = data_binary
|
||||
# clone since we parametrize the test and the classifier will be fitted
|
||||
# when testing the second and subsequent plotting function
|
||||
model = clone(clf)
|
||||
with pytest.raises(NotFittedError):
|
||||
Display.from_estimator(model, X, y)
|
||||
model.fit(X, y)
|
||||
disp = Display.from_estimator(model, X, y)
|
||||
assert model.__class__.__name__ in disp.line_.get_label()
|
||||
assert disp.estimator_name == model.__class__.__name__
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"Display", [DetCurveDisplay, PrecisionRecallDisplay, RocCurveDisplay]
|
||||
)
|
||||
def test_display_curve_n_samples_consistency(pyplot, data_binary, Display):
|
||||
"""Check the error raised when `y_pred` or `sample_weight` have inconsistent
|
||||
length."""
|
||||
X, y = data_binary
|
||||
classifier = DecisionTreeClassifier().fit(X, y)
|
||||
|
||||
msg = "Found input variables with inconsistent numbers of samples"
|
||||
with pytest.raises(ValueError, match=msg):
|
||||
Display.from_estimator(classifier, X[:-2], y)
|
||||
with pytest.raises(ValueError, match=msg):
|
||||
Display.from_estimator(classifier, X, y[:-2])
|
||||
with pytest.raises(ValueError, match=msg):
|
||||
Display.from_estimator(classifier, X, y, sample_weight=np.ones(X.shape[0] - 2))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"Display", [DetCurveDisplay, PrecisionRecallDisplay, RocCurveDisplay]
|
||||
)
|
||||
def test_display_curve_error_pos_label(pyplot, data_binary, Display):
|
||||
"""Check consistence of error message when `pos_label` should be specified."""
|
||||
X, y = data_binary
|
||||
y = y + 10
|
||||
|
||||
classifier = DecisionTreeClassifier().fit(X, y)
|
||||
y_pred = classifier.predict_proba(X)[:, -1]
|
||||
msg = r"y_true takes value in {10, 11} and pos_label is not specified"
|
||||
with pytest.raises(ValueError, match=msg):
|
||||
Display.from_predictions(y, y_pred)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"Display",
|
||||
[
|
||||
CalibrationDisplay,
|
||||
DetCurveDisplay,
|
||||
PrecisionRecallDisplay,
|
||||
RocCurveDisplay,
|
||||
PredictionErrorDisplay,
|
||||
ConfusionMatrixDisplay,
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"constructor",
|
||||
["from_predictions", "from_estimator"],
|
||||
)
|
||||
def test_classifier_display_curve_named_constructor_return_type(
|
||||
pyplot, data_binary, Display, constructor
|
||||
):
|
||||
"""Check that named constructors return the correct type when subclassed.
|
||||
|
||||
Non-regression test for:
|
||||
https://github.com/scikit-learn/scikit-learn/pull/27675
|
||||
"""
|
||||
X, y = data_binary
|
||||
|
||||
# This can be anything - we just need to check the named constructor return
|
||||
# type so the only requirement here is instantiating the class without error
|
||||
y_pred = y
|
||||
|
||||
classifier = LogisticRegression().fit(X, y)
|
||||
|
||||
class SubclassOfDisplay(Display):
|
||||
pass
|
||||
|
||||
if constructor == "from_predictions":
|
||||
curve = SubclassOfDisplay.from_predictions(y, y_pred)
|
||||
else: # constructor == "from_estimator"
|
||||
curve = SubclassOfDisplay.from_estimator(classifier, X, y)
|
||||
|
||||
assert isinstance(curve, SubclassOfDisplay)
|
||||
@ -0,0 +1,380 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
from numpy.testing import (
|
||||
assert_allclose,
|
||||
assert_array_equal,
|
||||
)
|
||||
|
||||
from sklearn.compose import make_column_transformer
|
||||
from sklearn.datasets import make_classification
|
||||
from sklearn.exceptions import NotFittedError
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix
|
||||
from sklearn.pipeline import make_pipeline
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from sklearn.svm import SVC, SVR
|
||||
|
||||
# TODO: Remove when https://github.com/numpy/numpy/issues/14397 is resolved
|
||||
pytestmark = pytest.mark.filterwarnings(
|
||||
"ignore:In future, it will be an error for 'np.bool_':DeprecationWarning:"
|
||||
"matplotlib.*"
|
||||
)
|
||||
|
||||
|
||||
def test_confusion_matrix_display_validation(pyplot):
|
||||
"""Check that we raise the proper error when validating parameters."""
|
||||
X, y = make_classification(
|
||||
n_samples=100, n_informative=5, n_classes=5, random_state=0
|
||||
)
|
||||
|
||||
with pytest.raises(NotFittedError):
|
||||
ConfusionMatrixDisplay.from_estimator(SVC(), X, y)
|
||||
|
||||
regressor = SVR().fit(X, y)
|
||||
y_pred_regressor = regressor.predict(X)
|
||||
y_pred_classifier = SVC().fit(X, y).predict(X)
|
||||
|
||||
err_msg = "ConfusionMatrixDisplay.from_estimator only supports classifiers"
|
||||
with pytest.raises(ValueError, match=err_msg):
|
||||
ConfusionMatrixDisplay.from_estimator(regressor, X, y)
|
||||
|
||||
err_msg = "Mix type of y not allowed, got types"
|
||||
with pytest.raises(ValueError, match=err_msg):
|
||||
# Force `y_true` to be seen as a regression problem
|
||||
ConfusionMatrixDisplay.from_predictions(y + 0.5, y_pred_classifier)
|
||||
with pytest.raises(ValueError, match=err_msg):
|
||||
ConfusionMatrixDisplay.from_predictions(y, y_pred_regressor)
|
||||
|
||||
err_msg = "Found input variables with inconsistent numbers of samples"
|
||||
with pytest.raises(ValueError, match=err_msg):
|
||||
ConfusionMatrixDisplay.from_predictions(y, y_pred_classifier[::2])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"])
|
||||
@pytest.mark.parametrize("with_labels", [True, False])
|
||||
@pytest.mark.parametrize("with_display_labels", [True, False])
|
||||
def test_confusion_matrix_display_custom_labels(
|
||||
pyplot, constructor_name, with_labels, with_display_labels
|
||||
):
|
||||
"""Check the resulting plot when labels are given."""
|
||||
n_classes = 5
|
||||
X, y = make_classification(
|
||||
n_samples=100, n_informative=5, n_classes=n_classes, random_state=0
|
||||
)
|
||||
classifier = SVC().fit(X, y)
|
||||
y_pred = classifier.predict(X)
|
||||
|
||||
# safe guard for the binary if/else construction
|
||||
assert constructor_name in ("from_estimator", "from_predictions")
|
||||
|
||||
ax = pyplot.gca()
|
||||
labels = [2, 1, 0, 3, 4] if with_labels else None
|
||||
display_labels = ["b", "d", "a", "e", "f"] if with_display_labels else None
|
||||
|
||||
cm = confusion_matrix(y, y_pred, labels=labels)
|
||||
common_kwargs = {
|
||||
"ax": ax,
|
||||
"display_labels": display_labels,
|
||||
"labels": labels,
|
||||
}
|
||||
if constructor_name == "from_estimator":
|
||||
disp = ConfusionMatrixDisplay.from_estimator(classifier, X, y, **common_kwargs)
|
||||
else:
|
||||
disp = ConfusionMatrixDisplay.from_predictions(y, y_pred, **common_kwargs)
|
||||
assert_allclose(disp.confusion_matrix, cm)
|
||||
|
||||
if with_display_labels:
|
||||
expected_display_labels = display_labels
|
||||
elif with_labels:
|
||||
expected_display_labels = labels
|
||||
else:
|
||||
expected_display_labels = list(range(n_classes))
|
||||
|
||||
expected_display_labels_str = [str(name) for name in expected_display_labels]
|
||||
|
||||
x_ticks = [tick.get_text() for tick in disp.ax_.get_xticklabels()]
|
||||
y_ticks = [tick.get_text() for tick in disp.ax_.get_yticklabels()]
|
||||
|
||||
assert_array_equal(disp.display_labels, expected_display_labels)
|
||||
assert_array_equal(x_ticks, expected_display_labels_str)
|
||||
assert_array_equal(y_ticks, expected_display_labels_str)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"])
|
||||
@pytest.mark.parametrize("normalize", ["true", "pred", "all", None])
|
||||
@pytest.mark.parametrize("include_values", [True, False])
|
||||
def test_confusion_matrix_display_plotting(
|
||||
pyplot,
|
||||
constructor_name,
|
||||
normalize,
|
||||
include_values,
|
||||
):
|
||||
"""Check the overall plotting rendering."""
|
||||
n_classes = 5
|
||||
X, y = make_classification(
|
||||
n_samples=100, n_informative=5, n_classes=n_classes, random_state=0
|
||||
)
|
||||
classifier = SVC().fit(X, y)
|
||||
y_pred = classifier.predict(X)
|
||||
|
||||
# safe guard for the binary if/else construction
|
||||
assert constructor_name in ("from_estimator", "from_predictions")
|
||||
|
||||
ax = pyplot.gca()
|
||||
cmap = "plasma"
|
||||
|
||||
cm = confusion_matrix(y, y_pred)
|
||||
common_kwargs = {
|
||||
"normalize": normalize,
|
||||
"cmap": cmap,
|
||||
"ax": ax,
|
||||
"include_values": include_values,
|
||||
}
|
||||
if constructor_name == "from_estimator":
|
||||
disp = ConfusionMatrixDisplay.from_estimator(classifier, X, y, **common_kwargs)
|
||||
else:
|
||||
disp = ConfusionMatrixDisplay.from_predictions(y, y_pred, **common_kwargs)
|
||||
|
||||
assert disp.ax_ == ax
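# Recompute the expected matrix with the same normalization scheme:
# "true" normalizes each row, "pred" each column, and "all" the grand total.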
|
||||
|
||||
if normalize == "true":
|
||||
cm = cm / cm.sum(axis=1, keepdims=True)
|
||||
elif normalize == "pred":
|
||||
cm = cm / cm.sum(axis=0, keepdims=True)
|
||||
elif normalize == "all":
|
||||
cm = cm / cm.sum()
|
||||
|
||||
assert_allclose(disp.confusion_matrix, cm)
|
||||
import matplotlib as mpl
|
||||
|
||||
assert isinstance(disp.im_, mpl.image.AxesImage)
|
||||
assert disp.im_.get_cmap().name == cmap
|
||||
assert isinstance(disp.ax_, pyplot.Axes)
|
||||
assert isinstance(disp.figure_, pyplot.Figure)
|
||||
|
||||
assert disp.ax_.get_ylabel() == "True label"
|
||||
assert disp.ax_.get_xlabel() == "Predicted label"
|
||||
|
||||
x_ticks = [tick.get_text() for tick in disp.ax_.get_xticklabels()]
|
||||
y_ticks = [tick.get_text() for tick in disp.ax_.get_yticklabels()]
|
||||
|
||||
expected_display_labels = list(range(n_classes))
|
||||
|
||||
expected_display_labels_str = [str(name) for name in expected_display_labels]
|
||||
|
||||
assert_array_equal(disp.display_labels, expected_display_labels)
|
||||
assert_array_equal(x_ticks, expected_display_labels_str)
|
||||
assert_array_equal(y_ticks, expected_display_labels_str)
|
||||
|
||||
image_data = disp.im_.get_array().data
|
||||
assert_allclose(image_data, cm)
|
||||
|
||||
if include_values:
|
||||
assert disp.text_.shape == (n_classes, n_classes)
|
||||
fmt = ".2g"
|
||||
expected_text = np.array([format(v, fmt) for v in cm.ravel(order="C")])
|
||||
text_text = np.array([t.get_text() for t in disp.text_.ravel(order="C")])
|
||||
assert_array_equal(expected_text, text_text)
|
||||
else:
|
||||
assert disp.text_ is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"])
|
||||
def test_confusion_matrix_display(pyplot, constructor_name):
|
||||
"""Check the behaviour of the default constructor without using the class
|
||||
methods."""
|
||||
n_classes = 5
|
||||
X, y = make_classification(
|
||||
n_samples=100, n_informative=5, n_classes=n_classes, random_state=0
|
||||
)
|
||||
classifier = SVC().fit(X, y)
|
||||
y_pred = classifier.predict(X)
|
||||
|
||||
# safe guard for the binary if/else construction
|
||||
assert constructor_name in ("from_estimator", "from_predictions")
|
||||
|
||||
cm = confusion_matrix(y, y_pred)
|
||||
common_kwargs = {
|
||||
"normalize": None,
|
||||
"include_values": True,
|
||||
"cmap": "viridis",
|
||||
"xticks_rotation": 45.0,
|
||||
}
|
||||
if constructor_name == "from_estimator":
|
||||
disp = ConfusionMatrixDisplay.from_estimator(classifier, X, y, **common_kwargs)
|
||||
else:
|
||||
disp = ConfusionMatrixDisplay.from_predictions(y, y_pred, **common_kwargs)
|
||||
|
||||
assert_allclose(disp.confusion_matrix, cm)
|
||||
assert disp.text_.shape == (n_classes, n_classes)
|
||||
|
||||
rotations = [tick.get_rotation() for tick in disp.ax_.get_xticklabels()]
|
||||
assert_allclose(rotations, 45.0)
|
||||
|
||||
image_data = disp.im_.get_array().data
|
||||
assert_allclose(image_data, cm)
|
||||
|
||||
disp.plot(cmap="plasma")
|
||||
assert disp.im_.get_cmap().name == "plasma"
|
||||
|
||||
disp.plot(include_values=False)
|
||||
assert disp.text_ is None
|
||||
|
||||
disp.plot(xticks_rotation=90.0)
|
||||
rotations = [tick.get_rotation() for tick in disp.ax_.get_xticklabels()]
|
||||
assert_allclose(rotations, 90.0)
|
||||
|
||||
disp.plot(values_format="e")
|
||||
expected_text = np.array([format(v, "e") for v in cm.ravel(order="C")])
|
||||
text_text = np.array([t.get_text() for t in disp.text_.ravel(order="C")])
|
||||
assert_array_equal(expected_text, text_text)
|
||||
|
||||
|
||||
def test_confusion_matrix_contrast(pyplot):
|
||||
"""Check that the text color is appropriate depending on background."""
|
||||
|
||||
cm = np.eye(2) / 2
|
||||
disp = ConfusionMatrixDisplay(cm, display_labels=[0, 1])
|
||||
|
||||
disp.plot(cmap=pyplot.cm.gray)
|
||||
# diagonal text is black
|
||||
assert_allclose(disp.text_[0, 0].get_color(), [0.0, 0.0, 0.0, 1.0])
|
||||
assert_allclose(disp.text_[1, 1].get_color(), [0.0, 0.0, 0.0, 1.0])
|
||||
|
||||
# off-diagonal text is white
|
||||
assert_allclose(disp.text_[0, 1].get_color(), [1.0, 1.0, 1.0, 1.0])
|
||||
assert_allclose(disp.text_[1, 0].get_color(), [1.0, 1.0, 1.0, 1.0])
|
||||
|
||||
disp.plot(cmap=pyplot.cm.gray_r)
|
||||
# off-diagonal text is black
|
||||
assert_allclose(disp.text_[0, 1].get_color(), [0.0, 0.0, 0.0, 1.0])
|
||||
assert_allclose(disp.text_[1, 0].get_color(), [0.0, 0.0, 0.0, 1.0])
|
||||
|
||||
# diagonal text is white
|
||||
assert_allclose(disp.text_[0, 0].get_color(), [1.0, 1.0, 1.0, 1.0])
|
||||
assert_allclose(disp.text_[1, 1].get_color(), [1.0, 1.0, 1.0, 1.0])
|
||||
|
||||
# Regression test for #15920
|
||||
cm = np.array([[19, 34], [32, 58]])
|
||||
disp = ConfusionMatrixDisplay(cm, display_labels=[0, 1])
|
||||
|
||||
disp.plot(cmap=pyplot.cm.Blues)
|
||||
min_color = pyplot.cm.Blues(0)
|
||||
max_color = pyplot.cm.Blues(255)
|
||||
assert_allclose(disp.text_[0, 0].get_color(), max_color)
|
||||
assert_allclose(disp.text_[0, 1].get_color(), max_color)
|
||||
assert_allclose(disp.text_[1, 0].get_color(), max_color)
|
||||
assert_allclose(disp.text_[1, 1].get_color(), min_color)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"clf",
|
||||
[
|
||||
LogisticRegression(),
|
||||
make_pipeline(StandardScaler(), LogisticRegression()),
|
||||
make_pipeline(
|
||||
make_column_transformer((StandardScaler(), [0, 1])),
|
||||
LogisticRegression(),
|
||||
),
|
||||
],
|
||||
ids=["clf", "pipeline-clf", "pipeline-column_transformer-clf"],
|
||||
)
|
||||
def test_confusion_matrix_pipeline(pyplot, clf):
|
||||
"""Check the behaviour of the plotting with more complex pipeline."""
|
||||
n_classes = 5
|
||||
X, y = make_classification(
|
||||
n_samples=100, n_informative=5, n_classes=n_classes, random_state=0
|
||||
)
|
||||
with pytest.raises(NotFittedError):
|
||||
ConfusionMatrixDisplay.from_estimator(clf, X, y)
|
||||
clf.fit(X, y)
|
||||
y_pred = clf.predict(X)
|
||||
|
||||
disp = ConfusionMatrixDisplay.from_estimator(clf, X, y)
|
||||
cm = confusion_matrix(y, y_pred)
|
||||
|
||||
assert_allclose(disp.confusion_matrix, cm)
|
||||
assert disp.text_.shape == (n_classes, n_classes)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"])
|
||||
def test_confusion_matrix_with_unknown_labels(pyplot, constructor_name):
|
||||
"""Check that when labels=None, the unique values in `y_pred` and `y_true`
|
||||
will be used.
|
||||
Non-regression test for:
|
||||
https://github.com/scikit-learn/scikit-learn/pull/18405
|
||||
"""
|
||||
n_classes = 5
|
||||
X, y = make_classification(
|
||||
n_samples=100, n_informative=5, n_classes=n_classes, random_state=0
|
||||
)
|
||||
classifier = SVC().fit(X, y)
|
||||
y_pred = classifier.predict(X)
|
||||
# create labels in `y_true` that were not seen during fitting and are not present
|
||||
# in 'classifier.classes_'
|
||||
y = y + 1
|
||||
|
||||
# safe guard for the binary if/else construction
|
||||
assert constructor_name in ("from_estimator", "from_predictions")
|
||||
|
||||
common_kwargs = {"labels": None}
|
||||
if constructor_name == "from_estimator":
|
||||
disp = ConfusionMatrixDisplay.from_estimator(classifier, X, y, **common_kwargs)
|
||||
else:
|
||||
disp = ConfusionMatrixDisplay.from_predictions(y, y_pred, **common_kwargs)
|
||||
|
||||
display_labels = [tick.get_text() for tick in disp.ax_.get_xticklabels()]
|
||||
expected_labels = [str(i) for i in range(n_classes + 1)]
|
||||
assert_array_equal(expected_labels, display_labels)
|
||||
|
||||
|
||||
def test_colormap_max(pyplot):
|
||||
"""Check that the max color is used for the color of the text."""
|
||||
gray = pyplot.get_cmap("gray", 1024)
|
||||
confusion_matrix = np.array([[1.0, 0.0], [0.0, 1.0]])
|
||||
|
||||
disp = ConfusionMatrixDisplay(confusion_matrix)
|
||||
disp.plot(cmap=gray)
|
||||
|
||||
color = disp.text_[1, 0].get_color()
|
||||
assert_allclose(color, [1.0, 1.0, 1.0, 1.0])
|
||||
|
||||
|
||||
def test_im_kw_adjust_vmin_vmax(pyplot):
|
||||
"""Check that im_kw passes kwargs to imshow"""
|
||||
|
||||
confusion_matrix = np.array([[0.48, 0.04], [0.08, 0.4]])
|
||||
disp = ConfusionMatrixDisplay(confusion_matrix)
|
||||
disp.plot(im_kw=dict(vmin=0.0, vmax=0.8))
|
||||
|
||||
clim = disp.im_.get_clim()
|
||||
assert clim[0] == pytest.approx(0.0)
|
||||
assert clim[1] == pytest.approx(0.8)
|
||||
|
||||
|
||||
def test_confusion_matrix_text_kw(pyplot):
|
||||
"""Check that text_kw is passed to the text call."""
|
||||
font_size = 15.0
|
||||
X, y = make_classification(random_state=0)
|
||||
classifier = SVC().fit(X, y)
|
||||
|
||||
# from_estimator passes the font size
|
||||
disp = ConfusionMatrixDisplay.from_estimator(
|
||||
classifier, X, y, text_kw={"fontsize": font_size}
|
||||
)
|
||||
for text in disp.text_.reshape(-1):
|
||||
assert text.get_fontsize() == font_size
|
||||
|
||||
# plot adjusts plot to new font size
|
||||
new_font_size = 20.0
|
||||
disp.plot(text_kw={"fontsize": new_font_size})
|
||||
for text in disp.text_.reshape(-1):
|
||||
assert text.get_fontsize() == new_font_size
|
||||
|
||||
# from_predictions passes the font size
|
||||
y_pred = classifier.predict(X)
|
||||
disp = ConfusionMatrixDisplay.from_predictions(
|
||||
y, y_pred, text_kw={"fontsize": font_size}
|
||||
)
|
||||
for text in disp.text_.reshape(-1):
|
||||
assert text.get_fontsize() == font_size
|
||||
@ -0,0 +1,106 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
from numpy.testing import assert_allclose
|
||||
|
||||
from sklearn.datasets import load_iris
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
from sklearn.metrics import DetCurveDisplay, det_curve
|
||||
|
||||
|
||||
@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"])
|
||||
@pytest.mark.parametrize("response_method", ["predict_proba", "decision_function"])
|
||||
@pytest.mark.parametrize("with_sample_weight", [True, False])
|
||||
@pytest.mark.parametrize("with_strings", [True, False])
|
||||
def test_det_curve_display(
|
||||
pyplot, constructor_name, response_method, with_sample_weight, with_strings
|
||||
):
|
||||
X, y = load_iris(return_X_y=True)
|
||||
# Binarize the data with only the two first classes
|
||||
X, y = X[y < 2], y[y < 2]
|
||||
|
||||
pos_label = None
|
||||
if with_strings:
|
||||
y = np.array(["c", "b"])[y]
|
||||
pos_label = "c"
|
||||
|
||||
if with_sample_weight:
|
||||
rng = np.random.RandomState(42)
|
||||
sample_weight = rng.randint(1, 4, size=(X.shape[0]))
|
||||
else:
|
||||
sample_weight = None
|
||||
|
||||
lr = LogisticRegression()
|
||||
lr.fit(X, y)
|
||||
y_pred = getattr(lr, response_method)(X)
|
||||
if y_pred.ndim == 2:
|
||||
y_pred = y_pred[:, 1]
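# `predict_proba` returns one column per class, so keep only the
# positive-class column; `decision_function` is already one-dimensional.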
|
||||
|
||||
# safe guard for the binary if/else construction
|
||||
assert constructor_name in ("from_estimator", "from_predictions")
|
||||
|
||||
common_kwargs = {
|
||||
"name": lr.__class__.__name__,
|
||||
"alpha": 0.8,
|
||||
"sample_weight": sample_weight,
|
||||
"pos_label": pos_label,
|
||||
}
|
||||
if constructor_name == "from_estimator":
|
||||
disp = DetCurveDisplay.from_estimator(lr, X, y, **common_kwargs)
|
||||
else:
|
||||
disp = DetCurveDisplay.from_predictions(y, y_pred, **common_kwargs)
|
||||
|
||||
fpr, fnr, _ = det_curve(
|
||||
y,
|
||||
y_pred,
|
||||
sample_weight=sample_weight,
|
||||
pos_label=pos_label,
|
||||
)
|
||||
|
||||
assert_allclose(disp.fpr, fpr)
|
||||
assert_allclose(disp.fnr, fnr)
|
||||
|
||||
assert disp.estimator_name == "LogisticRegression"
|
||||
|
||||
# importing matplotlib cannot fail thanks to the pyplot fixture
import matplotlib as mpl  # noqa
|
||||
|
||||
assert isinstance(disp.line_, mpl.lines.Line2D)
|
||||
assert disp.line_.get_alpha() == 0.8
|
||||
assert isinstance(disp.ax_, mpl.axes.Axes)
|
||||
assert isinstance(disp.figure_, mpl.figure.Figure)
|
||||
assert disp.line_.get_label() == "LogisticRegression"
|
||||
|
||||
expected_pos_label = 1 if pos_label is None else pos_label
|
||||
expected_ylabel = f"False Negative Rate (Positive label: {expected_pos_label})"
|
||||
expected_xlabel = f"False Positive Rate (Positive label: {expected_pos_label})"
|
||||
assert disp.ax_.get_ylabel() == expected_ylabel
|
||||
assert disp.ax_.get_xlabel() == expected_xlabel
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"constructor_name, expected_clf_name",
|
||||
[
|
||||
("from_estimator", "LogisticRegression"),
|
||||
("from_predictions", "Classifier"),
|
||||
],
|
||||
)
|
||||
def test_det_curve_display_default_name(
|
||||
pyplot,
|
||||
constructor_name,
|
||||
expected_clf_name,
|
||||
):
|
||||
# Check the default name display in the figure when `name` is not provided
|
||||
X, y = load_iris(return_X_y=True)
|
||||
# Binarize the data with only the two first classes
|
||||
X, y = X[y < 2], y[y < 2]
|
||||
|
||||
lr = LogisticRegression().fit(X, y)
|
||||
y_pred = lr.predict_proba(X)[:, 1]
|
||||
|
||||
if constructor_name == "from_estimator":
|
||||
disp = DetCurveDisplay.from_estimator(lr, X, y)
|
||||
else:
|
||||
disp = DetCurveDisplay.from_predictions(y, y_pred)
|
||||
|
||||
assert disp.estimator_name == expected_clf_name
|
||||
assert disp.line_.get_label() == expected_clf_name
|
||||
@ -0,0 +1,361 @@
|
||||
from collections import Counter
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from sklearn.compose import make_column_transformer
|
||||
from sklearn.datasets import load_breast_cancer, make_classification
|
||||
from sklearn.exceptions import NotFittedError
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
from sklearn.metrics import (
|
||||
PrecisionRecallDisplay,
|
||||
average_precision_score,
|
||||
precision_recall_curve,
|
||||
)
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.pipeline import make_pipeline
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from sklearn.utils import shuffle
|
||||
from sklearn.utils.fixes import trapezoid
|
||||
|
||||
# TODO: Remove when https://github.com/numpy/numpy/issues/14397 is resolved
|
||||
pytestmark = pytest.mark.filterwarnings(
|
||||
"ignore:In future, it will be an error for 'np.bool_':DeprecationWarning:"
|
||||
"matplotlib.*"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"])
|
||||
@pytest.mark.parametrize("response_method", ["predict_proba", "decision_function"])
|
||||
@pytest.mark.parametrize("drop_intermediate", [True, False])
|
||||
def test_precision_recall_display_plotting(
|
||||
pyplot, constructor_name, response_method, drop_intermediate
|
||||
):
|
||||
"""Check the overall plotting rendering."""
|
||||
X, y = make_classification(n_classes=2, n_samples=50, random_state=0)
|
||||
pos_label = 1
|
||||
|
||||
classifier = LogisticRegression().fit(X, y)
|
||||
|
||||
y_pred = getattr(classifier, response_method)(X)
|
||||
y_pred = y_pred if y_pred.ndim == 1 else y_pred[:, pos_label]
|
||||
|
||||
# safe guard for the binary if/else construction
|
||||
assert constructor_name in ("from_estimator", "from_predictions")
|
||||
|
||||
if constructor_name == "from_estimator":
|
||||
display = PrecisionRecallDisplay.from_estimator(
|
||||
classifier,
|
||||
X,
|
||||
y,
|
||||
response_method=response_method,
|
||||
drop_intermediate=drop_intermediate,
|
||||
)
|
||||
else:
|
||||
display = PrecisionRecallDisplay.from_predictions(
|
||||
y, y_pred, pos_label=pos_label, drop_intermediate=drop_intermediate
|
||||
)
|
||||
|
||||
precision, recall, _ = precision_recall_curve(
|
||||
y, y_pred, pos_label=pos_label, drop_intermediate=drop_intermediate
|
||||
)
|
||||
average_precision = average_precision_score(y, y_pred, pos_label=pos_label)
|
||||
|
||||
np.testing.assert_allclose(display.precision, precision)
|
||||
np.testing.assert_allclose(display.recall, recall)
|
||||
assert display.average_precision == pytest.approx(average_precision)
|
||||
|
||||
import matplotlib as mpl
|
||||
|
||||
assert isinstance(display.line_, mpl.lines.Line2D)
|
||||
assert isinstance(display.ax_, mpl.axes.Axes)
|
||||
assert isinstance(display.figure_, mpl.figure.Figure)
|
||||
|
||||
assert display.ax_.get_xlabel() == "Recall (Positive label: 1)"
|
||||
assert display.ax_.get_ylabel() == "Precision (Positive label: 1)"
|
||||
assert display.ax_.get_adjustable() == "box"
|
||||
assert display.ax_.get_aspect() in ("equal", 1.0)
|
||||
assert display.ax_.get_xlim() == display.ax_.get_ylim() == (-0.01, 1.01)
|
||||
|
||||
# plotting passing some new parameters
|
||||
display.plot(alpha=0.8, name="MySpecialEstimator")
|
||||
expected_label = f"MySpecialEstimator (AP = {average_precision:0.2f})"
|
||||
assert display.line_.get_label() == expected_label
|
||||
assert display.line_.get_alpha() == pytest.approx(0.8)
|
||||
|
||||
# Check that the chance level line is not plotted by default
|
||||
assert display.chance_level_ is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize("chance_level_kw", [None, {"color": "r"}])
|
||||
@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"])
|
||||
def test_precision_recall_chance_level_line(
|
||||
pyplot,
|
||||
chance_level_kw,
|
||||
constructor_name,
|
||||
):
|
||||
"""Check the chance level line plotting behavior."""
|
||||
X, y = make_classification(n_classes=2, n_samples=50, random_state=0)
|
||||
pos_prevalence = Counter(y)[1] / len(y)
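# For a precision-recall curve, the chance level is a horizontal line at the
# prevalence of the positive class, i.e. the precision of a no-skill
# classifier.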
|
||||
|
||||
lr = LogisticRegression()
|
||||
y_pred = lr.fit(X, y).predict_proba(X)[:, 1]
|
||||
|
||||
if constructor_name == "from_estimator":
|
||||
display = PrecisionRecallDisplay.from_estimator(
|
||||
lr,
|
||||
X,
|
||||
y,
|
||||
plot_chance_level=True,
|
||||
chance_level_kw=chance_level_kw,
|
||||
)
|
||||
else:
|
||||
display = PrecisionRecallDisplay.from_predictions(
|
||||
y,
|
||||
y_pred,
|
||||
plot_chance_level=True,
|
||||
chance_level_kw=chance_level_kw,
|
||||
)
|
||||
|
||||
import matplotlib as mpl # noqa
|
||||
|
||||
assert isinstance(display.chance_level_, mpl.lines.Line2D)
|
||||
assert tuple(display.chance_level_.get_xdata()) == (0, 1)
|
||||
assert tuple(display.chance_level_.get_ydata()) == (pos_prevalence, pos_prevalence)
|
||||
|
||||
# Checking for chance level line styles
|
||||
if chance_level_kw is None:
|
||||
assert display.chance_level_.get_color() == "k"
|
||||
else:
|
||||
assert display.chance_level_.get_color() == "r"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"constructor_name, default_label",
|
||||
[
|
||||
("from_estimator", "LogisticRegression (AP = {:.2f})"),
|
||||
("from_predictions", "Classifier (AP = {:.2f})"),
|
||||
],
|
||||
)
|
||||
def test_precision_recall_display_name(pyplot, constructor_name, default_label):
|
||||
"""Check the behaviour of the name parameters"""
|
||||
X, y = make_classification(n_classes=2, n_samples=100, random_state=0)
|
||||
pos_label = 1
|
||||
|
||||
classifier = LogisticRegression().fit(X, y)
|
||||
|
||||
y_pred = classifier.predict_proba(X)[:, pos_label]
|
||||
|
||||
# safe guard for the binary if/else construction
|
||||
assert constructor_name in ("from_estimator", "from_predictions")
|
||||
|
||||
if constructor_name == "from_estimator":
|
||||
display = PrecisionRecallDisplay.from_estimator(classifier, X, y)
|
||||
else:
|
||||
display = PrecisionRecallDisplay.from_predictions(
|
||||
y, y_pred, pos_label=pos_label
|
||||
)
|
||||
|
||||
average_precision = average_precision_score(y, y_pred, pos_label=pos_label)
|
||||
|
||||
# check that the default name is used
|
||||
assert display.line_.get_label() == default_label.format(average_precision)
|
||||
|
||||
# check that the name can be set
|
||||
display.plot(name="MySpecialEstimator")
|
||||
assert (
|
||||
display.line_.get_label()
|
||||
== f"MySpecialEstimator (AP = {average_precision:.2f})"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"clf",
|
||||
[
|
||||
make_pipeline(StandardScaler(), LogisticRegression()),
|
||||
make_pipeline(
|
||||
make_column_transformer((StandardScaler(), [0, 1])), LogisticRegression()
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_precision_recall_display_pipeline(pyplot, clf):
|
||||
X, y = make_classification(n_classes=2, n_samples=50, random_state=0)
|
||||
with pytest.raises(NotFittedError):
|
||||
PrecisionRecallDisplay.from_estimator(clf, X, y)
|
||||
clf.fit(X, y)
|
||||
display = PrecisionRecallDisplay.from_estimator(clf, X, y)
|
||||
assert display.estimator_name == clf.__class__.__name__
|
||||
|
||||
|
||||
def test_precision_recall_display_string_labels(pyplot):
|
||||
# regression test #15738
|
||||
cancer = load_breast_cancer()
|
||||
X, y = cancer.data, cancer.target_names[cancer.target]
|
||||
|
||||
lr = make_pipeline(StandardScaler(), LogisticRegression())
|
||||
lr.fit(X, y)
|
||||
for klass in cancer.target_names:
|
||||
assert klass in lr.classes_
|
||||
display = PrecisionRecallDisplay.from_estimator(lr, X, y)
|
||||
|
||||
y_pred = lr.predict_proba(X)[:, 1]
|
||||
avg_prec = average_precision_score(y, y_pred, pos_label=lr.classes_[1])
|
||||
|
||||
assert display.average_precision == pytest.approx(avg_prec)
|
||||
assert display.estimator_name == lr.__class__.__name__
|
||||
|
||||
err_msg = r"y_true takes value in {'benign', 'malignant'}"
|
||||
with pytest.raises(ValueError, match=err_msg):
|
||||
PrecisionRecallDisplay.from_predictions(y, y_pred)
|
||||
|
||||
display = PrecisionRecallDisplay.from_predictions(
|
||||
y, y_pred, pos_label=lr.classes_[1]
|
||||
)
|
||||
assert display.average_precision == pytest.approx(avg_prec)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"average_precision, estimator_name, expected_label",
|
||||
[
|
||||
(0.9, None, "AP = 0.90"),
|
||||
(None, "my_est", "my_est"),
|
||||
(0.8, "my_est2", "my_est2 (AP = 0.80)"),
|
||||
],
|
||||
)
|
||||
def test_default_labels(pyplot, average_precision, estimator_name, expected_label):
|
||||
"""Check the default labels used in the display."""
|
||||
precision = np.array([1, 0.5, 0])
|
||||
recall = np.array([0, 0.5, 1])
|
||||
display = PrecisionRecallDisplay(
|
||||
precision,
|
||||
recall,
|
||||
average_precision=average_precision,
|
||||
estimator_name=estimator_name,
|
||||
)
|
||||
display.plot()
|
||||
assert display.line_.get_label() == expected_label
|
||||
|
||||
|
||||
@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"])
|
||||
@pytest.mark.parametrize("response_method", ["predict_proba", "decision_function"])
|
||||
def test_plot_precision_recall_pos_label(pyplot, constructor_name, response_method):
|
||||
# check that we can provide the positive label and display the proper
|
||||
# statistics
|
||||
X, y = load_breast_cancer(return_X_y=True)
|
||||
# create a highly imbalanced version of the breast cancer dataset
|
||||
idx_positive = np.flatnonzero(y == 1)
|
||||
idx_negative = np.flatnonzero(y == 0)
|
||||
idx_selected = np.hstack([idx_negative, idx_positive[:25]])
|
||||
X, y = X[idx_selected], y[idx_selected]
|
||||
X, y = shuffle(X, y, random_state=42)
|
||||
# only use 2 features to make the problem even harder
|
||||
X = X[:, :2]
|
||||
y = np.array(["cancer" if c == 1 else "not cancer" for c in y], dtype=object)
|
||||
X_train, X_test, y_train, y_test = train_test_split(
|
||||
X,
|
||||
y,
|
||||
stratify=y,
|
||||
random_state=0,
|
||||
)
|
||||
|
||||
classifier = LogisticRegression()
|
||||
classifier.fit(X_train, y_train)
|
||||
|
||||
# sanity check to be sure the positive class is classes_[0] and that the
# class imbalance works against it (hence the low average precision below)
|
||||
assert classifier.classes_.tolist() == ["cancer", "not cancer"]
|
||||
|
||||
y_pred = getattr(classifier, response_method)(X_test)
|
||||
# we select the corresponding probability columns or reverse the decision
|
||||
# function otherwise
|
||||
y_pred_cancer = -1 * y_pred if y_pred.ndim == 1 else y_pred[:, 0]
|
||||
y_pred_not_cancer = y_pred if y_pred.ndim == 1 else y_pred[:, 1]
|
||||
|
||||
if constructor_name == "from_estimator":
|
||||
display = PrecisionRecallDisplay.from_estimator(
|
||||
classifier,
|
||||
X_test,
|
||||
y_test,
|
||||
pos_label="cancer",
|
||||
response_method=response_method,
|
||||
)
|
||||
else:
|
||||
display = PrecisionRecallDisplay.from_predictions(
|
||||
y_test,
|
||||
y_pred_cancer,
|
||||
pos_label="cancer",
|
||||
)
|
||||
# we should obtain the statistics of the "cancer" class
|
||||
avg_prec_limit = 0.65
|
||||
assert display.average_precision < avg_prec_limit
|
||||
assert -trapezoid(display.precision, display.recall) < avg_prec_limit
|
||||

    # otherwise we should obtain the statistics of the "not cancer" class
    if constructor_name == "from_estimator":
        display = PrecisionRecallDisplay.from_estimator(
            classifier,
            X_test,
            y_test,
            response_method=response_method,
            pos_label="not cancer",
        )
    else:
        display = PrecisionRecallDisplay.from_predictions(
            y_test,
            y_pred_not_cancer,
            pos_label="not cancer",
        )
    avg_prec_limit = 0.95
    assert display.average_precision > avg_prec_limit
    assert -trapezoid(display.precision, display.recall) > avg_prec_limit


@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"])
def test_precision_recall_prevalence_pos_label_reusable(pyplot, constructor_name):
    # Check that even if one passes plot_chance_level=False the first time
    # one can still call disp.plot with plot_chance_level=True and get the
    # chance level line
    X, y = make_classification(n_classes=2, n_samples=50, random_state=0)

    lr = LogisticRegression()
    y_pred = lr.fit(X, y).predict_proba(X)[:, 1]

    if constructor_name == "from_estimator":
        display = PrecisionRecallDisplay.from_estimator(
            lr, X, y, plot_chance_level=False
        )
    else:
        display = PrecisionRecallDisplay.from_predictions(
            y, y_pred, plot_chance_level=False
        )
    assert display.chance_level_ is None

    import matplotlib as mpl  # noqa

    # When calling from_estimator or from_predictions,
    # prevalence_pos_label should have been set, so that directly
    # calling plot_chance_level=True should plot the chance level line
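    # (for a precision-recall curve, the chance level is a horizontal line at
    # the prevalence of the positive class)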
    display.plot(plot_chance_level=True)
    assert isinstance(display.chance_level_, mpl.lines.Line2D)


def test_precision_recall_raise_no_prevalence(pyplot):
    # Check that an error is raised when plotting the chance level without
    # providing prevalence_pos_label
    precision = np.array([1, 0.5, 0])
    recall = np.array([0, 0.5, 1])
    display = PrecisionRecallDisplay(precision, recall)

    msg = (
        "You must provide prevalence_pos_label when constructing the "
        "PrecisionRecallDisplay object in order to plot the chance "
        "level line. Alternatively, you may use "
        "PrecisionRecallDisplay.from_estimator or "
        "PrecisionRecallDisplay.from_predictions "
        "to automatically set prevalence_pos_label"
    )

    with pytest.raises(ValueError, match=msg):
        display.plot(plot_chance_level=True)
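    # A minimal sketch (not exercised by this test) of how the chance level can
    # be enabled when building the display manually; 0.5 is an assumed
    # prevalence of the positive class:
    #     disp = PrecisionRecallDisplay(
    #         precision, recall, prevalence_pos_label=0.5
    #     )
    #     disp.plot(plot_chance_level=True)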
@ -0,0 +1,161 @@
import pytest
from numpy.testing import assert_allclose

from sklearn.datasets import load_diabetes
from sklearn.exceptions import NotFittedError
from sklearn.linear_model import Ridge
from sklearn.metrics import PredictionErrorDisplay

X, y = load_diabetes(return_X_y=True)


@pytest.fixture
def regressor_fitted():
    return Ridge().fit(X, y)


@pytest.mark.parametrize(
    "regressor, params, err_type, err_msg",
    [
        (
            Ridge().fit(X, y),
            {"subsample": -1},
            ValueError,
            "When an integer, subsample=-1 should be",
        ),
        (
            Ridge().fit(X, y),
            {"subsample": 20.0},
            ValueError,
            "When a floating-point, subsample=20.0 should be",
        ),
        (
            Ridge().fit(X, y),
            {"subsample": -20.0},
            ValueError,
            "When a floating-point, subsample=-20.0 should be",
        ),
        (
            Ridge().fit(X, y),
            {"kind": "xxx"},
            ValueError,
            "`kind` must be one of",
        ),
    ],
)
@pytest.mark.parametrize("class_method", ["from_estimator", "from_predictions"])
def test_prediction_error_display_raise_error(
    pyplot, class_method, regressor, params, err_type, err_msg
):
    """Check that we raise the proper error when validating the parameters."""
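    # `match` is treated as a regular expression searched within the error
    # message, so the parametrized prefixes above are enough to identify each
    # failure mode.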
    with pytest.raises(err_type, match=err_msg):
        if class_method == "from_estimator":
            PredictionErrorDisplay.from_estimator(regressor, X, y, **params)
        else:
            y_pred = regressor.predict(X)
            PredictionErrorDisplay.from_predictions(y_true=y, y_pred=y_pred, **params)


def test_from_estimator_not_fitted(pyplot):
    """Check that we raise a `NotFittedError` when the passed regressor is not
    fit."""
    regressor = Ridge()
    with pytest.raises(NotFittedError, match="is not fitted yet."):
        PredictionErrorDisplay.from_estimator(regressor, X, y)


@pytest.mark.parametrize("class_method", ["from_estimator", "from_predictions"])
@pytest.mark.parametrize("kind", ["actual_vs_predicted", "residual_vs_predicted"])
def test_prediction_error_display(pyplot, regressor_fitted, class_method, kind):
    """Check the default behaviour of the display."""
    if class_method == "from_estimator":
        display = PredictionErrorDisplay.from_estimator(
            regressor_fitted, X, y, kind=kind
        )
    else:
        y_pred = regressor_fitted.predict(X)
        display = PredictionErrorDisplay.from_predictions(
            y_true=y, y_pred=y_pred, kind=kind
        )

    if kind == "actual_vs_predicted":
        assert_allclose(display.line_.get_xdata(), display.line_.get_ydata())
        assert display.ax_.get_xlabel() == "Predicted values"
        assert display.ax_.get_ylabel() == "Actual values"
        assert display.line_ is not None
    else:
        assert display.ax_.get_xlabel() == "Predicted values"
        assert display.ax_.get_ylabel() == "Residuals (actual - predicted)"
        assert display.line_ is not None

    assert display.ax_.get_legend() is None


@pytest.mark.parametrize("class_method", ["from_estimator", "from_predictions"])
@pytest.mark.parametrize(
    "subsample, expected_size",
    [(5, 5), (0.1, int(X.shape[0] * 0.1)), (None, X.shape[0])],
)
def test_plot_prediction_error_subsample(
    pyplot, regressor_fitted, class_method, subsample, expected_size
):
    """Check the behaviour of `subsample`."""
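    # `subsample` is either an int (absolute number of points drawn), a float in
    # (0, 1] (fraction of the dataset), or None (plot every sample); the
    # `expected_size` values above encode that mapping.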
    if class_method == "from_estimator":
        display = PredictionErrorDisplay.from_estimator(
            regressor_fitted, X, y, subsample=subsample
        )
    else:
        y_pred = regressor_fitted.predict(X)
        display = PredictionErrorDisplay.from_predictions(
            y_true=y, y_pred=y_pred, subsample=subsample
        )
    assert len(display.scatter_.get_offsets()) == expected_size


@pytest.mark.parametrize("class_method", ["from_estimator", "from_predictions"])
def test_plot_prediction_error_ax(pyplot, regressor_fitted, class_method):
    """Check that we can pass an axis to the display."""
    _, ax = pyplot.subplots()
    if class_method == "from_estimator":
        display = PredictionErrorDisplay.from_estimator(regressor_fitted, X, y, ax=ax)
    else:
        y_pred = regressor_fitted.predict(X)
        display = PredictionErrorDisplay.from_predictions(
            y_true=y, y_pred=y_pred, ax=ax
        )
    assert display.ax_ is ax


@pytest.mark.parametrize("class_method", ["from_estimator", "from_predictions"])
def test_prediction_error_custom_artist(pyplot, regressor_fitted, class_method):
    """Check that we can tune the style of the lines."""
    extra_params = {
        "kind": "actual_vs_predicted",
        "scatter_kwargs": {"color": "red"},
        "line_kwargs": {"color": "black"},
    }
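    # `scatter_kwargs` and `line_kwargs` are forwarded to matplotlib's scatter
    # and plot calls; the display's default alpha of 0.8 is presumably why the
    # RGBA edgecolor asserted below ends in 0.8.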
    if class_method == "from_estimator":
        display = PredictionErrorDisplay.from_estimator(
            regressor_fitted, X, y, **extra_params
        )
    else:
        y_pred = regressor_fitted.predict(X)
        display = PredictionErrorDisplay.from_predictions(
            y_true=y, y_pred=y_pred, **extra_params
        )

    assert display.line_.get_color() == "black"
    assert_allclose(display.scatter_.get_edgecolor(), [[1.0, 0.0, 0.0, 0.8]])

    # create a display with the default values
    if class_method == "from_estimator":
        display = PredictionErrorDisplay.from_estimator(regressor_fitted, X, y)
    else:
        y_pred = regressor_fitted.predict(X)
        display = PredictionErrorDisplay.from_predictions(y_true=y, y_pred=y_pred)
    pyplot.close("all")

    display.plot(**extra_params)
    assert display.line_.get_color() == "black"
    assert_allclose(display.scatter_.get_edgecolor(), [[1.0, 0.0, 0.0, 0.8]])
@ -0,0 +1,315 @@
import numpy as np
import pytest
from numpy.testing import assert_allclose

from sklearn.compose import make_column_transformer
from sklearn.datasets import load_breast_cancer, load_iris
from sklearn.exceptions import NotFittedError
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import RocCurveDisplay, auc, roc_curve
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.utils import shuffle
from sklearn.utils.fixes import trapezoid


@pytest.fixture(scope="module")
def data():
    return load_iris(return_X_y=True)


@pytest.fixture(scope="module")
def data_binary(data):
    X, y = data
    return X[y < 2], y[y < 2]


@pytest.mark.parametrize("response_method", ["predict_proba", "decision_function"])
@pytest.mark.parametrize("with_sample_weight", [True, False])
@pytest.mark.parametrize("drop_intermediate", [True, False])
@pytest.mark.parametrize("with_strings", [True, False])
@pytest.mark.parametrize(
    "constructor_name, default_name",
    [
        ("from_estimator", "LogisticRegression"),
        ("from_predictions", "Classifier"),
    ],
)
def test_roc_curve_display_plotting(
    pyplot,
    response_method,
    data_binary,
    with_sample_weight,
    drop_intermediate,
    with_strings,
    constructor_name,
    default_name,
):
    """Check the overall plotting behaviour."""
    X, y = data_binary

    pos_label = None
    if with_strings:
        y = np.array(["c", "b"])[y]
        pos_label = "c"

    if with_sample_weight:
        rng = np.random.RandomState(42)
        sample_weight = rng.randint(1, 4, size=(X.shape[0]))
    else:
        sample_weight = None

    lr = LogisticRegression()
    lr.fit(X, y)

    y_pred = getattr(lr, response_method)(X)
    y_pred = y_pred if y_pred.ndim == 1 else y_pred[:, 1]
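    # `predict_proba` returns one column per class, so column 1 (the positive
    # class) is kept; `decision_function` is already a 1d score and is used as-is.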

    if constructor_name == "from_estimator":
        display = RocCurveDisplay.from_estimator(
            lr,
            X,
            y,
            sample_weight=sample_weight,
            drop_intermediate=drop_intermediate,
            pos_label=pos_label,
            alpha=0.8,
        )
    else:
        display = RocCurveDisplay.from_predictions(
            y,
            y_pred,
            sample_weight=sample_weight,
            drop_intermediate=drop_intermediate,
            pos_label=pos_label,
            alpha=0.8,
        )

    fpr, tpr, _ = roc_curve(
        y,
        y_pred,
        sample_weight=sample_weight,
        drop_intermediate=drop_intermediate,
        pos_label=pos_label,
    )

    assert_allclose(display.roc_auc, auc(fpr, tpr))
    assert_allclose(display.fpr, fpr)
    assert_allclose(display.tpr, tpr)

    assert display.estimator_name == default_name

    import matplotlib as mpl  # noqa

    assert isinstance(display.line_, mpl.lines.Line2D)
    assert display.line_.get_alpha() == 0.8
    assert isinstance(display.ax_, mpl.axes.Axes)
    assert isinstance(display.figure_, mpl.figure.Figure)
    assert display.ax_.get_adjustable() == "box"
    assert display.ax_.get_aspect() in ("equal", 1.0)
    assert display.ax_.get_xlim() == display.ax_.get_ylim() == (-0.01, 1.01)
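    # (the display draws the ROC curve in a square box: equal aspect and
    # identical x/y limits padded slightly beyond [0, 1])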

    expected_label = f"{default_name} (AUC = {display.roc_auc:.2f})"
    assert display.line_.get_label() == expected_label

    expected_pos_label = 1 if pos_label is None else pos_label
    expected_ylabel = f"True Positive Rate (Positive label: {expected_pos_label})"
    expected_xlabel = f"False Positive Rate (Positive label: {expected_pos_label})"

    assert display.ax_.get_ylabel() == expected_ylabel
    assert display.ax_.get_xlabel() == expected_xlabel


@pytest.mark.parametrize("plot_chance_level", [True, False])
@pytest.mark.parametrize(
    "chance_level_kw",
    [None, {"linewidth": 1, "color": "red", "label": "DummyEstimator"}],
)
@pytest.mark.parametrize(
    "constructor_name",
    ["from_estimator", "from_predictions"],
)
def test_roc_curve_chance_level_line(
    pyplot,
    data_binary,
    plot_chance_level,
    chance_level_kw,
    constructor_name,
):
    """Check the chance level line plotting behaviour."""
    X, y = data_binary

    lr = LogisticRegression()
    lr.fit(X, y)

    y_pred = lr.predict_proba(X)
    y_pred = y_pred if y_pred.ndim == 1 else y_pred[:, 1]

    if constructor_name == "from_estimator":
        display = RocCurveDisplay.from_estimator(
            lr,
            X,
            y,
            alpha=0.8,
            plot_chance_level=plot_chance_level,
            chance_level_kw=chance_level_kw,
        )
    else:
        display = RocCurveDisplay.from_predictions(
            y,
            y_pred,
            alpha=0.8,
            plot_chance_level=plot_chance_level,
            chance_level_kw=chance_level_kw,
        )

    import matplotlib as mpl  # noqa

    assert isinstance(display.line_, mpl.lines.Line2D)
    assert display.line_.get_alpha() == 0.8
    assert isinstance(display.ax_, mpl.axes.Axes)
    assert isinstance(display.figure_, mpl.figure.Figure)

    if plot_chance_level:
        assert isinstance(display.chance_level_, mpl.lines.Line2D)
        assert tuple(display.chance_level_.get_xdata()) == (0, 1)
        assert tuple(display.chance_level_.get_ydata()) == (0, 1)
    else:
        assert display.chance_level_ is None

    # Checking for chance level line styles
    if plot_chance_level and chance_level_kw is None:
        assert display.chance_level_.get_color() == "k"
        assert display.chance_level_.get_linestyle() == "--"
        assert display.chance_level_.get_label() == "Chance level (AUC = 0.5)"
    elif plot_chance_level:
        assert display.chance_level_.get_label() == chance_level_kw["label"]
        assert display.chance_level_.get_color() == chance_level_kw["color"]
        assert display.chance_level_.get_linewidth() == chance_level_kw["linewidth"]

@pytest.mark.parametrize(
    "clf",
    [
        LogisticRegression(),
        make_pipeline(StandardScaler(), LogisticRegression()),
        make_pipeline(
            make_column_transformer((StandardScaler(), [0, 1])), LogisticRegression()
        ),
    ],
)
@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"])
def test_roc_curve_display_complex_pipeline(pyplot, data_binary, clf, constructor_name):
    """Check the behaviour with a complex pipeline."""
    X, y = data_binary

    if constructor_name == "from_estimator":
        with pytest.raises(NotFittedError):
            RocCurveDisplay.from_estimator(clf, X, y)

    clf.fit(X, y)

    if constructor_name == "from_estimator":
        display = RocCurveDisplay.from_estimator(clf, X, y)
        name = clf.__class__.__name__
    else:
        display = RocCurveDisplay.from_predictions(y, y)
        name = "Classifier"

    assert name in display.line_.get_label()
    assert display.estimator_name == name


@pytest.mark.parametrize(
    "roc_auc, estimator_name, expected_label",
    [
        (0.9, None, "AUC = 0.90"),
        (None, "my_est", "my_est"),
        (0.8, "my_est2", "my_est2 (AUC = 0.80)"),
    ],
)
def test_roc_curve_display_default_labels(
    pyplot, roc_auc, estimator_name, expected_label
):
    """Check the default labels used in the display."""
    fpr = np.array([0, 0.5, 1])
    tpr = np.array([0, 0.5, 1])
    disp = RocCurveDisplay(
        fpr=fpr, tpr=tpr, roc_auc=roc_auc, estimator_name=estimator_name
    ).plot()
    assert disp.line_.get_label() == expected_label

@pytest.mark.parametrize("response_method", ["predict_proba", "decision_function"])
|
||||
@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"])
|
||||
def test_plot_roc_curve_pos_label(pyplot, response_method, constructor_name):
|
||||
# check that we can provide the positive label and display the proper
|
||||
# statistics
|
||||
X, y = load_breast_cancer(return_X_y=True)
|
||||
# create an highly imbalanced
|
||||
idx_positive = np.flatnonzero(y == 1)
|
||||
idx_negative = np.flatnonzero(y == 0)
|
||||
idx_selected = np.hstack([idx_negative, idx_positive[:25]])
|
||||
X, y = X[idx_selected], y[idx_selected]
|
||||
X, y = shuffle(X, y, random_state=42)
|
||||
# only use 2 features to make the problem even harder
|
||||
X = X[:, :2]
|
||||
y = np.array(["cancer" if c == 1 else "not cancer" for c in y], dtype=object)
|
||||
X_train, X_test, y_train, y_test = train_test_split(
|
||||
X,
|
||||
y,
|
||||
stratify=y,
|
||||
random_state=0,
|
||||
)
|
||||
|
||||
classifier = LogisticRegression()
|
||||
classifier.fit(X_train, y_train)
|
||||
|
||||
# sanity check to be sure the positive class is classes_[0] and that we
|
||||
# are betrayed by the class imbalance
|
||||
assert classifier.classes_.tolist() == ["cancer", "not cancer"]
|
||||
|
||||
y_pred = getattr(classifier, response_method)(X_test)
|
||||
# we select the corresponding probability columns or reverse the decision
|
||||
# function otherwise
|
||||
y_pred_cancer = -1 * y_pred if y_pred.ndim == 1 else y_pred[:, 0]
|
||||
y_pred_not_cancer = y_pred if y_pred.ndim == 1 else y_pred[:, 1]
|
||||

    if constructor_name == "from_estimator":
        display = RocCurveDisplay.from_estimator(
            classifier,
            X_test,
            y_test,
            pos_label="cancer",
            response_method=response_method,
        )
    else:
        display = RocCurveDisplay.from_predictions(
            y_test,
            y_pred_cancer,
            pos_label="cancer",
        )

    roc_auc_limit = 0.95679

    assert display.roc_auc == pytest.approx(roc_auc_limit)
    assert trapezoid(display.tpr, display.fpr) == pytest.approx(roc_auc_limit)

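    # Swapping the positive label mirrors the ROC curve, so the AUC is the same
    # whichever class is treated as positive; the same reference value applies
    # below.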
    if constructor_name == "from_estimator":
        display = RocCurveDisplay.from_estimator(
            classifier,
            X_test,
            y_test,
            response_method=response_method,
            pos_label="not cancer",
        )
    else:
        display = RocCurveDisplay.from_predictions(
            y_test,
            y_pred_not_cancer,
            pos_label="not cancer",
        )

    assert display.roc_auc == pytest.approx(roc_auc_limit)
    assert trapezoid(display.tpr, display.fpr) == pytest.approx(roc_auc_limit)