{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Multi-label data stratification\n",
"\n",
"With the development of more complex multi-label transformation methods, the community has come to realize how much the quality of classification depends on how the data is split into train/test sets or into folds for parameter estimation. More and more questions about multi-label stratification appear on Stack Overflow and [Data Science Stack Exchange](https://datascience.stackexchange.com/questions/33076/how-can-i-perform-stratified-sampling-for-multi-label-multi-class-classification).\n",
"\n",
"For reasons described [here](http://lpis.csd.auth.gr/publications/sechidis-ecmlpkdd-2011.pdf) and [here](http://proceedings.mlr.press/v74/szyma%C5%84ski17a.html), traditional single-label stratification approaches fail to produce balanced divisions of multi-label data, which prevents classifiers from generalizing properly.\n",
"\n",
"Some train/test splits contain no evidence at all for a given label in the train set. Others put a disproportionate share of the evidence for a label pair, as much as 70%, into the test set. This leaves the train set without enough evidence to estimate the conditional probabilities that describe label relations.\n",
"\n",
"You can also watch a great video presentation from ECML 2011 which explains this in depth:\n",
"\n",
"*On the Stratification of Multi-Label Data*, Grigorios Tsoumakas\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Scikit-multilearn provides an implementation of iterative stratification, which aims to provide a well-balanced distribution of evidence of label relations up to a given order. To see what this means, let's load some data. We'll use the scene data set, in both its divided and undivided variants, to illustrate the problem."
]
},
{
"cell_type": "code",
"execution_count": 263,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"scene:undivided - exists, not redownloading\n"
]
}
],
"source": [
"from skmultilearn.dataset import load_dataset\n",
"X, y, _, _ = load_dataset('scene', 'undivided')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's look at how many examples are available per label combination:"
]
},
{
"cell_type": "code",
"execution_count": 264,
"metadata": {},
"outputs": [],
"source": [
"from collections import Counter\n",
"from skmultilearn.model_selection.measures import get_combination_wise_output_matrix"
]
},
{
"cell_type": "code",
"execution_count": 265,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Counter({(0, 0): 427,\n",
" (0, 3): 1,\n",
" (0, 4): 38,\n",
" (0, 5): 19,\n",
" (1, 1): 364,\n",
" (2, 2): 397,\n",
" (2, 3): 24,\n",
" (2, 4): 14,\n",
" (3, 3): 433,\n",
" (3, 4): 76,\n",
" (3, 5): 6,\n",
" (4, 4): 533,\n",
" (4, 5): 1,\n",
" (5, 5): 431})"
]
},
"execution_count": 265,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"Counter(combination for row in get_combination_wise_output_matrix(y.toarray(), order=2) for combination in row)"
]
},
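{
"cell_type": "markdown",
"metadata": {},
"source": [
"The counts above can also be reproduced without scikit-multilearn, which helps make clear what is being counted. The sketch below assumes that `get_combination_wise_output_matrix` yields, for each row, the set of `order`-sized tuples of that row's positive label indices drawn with replacement. This assumption is consistent with diagonal keys such as `(0, 0)` in the output above, which under it simply count the examples that have label 0. The helper `count_label_combinations` and the toy matrix `y_small` are hypothetical and only for illustration; the helper expects a dense 0/1 array."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from collections import Counter\n",
"from itertools import combinations_with_replacement\n",
"\n",
"import numpy as np\n",
"\n",
"\n",
"def count_label_combinations(y, order=2):\n",
"    # Hypothetical helper. For each row of a dense 0/1 matrix, enumerate tuples of\n",
"    # `order` positive label indices drawn with replacement, so (i, i) stands for\n",
"    # label i on its own, and count how many rows contain each tuple\n",
"    counts = Counter()\n",
"    for row in np.asarray(y):\n",
"        positive = np.flatnonzero(row).tolist()\n",
"        counts.update(set(combinations_with_replacement(positive, order)))\n",
"    return counts\n",
"\n",
"\n",
"# Toy dense 0/1 label matrix: 3 examples, 3 labels\n",
"y_small = np.array([[1, 0, 1], [1, 0, 0], [0, 1, 1]])\n",
"count_label_combinations(y_small, order=2)"
]
},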
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's load the original division to see how the set was split into train/test data in 2004, before multi-label stratification methods appeared."
]
},
{
"cell_type": "code",
"execution_count": 266,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"scene:train - exists, not redownloading\n",
"scene:test - exists, not redownloading\n"
]
}
],
"source": [
"_, original_y_train, _, _ = load_dataset('scene', 'train')\n",
"_, original_y_test, _, _ = load_dataset('scene', 'test')"
]
},
{
"cell_type": "code",
"execution_count": 267,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd"
]
},
{
"cell_type": "code",
"execution_count": 268,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" (0, 0) (0, 3) (0, 4) (0, 5) (1, 1) (2, 2) (2, 3) (2, 4) (3, 3) \\\n",
"test 200.0 1.0 17.0 7.0 199.0 200.0 16.0 8.0 237.0 \n",
"train 227.0 0.0 21.0 12.0 165.0 197.0 8.0 6.0 196.0 \n",
"\n",
" (3, 4) (3, 5) (4, 4) (4, 5) (5, 5) \n",
"test 49.0 5.0 256.0 0.0 207.0 \n",
"train 27.0 1.0 277.0 1.0 224.0 "
]
},
"execution_count": 268,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pd.DataFrame({\n",
"    'train': Counter(str(combination) for row in get_combination_wise_output_matrix(original_y_train.toarray(), order=2) for combination in row),\n",
"    'test': Counter(str(combination) for row in get_combination_wise_output_matrix(original_y_test.toarray(), order=2) for combination in row)\n",
"}).T.fillna(0.0)"
]
},
{
"cell_type": "code",
"execution_count": 269,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(1211, 1196)"
]
},
"execution_count": 269,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"original_y_train.shape[0], original_y_test.shape[0]"
]
},
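{
"cell_type": "markdown",
"metadata": {},
"source": [
"To put numbers on the imbalance in the table above, one can compute, for every label pair, the share of its examples that landed in the test set. A share near 0.5 means the pair is balanced, while 1.0 means the train set has no evidence for it at all. The sketch below hard-codes a subset of the counts from the table, instead of reusing the notebook's variables, so that it runs on its own. `test_share` is a hypothetical helper name, not a scikit-multilearn function."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from collections import Counter\n",
"\n",
"\n",
"def test_share(train_counts, test_counts):\n",
"    # Hypothetical helper: fraction of each combination's examples that ended up\n",
"    # in the test set; combinations with no examples in either split are skipped\n",
"    shares = {}\n",
"    for key in sorted(set(train_counts) | set(test_counts)):\n",
"        total = train_counts[key] + test_counts[key]\n",
"        if total > 0:\n",
"            shares[key] = test_counts[key] / float(total)\n",
"    return shares\n",
"\n",
"\n",
"# A subset of the label-pair counts from the original 2004 division shown above\n",
"original_train = Counter({(0, 3): 0, (2, 3): 8, (3, 4): 27, (3, 5): 1, (4, 4): 277})\n",
"original_test = Counter({(0, 3): 1, (2, 3): 16, (3, 4): 49, (3, 5): 5, (4, 4): 256})\n",
"\n",
"for pair, share in test_share(original_train, original_test).items():\n",
"    print('{} {:.2f}'.format(pair, share))"
]
},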
{
"cell_type": "markdown",
"metadata": {},
"source": [
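"We can see that, although the two splits are nearly identical in size, the evidence for some label combinations is poorly balanced between them. For example, the pair (3, 4) occurs 49 times in the test set but only 27 times in the train set, and the pairs (0, 3) and (4, 5) each appear in only one of the two splits. While this is a toy case on a small data set, such imbalances are common in larger data sets as well. We would like to fix this.\n",
"\n",
"Let's load the iterative stratifier and divide the set again."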
]
},
{
"cell_type": "code",
"execution_count": 270,
"metadata": {},
"outputs": [],
"source": [
"from skmultilearn.model_selection import iterative_train_test_split"
]
},
{
"cell_type": "code",
"execution_count": 278,
"metadata": {},
"outputs": [],
"source": [
"X_train, y_train, X_test, y_test = iterative_train_test_split(X, y, test_size=0.5)"
]
},
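{
"cell_type": "markdown",
"metadata": {},
"source": [
"The call above relies on scikit-multilearn's implementation. For intuition about how iterative stratification reaches this kind of balance, here is a simplified, first-order sketch of the procedure described by Sechidis et al. in the ECML PKDD 2011 paper linked in the introduction. It repeatedly picks the label with the fewest remaining examples and sends each of those examples to the fold that still needs that label the most. Ties are broken by overall fold demand and then at random. This is not scikit-multilearn's code, which also balances higher-order label combinations. The function `iterative_stratification`, its `proportions` and `seed` parameters, and the toy matrix `y_toy` are hypothetical."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def iterative_stratification(y, proportions, seed=0):\n",
"    # Hypothetical simplified first-order sketch: returns a fold index for every\n",
"    # row of a dense 0/1 label matrix y, aiming at the given fold proportions\n",
"    y = np.asarray(y, dtype=bool)\n",
"    rng = np.random.RandomState(seed)\n",
"    proportions = np.asarray(proportions, dtype=float)\n",
"    # Desired number of examples per fold, overall and per (fold, label)\n",
"    desired = y.shape[0] * proportions\n",
"    desired_per_label = np.outer(proportions, y.sum(axis=0))\n",
"    folds = np.full(y.shape[0], -1, dtype=int)\n",
"    remaining = np.ones(y.shape[0], dtype=bool)\n",
"\n",
"    def assign(i, candidates):\n",
"        # Among candidate folds, prefer the largest overall demand, then pick at random\n",
"        best = candidates[desired[candidates] == desired[candidates].max()]\n",
"        fold = rng.choice(best)\n",
"        folds[i] = fold\n",
"        remaining[i] = False\n",
"        desired[fold] -= 1\n",
"        desired_per_label[fold, y[i]] -= 1\n",
"\n",
"    while y[remaining].any():\n",
"        # Label with the fewest remaining examples, ignoring exhausted labels\n",
"        counts = y[remaining].sum(axis=0)\n",
"        label = np.argmin(np.where(counts > 0, counts, np.inf))\n",
"        for i in np.flatnonzero(remaining & y[:, label]):\n",
"            # Folds that still need this label the most\n",
"            demand = desired_per_label[:, label]\n",
"            assign(i, np.flatnonzero(demand == demand.max()))\n",
"\n",
"    # Examples without any label go to the folds with the largest remaining demand\n",
"    for i in np.flatnonzero(remaining):\n",
"        assign(i, np.arange(len(proportions)))\n",
"    return folds\n",
"\n",
"\n",
"# Toy label matrix: 8 examples, 3 labels (the last example has no labels)\n",
"y_toy = np.array([\n",
"    [1, 0, 0],\n",
"    [1, 0, 0],\n",
"    [1, 1, 0],\n",
"    [1, 1, 0],\n",
"    [0, 1, 1],\n",
"    [0, 1, 1],\n",
"    [0, 0, 1],\n",
"    [0, 0, 0],\n",
"])\n",
"folds = iterative_stratification(y_toy, [0.5, 0.5])\n",
"for fold in (0, 1):\n",
"    print('fold {}: {} examples, label counts {}'.format(\n",
"        fold, (folds == fold).sum(), y_toy[folds == fold].sum(axis=0).tolist()))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"On this toy matrix both folds end up with four examples and identical counts for labels 0 and 1. Label 2 has an odd number of examples (three), so one fold necessarily receives one more of them than the other."
]
},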
{
"cell_type": "code",
"execution_count": 279,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" (0, 0) (0, 3) (0, 4) (0, 5) (1, 1) (2, 2) (2, 3) (2, 4) (3, 3) \\\n",
"test 213.0 0.0 19.0 9.0 182.0 199.0 12.0 7.0 217.0 \n",
"train 214.0 1.0 19.0 10.0 182.0 198.0 12.0 7.0 216.0 \n",
"\n",
" (3, 4) (3, 5) (4, 4) (4, 5) (5, 5) \n",
"test 38.0 3.0 267.0 1.0 215.0 \n",
"train 38.0 3.0 266.0 0.0 216.0 "
]
},
"execution_count": 279,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pd.DataFrame({\n",
"    'train': Counter(str(combination) for row in get_combination_wise_output_matrix(y_train.toarray(), order=2) for combination in row),\n",
"    'test': Counter(str(combination) for row in get_combination_wise_output_matrix(y_test.toarray(), order=2) for combination in row)\n",
"}).T.fillna(0.0)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
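"The new division is much more balanced. Every label pair that occurs more than once is now split between train and test with a difference of at most one example. For instance, the pair (3, 4) is split 38/38 and the pair (2, 3) 12/12. The two pairs with a single example, (0, 3) and (4, 5), necessarily end up in only one split."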
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.14"
}
},
"nbformat": 4,
"nbformat_minor": 2
}