80 lines
2.4 KiB
Python
80 lines
2.4 KiB
Python
"""assert functions from numpy and pandas testing
|
|
|
|
"""
|
|
from statsmodels.compat.pandas import testing as pdt
|
|
|
|
import numpy.testing as npt
|
|
import pandas
|
|
|
|
from statsmodels.tools.tools import Bunch
|
|
|
|
# Standard list for parsing tables
|
|
PARAM_LIST = ['params', 'bse', 'tvalues', 'pvalues']
|
|
|
|
|
|
def bunch_factory(attribute, columns):
|
|
"""
|
|
Generates a special purpose Bunch class
|
|
|
|
Parameters
|
|
----------
|
|
attribute: str
|
|
Attribute to access when splitting
|
|
columns: List[str]
|
|
List of names to use when splitting the columns of attribute
|
|
|
|
Notes
|
|
-----
|
|
After the class is initialized as a Bunch, the columne of attribute
|
|
are split so that Bunch has the keys in columns and
|
|
bunch[column[i]] = bunch[attribute][:, i]
|
|
"""
|
|
class FactoryBunch(Bunch):
|
|
def __init__(self, *args, **kwargs):
|
|
super().__init__(*args, **kwargs)
|
|
if not hasattr(self, attribute):
|
|
raise AttributeError('{} is required and must be passed to '
|
|
'the constructor'.format(attribute))
|
|
for i, att in enumerate(columns):
|
|
self[att] = getattr(self, attribute)[:, i]
|
|
|
|
return FactoryBunch
|
|
|
|
|
|
ParamsTableTestBunch = bunch_factory('params_table', PARAM_LIST)
|
|
|
|
MarginTableTestBunch = bunch_factory('margins_table', PARAM_LIST)
|
|
|
|
|
|
class Holder:
|
|
"""
|
|
Test-focused class to simplify accessing values by attribute
|
|
"""
|
|
def __init__(self, **kwds):
|
|
self.__dict__.update(kwds)
|
|
|
|
def __str__(self):
|
|
ss = "\n".join(str(k) + " = " + str(v).replace('\n', '\n ')
|
|
for k, v in vars(self).items())
|
|
return ss
|
|
|
|
def __repr__(self):
|
|
# use repr for values including nested cases as in tost
|
|
ss = "\n".join(str(k) + " = " + repr(v).replace('\n', '\n ')
|
|
for k, v in vars(self).items())
|
|
ss = str(self.__class__) + "\n" + ss
|
|
return ss
|
|
|
|
|
|
# adjusted functions
|
|
|
|
def assert_equal(actual, desired, err_msg='', verbose=True, **kwds):
|
|
if isinstance(desired, pandas.Index):
|
|
pdt.assert_index_equal(actual, desired)
|
|
elif isinstance(desired, pandas.Series):
|
|
pdt.assert_series_equal(actual, desired, **kwds)
|
|
elif isinstance(desired, pandas.DataFrame):
|
|
pdt.assert_frame_equal(actual, desired, **kwds)
|
|
else:
|
|
npt.assert_equal(actual, desired, err_msg='', verbose=True)
|