diff --git a/code/data/dictionaries/valid_countries.txt b/code/data/dictionaries/valid_countries.txt new file mode 100644 index 0000000..2deff92 --- /dev/null +++ b/code/data/dictionaries/valid_countries.txt @@ -0,0 +1,198 @@ +Bulgaria +Guinea-Bissau +Mauritius +Liechtenstein +Brazil +Central African Republic +Burkina Faso +Denmark +New Zealand +Nigeria +Andorra +Serbia +Comoros +Laos +Algeria +Bhutan +Sao Tome and Principe +Antigua and Barbuda +Yemen +South Korea +Bangladesh +Maldives +Cambodia +Mozambique +Austria +Brunei Darussalam +United States +Rwanda +Moldova +Greece +El Salvador +Norway +Marshall Islands +Suriname +Ethiopia +Trinidad and Tobago +Congo Republic +Equatorial Guinea +Costa Rica +Malta +Nauru +Hungary +Canada +Cameroon +Namibia +Libya +Italy +Thailand +Ghana +Ukraine +Zambia +Kuwait +Switzerland +Kiribati +Japan +Eritrea +Kyrgyz Republic +Russia +Haiti +Togo +Colombia +Ecuador +Albania +Slovenia +Portugal +Lithuania +Madagascar +Lebanon +Poland +United Kingdom +Cote d'Ivoire +Guinea +Bahamas +Tunisia +Peru +Honduras +Monaco +Egypt +France +Czech Republic +Latvia +Jordan +Argentina +Mauritania +Zimbabwe +Botswana +Luxembourg +Papua New Guinea +Vanuatu +Turkey +San Marino +Solomon Islands +Niue +Tuvalu +Croatia +Dominica +Spain +Uruguay +Eswatini +Bahrain +Liberia +Chile +Morocco +Benin +Malawi +St. Vincent and the Grenadines +Estonia +Iceland +Macedonia +DR Congo +Australia +Bolivia +Romania +Mali +St. Lucia +Sierra Leone +Belgium +Afghanistan +Gambia +Azerbaijan +Uzbekistan +Cabo Verde +Israel +South Africa +Seychelles +Palestine +Indonesia +Iraq +Timor-Leste +Kenya +Micronesia, Fed. Sts. +Nepal +Cook Islands +Paraguay +Finland +United Arab Emirates +Belarus +China +Myanmar +Niger +Tanzania +Tajikistan +Sudan +Grenada +South Sudan +Germany +Uganda +Iran +Burundi +Somalia +Qatar +Venezuela +Syria +Mongolia +Jamaica +Malaysia +Cuba +Armenia +Lesotho +Gabon +Palau +Oman +Fiji +Montenegro +Turkmenistan +Belize +Senegal +Guatemala +Sweden +Kazakhstan +Vatican +Pakistan +Slovakia +North Korea +Bosnia and Herzegovina +Tonga +Angola +Saudi Arabia +Panama +Djibouti +Barbados +Mexico +Netherlands +Guyana +India +Ireland +Dominican Republic +Samoa +Georgia +Chad +Vietnam +Philippines +St. 
Kitts and Nevis
+Cyprus
+Singapore
+Nicaragua
+Sri Lanka
+european union
diff --git a/code/scripts/plots/plot_participant_graph.py b/code/scripts/plots/plot_participant_graph.py
index 8043702..87b08da 100644
--- a/code/scripts/plots/plot_participant_graph.py
+++ b/code/scripts/plots/plot_participant_graph.py
@@ -1,84 +1,97 @@
 import pandas as pd
 import matplotlib.pyplot as plt
 import json
 import networkx as nx
 
+def find_largest_parties():
+    # read the list of valid country names
+    with open("../data/dictionaries/valid_countries.txt", "r") as country_file:
+        countries = [c.replace("\n", "") for c in country_file.readlines()]
+
+    complete_data = pd.read_csv("../results/complete_dataset.csv",
+                                encoding="utf-8-sig")
+    # keep only the party delegates whose affiliation is a valid country name
+    parties = complete_data.loc[complete_data["affiliation_category"] == "parties"]
+    parties = parties.loc[parties["affiliation"].isin(countries)]
+
+    # count the participants per country
+    total_nb_participants_per_country = dict()
+    grouped_parties = parties.groupby("affiliation")
+    for aff, people in grouped_parties:
+        total_nb_participants_per_country[aff] = len(people)
+
+    # return the country names, sorted by participant count in descending order
+    sorted_c = sorted(total_nb_participants_per_country.items(), key=lambda x: x[1], reverse=True)
+    return [x[0] for x in sorted_c]
+
 def plot(path):
     LABEL_IDX = 0
     NAME_IDX = 1
     AFFILIATION_IDX = 2
     CATEGORY_IDX = 3
 
     f = open(path, "r", encoding="utf-8")
     text = f.read()
     names = json.loads(text)
-    print(len(names))
 
     # exclude the names that have an error (two names in the same meeting)
     names = {n: l for n, l in names.items() if len(set([m[0] for m in l])) == len(l)}
-    print(len(names))
 
     country_file = open("../data/dictionaries/valid_countries.txt", "r")
     countries = country_file.readlines()
     countries = [c.replace("\n", "") for c in countries]
 
-    """G = nx.MultiGraph()
-    G.add_node("Germany")
-    G.add_node("Switzerland")
-    G.add_node("France")
-
-    G.add_edge("Germany", "Switzerland", weight=0.001)
-    G.add_edge("Germany", "France", weight=1)"""
+    max_set_n = 40
+    biggest_countries = find_largest_parties()[:max_set_n]
+    # biggest_countries.append("european union")
 
     G = nx.Graph()
     G.clear()
-    affiliations = set()  # TODO could just do it for the 40 countries that have the most participants
+    affiliations = set(biggest_countries)
+    G.add_nodes_from(biggest_countries)
     # TODO maybe add NGO's
-    max_set_n = 40
 
     for name, list in names.items():
         previous_affiliation = ""
         current_affiliation = ""
         for participation in list:
             if participation[AFFILIATION_IDX] in countries:
                 previous_affiliation = current_affiliation
                 current_affiliation = participation[AFFILIATION_IDX]
                 if current_affiliation not in affiliations and len(affiliations) < max_set_n:
                     print(current_affiliation)
                     G.add_node(current_affiliation)
                     affiliations.add(current_affiliation)
                 if previous_affiliation in affiliations and current_affiliation in affiliations and previous_affiliation != "" and previous_affiliation != current_affiliation:
                     if (previous_affiliation, current_affiliation) in G.edges:
                         # increase weight
                         if G[previous_affiliation][current_affiliation]["weight"] > 20:
                             print(name)
                             print(list)
                         G[previous_affiliation][current_affiliation]["weight"] += 1
                     else:
                         G.add_edge(previous_affiliation, current_affiliation, weight=1)
 
     # find the largest weight for the resizing of the edges
     sorted_edges = sorted(G.edges(data=True), key=lambda x: x[2]['weight'], reverse=True)
     max_weight = sorted_edges[0][2]["weight"]
     print(max_weight)
 
     pos = nx.circular_layout(G)
     # nx.draw_networkx(G, pos)
     nx.draw_networkx_nodes(
         G,
         pos,
         node_size=2,
     )
     nx.draw_networkx_labels(G, pos, font_color="red", font_weight="bold")
 
     for edge in G.edges(data='weight'):
         nx.draw_networkx_edges(G, pos, edgelist=[edge], width=edge[2]/max_weight*10)
 
     """plt.subplot(122)
     nx.draw_shell(G, nlist=[range(5, 10), range(5)], with_labels=True, font_weight='bold')"""
     """print(G.nodes())
     nx.draw(G)"""
     """plt.xlim(-0.05, 1.05)
     plt.ylim(-0.05, 1.05)"""
 
     plt.axis("off")
     plt.show()
\ No newline at end of file
diff --git a/code/scripts/predict_interventions.ipynb b/code/scripts/predict_interventions.ipynb
index f52604e..2f22c01 100644
--- a/code/scripts/predict_interventions.ipynb
+++ b/code/scripts/predict_interventions.ipynb
@@ -1,977 +1,977 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Predictive modelling of interventions\n",
    "## Ridge regression\n",
    "\n",
    "This notebook builds predictive models for the number of interventions a country makes during a meeting. We start with simple baselines and ridge regression, then try logarithmic transformations of the target and a combined two-step model.\n",
    "\n",
    "\n",
    "First of all, we import the necessary packages."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn import linear_model\n",
    "from sklearn import model_selection\n",
    "from sklearn.metrics import mean_squared_error\n",
    "from sklearn.metrics import accuracy_score\n",
    "\n",
    "# constants\n",
    "FIRST_COUNTRY_INDEX_WO_EXP = 12\n",
    "FIRST_COUNTRY_INDEX = FIRST_COUNTRY_INDEX_WO_EXP + 3"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Prepare the dataset\n",
    "The data is provided in pandas dataframes. If the necessary csv files are not available, they can be generated with the script 'prepare_intervention_data.py'. This data now needs to be converted into numpy arrays so that we can train our model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "9217\n",
      "['Afghanistan', 'Albania', 'Algeria', 'Andorra', 'Angola', 'Antigua and Barbuda', 'Argentina', 'Armenia', 'Australia', 'Austria', 'Azerbaijan', 'Bahamas', 'Bahrain', 'Bangladesh', 'Barbados', 'Belarus', 'Belgium', 'Belize', 'Benin', 'Bhutan', 'Bolivia', 'Bosnia and Herzegovina', 'Botswana', 'Brazil', 'Brunei Darussalam', 'Bulgaria', 'Burkina Faso', 'Burundi', 'Cabo Verde', 'Cambodia', 'Cameroon', 'Canada', 'Central African Republic', 'Chad', 'Chile', 'China', 'Colombia', 'Comoros', 'Congo Republic', 'Cook Islands', 'Costa Rica', \"Cote d'Ivoire\", 'Croatia', 'Cuba', 'Cyprus', 'Czech Republic', 'DR Congo', 'Denmark', 'Djibouti', 'Dominica', 'Dominican Republic', 'Ecuador', 'Egypt', 'El Salvador', 'Equatorial Guinea', 'Eritrea', 'Estonia', 'Eswatini', 'Ethiopia', 'Fiji', 'Finland', 'France', 'Gabon', 'Gambia', 'Georgia', 'Germany', 'Ghana', 'Greece', 'Grenada', 'Guatemala', 'Guinea', 'Guinea-Bissau', 'Guyana', 'Haiti', 'Honduras', 'Hungary', 'Iceland', 'India', 'Indonesia', 'Iran', 'Iraq', 'Ireland', 'Israel', 'Italy', 'Jamaica', 'Japan', 'Jordan', 'Kazakhstan', 'Kenya', 'Kiribati', 'Kuwait', 'Kyrgyz Republic', 'Laos', 'Latvia', 'Lebanon', 'Lesotho', 'Liberia', 'Libya', 'Liechtenstein', 'Lithuania', 'Luxembourg', 'Macedonia', 'Madagascar', 'Malawi', 'Malaysia', 'Maldives', 'Mali', 'Malta', 'Marshall Islands', 'Mauritania', 'Mauritius', 'Mexico', 'Micronesia, Fed. Sts.', 'Moldova', 'Monaco', 'Mongolia', 'Montenegro', 'Morocco', 'Mozambique', 'Myanmar', 'Namibia', 'Nauru', 'Nepal', 'Netherlands', 'New Zealand', 'Nicaragua', 'Niger', 'Nigeria', 'Niue', 'North Korea', 'Norway', 'Oman', 'Pakistan', 'Palau', 'Palestine', 'Panama', 'Papua New Guinea', 'Paraguay', 'Peru', 'Philippines', 'Poland', 'Portugal', 'Qatar', 'Romania', 'Russia', 'Rwanda', 'Samoa', 'San Marino', 'Sao Tome and Principe', 'Saudi Arabia', 'Senegal', 'Serbia', 'Seychelles', 'Sierra Leone', 'Singapore', 'Slovakia', 'Slovenia', 'Solomon Islands', 'Somalia', 'South Africa', 'South Korea', 'South Sudan', 'Spain', 'Sri Lanka', 'St. Kitts and Nevis', 'St. Lucia', 'St. Vincent and the Grenadines', 'Sudan', 'Suriname', 'Sweden', 'Switzerland', 'Syria', 'Tajikistan', 'Tanzania', 'Thailand', 'Timor-Leste', 'Togo', 'Tonga', 'Trinidad and Tobago', 'Tunisia', 'Turkey', 'Turkmenistan', 'Tuvalu', 'Uganda', 'Ukraine', 'United Arab Emirates', 'United Kingdom', 'United States', 'Uruguay', 'Uzbekistan', 'Vanuatu', 'Vatican', 'Venezuela', 'Vietnam', 'Yemen', 'Zambia', 'Zimbabwe']\n"
     ]
    }
   ],
   "source": [
    "data = pd.read_csv(\"../data/data_regression/dataset_interventions.csv\",\n",
    "                   encoding=\"utf-8-sig\")\n",
    "\n",
-    "D = 213\n",
+    "D = 214\n",
    "D_wo_exp = D - 3\n",
    "N = len(data)\n",
    "print(N)\n",
    "dataset = np.zeros((N, D), dtype=np.float64)\n",
    "dataset_without_exp = np.zeros((N, D_wo_exp), dtype=np.float64)\n",
    "\n",
    "dataset_without_exp[:,:FIRST_COUNTRY_INDEX_WO_EXP] = (data.loc[:,\"year\":\"woman_proportion\"]).to_numpy()\n",
    "dataset[:,:FIRST_COUNTRY_INDEX] = (data.loc[:,\"year\":\"experience score parties rate\"]).to_numpy()\n",
    "\n",
    "labelset = np.zeros((N,), dtype=np.float64)\n",
    "labelset[:] = (data.loc[:, \"nb_interventions\"]).to_numpy()\n",
    "\n",
    "# read the valid countries\n",
    "country_file = open(\"../data/dictionaries/valid_countries.txt\", \"r\")\n",
    "countries = country_file.readlines()\n",
    "countries = [c.replace(\"\\n\", \"\") for c in countries]\n",
    "countries = sorted(countries)\n",
    "print(countries)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The remaining part of the dataset encodes the country. We need a function that returns the index of a country in the sorted list of valid countries."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_country_index(country):\n",
    "    if country in countries:\n",
    "        return countries.index(country)\n",
    "    else:\n",
    "        # unknown country\n",
    "        return len(countries)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To finalize the dataset, we apply this function to one-hot encode the countries: for the country indices (12 and upwards), a 1 means that the affiliation is this country and a 0 means that it isn't."
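    ,
    "\n",
    "\n",
    "For example, in the sorted list printed above, 'Afghanistan' has index 0 and 'Albania' has index 1, so a sample affiliated with Albania gets a 1 in column FIRST_COUNTRY_INDEX_WO_EXP + 1 = 13 of dataset_without_exp (and in column FIRST_COUNTRY_INDEX + 1 = 16 of dataset); an affiliation that is not in the list falls into the extra 'unknown country' column at the end."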
] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "for i in range(N):\n", " country = data.iloc[i,1]\n", " dataset_without_exp[i, FIRST_COUNTRY_INDEX_WO_EXP + get_country_index(country)] = 1\n", " dataset[i, FIRST_COUNTRY_INDEX + get_country_index(country)] = 1" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Split the data to training data and test data\n", "In a first step, I consider the first 80% of the samples as training data and the resting 20% as test data" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "# shuffle everything\n", "np.random.seed(2020)\n", "p = np.random.permutation(N)\n", "shuffled_dataset_wo_exp = dataset_without_exp[p]\n", "shuffled_dataset = dataset[p]\n", "shuffled_labelset = labelset[p]\n", "\n", "SPLIT_IDX = 7400\n", "# seperate train and test data\n", "X_train_without_exp = shuffled_dataset_wo_exp[:SPLIT_IDX]\n", "X_train = shuffled_dataset[:SPLIT_IDX]\n", "Y_train = shuffled_labelset[:SPLIT_IDX]\n", "\n", "X_test_without_exp = shuffled_dataset_wo_exp[SPLIT_IDX:]\n", "X_test = shuffled_dataset[SPLIT_IDX:]\n", "Y_test = shuffled_labelset[SPLIT_IDX:]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Baseline models\n", "To have a upper bound for the performance of our models, we introduce two baseline models. The first one consists in predicting always 0 interventions." ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Mean squared error: 91.79\n" ] } ], "source": [ "n = Y_test.shape\n", "test_predict_baseline_zero = np.zeros(n)\n", "baseline_zero_mse = mean_squared_error(Y_test, test_predict_baseline_zero)\n", "print('Mean squared error: %.2f'\n", " % baseline_zero_mse)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Another baseline is to predict for each country the average number of interventions done in the training data meetings." ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Mean squared error: 29.33\n" ] } ], "source": [ "test_predict_baseline_avg = np.zeros(n)\n", "\n", "# fill in a list with the averages\n", "intervention_averages = []\n", " \n", "# predict accordingly\n", "for i in range(len(countries)):\n", " index = FIRST_COUNTRY_INDEX + i\n", " train_samples_this_country = (Y_train[X_train[:,index] == 1])\n", " avg = np.sum(train_samples_this_country) / len(train_samples_this_country)\n", " intervention_averages.append(avg)\n", " \n", " test_predict_baseline_avg[X_test[:,index] == 1] = avg\n", " \n", "baseline_avg_mse = mean_squared_error(Y_test, test_predict_baseline_avg)\n", "print('Mean squared error: %.2f'\n", " % baseline_avg_mse)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As expected, this result is equal to a simple linear model without global bias that only works on the country data." 
] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Mean squared error: 29.33\n" ] } ], "source": [ "# baseline data\n", "X_train_baseline = X_train[:,FIRST_COUNTRY_INDEX:]\n", "X_test_baseline = X_test[:,FIRST_COUNTRY_INDEX:]\n", "\n", "reg = linear_model.LinearRegression(fit_intercept=False)\n", "reg.fit(X_train_baseline, Y_train)\n", "test_predict = reg.predict(X_test_baseline)\n", "intervention_mse = mean_squared_error(Y_test, test_predict)\n", "print('Mean squared error: %.2f'\n", " % intervention_mse)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Ridge regression on all the data without experience\n", "Now, we can train the actual first model, conventional ridge regression that expects a gaussion distribution. We first use crossvalidation to determine the best regularizer." ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The best regularizer is lambda = 0.41842885079015846\n" ] } ], "source": [ "# cross validation to determine regularizer\n", "reg = linear_model.RidgeCV(alphas=np.logspace(-6, 6, 1000))\n", "reg.fit(X_train_without_exp, Y_train)\n", "lambda_ = reg.alpha_\n", "print(f\"The best regularizer is lambda = {lambda_}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, we train the model with the optimal lambda." ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[-7.13642429e-02 6.58946776e-03 -3.82433989e-01 -2.49429907e-01\n", " -5.96414687e-01 -3.34486906e-01 1.39375847e+00 6.47823435e-01\n", " -6.05626958e-01 -2.55623447e-01 -2.60952679e-03 3.95918137e-02\n", " -1.21377146e+00 -1.81767521e+00 -4.71824154e-01 -1.19961948e+00\n", " -1.65663683e+00 -1.44488124e+00 2.72858263e+00 -1.69571753e+00\n", " 1.64524129e+01 -1.58900153e+00 -1.72537838e+00 -1.62797160e+00\n", " -1.53844956e+00 -3.42891263e-02 -9.00781289e-01 -4.59202070e-01\n", " -1.77278387e+00 -1.48543371e+00 -1.34299708e+00 -1.35349655e+00\n", " 3.70911272e+00 -1.46326714e+00 -1.40883620e+00 8.23009830e+00\n", " -1.35422893e+00 -1.38508140e+00 -1.11907158e+00 -1.53542728e+00\n", " -1.62580022e+00 -1.65087519e+00 -1.60252097e+00 1.30240996e+01\n", " -1.31650172e+00 -1.65965215e+00 -3.17488981e-01 4.56621993e+01\n", " 1.62312763e+00 -1.89724934e+00 -1.89948799e+00 -1.51666065e+00\n", " -5.98840741e-01 -1.99762987e+00 -1.22642821e+00 -1.30885382e+00\n", " -1.54629655e+00 -1.63875425e+00 -1.91495061e+00 -1.62549591e+00\n", " -1.74153932e+00 -1.66840866e+00 -1.39982960e+00 -7.75761703e-01\n", " 3.71759205e-01 -1.18158657e+00 -1.45139592e+00 -1.82436479e+00\n", " -1.83757737e+00 -1.63917996e+00 -1.50203893e+00 -1.51269781e+00\n", " -1.63591116e+00 -1.04561859e+00 -1.38372966e+00 -7.80581567e-01\n", " -1.62876818e+00 -1.05448387e+00 -4.81512312e-01 -1.81114302e+00\n", " -5.04974347e-01 -1.17298904e+00 -1.88757119e+00 -1.79047334e+00\n", " -1.52190023e+00 -1.60448353e+00 -1.40594681e+00 -1.06750718e+00\n", " -5.28399045e-01 7.49276801e+00 6.30359904e-01 1.04433777e+00\n", " -1.23543153e+00 -1.71903051e+00 -1.72589773e+00 -1.53173339e+00\n", " -1.40181466e+00 1.52854831e+01 -1.50459493e+00 -8.54686597e-01\n", " -7.26223162e-01 -1.44016893e+00 6.44598330e-01 -1.44790264e+00\n", " -1.80902728e+00 -1.73040640e+00 -1.55656881e+00 -1.83424344e+00\n", " -1.58254783e+00 -1.38620048e+00 -1.57015830e+00 -1.66656551e+00\n", " 
-1.62861822e+00 -1.55519089e+00 -1.64780335e+00 -1.53947034e+00\n", " 8.31876980e-01 -1.38660348e+00 -1.43404338e+00 -1.64099921e+00\n", " 9.43863715e-01 -1.17462779e+00 -1.06151389e+00 1.09378073e+00\n", " -3.50150316e-01 -1.74820156e+00 -1.74724953e+00 -1.82254524e+00\n", " -1.45728350e+00 -1.72301180e+00 -1.66051204e+00 -1.71446044e+00\n", " -1.40461973e+00 -1.44136632e+00 -1.24447239e+00 -1.02824441e+00\n", " 7.82166535e+00 -6.88488114e-01 -1.89775217e+00 1.10945218e-01\n", " -1.69381285e+00 -1.67813375e+00 7.84567195e+00 -1.12071394e+00\n", " -1.46714631e-01 -1.55216505e+00 -9.39937768e-01 -1.16002852e+00\n", " 9.00907200e-01 -1.41621122e+00 -1.87597216e-01 2.96127915e+00\n", " -8.19373997e-01 -1.82536960e+00 -1.03491066e+00 -1.54119289e+00\n", " 5.21971752e+00 -1.61707535e+00 9.64571438e-02 -1.85697260e+00\n", " -1.69489468e+00 1.32126052e+01 -7.17685890e-01 -1.51187939e+00\n", " -1.66654440e+00 -1.36268284e+00 -3.65299724e-01 -1.72874729e+00\n", " -1.28401247e+00 -1.53892082e+00 -1.37462923e+00 5.01072537e+00\n", " 5.05612917e-01 -1.12228025e+00 -1.21343874e+00 -1.52425989e+00\n", " -1.63414787e+00 -1.43100126e+00 -1.68297071e+00 -1.15365286e+00\n", " -1.61481048e+00 -1.50666642e+00 8.64789477e+00 -1.45831372e+00\n", " -1.39809723e+00 -2.32839419e-01 -7.15218490e-01 -1.07007246e+00\n", " -1.76416753e+00 -1.65434178e+00 -1.12325176e+00 -1.72949626e+00\n", " -7.59980720e-01 -1.77653753e+00 3.90534733e+00 -3.46604501e-01\n", " -8.46975381e-01 -1.08871290e+00 -2.05608339e+00 5.30970083e+01\n", " -3.93592744e-01 -1.69549868e+00 -1.77047873e+00 -1.88737188e+00\n", " 1.32449924e+00 -1.55660831e+00 -1.68399117e+00 -1.15687968e+00\n", " -7.88748835e-01 -2.25608474e+00]\n" ] } ], "source": [ "reg = linear_model.Ridge(alpha=lambda_)\n", "reg.fit(X_train_without_exp, Y_train)\n", "w0 = reg.intercept_\n", "W = reg.coef_\n", "#print(w0)\n", "print(W)" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Mean squared error: 29.27\n" ] } ], "source": [ "test_predict = reg.predict(X_test_without_exp)\n", "print('Mean squared error: %.2f'\n", " % mean_squared_error(Y_test, test_predict))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As we notice, the mean square error is basically equal to the one in the baseline model. The additional dimensions do not help to improve the model." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Linear Model with logarithmic transformation\n", "The reason for the failure of the normal linear model is that our Y doesn't follow a Gaussian distribution, but a logarithmic. We thus try the same model but with a transformed Y. We transform the y's accordingly: \n", "$y' = log(c + y)$ with $c > 0$.\n", "We try different values for c. First, c = 1 (which preserves y = 0 to be y' = 0)" ] }, { "cell_type": "code", "execution_count": 200, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "c = 10\n", "Lambda is 0.3447764054734464\n", "[2.36003016 2.29673558 2.41597337 ... 2.98163425 2.30944377 3.87097441]\n", "[ 0.59127089 -0.05832438 1.20066747 ... 9.72001773 0.06882252\n", " 37.98912433]\n", "[2.30258509 2.30258509 2.39789527 ... 2.30258509 2.30258509 3.8501476 ]\n", "[ 0. 0. 1. ... 0. 0. 37.]\n", "Mean squared error : 34.31\n", "c = 1000\n", "Lambda is 0.4070142453219439\n", "[6.90866435 6.90759974 6.90938658 ... 6.91880515 6.90767333 6.9534976 ]\n", "[ 0.9094879 -0.15552457 1.63262819 ... 
11.11114723 -0.08194264\n", " 46.80463149]\n", "[6.90775528 6.90775528 6.90875478 ... 6.90775528 6.90775528 6.94408721]\n", "[ 0. 0. 1. ... 0. 0. 37.]\n", "Mean squared error : 29.44\n", "c = 2000\n", "Lambda is 0.4070142453219439\n", "[7.60136086 7.60082379 7.60172055 ... 7.606452 7.60085931 7.62413461]\n", "[ 0.91701708 -0.15734237 1.63685623 ... 11.12993688 -0.08628893\n", " 47.00824055]\n", "[7.60090246 7.60090246 7.60140233 ... 7.60090246 7.60090246 7.61923342]\n", "[ 0. 0. 1. ... 0. 0. 37.]\n", "Mean squared error : 29.35\n", "c = 5000\n", "Lambda is 0.41842885079015846\n", "[8.51737757 8.51716146 8.51752098 ... 8.51941864 8.51717549 8.52657378]\n", "[ 0.92196911 -0.15863347 1.63922619 ... 11.13962064 -0.08850618\n", " 47.12363259]\n", "[8.51719319 8.51719319 8.51739317 ... 8.51719319 8.51719319 8.52456595]\n", "[ 0. 0. 1. ... 0. 0. 37.]\n", "Mean squared error : 29.30\n", "c = 10000\n", "Lambda is 0.41842885079015846\n", "[9.21043273 9.21032447 9.21050437 ... 9.2114541 9.21033143 9.215046 ]\n", "[ 0.92360649 -0.15901823 1.64008583 ... 11.14348634 -0.08942983\n", " 47.16715676]\n", "[9.21034037 9.21034037 9.21044037 ... 9.21034037 9.21034037 9.21403354]\n", "[ 0. 0. 1. ... 0. 0. 37.]\n", "Mean squared error : 29.29\n", "c = 20000\n", "Lambda is 0.41842885079015846\n", "[9.90353377 9.90347959 9.90356958 ... 9.90404467 9.90348306 9.90584423]\n", "[ 0.92443436 -0.15921203 1.64051656 ... 11.14542617 -0.08989534\n", " 47.18910678]\n", "[9.90348755 9.90348755 9.90353755 ... 9.90348755 9.90348755 9.90533584]\n", "[ 0. 0. 1. ... 0. 0. 37.]\n", "Mean squared error : 29.28\n" ] } ], "source": [ "cs = [10, 1000, 2000, 5000, 10000, 20000]\n", "\n", "for c in cs:\n", " print(f\"c = {c}\")\n", " # logarithmic transformation\n", " Y_train_transf = np.log(c + Y_train)\n", " Y_test_transf = np.log(c + Y_test)\n", "\n", " # crossvalidation for lambda\n", " reg = linear_model.RidgeCV(alphas=np.logspace(-6, 6, 1000))\n", " reg.fit(X_train_without_exp, Y_train_transf)\n", " lambda_ = reg.alpha_\n", " print(f\"Lambda is {lambda_}\")\n", " test_predict_transf = reg.predict(X_test_without_exp)\n", " # transform the output back\n", " test_predict = np.exp(test_predict_transf) - c\n", " \n", " print(test_predict_transf)\n", " print(test_predict)\n", " print(Y_test_transf)\n", " print(Y_test)\n", "\n", " print('Mean squared error : %.2f'\n", " % mean_squared_error(Y_test, test_predict))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To compare, we do the baseline data (only countries) with the optimal c that we found" ] }, { "cell_type": "code", "execution_count": 185, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Mean squared error : 44.81\n" ] } ], "source": [ "c = 0.3 # the optimal found above\n", "\n", "# logarithmic transformation\n", "Y_train_transf = np.log(c + Y_train)\n", "Y_test_transf = np.log(c + Y_test)\n", "\n", "# crossvalidation for lambda\n", "reg = linear_model.RidgeCV(alphas=np.logspace(-6, 6, 1000))\n", "reg.fit(X_train_baseline, Y_train_transf)\n", "lambda_ = reg.alpha_\n", "test_predict_transf = reg.predict(X_test_baseline)\n", "# transform the output back\n", "test_predict = np.exp(test_predict_transf) - c\n", "\n", "print('Mean squared error : %.2f'\n", " % mean_squared_error(Y_test, test_predict))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Combine multiple models\n", "Even with the logistic transformation, our data doesn't follow a Gaussian distribution. 
One problem is that there are a lot of samples with value 0. Several papers suggest for situations like that to first perform a classifying task that decides whether a sample is 0 or not, and then apply a second model. (e.g. https://www.kent.ac.uk/smsas/personal/msr/webfiles/zip/ibc_fin.pdf) As we work with count data, Poisson distribution may fit our data better.\\\n", "Hence, we first apply logistic regression to classify into 0 and non 0." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "C:\\ProgramData\\Miniconda3\\lib\\site-packages\\sklearn\\linear_model\\_logistic.py:762: ConvergenceWarning: lbfgs failed to converge (status=1):\n", "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n", "\n", "Increase the number of iterations (max_iter) or scale the data as shown in:\n", " https://scikit-learn.org/stable/modules/preprocessing.html\n", "Please also refer to the documentation for alternative solver options:\n", " https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n", " n_iter_i = _check_optimize_result(\n", "C:\\ProgramData\\Miniconda3\\lib\\site-packages\\sklearn\\linear_model\\_logistic.py:762: ConvergenceWarning: lbfgs failed to converge (status=1):\n", "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n", "\n", "Increase the number of iterations (max_iter) or scale the data as shown in:\n", " https://scikit-learn.org/stable/modules/preprocessing.html\n", "Please also refer to the documentation for alternative solver options:\n", " https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n", " n_iter_i = _check_optimize_result(\n", "C:\\ProgramData\\Miniconda3\\lib\\site-packages\\sklearn\\linear_model\\_logistic.py:762: ConvergenceWarning: lbfgs failed to converge (status=1):\n", "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n", "\n", "Increase the number of iterations (max_iter) or scale the data as shown in:\n", " https://scikit-learn.org/stable/modules/preprocessing.html\n", "Please also refer to the documentation for alternative solver options:\n", " https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n", " n_iter_i = _check_optimize_result(\n", "C:\\ProgramData\\Miniconda3\\lib\\site-packages\\sklearn\\linear_model\\_logistic.py:762: ConvergenceWarning: lbfgs failed to converge (status=1):\n", "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n", "\n", "Increase the number of iterations (max_iter) or scale the data as shown in:\n", " https://scikit-learn.org/stable/modules/preprocessing.html\n", "Please also refer to the documentation for alternative solver options:\n", " https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n", " n_iter_i = _check_optimize_result(\n", "C:\\ProgramData\\Miniconda3\\lib\\site-packages\\sklearn\\linear_model\\_logistic.py:762: ConvergenceWarning: lbfgs failed to converge (status=1):\n", "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n", "\n", "Increase the number of iterations (max_iter) or scale the data as shown in:\n", " https://scikit-learn.org/stable/modules/preprocessing.html\n", "Please also refer to the documentation for alternative solver options:\n", " https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n", " n_iter_i = _check_optimize_result(\n", "C:\\ProgramData\\Miniconda3\\lib\\site-packages\\sklearn\\linear_model\\_logistic.py:762: ConvergenceWarning: lbfgs failed to converge (status=1):\n", "STOP: TOTAL NO. 
of ITERATIONS REACHED LIMIT.\n", "\n", "Increase the number of iterations (max_iter) or scale the data as shown in:\n", " https://scikit-learn.org/stable/modules/preprocessing.html\n", "Please also refer to the documentation for alternative solver options:\n", " https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n", " n_iter_i = _check_optimize_result(\n", "C:\\ProgramData\\Miniconda3\\lib\\site-packages\\sklearn\\linear_model\\_logistic.py:762: ConvergenceWarning: lbfgs failed to converge (status=1):\n", "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n", "\n", "Increase the number of iterations (max_iter) or scale the data as shown in:\n", " https://scikit-learn.org/stable/modules/preprocessing.html\n", "Please also refer to the documentation for alternative solver options:\n", " https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n", " n_iter_i = _check_optimize_result(\n", "C:\\ProgramData\\Miniconda3\\lib\\site-packages\\sklearn\\linear_model\\_logistic.py:762: ConvergenceWarning: lbfgs failed to converge (status=1):\n", "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n", "\n", "Increase the number of iterations (max_iter) or scale the data as shown in:\n", " https://scikit-learn.org/stable/modules/preprocessing.html\n", "Please also refer to the documentation for alternative solver options:\n", " https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n", " n_iter_i = _check_optimize_result(\n" ] } ], "source": [ "Y_train_class = Y_train > 0\n", "Y_test_class = Y_test > 0\n", "\n", "clf = linear_model.LogisticRegression(max_iter=2000, fit_intercept=True)\n", "\n", "# do crossvalidation\n", "params = {'C': np.logspace(-3, 3, 100)}\n", "cv = model_selection.GridSearchCV(clf, params)\n", "cv.fit(X_train_without_exp, Y_train_class)\n", "print(cv.best_params_)\n", "\n", "predict_class = cv.predict(X_test_without_exp)\n", "print(1 - np.mean(predict_class))\n", "print(1 - np.mean(Y_test_class))\n", "accuracy = accuracy_score(Y_test_class, predict_class)\n", "print(f\"accuracy = {accuracy}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "(Note: only 78.3% accuracy on training data.)\\\n", "Now, on the data that is not zero, we apply a Poission regressor." 
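    ,
    "\n",
    "\n",
    "Conceptually, the combined model is a hurdle model: the classifier estimates whether $y > 0$, the regressor models the counts on the non-zero part, and the final prediction is 0 for samples classified as zero and the regressor's back-transformed output otherwise."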
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "\"\"\"X_train_poiss = X_train[clf.predict(X_train) == 1]\n", "Y_train_poiss = Y_train[clf.predict(X_train) == 1]\n", "X_test_poiss = X_test[clf.predict(X_test) == 1]\n", "Y_test_poiss = Y_test[clf.predict(X_test) == 1]\"\"\"\n", "X_train_poiss = X_train_without_exp[Y_train_class == 1]\n", "Y_train_poiss = Y_train[Y_train_class == 1]\n", "X_test_poiss = X_test[cv.predict(X_test) == 1]\n", "Y_test_poiss = Y_test[cv.predict(X_test) == 1]\n", "\n", "# log transformation\n", "c = 100 # TODO try others\n", "Y_train_poiss = np.log(c + Y_train_poiss)\n", "Y_test_poiss = np.log(c + Y_test_poiss)\n", "\n", "params = {'alpha': np.logspace(-4, 1, 100)}\n", "reg_base = linear_model.PoissonRegressor(max_iter=1000)\n", "reg = model_selection.GridSearchCV(reg_base, params)\n", "# reg = linear_model.RidgeCV(alphas=np.logspace(-6, 6, 1000))\n", "reg.fit(X_train_poiss, Y_train_poiss)\n", "print(reg.best_params_)\n", "predict_poiss = reg.predict(X_test_poiss)\n", "\n", "predict_poiss = np.exp(predict_poiss) - c\n", "print('Mean squared error on Poisson data : %.2f'\n", " % mean_squared_error(np.exp(Y_test_poiss) - c, predict_poiss))\n", "print((np.exp(Y_test_poiss) - c)[:20])\n", "print(predict_poiss[:20])\n", "\n", "# on everything\n", "predict = np.zeros(Y_test.shape)\n", "print(cv.predict(X_test_without_exp)[:20])\n", "print(Y_test[:20])\n", "predict[cv.predict(X_test_without_exp) == 1] = predict_poiss\n", "print(predict[:20])\n", "\n", "print('Mean squared error : %.2f'\n", " % mean_squared_error(Y_test, predict))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Models with experience data\n", "We now build the same models but including the data about the experience of participants.\n", "\n", "First we build a trivial linear model." 
] }, { "cell_type": "code", "execution_count": 186, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The best regularizer is lambda = 0.4310855408791511\n" ] } ], "source": [ "# cross validation to determine regularizer\n", "reg = linear_model.RidgeCV(alphas=np.logspace(-6, 6, 10000))\n", "reg.fit(X_train, Y_train)\n", "lambda_ = reg.alpha_\n", "print(f\"The best regularizer is lambda = {lambda_}\")" ] }, { "cell_type": "code", "execution_count": 187, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2.722046481720346\n", "[-8.43473593e-02 6.26960317e-03 -3.98085674e-01 -2.90482130e-01\n", " -5.12914835e-01 -3.71611114e-01 1.34242274e+00 6.49170731e-01\n", " -5.44319820e-01 -2.72265577e-01 -2.89018243e-03 9.55562443e-02\n", " 6.49635158e-02 -4.68683931e-02 6.65223015e-01 -9.84437091e-01\n", " -1.80362446e+00 -5.12692191e-01 -9.91806006e-01 -1.68002152e+00\n", " -1.53586192e+00 2.68399798e+00 -1.82601330e+00 1.64301497e+01\n", " -1.70359009e+00 -1.58728094e+00 -1.75394356e+00 -1.34535741e+00\n", " 1.10495209e-01 -9.26532156e-01 -3.82603590e-01 -1.89918146e+00\n", " -1.51184759e+00 -1.41230163e+00 -1.42684129e+00 3.74873110e+00\n", " -1.52708101e+00 -1.46017817e+00 8.08336879e+00 -1.28808406e+00\n", " -1.38479050e+00 -1.13317333e+00 -1.54291836e+00 -1.57391126e+00\n", " -1.70068092e+00 -1.55003997e+00 1.29949447e+01 -1.28224213e+00\n", " -1.75957248e+00 -3.46347655e-01 4.55078495e+01 1.62410914e+00\n", " -1.90951002e+00 -1.95422864e+00 -1.53798071e+00 -6.34843829e-01\n", " -2.09753123e+00 -1.25685698e+00 -1.36027213e+00 -1.66825055e+00\n", " -1.70448231e+00 -1.94765895e+00 -1.69766898e+00 -1.78561157e+00\n", " -1.60419197e+00 -1.40494019e+00 -6.96473887e-01 3.99984755e-01\n", " -1.11479103e+00 -1.29756435e+00 -1.72272990e+00 -1.90810839e+00\n", " -1.66195359e+00 -1.46676751e+00 -1.46267946e+00 -1.78615344e+00\n", " -1.17752944e+00 -1.37495124e+00 -7.94814627e-01 -1.64157573e+00\n", " -1.11586971e+00 -4.94737185e-01 -1.88415981e+00 -4.46839462e-01\n", " -1.24743318e+00 -1.93622801e+00 -1.84826398e+00 -1.34628300e+00\n", " -1.61211143e+00 -1.34232147e+00 -1.06422000e+00 -5.80303602e-01\n", " 7.48796414e+00 5.25168406e-01 1.07299085e+00 -1.13009502e+00\n", " -1.78922911e+00 -1.71047559e+00 -1.67376033e+00 -1.37914625e+00\n", " 1.52506273e+01 -1.45948639e+00 -8.00040238e-01 -7.59025274e-01\n", " -1.39121627e+00 6.10507397e-01 -1.41321334e+00 -1.87197330e+00\n", " -1.77123296e+00 -1.55599978e+00 -1.83748479e+00 -1.54290840e+00\n", " -1.29645972e+00 -1.69492558e+00 -1.65978325e+00 -1.69882178e+00\n", " -1.40547454e+00 -1.65781563e+00 -1.50002388e+00 8.02561464e-01\n", " -1.45477912e+00 -1.50131342e+00 -1.61676080e+00 9.57403312e-01\n", " -1.17694765e+00 -1.07033624e+00 1.03583989e+00 -4.34109768e-01\n", " -1.85113699e+00 -1.75541627e+00 -1.71456365e+00 -1.36478911e+00\n", " -1.78648122e+00 -1.67076913e+00 -1.47124550e+00 -1.46762853e+00\n", " -1.43353122e+00 -1.15059462e+00 -1.03921150e+00 7.77069994e+00\n", " -6.52537753e-01 -1.95359579e+00 4.92766321e-03 -1.63400977e+00\n", " -1.37799849e+00 7.73507960e+00 -1.07260906e+00 -5.51386965e-02\n", " -1.45188231e+00 -5.80168754e-01 -1.20330261e+00 9.09892623e-01\n", " -1.35067420e+00 -1.67099026e-01 2.93620071e+00 -8.90215432e-01\n", " -1.88844485e+00 -9.98748467e-01 -1.57422992e+00 5.19862332e+00\n", " -1.65280877e+00 5.05719830e-02 -1.64993873e+00 -1.73851840e+00\n", " 1.31559722e+01 -8.57390932e-01 -1.51660865e+00 -1.70640892e+00\n", " -1.32325122e+00 -4.10204037e-01 
-1.76533400e+00 -1.32027774e+00\n",
      " -1.60915427e+00 -9.52788926e-01 5.03475436e+00 4.64582740e-01\n",
      " -9.04880738e-01 -1.27639527e+00 -1.39169112e+00 -1.72193367e+00\n",
      " -1.45730161e+00 -1.62931366e+00 -1.14442318e+00 -1.47834389e+00\n",
      " -1.60558557e+00 8.53798000e+00 -1.29385700e+00 -1.39512784e+00\n",
      " -2.46724543e-01 -7.95448169e-01 -1.05882990e+00 -1.77374134e+00\n",
      " -1.48409808e+00 -1.17958151e+00 -1.84262886e+00 -6.26374735e-01\n",
      " -1.73393168e+00 3.82567364e+00 -4.48226284e-01 -8.15663029e-01\n",
      " -1.05530367e+00 -2.12943893e+00 5.29590505e+01 -4.63005762e-01\n",
      " -1.70223872e+00 -1.72729716e+00 -1.63683781e+00 1.29198850e+00\n",
      " -1.60115424e+00 -1.63197052e+00 -1.14126707e+00 -8.96069743e-01\n",
      " -2.18061272e+00 0.00000000e+00]\n"
     ]
    }
   ],
   "source": [
    "reg = linear_model.Ridge(alpha=lambda_)\n",
    "reg.fit(X_train, Y_train)\n",
    "w0 = reg.intercept_\n",
    "W = reg.coef_\n",
    "print(w0)\n",
    "print(W)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 188,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 1.18297512e+00 -7.42407867e-01 1.65214044e+00 -2.46034308e-02\n",
      " -4.69021814e-01 8.68726653e-01 6.72104627e+00 8.29711110e-01\n",
      " 1.30591417e+00 3.87780478e-01 2.73052330e-02 -8.52067029e-01\n",
      " 2.42915316e+00 6.22925442e-03 9.82701624e-01 1.85688847e+00\n",
      " 5.14099547e-01 6.61749796e-01 -1.34684735e-01 -1.17630981e-01]\n",
      "[0. 0. 1. 0. 0. 1. 6. 0. 4. 0. 0. 1. 2. 0. 5. 0. 1. 0. 0. 0.]\n",
      "Mean squared error: 29.32\n"
     ]
    }
   ],
   "source": [
    "test_predict = reg.predict(X_test)\n",
    "print(test_predict[:20])\n",
    "print(Y_test[:20])\n",
    "print('Mean squared error: %.2f'\n",
    "      % mean_squared_error(Y_test, test_predict))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Second, we fit the linear model with the logarithmic transformation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 201,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "c = 10\n",
      "Lambda: 0.3544455673970436\n",
      "[2.36868331 2.27255708 2.41760197 ... 2.97955573 2.31652432 3.87262415]\n",
      "[2.30258509 2.30258509 2.39789527 ... 2.30258509 2.30258509 3.8501476 ]\n",
      "[ 0.68331637 -0.29581649 1.21892376 ... 9.67907192 0.14036829\n",
      " 38.06835913]\n",
      "[ 0. 0. 1. ... 0. 0. 37.]\n",
      "sanity\n",
      "[[23. 10. 0. ... 0. 0. 0.]\n",
      " [15. 2. 1. ... 0. 0. 0.]\n",
      " [14. 13. 0. ... 0. 0. 0.]\n",
      " ...\n",
      " [21. 11. 1. ... 0. 0. 0.]\n",
      " [18. 10. 0. ... 0. 0. 0.]\n",
      " [24. 6. 1. ... 0. 0. 0.]]\n",
      "Mean squared error : 34.32\n",
      "c = 100\n",
      "Lambda: 0.3643858983763548\n",
      "[4.61541371 4.59918935 4.62094193 ... 4.70766727 4.60605685 4.97402359]\n",
      "[4.60517019 4.60517019 4.61512052 ... 4.60517019 4.60517019 4.91998093]\n",
      "[ 1.02961728 -0.59629879 1.58967752 ... 10.79340722 0.08870552\n",
      " 44.60755927]\n",
      "[ 0. 0. 1. ... 0. 0. 37.]\n",
      "sanity\n",
      "[[23. 10. 0. ... 0. 0. 0.]\n",
      " [15. 2. 1. ... 0. 0. 0.]\n",
      " [14. 13. 0. ... 0. 0. 0.]\n",
      " ...\n",
      " [21. 11. 1. ... 0. 0. 0.]\n",
      " [18. 10. 0. ... 0. 0. 0.]\n",
      " [24. 6. 1. ... 0. 0. 0.]]\n",
      "Mean squared error : 30.59\n",
      "c = 500\n",
      "Lambda: 0.4070142453219439\n",
      "[6.2168895 6.21320115 6.21788416 ... 6.23641285 6.21472897 6.30389222]\n",
      "[6.2146081 6.2146081 6.2166061 ... 6.2146081 6.2146081 6.28599809]\n",
      "[ 1.14200538 -0.70298117 1.64071919 ... 11.02210731 0.06044109\n",
      " 46.69563152]\n",
      "[ 0. 0. 1. ... 0. 0. 37.]\n",
      "sanity\n",
      "[[23. 10. 0. ... 0. 0. 0.]\n",
      " [15. 2. 1. ... 0. 0. 0.]\n",
      " [14. 13. 0. ... 0. 0. 0.]\n",
      " ...\n",
      " [21. 11. 1. ... 0. 0. 0.]\n",
      " [18. 10. 0. ... 0. 0. 0.]\n",
      " [24. 6. 1. ... 0. 0. 0.]]\n",
      "Mean squared error : 29.64\n",
      "c = 1000\n",
      "Lambda: 0.41842885079015846\n",
      "[6.90891612 6.90703328 6.90940047 ... 6.91874894 6.90781098 6.95374095]\n",
      "[6.90775528 6.90775528 6.90875478 ... 6.90775528 6.90775528 6.94408721]\n",
      "[ 1.16151439 -0.72173879 1.64655001 ... 11.05431636 0.05569874\n",
      " 47.0594116 ]\n",
      "[ 0. 0. 1. ... 0. 0. 37.]\n",
      "sanity\n",
      "[[23. 10. 0. ... 0. 0. 0.]\n",
      " [15. 2. 1. ... 0. 0. 0.]\n",
      " [14. 13. 0. ... 0. 0. 0.]\n",
      " ...\n",
      " [21. 11. 1. ... 0. 0. 0.]\n",
      " [18. 10. 0. ... 0. 0. 0.]\n",
      " [24. 6. 1. ... 0. 0. 0.]]\n",
      "Mean squared error : 29.49\n"
     ]
    }
   ],
   "source": [
    "cs = [10, 100, 500, 1000]\n",
    "\n",
    "for c in cs:\n",
    "    print(f\"c = {c}\")\n",
    "    # logarithmic transformation\n",
    "    Y_train_transf = np.log(c + Y_train)\n",
    "    Y_test_transf = np.log(c + Y_test)\n",
    "    \n",
    "    # cross-validation for lambda\n",
    "    reg = linear_model.RidgeCV(alphas=np.logspace(-6, 6, 1000))\n",
    "    reg.fit(X_train, Y_train_transf)\n",
    "    print(f\"Lambda: {reg.alpha_}\")\n",
    "    test_predict_transf = reg.predict(X_test)\n",
    "    # transform the output back\n",
    "    test_predict = np.exp(test_predict_transf) - c\n",
    "    \n",
    "    print(test_predict_transf)\n",
    "    print(Y_test_transf)\n",
    "    print(test_predict)\n",
    "    print(Y_test)\n",
    "    \n",
    "    print('Mean squared error : %.2f'\n",
    "          % mean_squared_error(Y_test, test_predict))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Last, we apply the two-step model with logistic regression and the Poisson regressor."
   ]
  },
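  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sketch (not yet run), this last step can simply mirror the two-step cells above, now using the experience features; the names clf_exp, cv_exp and reg_exp are new, so the models fitted above are not overwritten."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# sketch of the two-step model on the data with experience features\n",
    "# step 1: logistic regression classifies samples into zero and non-zero\n",
    "clf_exp = linear_model.LogisticRegression(max_iter=2000, fit_intercept=True)\n",
    "cv_exp = model_selection.GridSearchCV(clf_exp, {'C': np.logspace(-3, 3, 100)})\n",
    "cv_exp.fit(X_train, Y_train_class)\n",
    "\n",
    "# step 2: Poisson regressor on the log-transformed non-zero training samples\n",
    "c = 100\n",
    "X_train_poiss_exp = X_train[Y_train_class == 1]\n",
    "Y_train_poiss_exp = np.log(c + Y_train[Y_train_class == 1])\n",
    "reg_exp = model_selection.GridSearchCV(\n",
    "    linear_model.PoissonRegressor(max_iter=1000),\n",
    "    {'alpha': np.logspace(-4, 1, 100)})\n",
    "reg_exp.fit(X_train_poiss_exp, Y_train_poiss_exp)\n",
    "\n",
    "# combine: 0 where the classifier predicts zero, back-transformed counts elsewhere\n",
    "predict = np.zeros(Y_test.shape)\n",
    "nonzero = cv_exp.predict(X_test) == 1\n",
    "predict[nonzero] = np.exp(reg_exp.predict(X_test[nonzero])) - c\n",
    "print('Mean squared error : %.2f' % mean_squared_error(Y_test, predict))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}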