{ "cells": [ { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from nim_env import NimEnv\n", "from utils import compute_nim_sum, optimal_policy, random_policy\n", "import random\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "# from agent import RLAgent\n", "from sarsa import RLAgent\n", "from tqdm import tqdm\n", "import pandas as pd\n", "pd.options.display.float_format = '{:,.3f}'.format\n", "\n", "# plt figure setup\n", "from matplotlib import rc\n", "\n", "plt.rc('axes', labelsize=14) # fontsize of the x and y labels\n", "plt.rc('axes', titlesize=14)\n", "plt.rc('xtick', labelsize=13) # fontsize of the tick labels\n", "plt.rc('ytick', labelsize=13)\n", "plt.rc('legend', fontsize=14) # legend fontsize\n", "plt.rc('figure', titlesize=14) # fontsize of the figure title\n", "plt.rc('lines', markersize=7)\n", "plt.rc('lines', linewidth=2)\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "
\n", " \n", "
\n", "\n", " | (1, 1) | \n", "(1, 2) | \n", "(1, 3) | \n", "(1, 4) | \n", "(1, 5) | \n", "(1, 6) | \n", "(1, 7) | \n", "(2, 1) | \n", "(2, 2) | \n", "(2, 3) | \n", "... | \n", "(2, 5) | \n", "(2, 6) | \n", "(2, 7) | \n", "(3, 1) | \n", "(3, 2) | \n", "(3, 3) | \n", "(3, 4) | \n", "(3, 5) | \n", "(3, 6) | \n", "(3, 7) | \n", "
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
(0, 0, 0) | \n", "0.889 | \n", "0.741 | \n", "0.794 | \n", "0.628 | \n", "0.820 | \n", "0.794 | \n", "0.837 | \n", "0.898 | \n", "0.445 | \n", "0.824 | \n", "... | \n", "0.693 | \n", "0.724 | \n", "0.650 | \n", "0.859 | \n", "0.585 | \n", "0.488 | \n", "0.537 | \n", "0.432 | \n", "0.766 | \n", "0.793 | \n", "
(0, 0, 1) | \n", "-0.811 | \n", "-0.866 | \n", "-0.852 | \n", "-0.815 | \n", "-0.861 | \n", "-0.868 | \n", "-0.853 | \n", "-0.818 | \n", "-0.846 | \n", "-0.787 | \n", "... | \n", "-0.819 | \n", "-0.820 | \n", "-0.877 | \n", "1.000 | \n", "-0.875 | \n", "-0.824 | \n", "-0.851 | \n", "-0.891 | \n", "-0.886 | \n", "0.000 | \n", "
(0, 0, 2) | \n", "-0.458 | \n", "-0.400 | \n", "-0.513 | \n", "-0.474 | \n", "-0.512 | \n", "-0.468 | \n", "-0.571 | \n", "-0.442 | \n", "-0.462 | \n", "-0.510 | \n", "... | \n", "-0.605 | \n", "-0.520 | \n", "-0.438 | \n", "-0.675 | \n", "1.000 | \n", "-0.440 | \n", "-0.485 | \n", "-0.475 | \n", "0.000 | \n", "0.000 | \n", "
(0, 0, 3) | \n", "-0.467 | \n", "-0.486 | \n", "0.000 | \n", "-0.294 | \n", "-0.403 | \n", "0.000 | \n", "-0.110 | \n", "-0.469 | \n", "-0.226 | \n", "-0.166 | \n", "... | \n", "-0.276 | \n", "-0.161 | \n", "-0.239 | \n", "-0.879 | \n", "-0.915 | \n", "1.000 | \n", "-0.092 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "
(0, 0, 4) | \n", "-0.466 | \n", "-0.265 | \n", "0.000 | \n", "0.000 | \n", "-0.337 | \n", "0.000 | \n", "0.000 | \n", "-0.462 | \n", "-0.188 | \n", "-0.036 | \n", "... | \n", "-0.184 | \n", "-0.176 | \n", "-0.129 | \n", "-0.620 | \n", "-0.824 | \n", "-0.768 | \n", "1.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "
(0, 0, 5) | \n", "-0.090 | \n", "0.000 | \n", "0.000 | \n", "-0.107 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "-0.143 | \n", "-0.044 | \n", "0.000 | \n", "... | \n", "0.000 | \n", "-0.036 | \n", "0.000 | \n", "-0.369 | \n", "-0.482 | \n", "-0.589 | \n", "-0.200 | \n", "1.000 | \n", "0.000 | \n", "0.000 | \n", "
(0, 0, 6) | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "-0.036 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "-0.179 | \n", "0.000 | \n", "-0.059 | \n", "... | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "-0.226 | \n", "-0.200 | \n", "-0.360 | \n", "-0.200 | \n", "0.000 | \n", "0.999 | \n", "0.000 | \n", "
(0, 0, 7) | \n", "0.000 | \n", "-0.011 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.006 | \n", "-0.036 | \n", "... | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "-0.360 | \n", "-0.200 | \n", "-0.200 | \n", "-0.360 | \n", "0.965 | \n", "
(0, 1, 0) | \n", "-0.734 | \n", "-0.827 | \n", "-0.828 | \n", "-0.853 | \n", "-0.861 | \n", "-0.887 | \n", "-0.888 | \n", "1.000 | \n", "-0.712 | \n", "-0.883 | \n", "... | \n", "-0.847 | \n", "-0.882 | \n", "0.000 | \n", "-0.766 | \n", "-0.857 | \n", "-0.873 | \n", "-0.864 | \n", "-0.856 | \n", "-0.859 | \n", "-0.868 | \n", "
(0, 1, 1) | \n", "0.800 | \n", "0.801 | \n", "0.798 | \n", "0.803 | \n", "0.788 | \n", "0.785 | \n", "0.637 | \n", "-0.992 | \n", "0.755 | \n", "0.767 | \n", "... | \n", "0.763 | \n", "0.780 | \n", "0.000 | \n", "-0.994 | \n", "0.798 | \n", "0.745 | \n", "0.797 | \n", "0.698 | \n", "0.769 | \n", "0.000 | \n", "
10 rows × 21 columns
\n", "\n", " | (1, 1) | \n", "(1, 2) | \n", "(1, 3) | \n", "(1, 4) | \n", "(1, 5) | \n", "(1, 6) | \n", "(1, 7) | \n", "(2, 1) | \n", "(2, 2) | \n", "(2, 3) | \n", "... | \n", "(2, 5) | \n", "(2, 6) | \n", "(2, 7) | \n", "(3, 1) | \n", "(3, 2) | \n", "(3, 3) | \n", "(3, 4) | \n", "(3, 5) | \n", "(3, 6) | \n", "(3, 7) | \n", "
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
(0, 0, 0) | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "... | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "
(0, 0, 1) | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "... | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "1.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "
(0, 0, 2) | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "... | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "-1.000 | \n", "1.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "
(0, 0, 3) | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "... | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "-0.977 | \n", "-0.986 | \n", "1.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "
(0, 0, 4) | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "... | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "-0.790 | \n", "-0.790 | \n", "-0.738 | \n", "1.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "
(0, 0, 5) | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "... | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "-0.488 | \n", "-0.738 | \n", "-0.360 | \n", "-0.590 | \n", "1.000 | \n", "0.000 | \n", "0.000 | \n", "
(0, 0, 6) | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "... | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "-0.200 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.988 | \n", "0.000 | \n", "
(0, 0, 7) | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "... | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "-0.200 | \n", "-0.200 | \n", "-0.200 | \n", "-0.360 | \n", "-0.360 | \n", "-0.200 | \n", "0.956 | \n", "
(0, 1, 0) | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "1.000 | \n", "0.000 | \n", "0.000 | \n", "... | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "
(0, 1, 1) | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "-1.000 | \n", "0.000 | \n", "0.000 | \n", "... | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "-1.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "0.000 | \n", "
10 rows × 21 columns
\n", "