diff --git a/code/data/dictionaries/valid_countries.txt b/code/data/dictionaries/valid_countries.txt index 2572722..bc8269b 100644 --- a/code/data/dictionaries/valid_countries.txt +++ b/code/data/dictionaries/valid_countries.txt @@ -1,198 +1,198 @@ Myanmar Switzerland Zambia Turkey Bolivia Dominican Republic Vietnam Guyana Egypt Mauritius Micronesia, Fed. Sts. Sierra Leone Congo Republic Suriname Jamaica Tanzania Fiji Luxembourg Kiribati Benin Cabo Verde Moldova Australia Eritrea Ecuador Marshall Islands Niue Brunei Darussalam Palestine Oman Hungary Sri Lanka Cote d'Ivoire China Ukraine Ethiopia Slovakia Niger Cuba Belarus Bangladesh Central African Republic Iraq Burundi Cook Islands Equatorial Guinea Armenia St. Vincent and the Grenadines Somalia Chad Nicaragua Gambia Greece Russia Israel Bahamas Finland Syria Haiti Kazakhstan Poland Lesotho Qatar Senegal Latvia Bosnia and Herzegovina Netherlands Guinea-Bissau Timor-Leste DR Congo Italy Antigua and Barbuda Singapore Mozambique Eswatini Seychelles Maldives United Kingdom Morocco France Venezuela Austria Uzbekistan Solomon Islands Kyrgyz Republic Rwanda United States Uruguay European Union Norway Bahrain Comoros Tunisia Albania Liberia North Korea Nauru Pakistan Zimbabwe Afghanistan Guatemala South Africa Germany Vanuatu South Sudan Mali Gabon New Zealand Romania Belize Honduras Tonga Palau Papua New Guinea Togo Belgium Malawi Sweden Canada St. Lucia Tuvalu Monaco Chile Denmark Nepal Serbia Ghana Czech Republic South Korea Macedonia Madagascar Cambodia Turkmenistan Azerbaijan Croatia Burkina Faso Lithuania Sudan Thailand Ireland Grenada Iran Djibouti Barbados India Iceland Bulgaria Slovenia Dominica Georgia Philippines Jordan Colombia Costa Rica Botswana Lebanon Guinea San Marino Andorra Paraguay Samoa Tajikistan St. Kitts and Nevis Saudi Arabia Argentina Yemen Mexico Brazil Peru Sao Tome and Principe Cameroon Montenegro Estonia Angola Uganda Panama Nigeria Portugal Kuwait Mongolia El Salvador Algeria Namibia Spain Japan Malta Kenya Trinidad and Tobago Laos Bhutan Malaysia Liechtenstein Cyprus Indonesia Libya United Arab Emirates Mauritania -european union \ No newline at end of file +European Union \ No newline at end of file diff --git a/code/scripts/predict_interventions.ipynb b/code/scripts/predict_interventions.ipynb index 8fb0745..3adadd3 100644 --- a/code/scripts/predict_interventions.ipynb +++ b/code/scripts/predict_interventions.ipynb @@ -1,1010 +1,864 @@ { "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Predictive modelling of interventions\n", "## Ridge regression\n", "\n", "This notebook builds models that predict the number of interventions per country from the prepared intervention dataset, starting with ridge regression and comparing it against simple baselines.\n", "\n", "\n", "First of all, we import the necessary packages." ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "from sklearn import linear_model\n", "from sklearn import model_selection\n", "from sklearn.metrics import mean_squared_error\n", "from sklearn.metrics import accuracy_score\n", "\n", "# constants\n", "FIRST_COUNTRY_INDEX_WO_EXP = 12\n", "FIRST_COUNTRY_INDEX = FIRST_COUNTRY_INDEX_WO_EXP + 3" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Prepare the dataset\n", "The data is provided in pandas dataframes. If the necessary csv files are not available, they can be generated with the script 'prepare_intervention_data.py'. This data now needs to be converted into numpy arrays so that we can train our models."
] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "9217\n", - "['Afghanistan', 'Albania', 'Algeria', 'Andorra', 'Angola', 'Antigua and Barbuda', 'Argentina', 'Armenia', 'Australia', 'Austria', 'Azerbaijan', 'Bahamas', 'Bahrain', 'Bangladesh', 'Barbados', 'Belarus', 'Belgium', 'Belize', 'Benin', 'Bhutan', 'Bolivia', 'Bosnia and Herzegovina', 'Botswana', 'Brazil', 'Brunei Darussalam', 'Bulgaria', 'Burkina Faso', 'Burundi', 'Cabo Verde', 'Cambodia', 'Cameroon', 'Canada', 'Central African Republic', 'Chad', 'Chile', 'China', 'Colombia', 'Comoros', 'Congo Republic', 'Cook Islands', 'Costa Rica', \"Cote d'Ivoire\", 'Croatia', 'Cuba', 'Cyprus', 'Czech Republic', 'DR Congo', 'Denmark', 'Djibouti', 'Dominica', 'Dominican Republic', 'Ecuador', 'Egypt', 'El Salvador', 'Equatorial Guinea', 'Eritrea', 'Estonia', 'Eswatini', 'Ethiopia', 'Fiji', 'Finland', 'France', 'Gabon', 'Gambia', 'Georgia', 'Germany', 'Ghana', 'Greece', 'Grenada', 'Guatemala', 'Guinea', 'Guinea-Bissau', 'Guyana', 'Haiti', 'Honduras', 'Hungary', 'Iceland', 'India', 'Indonesia', 'Iran', 'Iraq', 'Ireland', 'Israel', 'Italy', 'Jamaica', 'Japan', 'Jordan', 'Kazakhstan', 'Kenya', 'Kiribati', 'Kuwait', 'Kyrgyz Republic', 'Laos', 'Latvia', 'Lebanon', 'Lesotho', 'Liberia', 'Libya', 'Liechtenstein', 'Lithuania', 'Luxembourg', 'Macedonia', 'Madagascar', 'Malawi', 'Malaysia', 'Maldives', 'Mali', 'Malta', 'Marshall Islands', 'Mauritania', 'Mauritius', 'Mexico', 'Micronesia, Fed. Sts.', 'Moldova', 'Monaco', 'Mongolia', 'Montenegro', 'Morocco', 'Mozambique', 'Myanmar', 'Namibia', 'Nauru', 'Nepal', 'Netherlands', 'New Zealand', 'Nicaragua', 'Niger', 'Nigeria', 'Niue', 'North Korea', 'Norway', 'Oman', 'Pakistan', 'Palau', 'Palestine', 'Panama', 'Papua New Guinea', 'Paraguay', 'Peru', 'Philippines', 'Poland', 'Portugal', 'Qatar', 'Romania', 'Russia', 'Rwanda', 'Samoa', 'San Marino', 'Sao Tome and Principe', 'Saudi Arabia', 'Senegal', 'Serbia', 'Seychelles', 'Sierra Leone', 'Singapore', 'Slovakia', 'Slovenia', 'Solomon Islands', 'Somalia', 'South Africa', 'South Korea', 'South Sudan', 'Spain', 'Sri Lanka', 'St. Kitts and Nevis', 'St. Lucia', 'St. 
Vincent and the Grenadines', 'Sudan', 'Suriname', 'Sweden', 'Switzerland', 'Syria', 'Tajikistan', 'Tanzania', 'Thailand', 'Timor-Leste', 'Togo', 'Tonga', 'Trinidad and Tobago', 'Tunisia', 'Turkey', 'Turkmenistan', 'Tuvalu', 'Uganda', 'Ukraine', 'United Arab Emirates', 'United Kingdom', 'United States', 'Uruguay', 'Uzbekistan', 'Vanuatu', 'Vatican', 'Venezuela', 'Vietnam', 'Yemen', 'Zambia', 'Zimbabwe', 'european union']\n" + "['Afghanistan', 'Albania', 'Algeria', 'Andorra', 'Angola', 'Antigua and Barbuda', 'Argentina', 'Armenia', 'Australia', 'Austria', 'Azerbaijan', 'Bahamas', 'Bahrain', 'Bangladesh', 'Barbados', 'Belarus', 'Belgium', 'Belize', 'Benin', 'Bhutan', 'Bolivia', 'Bosnia and Herzegovina', 'Botswana', 'Brazil', 'Brunei Darussalam', 'Bulgaria', 'Burkina Faso', 'Burundi', 'Cabo Verde', 'Cambodia', 'Cameroon', 'Canada', 'Central African Republic', 'Chad', 'Chile', 'China', 'Colombia', 'Comoros', 'Congo Republic', 'Cook Islands', 'Costa Rica', \"Cote d'Ivoire\", 'Croatia', 'Cuba', 'Cyprus', 'Czech Republic', 'DR Congo', 'Denmark', 'Djibouti', 'Dominica', 'Dominican Republic', 'Ecuador', 'Egypt', 'El Salvador', 'Equatorial Guinea', 'Eritrea', 'Estonia', 'Eswatini', 'Ethiopia', 'European Union', 'European Union', 'Fiji', 'Finland', 'France', 'Gabon', 'Gambia', 'Georgia', 'Germany', 'Ghana', 'Greece', 'Grenada', 'Guatemala', 'Guinea', 'Guinea-Bissau', 'Guyana', 'Haiti', 'Honduras', 'Hungary', 'Iceland', 'India', 'Indonesia', 'Iran', 'Iraq', 'Ireland', 'Israel', 'Italy', 'Jamaica', 'Japan', 'Jordan', 'Kazakhstan', 'Kenya', 'Kiribati', 'Kuwait', 'Kyrgyz Republic', 'Laos', 'Latvia', 'Lebanon', 'Lesotho', 'Liberia', 'Libya', 'Liechtenstein', 'Lithuania', 'Luxembourg', 'Macedonia', 'Madagascar', 'Malawi', 'Malaysia', 'Maldives', 'Mali', 'Malta', 'Marshall Islands', 'Mauritania', 'Mauritius', 'Mexico', 'Micronesia, Fed. Sts.', 'Moldova', 'Monaco', 'Mongolia', 'Montenegro', 'Morocco', 'Mozambique', 'Myanmar', 'Namibia', 'Nauru', 'Nepal', 'Netherlands', 'New Zealand', 'Nicaragua', 'Niger', 'Nigeria', 'Niue', 'North Korea', 'Norway', 'Oman', 'Pakistan', 'Palau', 'Palestine', 'Panama', 'Papua New Guinea', 'Paraguay', 'Peru', 'Philippines', 'Poland', 'Portugal', 'Qatar', 'Romania', 'Russia', 'Rwanda', 'Samoa', 'San Marino', 'Sao Tome and Principe', 'Saudi Arabia', 'Senegal', 'Serbia', 'Seychelles', 'Sierra Leone', 'Singapore', 'Slovakia', 'Slovenia', 'Solomon Islands', 'Somalia', 'South Africa', 'South Korea', 'South Sudan', 'Spain', 'Sri Lanka', 'St. Kitts and Nevis', 'St. Lucia', 'St. 
Vincent and the Grenadines', 'Sudan', 'Suriname', 'Sweden', 'Switzerland', 'Syria', 'Tajikistan', 'Tanzania', 'Thailand', 'Timor-Leste', 'Togo', 'Tonga', 'Trinidad and Tobago', 'Tunisia', 'Turkey', 'Turkmenistan', 'Tuvalu', 'Uganda', 'Ukraine', 'United Arab Emirates', 'United Kingdom', 'United States', 'Uruguay', 'Uzbekistan', 'Vanuatu', 'Venezuela', 'Vietnam', 'Yemen', 'Zambia', 'Zimbabwe']\n" ] } ], "source": [ "data = pd.read_csv(\"../data/data_regression/dataset_interventions.csv\",\n", " encoding=\"utf-8-sig\")\n", "\n", "D = 214\n", "D_wo_exp = D - 3\n", "N = len(data)\n", "print(N)\n", "dataset = np.zeros((N, D), dtype=np.float64)\n", "dataset_without_exp = np.zeros((N, D_wo_exp), dtype=np.float64)\n", "\n", "dataset_without_exp[:,:FIRST_COUNTRY_INDEX_WO_EXP] = (data.loc[:,\"year\":\"woman_proportion\"]).to_numpy()\n", "dataset[:,:FIRST_COUNTRY_INDEX] = (data.loc[:,\"year\":\"experience score parties rate\"]).to_numpy()\n", "\n", "labelset = np.zeros((N,), dtype=np.float64)\n", "labelset[:] = (data.loc[:, \"nb_interventions\"]).to_numpy()\n", "\n", "# read the valid countries\n", "country_file = open(\"../data/dictionaries/valid_countries.txt\", \"r\")\n", "countries = country_file.readlines()\n", "countries = [c.replace(\"\\n\", \"\") for c in countries]\n", "countries = sorted(countries)\n", "print(countries)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The remaining part of the dataset encodes the country. We write a function that returns the index of a country in the sorted list of valid countries." ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "def get_country_index(country):\n", " if country in countries:\n", " return countries.index(country)\n", " else:\n", " # unknown country\n", " return len(countries)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To finalize the dataset, we use this function to one-hot encode the country: for indices 12 to 209, a 1 means that the affiliation is this country and a 0 means that it isn't." ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "for i in range(N):\n", " country = data.iloc[i,1]\n", " dataset_without_exp[i, FIRST_COUNTRY_INDEX_WO_EXP + get_country_index(country)] = 1\n", " dataset[i, FIRST_COUNTRY_INDEX + get_country_index(country)] = 1" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Split the data into training and test data\n", "In a first step, we shuffle the samples and take 80% of them as training data and the remaining 20% as test data." ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "# shuffle everything\n", "np.random.seed(2020)\n", "p = np.random.permutation(N)\n", "shuffled_dataset_wo_exp = dataset_without_exp[p]\n", "shuffled_dataset = dataset[p]\n", "shuffled_labelset = labelset[p]\n", "\n", "SPLIT_IDX = 7400\n", "# separate train and test data\n", "X_train_without_exp = shuffled_dataset_wo_exp[:SPLIT_IDX]\n", "X_train = shuffled_dataset[:SPLIT_IDX]\n", "Y_train = shuffled_labelset[:SPLIT_IDX]\n", "\n", "X_test_without_exp = shuffled_dataset_wo_exp[SPLIT_IDX:]\n", "X_test = shuffled_dataset[SPLIT_IDX:]\n", "Y_test = shuffled_labelset[SPLIT_IDX:]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Baseline models\n", "To have an upper bound on the error of our models, we introduce two baseline models. The first one always predicts 0 interventions."
] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Mean squared error: 91.79\n" ] } ], "source": [ "n = Y_test.shape\n", "test_predict_baseline_zero = np.zeros(n)\n", "baseline_zero_mse = mean_squared_error(Y_test, test_predict_baseline_zero)\n", "print('Mean squared error: %.2f'\n", " % baseline_zero_mse)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Another baseline predicts, for each country, the average number of interventions made by that country in the training data." ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Mean squared error: 29.33\n" ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + ":10: RuntimeWarning: invalid value encountered in double_scalars\n", + " avg = np.sum(train_samples_this_country) / len(train_samples_this_country)\n" + ] } ], "source": [ "test_predict_baseline_avg = np.zeros(n)\n", "\n", "# fill in a list with the averages\n", "intervention_averages = []\n", " \n", "# predict accordingly\n", "for i in range(len(countries)):\n", " index = FIRST_COUNTRY_INDEX + i\n", " train_samples_this_country = (Y_train[X_train[:,index] == 1])\n", " avg = np.sum(train_samples_this_country) / len(train_samples_this_country)\n", " intervention_averages.append(avg)\n", " \n", " test_predict_baseline_avg[X_test[:,index] == 1] = avg\n", " \n", "baseline_avg_mse = mean_squared_error(Y_test, test_predict_baseline_avg)\n", "print('Mean squared error: %.2f'\n", " % baseline_avg_mse)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As expected, this result is equal to that of a simple linear model without a global bias that uses only the country indicators." ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Mean squared error: 29.33\n" ] } ], "source": [ "# baseline data\n", "X_train_baseline = X_train[:,FIRST_COUNTRY_INDEX:]\n", "X_test_baseline = X_test[:,FIRST_COUNTRY_INDEX:]\n", "\n", "reg = linear_model.LinearRegression(fit_intercept=False)\n", "reg.fit(X_train_baseline, Y_train)\n", "test_predict = reg.predict(X_test_baseline)\n", "intervention_mse = mean_squared_error(Y_test, test_predict)\n", "print('Mean squared error: %.2f'\n", " % intervention_mse)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Ridge regression on all the data without experience\n", "Now we can train the first actual model: conventional ridge regression, which assumes a Gaussian distribution of the targets. We first use cross-validation to determine the best regularizer." ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "The best regularizer is lambda = 0.41842885079015846\n" + "The best regularizer is lambda = 0.43016357581067904\n" ] } ], "source": [ "# cross validation to determine regularizer\n", "reg = linear_model.RidgeCV(alphas=np.logspace(-6, 6, 1000))\n", "reg.fit(X_train_without_exp, Y_train)\n", "lambda_ = reg.alpha_\n", "print(f\"The best regularizer is lambda = {lambda_}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, we train the model with the optimal lambda."
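The RuntimeWarning in the per-country-average cell above is a real bug: for a country that never occurs in the training split, `len(train_samples_this_country)` is 0, so the average becomes `nan` and silently propagates into the predictions. A minimal guarded sketch, reusing the notebook's variables; the fallback of 0 interventions for countries unseen in training is our assumption:

```python
# Guarded per-country average baseline: avoids the nan / RuntimeWarning
# for countries that have no samples in the training split.
test_predict_baseline_avg = np.zeros(Y_test.shape)
intervention_averages = []

for i in range(len(countries)):
    index = FIRST_COUNTRY_INDEX + i
    train_samples_this_country = Y_train[X_train[:, index] == 1]
    if len(train_samples_this_country) > 0:
        avg = train_samples_this_country.mean()
    else:
        avg = 0.0  # assumption: predict 0 for countries unseen in training
    intervention_averages.append(avg)
    test_predict_baseline_avg[X_test[:, index] == 1] = avg

print('Mean squared error: %.2f' % mean_squared_error(Y_test, test_predict_baseline_avg))
```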
] }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "3.182209447564386\n", - "[-7.14734979e-02 6.52838466e-03 -3.82883323e-01 -2.58329613e-01\n", - " -6.02158216e-01 -3.41885362e-01 1.41709323e+00 6.49253475e-01\n", - " -6.03373961e-01 -2.60599551e-01 -1.87742671e-03 3.91760982e-02\n", - " -1.20352605e+00 -1.80764310e+00 -4.61710764e-01 -1.18954142e+00\n", - " -1.64719167e+00 -1.43437913e+00 2.73780221e+00 -1.68647639e+00\n", - " 1.64640285e+01 -1.57773210e+00 -1.71630673e+00 -1.61822630e+00\n", - " -1.53020504e+00 -2.41328224e-02 -8.90536988e-01 -4.49662593e-01\n", - " -1.76334138e+00 -1.47693523e+00 -1.33288221e+00 -1.34516588e+00\n", - " 3.71749718e+00 -1.45524913e+00 -1.39919608e+00 8.24206642e+00\n", - " -1.34283398e+00 -1.37468227e+00 -1.11016309e+00 -1.52596181e+00\n", - " -1.61622650e+00 -1.63972695e+00 -1.59446340e+00 1.30366912e+01\n", - " -1.30727254e+00 -1.65013704e+00 -3.07700340e-01 4.56725136e+01\n", - " 1.63341814e+00 -1.88830103e+00 -1.88987664e+00 -1.50883735e+00\n", - " -5.89504743e-01 -1.98634379e+00 -1.21660596e+00 -1.29859388e+00\n", - " -1.53582039e+00 -1.62809156e+00 -1.90474321e+00 -1.61380950e+00\n", - " -1.73130243e+00 -1.65958268e+00 -1.39042754e+00 -7.65658827e-01\n", - " 3.81567377e-01 -1.17076128e+00 -1.44003840e+00 -1.81434078e+00\n", - " -1.82717214e+00 -1.63001318e+00 -1.49271095e+00 -1.50313594e+00\n", - " -1.62429431e+00 -1.03366030e+00 -1.37387579e+00 -7.71079345e-01\n", - " -1.61888067e+00 -1.04134239e+00 -4.71331568e-01 -1.80200858e+00\n", - " -4.94963062e-01 -1.16321999e+00 -1.87865814e+00 -1.78254684e+00\n", - " -1.51290163e+00 -1.59381722e+00 -1.39826699e+00 -1.05733646e+00\n", - " -5.18016798e-01 7.50358743e+00 6.42743517e-01 1.05322086e+00\n", - " -1.22525172e+00 -1.70830773e+00 -1.71539933e+00 -1.52047812e+00\n", - " -1.39296340e+00 1.52986347e+01 -1.49541387e+00 -8.45856296e-01\n", - " -7.15909293e-01 -1.42981644e+00 6.50272670e-01 -1.43734580e+00\n", - " -1.79837321e+00 -1.71971578e+00 -1.54730242e+00 -1.82561628e+00\n", - " -1.57221936e+00 -1.37859240e+00 -1.55967408e+00 -1.65595390e+00\n", - " -1.61918234e+00 -1.54434897e+00 -1.63753495e+00 -1.52894298e+00\n", - " 8.41173064e-01 -1.37581537e+00 -1.42586805e+00 -1.63174046e+00\n", - " 9.54034713e-01 -1.16595484e+00 -1.05208959e+00 1.10228282e+00\n", - " -3.40244480e-01 -1.73882737e+00 -1.73931887e+00 -1.81260474e+00\n", - " -1.44758038e+00 -1.71094120e+00 -1.65010398e+00 -1.70391805e+00\n", - " -1.39478881e+00 -1.43233807e+00 -1.23430055e+00 -1.01675881e+00\n", - " 7.83271854e+00 -6.79004748e-01 -1.88887746e+00 1.19735370e-01\n", - " -1.68457414e+00 -1.67053555e+00 7.85666201e+00 -1.11010248e+00\n", - " -1.37171567e-01 -1.54193034e+00 -9.29120056e-01 -1.15191027e+00\n", - " 9.10523842e-01 -1.40703035e+00 -1.77564729e-01 2.97143869e+00\n", - " -8.08600533e-01 -1.81469245e+00 -1.02633692e+00 -1.53054706e+00\n", - " 5.22942025e+00 -1.60813880e+00 1.06096634e-01 -1.84976248e+00\n", - " -1.68504983e+00 1.32175727e+01 -7.08236565e-01 -1.50126559e+00\n", - " -1.65639243e+00 -1.35251530e+00 -3.54015492e-01 -1.71816176e+00\n", - " -1.27435533e+00 -1.52911872e+00 -1.36512654e+00 5.02279970e+00\n", - " 5.15304967e-01 -1.11058517e+00 -1.20162175e+00 -1.51456943e+00\n", - " -1.62411588e+00 -1.42120300e+00 -1.67294335e+00 -1.14383697e+00\n", - " -1.60597939e+00 -1.49467714e+00 8.65828845e+00 -1.44860754e+00\n", - " -1.38915177e+00 -2.22540822e-01 -7.05083594e-01 -1.06000380e+00\n", - " 
-1.75697172e+00 -1.64382416e+00 -1.11330094e+00 -1.72002938e+00\n", - " -7.47629813e-01 -1.76647072e+00 3.91561381e+00 -3.36183796e-01\n", - " -8.36675496e-01 -1.07905006e+00 -2.04345113e+00 5.31106523e+01\n", - " -3.83528384e-01 -1.68751094e+00 -1.76098842e+00 -1.88344212e+00\n", - " 1.33494068e+00 -1.54592895e+00 -1.67543079e+00 -1.14629255e+00\n", - " -7.78159268e-01 -1.93051315e+00 -2.27439002e+00]\n" + "3.1901302350423615\n", + "[-7.14120252e-02 6.60173200e-03 -3.82140800e-01 -2.46897866e-01\n", + " -5.94272034e-01 -3.33064352e-01 1.37588809e+00 6.52290537e-01\n", + " -6.00923896e-01 -2.53020482e-01 -2.25345629e-03 3.99739767e-02\n", + " -1.22227069e+00 -1.82666613e+00 -4.81319276e-01 -1.20327513e+00\n", + " -1.66545548e+00 -1.45410213e+00 2.71809137e+00 -1.70462142e+00\n", + " 1.64380064e+01 -1.59832398e+00 -1.73408519e+00 -1.63677519e+00\n", + " -1.54726449e+00 -4.41026726e-02 -9.10089342e-01 -4.68364939e-01\n", + " -1.78228197e+00 -1.49469663e+00 -1.35226888e+00 -1.36272750e+00\n", + " 3.69820465e+00 -1.47230968e+00 -1.41822641e+00 8.21725633e+00\n", + " -1.36228025e+00 -1.39434270e+00 -1.12850913e+00 -1.54448677e+00\n", + " -1.63449788e+00 -1.65991367e+00 -1.61171053e+00 1.30094802e+01\n", + " -1.32576815e+00 -1.66881530e+00 -3.27231731e-01 4.56394989e+01\n", + " 1.61297924e+00 -1.90647670e+00 -1.90866284e+00 -1.52628124e+00\n", + " -6.08536015e-01 -2.00680603e+00 -1.23532722e+00 -1.31802798e+00\n", + " -1.55475789e+00 -1.64761810e+00 -1.92414128e+00 -1.63482173e+00\n", + " -1.75047485e+00 -1.67738098e+00 -1.40897446e+00 -7.85235275e-01\n", + " 3.61868098e-01 -1.19047560e+00 -1.45979952e+00 -1.83309881e+00\n", + " -1.84686535e+00 -1.64841110e+00 -1.51138306e+00 0.00000000e+00\n", + " 0.00000000e+00 -1.52189501e+00 -1.64560317e+00 -1.05588588e+00\n", + " -1.39266183e+00 -7.90186199e-01 -1.63761675e+00 -1.06432202e+00\n", + " -4.91026168e-01 -1.82033415e+00 -5.14292152e-01 -1.18233485e+00\n", + " -1.89644777e+00 -1.79951787e+00 -1.53102507e+00 -1.61330850e+00\n", + " -1.41527559e+00 -1.07661299e+00 -5.37694838e-01 7.48095214e+00\n", + " 6.19985232e-01 1.03385441e+00 -1.24421777e+00 -1.72826381e+00\n", + " -1.73478526e+00 -1.54141921e+00 -1.41105880e+00 1.52707445e+01\n", + " -1.51358688e+00 -8.64109515e-01 -7.36153034e-01 -1.44924517e+00\n", + " 6.32904669e-01 -1.45634704e+00 -1.81806067e+00 -1.73939864e+00\n", + " -1.56573334e+00 -1.84336793e+00 -1.59140134e+00 -1.39535796e+00\n", + " -1.57873912e+00 -1.67540369e+00 -1.63775388e+00 -1.56376189e+00\n", + " -1.65660141e+00 -1.54867973e+00 8.21498673e-01 -1.39576120e+00\n", + " -1.44333079e+00 -1.65002440e+00 9.33884765e-01 -1.18392388e+00\n", + " -1.07066588e+00 1.08352942e+00 -3.59607422e-01 -1.75721231e+00\n", + " -1.75614820e+00 -1.83159110e+00 -1.46524218e+00 -1.73297949e+00\n", + " -1.66973682e+00 -1.72343421e+00 -1.41373916e+00 -1.45024701e+00\n", + " -1.25380233e+00 -1.03778672e+00 7.80945382e+00 -6.97805121e-01\n", + " -1.90689830e+00 1.00204959e-01 -1.70263005e+00 -1.68521853e+00\n", + " 7.83308085e+00 -1.12959868e+00 -1.56284587e-01 -1.56092373e+00\n", + " -9.47300680e-01 -1.16953464e+00 8.90899283e-01 -1.42552782e+00\n", + " -1.97512573e-01 2.95042516e+00 -8.29222217e-01 -1.83455820e+00\n", + " -1.04487435e+00 -1.55036868e+00 5.20847471e+00 -1.62599195e+00\n", + " 8.67960462e-02 -1.86287916e+00 -1.70351478e+00 1.31969162e+01\n", + " -7.27395236e-01 -1.52021097e+00 -1.67556876e+00 -1.37164381e+00\n", + " -3.74919190e-01 -1.73794237e+00 -1.29319626e+00 -1.54786588e+00\n", + " -1.38258546e+00 4.99885986e+00 4.94426400e-01 
-1.12980470e+00\n", + " -1.22306656e+00 -1.53353193e+00 -1.64305328e+00 -1.44041827e+00\n", + " -1.69158106e+00 -1.16319553e+00 -1.62400580e+00 -1.51606916e+00\n", + " 8.63611424e+00 -1.46740249e+00 -1.40676167e+00 -2.42489283e-01\n", + " -7.25759727e-01 -1.07857631e+00 -1.77372625e+00 -1.66302307e+00\n", + " -1.13248700e+00 -1.73855687e+00 -7.69448205e-01 -1.78538213e+00\n", + " 3.89446798e+00 -3.56418948e-01 -8.55712292e-01 -1.09887879e+00\n", + " -2.06534896e+00 5.30717614e+01 -4.03126586e-01 -1.70509787e+00\n", + " -1.77945281e+00 1.31438183e+00 -1.56584114e+00 -1.69269384e+00\n", + " -1.16623707e+00 -7.98396870e-01 -2.26642611e+00]\n" ] } ], "source": [ "reg = linear_model.Ridge(alpha=lambda_)\n", "reg.fit(X_train_without_exp, Y_train)\n", "w0 = reg.intercept_\n", "W = reg.coef_\n", "print(w0)\n", "print(W)\n", "np.savetxt(\"W_.csv\", W, delimiter=\",\")\n", "d = pd.DataFrame(countries)\n", "d.to_csv(\"countries.csv\")" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Mean squared error: 29.27\n" ] } ], "source": [ "test_predict = reg.predict(X_test_without_exp)\n", "print('Mean squared error: %.2f'\n", " % mean_squared_error(Y_test, test_predict))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As we can see, the mean squared error is essentially equal to that of the baseline model. The additional dimensions do not improve the model." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Linear Model with logarithmic transformation\n", "The plain linear model fails because our Y does not follow a Gaussian distribution; it is strongly right-skewed. We thus try the same model but with a transformed Y:\n", "$y' = log(c + y)$ with $c > 0$.\n", "We try different values for c (c = 1 would map y = 0 to y' = 0)." ] }, { "cell_type": "code", - "execution_count": 200, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "c = 10\n", - "Lambda is 0.3447764054734464\n", - "[2.36003016 2.29673558 2.41597337 ... 2.98163425 2.30944377 3.87097441]\n", - "[ 0.59127089 -0.05832438 1.20066747 ... 9.72001773 0.06882252\n", - " 37.98912433]\n", - "[2.30258509 2.30258509 2.39789527 ... 2.30258509 2.30258509 3.8501476 ]\n", - "[ 0. 0. 1. ... 0. 0. 37.]\n", - "Mean squared error : 34.31\n", - "c = 1000\n", - "Lambda is 0.4070142453219439\n", - "[6.90866435 6.90759974 6.90938658 ... 6.91880515 6.90767333 6.9534976 ]\n", - "[ 0.9094879 -0.15552457 1.63262819 ... 11.11114723 -0.08194264\n", - " 46.80463149]\n", - "[6.90775528 6.90775528 6.90875478 ... 6.90775528 6.90775528 6.94408721]\n", - "[ 0. 0. 1. ... 0. 0. 37.]\n", - "Mean squared error : 29.44\n", - "c = 2000\n", - "Lambda is 0.4070142453219439\n", - "[7.60136086 7.60082379 7.60172055 ... 7.606452 7.60085931 7.62413461]\n", - "[ 0.91701708 -0.15734237 1.63685623 ... 11.12993688 -0.08628893\n", - " 47.00824055]\n", - "[7.60090246 7.60090246 7.60140233 ... 7.60090246 7.60090246 7.61923342]\n", - "[ 0. 0. 1. ... 0. 0. 37.]\n", - "Mean squared error : 29.35\n", - "c = 5000\n", - "Lambda is 0.41842885079015846\n", - "[8.51737757 8.51716146 8.51752098 ... 8.51941864 8.51717549 8.52657378]\n", - "[ 0.92196911 -0.15863347 1.63922619 ... 11.13962064 -0.08850618\n", - " 47.12363259]\n", - "[8.51719319 8.51719319 8.51739317 ... 8.51719319 8.51719319 8.52456595]\n", - "[ 0. 0. 1. ... 0. 0. 
37.]\n", - "Mean squared error : 29.30\n", - "c = 10000\n", - "Lambda is 0.41842885079015846\n", - "[9.21043273 9.21032447 9.21050437 ... 9.2114541 9.21033143 9.215046 ]\n", - "[ 0.92360649 -0.15901823 1.64008583 ... 11.14348634 -0.08942983\n", - " 47.16715676]\n", - "[9.21034037 9.21034037 9.21044037 ... 9.21034037 9.21034037 9.21403354]\n", - "[ 0. 0. 1. ... 0. 0. 37.]\n", - "Mean squared error : 29.29\n", - "c = 20000\n", - "Lambda is 0.41842885079015846\n", - "[9.90353377 9.90347959 9.90356958 ... 9.90404467 9.90348306 9.90584423]\n", - "[ 0.92443436 -0.15921203 1.64051656 ... 11.14542617 -0.08989534\n", - " 47.18910678]\n", - "[9.90348755 9.90348755 9.90353755 ... 9.90348755 9.90348755 9.90533584]\n", - "[ 0. 0. 1. ... 0. 0. 37.]\n", - "Mean squared error : 29.28\n" - ] - } - ], + "outputs": [], "source": [ - "cs = [10, 1000, 2000, 5000, 10000, 20000]\n", + "cs = [0.01, 10, 1000, 2000, 5000, 10000, 20000]\n", "\n", "for c in cs:\n", " print(f\"c = {c}\")\n", " # logarithmic transformation\n", " Y_train_transf = np.log(c + Y_train)\n", " Y_test_transf = np.log(c + Y_test)\n", "\n", " # cross-validation for lambda\n", " reg = linear_model.RidgeCV(alphas=np.logspace(-6, 6, 1000))\n", " reg.fit(X_train_without_exp, Y_train_transf)\n", " lambda_ = reg.alpha_\n", " print(f\"Lambda is {lambda_}\")\n", " test_predict_transf = reg.predict(X_test_without_exp)\n", " # transform the output back\n", " test_predict = np.exp(test_predict_transf) - c\n", " \n", - " print(test_predict_transf)\n", - " print(test_predict)\n", - " print(Y_test_transf)\n", - " print(Y_test)\n", + " #print(test_predict_transf)\n", + " print(test_predict[:30])\n", + " #print(Y_test_transf)\n", + " print(Y_test[:30])\n", "\n", " print('Mean squared error : %.2f'\n", " % mean_squared_error(Y_test, test_predict))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For comparison, we run the baseline model (countries only) with the optimal c found above." ] }, { "cell_type": "code", "execution_count": 185, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Mean squared error : 44.81\n" ] } ], "source": [ "c = 0.3 # the optimal found above\n", "\n", "# logarithmic transformation\n", "Y_train_transf = np.log(c + Y_train)\n", "Y_test_transf = np.log(c + Y_test)\n", "\n", "# cross-validation for lambda\n", "reg = linear_model.RidgeCV(alphas=np.logspace(-6, 6, 1000))\n", "reg.fit(X_train_baseline, Y_train_transf)\n", "lambda_ = reg.alpha_\n", "test_predict_transf = reg.predict(X_test_baseline)\n", "# transform the output back\n", "test_predict = np.exp(test_predict_transf) - c\n", "\n", "print('Mean squared error : %.2f'\n", " % mean_squared_error(Y_test, test_predict))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Combine multiple models\n", "Even with the logarithmic transformation, our data doesn't follow a Gaussian distribution. One problem is that many samples have the value 0. For such situations, several papers suggest first performing a classification task that decides whether a sample is 0 or not, and then applying a second model (e.g. https://www.kent.ac.uk/smsas/personal/msr/webfiles/zip/ibc_fin.pdf). As we work with count data, a Poisson distribution may fit our data better.\\\n", "Hence, we first apply logistic regression to classify samples into 0 and non-0."
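The two-step idea described above can be stated as a decomposition of the conditional mean; the short LaTeX note below spells out what the classifier-then-regressor pipeline approximates (the notation is ours, not from the notebook):

```latex
% Hurdle-style decomposition: the classifier targets Pr(Y > 0 | x),
% the count model targets E[Y | Y > 0, x].
\[
  \mathbb{E}[Y \mid x] = \Pr(Y > 0 \mid x) \cdot \mathbb{E}[Y \mid Y > 0,\, x]
\]
% The notebook's pipeline hard-gates instead: it predicts 0 whenever the
% classifier outputs class 0, i.e. it replaces Pr(Y > 0 | x) by the 0/1
% indicator of the predicted class.
```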
] }, { "cell_type": "code", - "execution_count": 13, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "C:\\ProgramData\\Miniconda3\\lib\\site-packages\\sklearn\\linear_model\\_logistic.py:762: ConvergenceWarning: lbfgs failed to converge (status=1):\n", - "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n", - "\n", - "Increase the number of iterations (max_iter) or scale the data as shown in:\n", - " https://scikit-learn.org/stable/modules/preprocessing.html\n", - "Please also refer to the documentation for alternative solver options:\n", - " https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n", - " n_iter_i = _check_optimize_result(\n", - "C:\\ProgramData\\Miniconda3\\lib\\site-packages\\sklearn\\linear_model\\_logistic.py:762: ConvergenceWarning: lbfgs failed to converge (status=1):\n", - "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n", - "\n", - "Increase the number of iterations (max_iter) or scale the data as shown in:\n", - " https://scikit-learn.org/stable/modules/preprocessing.html\n", - "Please also refer to the documentation for alternative solver options:\n", - " https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n", - " n_iter_i = _check_optimize_result(\n", - "C:\\ProgramData\\Miniconda3\\lib\\site-packages\\sklearn\\linear_model\\_logistic.py:762: ConvergenceWarning: lbfgs failed to converge (status=1):\n", - "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n", - "\n", - "Increase the number of iterations (max_iter) or scale the data as shown in:\n", - " https://scikit-learn.org/stable/modules/preprocessing.html\n", - "Please also refer to the documentation for alternative solver options:\n", - " https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n", - " n_iter_i = _check_optimize_result(\n", - "C:\\ProgramData\\Miniconda3\\lib\\site-packages\\sklearn\\linear_model\\_logistic.py:762: ConvergenceWarning: lbfgs failed to converge (status=1):\n", - "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n", - "\n", - "Increase the number of iterations (max_iter) or scale the data as shown in:\n", - " https://scikit-learn.org/stable/modules/preprocessing.html\n", - "Please also refer to the documentation for alternative solver options:\n", - " https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n", - " n_iter_i = _check_optimize_result(\n", - "C:\\ProgramData\\Miniconda3\\lib\\site-packages\\sklearn\\linear_model\\_logistic.py:762: ConvergenceWarning: lbfgs failed to converge (status=1):\n", - "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n", - "\n", - "Increase the number of iterations (max_iter) or scale the data as shown in:\n", - " https://scikit-learn.org/stable/modules/preprocessing.html\n", - "Please also refer to the documentation for alternative solver options:\n", - " https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n", - " n_iter_i = _check_optimize_result(\n", - "C:\\ProgramData\\Miniconda3\\lib\\site-packages\\sklearn\\linear_model\\_logistic.py:762: ConvergenceWarning: lbfgs failed to converge (status=1):\n", - "STOP: TOTAL NO. 
of ITERATIONS REACHED LIMIT.\n", - "\n", - "Increase the number of iterations (max_iter) or scale the data as shown in:\n", - " https://scikit-learn.org/stable/modules/preprocessing.html\n", - "Please also refer to the documentation for alternative solver options:\n", - " https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n", - " n_iter_i = _check_optimize_result(\n", - "C:\\ProgramData\\Miniconda3\\lib\\site-packages\\sklearn\\linear_model\\_logistic.py:762: ConvergenceWarning: lbfgs failed to converge (status=1):\n", - "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n", - "\n", - "Increase the number of iterations (max_iter) or scale the data as shown in:\n", - " https://scikit-learn.org/stable/modules/preprocessing.html\n", - "Please also refer to the documentation for alternative solver options:\n", - " https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n", - " n_iter_i = _check_optimize_result(\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'C': 1.873817422860385}\n", - "0.7506879471656577\n", - "0.7006053935057788\n", - "accuracy = 0.7991194276279582\n" - ] - } - ], + "outputs": [], "source": [ "Y_train_class = Y_train > 0\n", "Y_test_class = Y_test > 0\n", "\n", "clf = linear_model.LogisticRegression(max_iter=2000, fit_intercept=True)\n", "\n", "# do cross-validation\n", "params = {'C': np.logspace(-3, 3, 100)}\n", "cv = model_selection.GridSearchCV(clf, params)\n", "cv.fit(X_train_without_exp, Y_train_class)\n", "print(cv.best_params_)\n", "\n", "predict_class = cv.predict(X_test_without_exp)\n", "print(1 - np.mean(predict_class))\n", "print(1 - np.mean(Y_test_class))\n", "accuracy = accuracy_score(Y_test_class, predict_class)\n", "print(f\"accuracy = {accuracy}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "(Note: only 78.3% accuracy on the training data.)\\\n", "Now, on the data that is not zero, we apply a Poisson regressor." ] }, { "cell_type": "code", - "execution_count": 16, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'alpha': 0.002915053062825176}\n", - "Mean squared error on Poisson data : 112.91\n", - "[6.00000000e+00 4.26325641e-14 1.00000000e+00 2.00000000e+00\n", - " 7.00000000e+00 8.00000000e+00 3.00000000e+00 4.50000000e+01\n", - " 1.50000000e+01 1.50000000e+01 4.26325641e-14 3.00000000e+00\n", - " 4.26325641e-14 1.00000000e+00 4.26325641e-14 4.26325641e-14\n", - " 1.50000000e+01 7.00000000e+00 4.26325641e-14 1.00000000e+01]\n", - "[ 8.62270811 1.1837323 8.94528746 17.88349906 2.28719003 5.21037092\n", - " 7.33358919 47.58687484 19.17573901 5.29816849 3.18807448 2.24013399\n", - " 0.79890823 2.18165687 3.60138895 13.15737968 12.73060578 3.67746909\n", - " 1.86193717 2.27919197]\n", - "[False False False False False False True False False False False False\n", - " False False False True False False False False]\n", - "[0. 0. 1. 0. 0. 1. 6. 0. 4. 0. 0. 1. 2. 0. 5. 0. 1. 0. 0. 0.]\n", - "[0. 0. 0. 0. 0. 0.\n", - " 8.62270811 0. 0. 0. 0. 0.\n", - " 0. 0. 0. 1.1837323 0. 0.\n", - " 0. 0. 
]\n", - "Mean squared error : 29.68\n" - ] - } - ], + "outputs": [], "source": [ "\"\"\"X_train_poiss = X_train[clf.predict(X_train) == 1]\n", "Y_train_poiss = Y_train[clf.predict(X_train) == 1]\n", "X_test_poiss = X_test[clf.predict(X_test) == 1]\n", "Y_test_poiss = Y_test[clf.predict(X_test) == 1]\"\"\"\n", "X_train_poiss = X_train_without_exp[Y_train_class == 1]\n", "Y_train_poiss = Y_train[Y_train_class == 1]\n", "X_test_poiss = X_test_without_exp[cv.predict(X_test_without_exp) == 1]\n", "Y_test_poiss = Y_test[cv.predict(X_test_without_exp) == 1]\n", "\n", "# log transformation\n", "c = 100 # TODO try others\n", - "Y_train_poiss = np.log(c + Y_train_poiss)\n", - "Y_test_poiss = np.log(c + Y_test_poiss)\n", + "#Y_train_poiss = np.log(c + Y_train_poiss)\n", + "#Y_test_poiss = np.log(c + Y_test_poiss)\n", "\n", "params = {'alpha': np.logspace(-4, 1, 100)}\n", "reg_base = linear_model.PoissonRegressor(max_iter=1000)\n", "reg = model_selection.GridSearchCV(reg_base, params)\n", "# reg = linear_model.RidgeCV(alphas=np.logspace(-6, 6, 1000))\n", "reg.fit(X_train_poiss, Y_train_poiss)\n", "print(reg.best_params_)\n", "predict_poiss = reg.predict(X_test_poiss)\n", "\n", - "predict_poiss = np.exp(predict_poiss) - c\n", + "#predict_poiss = np.exp(predict_poiss) - c\n", "print('Mean squared error on Poisson data : %.2f'\n", " % mean_squared_error(np.exp(Y_test_poiss) - c, predict_poiss))\n", "print((np.exp(Y_test_poiss) - c)[:20])\n", "print(predict_poiss[:20])\n", "\n", "# on everything\n", "predict = np.zeros(Y_test.shape)\n", "print(cv.predict(X_test_without_exp)[:20])\n", "print(Y_test[:20])\n", "predict[cv.predict(X_test_without_exp) == 1] = predict_poiss\n", "print(predict[:20])\n", "\n", "print('Mean squared error : %.2f'\n", " % mean_squared_error(Y_test, predict))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Models with experience data\n", "We now build the same models but including the data about the experience of participants.\n", "\n", "First we build a trivial linear model." 
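The ConvergenceWarnings stripped from the outputs above carry scikit-learn's standard advice: increase max_iter or scale the data. Since the experience columns live on a very different scale than the 0/1 country indicators, standardizing inside a pipeline is a natural follow-up. A minimal sketch (our suggestion, not part of the original notebook), reusing the arrays defined earlier; note that StandardScaler also rescales the one-hot columns, a simplification:

```python
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Scale features before the logistic classifier; inside a Pipeline the
# scaler is re-fit on each cross-validation training fold, avoiding leakage.
pipe = make_pipeline(StandardScaler(),
                     linear_model.LogisticRegression(max_iter=2000))
params = {'logisticregression__C': np.logspace(-3, 3, 100)}
cv_scaled = model_selection.GridSearchCV(pipe, params)
cv_scaled.fit(X_train, Y_train > 0)
print(cv_scaled.best_params_)
print(accuracy_score(Y_test > 0, cv_scaled.predict(X_test)))
```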
] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The best regularizer is lambda = 0.43016357581067904\n" ] } ], "source": [ "# cross validation to determine regularizer\n", "reg = linear_model.RidgeCV(alphas=np.logspace(-6, 6, 1000))\n", "reg.fit(X_train, Y_train)\n", "lambda_ = reg.alpha_\n", "print(f\"The best regularizer is lambda = {lambda_}\")" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2.714562636822776\n", "[-8.42740280e-02 6.25202489e-03 -3.97876267e-01 -2.93231523e-01\n", " -5.14868906e-01 -3.73958354e-01 1.34920496e+00 6.49909570e-01\n", " -5.43697937e-01 -2.73357809e-01 -2.59025366e-03 9.49772406e-02\n", " 6.47481292e-02 -4.74598407e-02 6.65238396e-01 -9.75729925e-01\n", " -1.79266387e+00 -5.02591910e-01 -9.84230361e-01 -1.67043563e+00\n", " -1.52479569e+00 2.69375962e+00 -1.81358927e+00 1.64418397e+01\n", " -1.68908592e+00 -1.57797100e+00 -1.74432568e+00 -1.33684674e+00\n", " 1.20970239e-01 -9.15636708e-01 -3.73511731e-01 -1.88656071e+00\n", " -1.50230637e+00 -1.40210520e+00 -1.41756276e+00 3.75832932e+00\n", " -1.51804693e+00 -1.45012134e+00 8.09619713e+00 -1.27895426e+00\n", " -1.37495790e+00 -1.12157053e+00 -1.53314379e+00 -1.56432065e+00\n", " -1.68975876e+00 -1.54085838e+00 1.30082001e+01 -1.27179105e+00\n", " -1.74863944e+00 -3.35535110e-01 4.55224333e+01 1.63400780e+00\n", " -1.89964038e+00 -1.94438671e+00 -1.52883914e+00 -6.24845984e-01\n", " -2.08660149e+00 -1.24560424e+00 -1.34991799e+00 -1.65852561e+00\n", " -1.69334413e+00 -1.93723036e+00 -1.68543213e+00 -1.77529513e+00\n", " -1.59457133e+00 -1.39535678e+00 -6.86861413e-01 4.09988631e-01\n", " -1.10503374e+00 -1.28852895e+00 -1.71304084e+00 -1.89723502e+00\n", " -1.65168958e+00 -1.45686586e+00 -1.45339301e+00 -1.77245492e+00\n", " -1.16408930e+00 -1.36566613e+00 -7.84256269e-01 -1.63106818e+00\n", " -1.09980692e+00 -4.83846117e-01 -1.87408892e+00 -4.36841844e-01\n", " -1.23769530e+00 -1.92533359e+00 -1.83846185e+00 -1.33700719e+00\n", " -1.60267855e+00 -1.33386012e+00 -1.05383924e+00 -5.69336458e-01\n", " 7.49910832e+00 5.37514672e-01 1.08236937e+00 -1.12133728e+00\n", " -1.77822206e+00 -1.70060854e+00 -1.66026305e+00 -1.36826133e+00\n", " 1.52645837e+01 -1.45017861e+00 -7.90517561e-01 -7.48708533e-01\n", " -1.38161557e+00 6.18930599e-01 -1.40373178e+00 -1.86129296e+00\n", " -1.76046894e+00 -1.54644460e+00 -1.82753573e+00 -1.53318533e+00\n", " -1.28835639e+00 -1.68456315e+00 -1.64964844e+00 -1.68735017e+00\n", " -1.39610882e+00 -1.64728863e+00 -1.48991943e+00 8.12440384e-01\n", " -1.44328630e+00 -1.49110910e+00 -1.60726979e+00 9.68081106e-01\n", " -1.16774895e+00 -1.05966482e+00 1.04529681e+00 -4.24174998e-01\n", " -1.84058382e+00 -1.74640645e+00 -1.70424572e+00 -1.35619692e+00\n", " -1.77555080e+00 -1.66107081e+00 -1.46161069e+00 -1.45756560e+00\n", " -1.42398724e+00 -1.14040693e+00 -1.02721690e+00 7.78278659e+00\n", " -6.42989395e-01 -1.94400679e+00 1.50227193e-02 -1.62497935e+00\n", " -1.36983621e+00 7.74937266e+00 -1.06304981e+00 -4.56532709e-02\n", " -1.44252989e+00 -5.71087078e-01 -1.19374392e+00 9.20164317e-01\n", " -1.34127353e+00 -1.57390571e-01 2.94659308e+00 -8.78456301e-01\n", " -1.87775518e+00 -9.89449966e-01 -1.56398070e+00 5.20976872e+00\n", " -1.64314723e+00 6.16255646e-02 -1.64255155e+00 -1.72854832e+00\n", " 1.31645249e+01 -8.46271225e-01 -1.50611094e+00 -1.69595500e+00\n", " -1.31357816e+00 -3.99542586e-01 
-1.75441370e+00 -1.30883835e+00\n", " -1.59897455e+00 -9.44741237e-01 5.04731873e+00 4.76013911e-01\n", " -8.95775573e-01 -1.26473470e+00 -1.38225356e+00 -1.71136052e+00\n", " -1.44632041e+00 -1.61973937e+00 -1.13417767e+00 -1.46969202e+00\n", " -1.59238051e+00 8.55055036e+00 -1.28451181e+00 -1.38566326e+00\n", " -2.36413389e-01 -7.84661690e-01 -1.04929789e+00 -1.76424387e+00\n", " -1.47501580e+00 -1.16792986e+00 -1.83214956e+00 -6.15581520e-01\n", " -1.72389105e+00 3.83624284e+00 -4.37055990e-01 -8.05683577e-01\n", " -1.04551526e+00 -2.11674385e+00 5.29749490e+01 -4.52297529e-01\n", " -1.69245118e+00 -1.71756198e+00 -1.63084216e+00 1.30245117e+00\n", " -1.59072302e+00 -1.62281991e+00 -1.13137294e+00 -8.85130642e-01\n", " -2.03543256e+00 -2.17923487e+00]\n" ] } ], "source": [ "reg = linear_model.Ridge(alpha=lambda_)\n", "reg.fit(X_train, Y_train)\n", "w0 = reg.intercept_\n", "W = reg.coef_\n", "print(w0)\n", "print(W)\n", "np.savetxt(\"W_withexp.csv\", W, delimiter=\",\")" ] }, { "cell_type": "code", "execution_count": 188, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[ 1.18297512e+00 -7.42407867e-01 1.65214044e+00 -2.46034308e-02\n", " -4.69021814e-01 8.68726653e-01 6.72104627e+00 8.29711110e-01\n", " 1.30591417e+00 3.87780478e-01 2.73052330e-02 -8.52067029e-01\n", " 2.42915316e+00 6.22925442e-03 9.82701624e-01 1.85688847e+00\n", " 5.14099547e-01 6.61749796e-01 -1.34684735e-01 -1.17630981e-01]\n", "[0. 0. 1. 0. 0. 1. 6. 0. 4. 0. 0. 1. 2. 0. 5. 0. 1. 0. 0. 0.]\n", "Mean squared error: 29.32\n" ] } ], "source": [ "test_predict = reg.predict(X_test)\n", "print(test_predict[:20])\n", "print(Y_test[:20])\n", "print('Mean squared error: %.2f'\n", " % mean_squared_error(Y_test, test_predict))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Second, we fit the linear model with the logarithmic transformation." ] }, { "cell_type": "code", "execution_count": 201, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "c = 10\n", "Lambda: 0.3544455673970436\n", "[2.36868331 2.27255708 2.41760197 ... 2.97955573 2.31652432 3.87262415]\n", "[2.30258509 2.30258509 2.39789527 ... 2.30258509 2.30258509 3.8501476 ]\n", "[ 0.68331637 -0.29581649 1.21892376 ... 9.67907192 0.14036829\n", " 38.06835913]\n", "[ 0. 0. 1. ... 0. 0. 37.]\n", "sanity\n", "[[23. 10. 0. ... 0. 0. 0.]\n", " [15. 2. 1. ... 0. 0. 0.]\n", " [14. 13. 0. ... 0. 0. 0.]\n", " ...\n", " [21. 11. 1. ... 0. 0. 0.]\n", " [18. 10. 0. ... 0. 0. 0.]\n", " [24. 6. 1. ... 0. 0. 0.]]\n", "Mean squared error : 34.32\n", "c = 100\n", "Lambda: 0.3643858983763548\n", "[4.61541371 4.59918935 4.62094193 ... 4.70766727 4.60605685 4.97402359]\n", "[4.60517019 4.60517019 4.61512052 ... 4.60517019 4.60517019 4.91998093]\n", "[ 1.02961728 -0.59629879 1.58967752 ... 10.79340722 0.08870552\n", " 44.60755927]\n", "[ 0. 0. 1. ... 0. 0. 37.]\n", "sanity\n", "[[23. 10. 0. ... 0. 0. 0.]\n", " [15. 2. 1. ... 0. 0. 0.]\n", " [14. 13. 0. ... 0. 0. 0.]\n", " ...\n", " [21. 11. 1. ... 0. 0. 0.]\n", " [18. 10. 0. ... 0. 0. 0.]\n", " [24. 6. 1. ... 0. 0. 0.]]\n", "Mean squared error : 30.59\n", "c = 500\n", "Lambda: 0.4070142453219439\n", "[6.2168895 6.21320115 6.21788416 ... 6.23641285 6.21472897 6.30389222]\n", "[6.2146081 6.2146081 6.2166061 ... 6.2146081 6.2146081 6.28599809]\n", "[ 1.14200538 -0.70298117 1.64071919 ... 11.02210731 0.06044109\n", " 46.69563152]\n", "[ 0. 0. 1. ... 0. 0. 37.]\n", "sanity\n", "[[23. 10. 0. ... 0. 0. 0.]\n", " [15. 2. 1. ... 0. 0. 0.]\n", " [14. 13. 0. ... 0. 0. 
0.]\n", " ...\n", " [21. 11. 1. ... 0. 0. 0.]\n", " [18. 10. 0. ... 0. 0. 0.]\n", " [24. 6. 1. ... 0. 0. 0.]]\n", "Mean squared error : 29.64\n", "c = 1000\n", "Lambda: 0.41842885079015846\n", "[6.90891612 6.90703328 6.90940047 ... 6.91874894 6.90781098 6.95374095]\n", "[6.90775528 6.90775528 6.90875478 ... 6.90775528 6.90775528 6.94408721]\n", "[ 1.16151439 -0.72173879 1.64655001 ... 11.05431636 0.05569874\n", " 47.0594116 ]\n", "[ 0. 0. 1. ... 0. 0. 37.]\n", "sanity\n", "[[23. 10. 0. ... 0. 0. 0.]\n", " [15. 2. 1. ... 0. 0. 0.]\n", " [14. 13. 0. ... 0. 0. 0.]\n", " ...\n", " [21. 11. 1. ... 0. 0. 0.]\n", " [18. 10. 0. ... 0. 0. 0.]\n", " [24. 6. 1. ... 0. 0. 0.]]\n", "Mean squared error : 29.49\n" ] } ], "source": [ "cs = [10, 100, 500, 1000]\n", "\n", "for c in cs:\n", " print(f\"c = {c}\")\n", " # logarithmic transformation\n", " Y_train_transf = np.log(c + Y_train)\n", " Y_test_transf = np.log(c + Y_test)\n", " \n", " # cross-validation for lambda\n", " reg = linear_model.RidgeCV(alphas=np.logspace(-6, 6, 1000))\n", " reg.fit(X_train, Y_train_transf)\n", " print(f\"Lambda: {reg.alpha_}\")\n", " test_predict_transf = reg.predict(X_test)\n", " # transform the output back\n", " test_predict = np.exp(test_predict_transf) - c\n", " \n", " print(test_predict_transf)\n", " print(Y_test_transf)\n", " print(test_predict)\n", " print(Y_test)\n", " \n", " print('Mean squared error : %.2f'\n", " % mean_squared_error(Y_test, test_predict))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Lastly, we apply the two-step model with Logistic Regression and the Poisson Regressor." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.3" } }, "nbformat": 4, "nbformat_minor": 4 } diff --git a/report/images/distr_interventions.png b/report/images/distr_interventions.png new file mode 100644 index 0000000..7c3f410 Binary files /dev/null and b/report/images/distr_interventions.png differ diff --git a/report/images/experiencescore_overview.png b/report/images/experiencescore_overview.png new file mode 100644 index 0000000..c0de730 Binary files /dev/null and b/report/images/experiencescore_overview.png differ diff --git a/report/images/ff_cop.png b/report/images/ff_cop.png new file mode 100644 index 0000000..61ef6a3 Binary files /dev/null and b/report/images/ff_cop.png differ diff --git a/report/images/ff_sb.png b/report/images/ff_sb.png new file mode 100644 index 0000000..04c5bb0 Binary files /dev/null and b/report/images/ff_sb.png differ diff --git a/report/images/participant_flow_maxdegree_allafs.png b/report/images/participant_flow_maxdegree_allafs.png new file mode 100644 index 0000000..cb360da Binary files /dev/null and b/report/images/participant_flow_maxdegree_allafs.png differ diff --git a/report/images/roles_cop.png b/report/images/roles_cop.png new file mode 100644 index 0000000..c3403f2 Binary files /dev/null and b/report/images/roles_cop.png differ diff --git a/report/images/roles_sb.png b/report/images/roles_sb.png new file mode 
100644 index 0000000..e42d73b Binary files /dev/null and b/report/images/roles_sb.png differ
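The notebook announces the two-step model (logistic regression plus Poisson regressor) on the experience data but leaves the final cells empty. A minimal sketch of how those cells could look, mirroring the earlier without-experience version and reusing the notebook's variable names (X_train and X_test here include the experience columns); this is our outline under those assumptions, not the authors' implementation:

```python
# Two-step model on the data including experience features:
# 1) classify zero vs. non-zero, 2) Poisson-regress the non-zero counts.
Y_train_class = Y_train > 0

clf = model_selection.GridSearchCV(
    linear_model.LogisticRegression(max_iter=2000),
    {'C': np.logspace(-3, 3, 100)})
clf.fit(X_train, Y_train_class)

# Poisson regressor trained only on samples with at least one intervention
X_train_poiss = X_train[Y_train_class]
Y_train_poiss = Y_train[Y_train_class]
reg = model_selection.GridSearchCV(
    linear_model.PoissonRegressor(max_iter=1000),
    {'alpha': np.logspace(-4, 1, 100)})
reg.fit(X_train_poiss, Y_train_poiss)

# combine: predict 0 where the classifier says 0, else the Poisson estimate
predict = np.zeros(Y_test.shape)
mask = clf.predict(X_test) == 1
predict[mask] = reg.predict(X_test[mask])
print('Mean squared error : %.2f' % mean_squared_error(Y_test, predict))
```

Comparing this model's MSE against the 29.32 of the plain ridge model with experience would close the loop the notebook leaves open.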