Files
2025-08-01 04:33:03 -04:00

80 lines
2.4 KiB
Python

"""Test helpers: assertion functions adapted from numpy and pandas testing,
plus small result containers used by the test suite.
"""
from statsmodels.compat.pandas import testing as pdt
import numpy.testing as npt
import pandas
from statsmodels.tools.tools import Bunch
# Default column names used when splitting a parsed results table:
# column i of the table is stored under the name PARAM_LIST[i].
PARAM_LIST = [
    "params",
    "bse",
    "tvalues",
    "pvalues",
]
def bunch_factory(attribute, columns):
    """
    Generates a special purpose Bunch class

    Parameters
    ----------
    attribute : str
        Attribute to access when splitting
    columns : List[str]
        Names to use when splitting the columns of attribute. Any iterable
        of names is accepted; it is materialized once when the class is
        created.

    Returns
    -------
    type
        A ``Bunch`` subclass whose constructor requires ``attribute`` to be
        supplied and then adds one entry per name in ``columns``.

    Raises
    ------
    AttributeError
        Raised by the generated class's constructor when ``attribute`` is
        not passed to it.

    Notes
    -----
    After the class is initialized as a Bunch, the columns of attribute
    are split so that Bunch has the keys in columns and
    bunch[columns[i]] = bunch[attribute][:, i]
    """
    # Materialize the names once so every instance sees the same columns;
    # a generator/iterator would otherwise be exhausted after the first
    # instance is constructed.
    columns = tuple(columns)

    class FactoryBunch(Bunch):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            if not hasattr(self, attribute):
                raise AttributeError('{} is required and must be passed to '
                                     'the constructor'.format(attribute))
            # Fetch the table once instead of once per column.
            table = getattr(self, attribute)
            for i, att in enumerate(columns):
                self[att] = table[:, i]

    return FactoryBunch
# Ready-made Bunch classes: each requires the named table at construction
# and exposes that table's columns under the names listed in PARAM_LIST.
ParamsTableTestBunch = bunch_factory(attribute='params_table',
                                     columns=PARAM_LIST)
MarginTableTestBunch = bunch_factory(attribute='margins_table',
                                     columns=PARAM_LIST)
class Holder:
    """
    Lightweight attribute container for test fixtures.

    Every keyword passed to the constructor is stored as an instance
    attribute, so expected values can be read back as ``holder.name``.
    """

    def __init__(self, **kwds):
        vars(self).update(kwds)

    def __str__(self):
        # One "name = value" line per attribute; continuation lines of
        # multi-line values are indented so they stay visually grouped.
        lines = [f"{key!s} = " + str(val).replace('\n', '\n ')
                 for key, val in vars(self).items()]
        return "\n".join(lines)

    def __repr__(self):
        # repr() is used for the values so nested containers (e.g. the
        # tost results) print unambiguously; the class is shown first.
        lines = [f"{key!s} = " + repr(val).replace('\n', '\n ')
                 for key, val in vars(self).items()]
        body = "\n".join(lines)
        return f"{self.__class__!s}\n{body}"
# adjusted functions
def assert_equal(actual, desired, err_msg='', verbose=True, **kwds):
    """
    Assert equality, dispatching on the type of ``desired``.

    pandas ``Index``, ``Series`` and ``DataFrame`` objects are compared with
    the matching ``pandas.testing`` function; everything else is compared
    with ``numpy.testing.assert_equal``.

    Parameters
    ----------
    actual : object
        The object to check.
    desired : object
        The expected object; its type selects the comparison function.
    err_msg : str, optional
        Forwarded to ``numpy.testing.assert_equal``; not used by the pandas
        branches.
    verbose : bool, optional
        Forwarded to ``numpy.testing.assert_equal``; not used by the pandas
        branches.
    **kwds
        Extra keyword arguments forwarded to ``assert_series_equal`` or
        ``assert_frame_equal``. They are not passed on for ``Index`` or the
        numpy comparison.

    Raises
    ------
    AssertionError
        If ``actual`` and ``desired`` are not equal.
    """
    if isinstance(desired, pandas.Index):
        pdt.assert_index_equal(actual, desired)
    elif isinstance(desired, pandas.Series):
        pdt.assert_series_equal(actual, desired, **kwds)
    elif isinstance(desired, pandas.DataFrame):
        pdt.assert_frame_equal(actual, desired, **kwds)
    else:
        # Forward the caller's options; these were previously replaced by
        # hard-coded defaults, silently discarding a custom err_msg.
        npt.assert_equal(actual, desired, err_msg=err_msg, verbose=verbose)