some new features
This commit is contained in:
@ -0,0 +1,204 @@
|
||||
"""
|
||||
Authors: Josef Perktold, Skipper Seabold, Denis A. Engemann
|
||||
"""
|
||||
from statsmodels.compat.python import lrange
|
||||
|
||||
import numpy as np
|
||||
|
||||
from statsmodels.graphics.plottools import rainbow
|
||||
import statsmodels.graphics.utils as utils
|
||||
|
||||
|
||||
def interaction_plot(x, trace, response, func="mean", ax=None, plottype='b',
                     xlabel=None, ylabel=None, colors=None, markers=None,
                     linestyles=None, legendloc='best', legendtitle=None,
                     **kwargs):
    """
    Interaction plot for factor level statistics.

    Note. If categorical (string) factors are supplied for `x`, the levels
    will be internally recoded to integers. This ensures matplotlib
    compatibility. Uses a DataFrame to calculate an `aggregate` statistic
    for each combination of the `trace` and `x` levels.

    Parameters
    ----------
    x : array_like
        The `x` factor levels constitute the x-axis. If a `pandas.Series` is
        given its name will be used in `xlabel` if `xlabel` is None.
    trace : array_like
        The `trace` factor levels will be drawn as lines in the plot.
        If `trace` is a `pandas.Series` its name will be used as the
        `legendtitle` if `legendtitle` is None.
    response : array_like
        The response or dependent variable. If a `pandas.Series` is given
        its name will be used in `ylabel` if `ylabel` is None.
    func : str or function
        Anything accepted by `pandas.DataFrame.aggregate`. This is applied to
        the response variable grouped by the `trace` and `x` levels.
    ax : axes, optional
        Matplotlib axes instance
    plottype : str {'line', 'scatter', 'both'}, optional
        The type of plot to return. Can be 'l', 's', or 'b'
    xlabel : str, optional
        Label to use for `x`. Default is 'X'. If `x` is a named
        `pandas.Series` it will use the series name.
    ylabel : str, optional
        Label to use for `response`. The y-axis always reads
        '<func> of <ylabel>'. Default is 'response'. If `response` is a
        named `pandas.Series` it will use the series name.
    colors : list, optional
        If given, must have length == number of levels in trace.
    markers : list, optional
        If given, must have length == number of levels in trace
    linestyles : list, optional
        If given, must have length == number of levels in trace.
    legendloc : {None, str, int}
        Location passed to the legend command.
    legendtitle : {None, str}
        Title of the legend. Default is 'Trace'. If `trace` is a named
        `pandas.Series` it will use the series name.
    **kwargs
        These will be passed to the plot command used either plot or scatter.
        If you want to control the overall plotting options, use kwargs.

    Returns
    -------
    Figure
        The figure given by `ax.figure` or a new instance.

    Raises
    ------
    ValueError
        If `colors`, `markers` or `linestyles` do not have one entry per
        trace level, or if `plottype` is not one of the accepted values.

    Examples
    --------
    >>> import numpy as np
    >>> np.random.seed(12345)
    >>> weight = np.random.randint(1,4,size=60)
    >>> duration = np.random.randint(1,3,size=60)
    >>> days = np.log(np.random.randint(1,30, size=60))
    >>> fig = interaction_plot(weight, duration, days,
    ...             colors=['red','blue'], markers=['D','^'], ms=10)
    >>> import matplotlib.pyplot as plt
    >>> plt.show()

    .. plot::

       import numpy as np
       from statsmodels.graphics.factorplots import interaction_plot
       np.random.seed(12345)
       weight = np.random.randint(1,4,size=60)
       duration = np.random.randint(1,3,size=60)
       days = np.log(np.random.randint(1,30, size=60))
       fig = interaction_plot(weight, duration, days,
                   colors=['red','blue'], markers=['D','^'], ms=10)
       import matplotlib.pyplot as plt
       #plt.show()
    """

    from pandas import DataFrame

    fig, ax = utils.create_mpl_ax(ax)

    # An unnamed Series has ``name is None``; fall back to the documented
    # defaults instead of rendering the text "None".
    response_name = ylabel or _name_or_default(response, 'response')
    func_name = getattr(func, "__name__", str(func))
    ylabel = f'{func_name} of {response_name}'
    xlabel = xlabel or _name_or_default(x, 'X')
    legendtitle = legendtitle or _name_or_default(trace, 'Trace')

    ax.set_ylabel(ylabel)
    ax.set_xlabel(xlabel)

    x_values = x_levels = None
    # Look at the first element positionally: ``x[0]`` on a Series is a
    # label lookup and fails when the index does not contain 0.
    x_array = np.asarray(x)
    if isinstance(x_array[0], str):
        x_levels = list(np.unique(x_array))
        x_values = list(range(len(x_levels)))
        # _recode needs an object exposing ``.dtype``; plain sequences are
        # handed over as arrays, Series are passed through unchanged.
        x = _recode(x if hasattr(x, 'dtype') else x_array,
                    dict(zip(x_levels, x_values)))

    data = DataFrame(dict(x=x, trace=trace, response=response))
    plot_data = data.groupby(['trace', 'x']).aggregate(func).reset_index()

    # check plot args
    n_trace = len(plot_data['trace'].unique())

    linestyles = ['-'] * n_trace if linestyles is None else linestyles
    markers = ['.'] * n_trace if markers is None else markers
    colors = rainbow(n_trace) if colors is None else colors

    if len(linestyles) != n_trace:
        raise ValueError("Must be a linestyle for each trace level")
    if len(markers) != n_trace:
        raise ValueError("Must be a marker for each trace level")
    if len(colors) != n_trace:
        raise ValueError("Must be a color for each trace level")

    if plottype in ('both', 'b'):
        mode = 'both'
    elif plottype in ('line', 'l'):
        mode = 'line'
    elif plottype in ('scatter', 's'):
        mode = 'scatter'
    else:
        raise ValueError(f"Plot type {plottype} not understood")

    for i, (_, group) in enumerate(plot_data.groupby('trace')):
        # trace label
        label = str(group['trace'].values[0])
        if mode == 'scatter':
            ax.scatter(group['x'], group['response'], color=colors[i],
                       label=label, marker=markers[i], **kwargs)
            continue
        style = dict(color=colors[i], label=label, linestyle=linestyles[i])
        if mode == 'both':
            # markers are only drawn on top of the lines for 'both'
            style['marker'] = markers[i]
        ax.plot(group['x'], group['response'], **style, **kwargs)

    ax.legend(loc=legendloc, title=legendtitle)
    ax.margins(.1)

    # Restore the original string levels as tick labels after recoding.
    if x_levels is not None:
        ax.set_xticks(x_values)
        ax.set_xticklabels(x_levels)
    return fig


def _name_or_default(obj, default):
    """Return ``obj.name`` unless it is missing or None, else `default`."""
    name = getattr(obj, 'name', None)
    return default if name is None else name
|
||||
|
||||
|
||||
def _recode(x, levels):
|
||||
""" Recode categorial data to int factor.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : array_like
|
||||
array like object supporting with numpy array methods of categorially
|
||||
coded data.
|
||||
levels : dict
|
||||
mapping of labels to integer-codings
|
||||
|
||||
Returns
|
||||
-------
|
||||
out : instance numpy.ndarray
|
||||
"""
|
||||
from pandas import Series
|
||||
name = None
|
||||
index = None
|
||||
|
||||
if isinstance(x, Series):
|
||||
name = x.name
|
||||
index = x.index
|
||||
x = x.values
|
||||
|
||||
if x.dtype.type not in [np.str_, np.object_]:
|
||||
raise ValueError('This is not a categorial factor.'
|
||||
' Array of str type required.')
|
||||
|
||||
elif not isinstance(levels, dict):
|
||||
raise ValueError('This is not a valid value for levels.'
|
||||
' Dict required.')
|
||||
|
||||
elif not (np.unique(x) == np.unique(list(levels.keys()))).all():
|
||||
raise ValueError('The levels do not match the array values.')
|
||||
|
||||
else:
|
||||
out = np.empty(x.shape[0], dtype=int)
|
||||
for level, coding in levels.items():
|
||||
out[x == level] = coding
|
||||
|
||||
if name:
|
||||
out = Series(out, name=name, index=index)
|
||||
|
||||
return out
|
||||
Reference in New Issue
Block a user