{ "cells": [ { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from nim_env import NimEnv\n", "from utils import compute_nim_sum, optimal_policy, random_policy\n", "import random\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "# from agent import RLAgent\n", "from sarsa import RLAgent\n", "from tqdm import tqdm\n", "import pandas as pd\n", "pd.options.display.float_format = '{:,.3f}'.format\n", "\n",
"# Seed the global RNGs (random.sample draws the initial heaps; numpy for anything using np.random)\n",
"# so that training / evaluation runs are repeatable after Restart & Run All.\n",
"# NOTE(review): if RLAgent, NimEnv or optimal_policy keep their own RNG, seed those as well.\n",
"SEED = 0\n",
"random.seed(SEED)\n",
"np.random.seed(SEED)\n",
"\n",
"# plt figure setup\n", "from matplotlib import rc\n", "\n", "plt.rc('axes', labelsize=14) # fontsize of the x and y labels\n", "plt.rc('axes', titlesize=14)\n", "plt.rc('xtick', labelsize=13) # fontsize of the tick labels\n", "plt.rc('ytick', labelsize=13)\n", "plt.rc('legend', fontsize=14) # legend fontsize\n", "plt.rc('figure', titlesize=14) # fontsize of the figure title\n", "plt.rc('lines', markersize=7)\n", "plt.rc('lines', linewidth=2)\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "

\n", " \"ny\"\n", "

\n", "
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### SARSA agent against fixed policy (teacher)" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "import random\n", "\n", "def evaluate_fixed(agent, N):\n", " player0_ep_rewards = [0]*N\n", " player1_ep_rewards = [0]*N\n", " for i in range(N):\n", " done = False\n", " heaps = random.sample(range(1, 8), 3)\n", " env = NimEnv(heaps)\n", " turn = 0\n", " while not done:\n", " action = agent.decide(heaps)\n", " next_heaps, winner, reward, done, turn = env.step(action)\n", " if done:\n", " break\n", " adv_action = optimal_policy(next_heaps, randomness=0.2)\n", " nextnext_heaps, winner, adv_reward, done, _ = env.step(adv_action)\n", " heaps = nextnext_heaps\n", "\n", "\n", " if winner == 0:\n", " player0_ep_rewards.append(1) \n", " player1_ep_rewards.append(-1) \n", " else: # if winner == 1 \n", " player0_ep_rewards.append(-1)\n", " player1_ep_rewards.append(1) \n", "\n", " return np.mean(player0_ep_rewards), np.mean(player1_ep_rewards)\n" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "heaps = [7,7,7]\n", "env = NimEnv(heaps)\n", "\n", "\n", "def run_experiment(env, num_episodes):\n", " player0_rewards = []\n", " player1_rewards = []\n", " agent0 = RLAgent() # RL agent for player 0\n", " # agent1 = RLAgent() # RL agent for player 1\n", "\n", " for episode in tqdm(range(1, num_episodes+1)):\n", " heaps = env.reset() \n", " env = NimEnv(heaps)\n", " \n", " winner = None\n", " done = False\n", "\n", " action = None\n", " reward = None\n", " old_0 = (None, None) # old state/action for player0\n", " old_reward = None\n", " i = 0\n", " while not done:\n", " action = agent0.decide(heaps)\n", " next_heaps, winner, reward, done, _ = env.step(action)\n", " if i > 0:\n", " agent0.learn(old_0[0], old_0[1], old_reward, heaps, action)\n", " if done:\n", " agent0.learn(heaps, action, reward[0])\n", " break\n", " \n", " adv_action = 
optimal_policy(next_heaps, randomness=0.2)\n", " nextnext_heaps, winner, adv_reward, done, _ = env.step(adv_action)\n", " if done:\n", " agent0.learn(heaps, action, adv_reward[0])\n", " break\n", " \n", " old_0 = heaps, action\n", " old_reward = adv_reward[0]\n", " heaps = nextnext_heaps\n", " i += 1\n", "\n", " r0, r1 = evaluate_fixed(agent0, 10)\n", " player0_rewards.append(r0) \n", " player1_rewards.append(r1) \n", " \n", " return player0_rewards, player1_rewards, agent0" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50000/50000 [03:20<00:00, 249.48it/s]\n" ] } ], "source": [ "num_episodes = 50000\n", "r0, r1, agent = run_experiment(env, num_episodes)" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "N=100\n", "r0_conv = np.convolve(r0, np.ones(N)/N, mode='valid')\n", "r1_conv = np.convolve(r1, np.ones(N)/N, mode='valid')\n", "fig = plt.figure(figsize=(12,8))\n", "plt.plot(range(1, len(r0_conv)+1), r0_conv, color='b', label='SARSA (player 0)')\n", "plt.plot(range(1, len(r1_conv)+1), r1_conv, color='r', label=r'Optimal ($\\epsilon = 0.2$)')\n", "plt.legend()\n", "plt.xlabel('episode')\n", "plt.ylabel('reward')\n", "plt.grid(True,'major',linestyle='-',linewidth=0.5)\n", "plt.grid(True,'minor',linestyle='--',linewidth=0.25) \n", "plt.title(r'SARSA against $\\epsilon$-greedy optimal policy')\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Q-table\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
(1, 1)(1, 2)(1, 3)(1, 4)(1, 5)(1, 6)(1, 7)(2, 1)(2, 2)(2, 3)...(2, 5)(2, 6)(2, 7)(3, 1)(3, 2)(3, 3)(3, 4)(3, 5)(3, 6)(3, 7)
(0, 0, 0)0.8890.7410.7940.6280.8200.7940.8370.8980.4450.824...0.6930.7240.6500.8590.5850.4880.5370.4320.7660.793
(0, 0, 1)-0.811-0.866-0.852-0.815-0.861-0.868-0.853-0.818-0.846-0.787...-0.819-0.820-0.8771.000-0.875-0.824-0.851-0.891-0.8860.000
(0, 0, 2)-0.458-0.400-0.513-0.474-0.512-0.468-0.571-0.442-0.462-0.510...-0.605-0.520-0.438-0.6751.000-0.440-0.485-0.4750.0000.000
(0, 0, 3)-0.467-0.4860.000-0.294-0.4030.000-0.110-0.469-0.226-0.166...-0.276-0.161-0.239-0.879-0.9151.000-0.0920.0000.0000.000
(0, 0, 4)-0.466-0.2650.0000.000-0.3370.0000.000-0.462-0.188-0.036...-0.184-0.176-0.129-0.620-0.824-0.7681.0000.0000.0000.000
(0, 0, 5)-0.0900.0000.000-0.1070.0000.0000.000-0.143-0.0440.000...0.000-0.0360.000-0.369-0.482-0.589-0.2001.0000.0000.000
(0, 0, 6)0.0000.0000.000-0.0360.0000.0000.000-0.1790.000-0.059...0.0000.0000.000-0.226-0.200-0.360-0.2000.0000.9990.000
(0, 0, 7)0.000-0.0110.0000.0000.0000.0000.0000.0000.006-0.036...0.0000.0000.0000.0000.000-0.360-0.200-0.200-0.3600.965
(0, 1, 0)-0.734-0.827-0.828-0.853-0.861-0.887-0.8881.000-0.712-0.883...-0.847-0.8820.000-0.766-0.857-0.873-0.864-0.856-0.859-0.868
(0, 1, 1)0.8000.8010.7980.8030.7880.7850.637-0.9920.7550.767...0.7630.7800.000-0.9940.7980.7450.7970.6980.7690.000
\n", "

10 rows × 21 columns

\n", "
" ], "text/plain": [ " (1, 1) (1, 2) (1, 3) (1, 4) (1, 5) (1, 6) (1, 7) (2, 1) \\\n", "(0, 0, 0) 0.889 0.741 0.794 0.628 0.820 0.794 0.837 0.898 \n", "(0, 0, 1) -0.811 -0.866 -0.852 -0.815 -0.861 -0.868 -0.853 -0.818 \n", "(0, 0, 2) -0.458 -0.400 -0.513 -0.474 -0.512 -0.468 -0.571 -0.442 \n", "(0, 0, 3) -0.467 -0.486 0.000 -0.294 -0.403 0.000 -0.110 -0.469 \n", "(0, 0, 4) -0.466 -0.265 0.000 0.000 -0.337 0.000 0.000 -0.462 \n", "(0, 0, 5) -0.090 0.000 0.000 -0.107 0.000 0.000 0.000 -0.143 \n", "(0, 0, 6) 0.000 0.000 0.000 -0.036 0.000 0.000 0.000 -0.179 \n", "(0, 0, 7) 0.000 -0.011 0.000 0.000 0.000 0.000 0.000 0.000 \n", "(0, 1, 0) -0.734 -0.827 -0.828 -0.853 -0.861 -0.887 -0.888 1.000 \n", "(0, 1, 1) 0.800 0.801 0.798 0.803 0.788 0.785 0.637 -0.992 \n", "\n", " (2, 2) (2, 3) ... (2, 5) (2, 6) (2, 7) (3, 1) (3, 2) \\\n", "(0, 0, 0) 0.445 0.824 ... 0.693 0.724 0.650 0.859 0.585 \n", "(0, 0, 1) -0.846 -0.787 ... -0.819 -0.820 -0.877 1.000 -0.875 \n", "(0, 0, 2) -0.462 -0.510 ... -0.605 -0.520 -0.438 -0.675 1.000 \n", "(0, 0, 3) -0.226 -0.166 ... -0.276 -0.161 -0.239 -0.879 -0.915 \n", "(0, 0, 4) -0.188 -0.036 ... -0.184 -0.176 -0.129 -0.620 -0.824 \n", "(0, 0, 5) -0.044 0.000 ... 0.000 -0.036 0.000 -0.369 -0.482 \n", "(0, 0, 6) 0.000 -0.059 ... 0.000 0.000 0.000 -0.226 -0.200 \n", "(0, 0, 7) 0.006 -0.036 ... 0.000 0.000 0.000 0.000 0.000 \n", "(0, 1, 0) -0.712 -0.883 ... -0.847 -0.882 0.000 -0.766 -0.857 \n", "(0, 1, 1) 0.755 0.767 ... 
0.763 0.780 0.000 -0.994 0.798 \n", "\n", " (3, 3) (3, 4) (3, 5) (3, 6) (3, 7) \n", "(0, 0, 0) 0.488 0.537 0.432 0.766 0.793 \n", "(0, 0, 1) -0.824 -0.851 -0.891 -0.886 0.000 \n", "(0, 0, 2) -0.440 -0.485 -0.475 0.000 0.000 \n", "(0, 0, 3) 1.000 -0.092 0.000 0.000 0.000 \n", "(0, 0, 4) -0.768 1.000 0.000 0.000 0.000 \n", "(0, 0, 5) -0.589 -0.200 1.000 0.000 0.000 \n", "(0, 0, 6) -0.360 -0.200 0.000 0.999 0.000 \n", "(0, 0, 7) -0.360 -0.200 -0.200 -0.360 0.965 \n", "(0, 1, 0) -0.873 -0.864 -0.856 -0.859 -0.868 \n", "(0, 1, 1) 0.745 0.797 0.698 0.769 0.000 \n", "\n", "[10 rows x 21 columns]" ] }, "execution_count": 45, "metadata": {}, "output_type": "execute_result" } ], "source": [ "print('Q-table')\n", "agent.q.head(10)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### SARSA agent against another SARSA agent" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [
"import random\n",
"\n",
"def evaluate(agent, N):\n",
"    \"\"\"Play N evaluation games between agent[0] (player 0) and agent[1] (player 1).\n",
"\n",
"    Each game starts from 3 distinct random heap sizes in 1..7. A win scores +1\n",
"    and a loss -1; returns (mean reward of player 0, mean reward of player 1).\n",
"    \"\"\"\n",
"    # Start from empty lists: pre-filling them with N zeros halved every reported mean.\n",
"    player0_ep_rewards = []\n",
"    player1_ep_rewards = []\n",
"    for _ in range(N):\n",
"        done = False\n",
"        heaps = random.sample(range(1, 8), 3)\n",
"        env = NimEnv(heaps)\n",
"        turn = 0\n",
"        while not done:\n",
"            action = agent[turn].decide(heaps)\n",
"            # Advance to the new position (previously `heaps` was never updated, so every\n",
"            # move was chosen for the game's initial heaps). The last value returned by\n",
"            # env.step is used as the index of the agent that moves next.\n",
"            heaps, winner, reward, done, turn = env.step(action)\n",
"\n",
"        if winner == 0:\n",
"            player0_ep_rewards.append(1)\n",
"            player1_ep_rewards.append(-1)\n",
"        else:  # winner == 1\n",
"            player0_ep_rewards.append(-1)\n",
"            player1_ep_rewards.append(1)\n",
"\n",
"    return np.mean(player0_ep_rewards), np.mean(player1_ep_rewards)\n" ] },
{ "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [
"heaps = [7,7,7]\n",
"env = NimEnv(heaps)\n",
"\n",
"\n",
"def run_experiment(env, num_episodes):\n",
"    \"\"\"Train two SARSA agents against each other (agent0 always moves first).\n",
"\n",
"    After every training episode both agents are scored on 10 random games with\n",
"    evaluate. Returns (player-0 scores, player-1 scores, [agent0, agent1]).\n",
"    \"\"\"\n",
"    player0_rewards = []\n",
"    player1_rewards = []\n",
"    agent0 = RLAgent() # RL agent for player 0\n",
"    agent1 = RLAgent() # RL agent for player 1\n",
"\n",
"    for episode in tqdm(range(1, num_episodes+1)):\n",
"        heaps = env.reset()\n",
"        env = NimEnv(heaps)\n",
"\n",
"        winner = None\n",
"        done = False\n",
"\n",
"        action = [None, None]\n",
"        reward = [None, None]\n",
"        old_0 = (None, None) # old state/action for player0\n",
"        old_1 = (None, None) # old state/action for player1\n",
"        old_reward = None\n",
"        i = 0\n",
"        while not done:\n",
"            # --- player 0 moves ---\n",
"            action[0] = agent0.decide(heaps)\n",
"            next_heaps, winner, reward[0], done, _ = env.step(action[0])\n",
"            # SARSA update of player 0's previous (state, action) towards the current pair.\n",
"            if i > 0:\n",
"                agent0.learn(old_0[0], old_0[1], old_reward[0], heaps, action[0])\n",
"            if done:\n",
"                # Terminal updates: player 0's final move, and player 1's last move (if it made one).\n",
"                agent0.learn(heaps, action[0], reward[0][0])\n",
"                if i > 0:\n",
"                    agent1.learn(old_1[0], old_1[1], reward[0][1])\n",
"                break\n",
"\n",
"            # --- player 1 moves ---\n",
"            action[1] = agent1.decide(next_heaps)\n",
"            nextnext_heaps, winner, reward[1], done, _ = env.step(action[1])\n",
"            if i > 0:\n",
"                agent1.learn(old_1[0], old_1[1], old_reward[1], next_heaps, action[1])\n",
"            if done:\n",
"                agent0.learn(heaps, action[0], reward[1][0])\n",
"                agent1.learn(next_heaps, action[1], reward[1][1])\n",
"                break\n",
"\n",
"            old_0 = heaps, action[0]\n",
"            old_1 = next_heaps, action[1]\n",
"            old_reward = reward[1]\n",
"            heaps = nextnext_heaps\n",
"            i += 1\n",
"\n",
"        r0, r1 = evaluate([agent0, agent1], 10)\n",
"        player0_rewards.append(r0)\n",
"        player1_rewards.append(r1)\n",
"\n",
"    return player0_rewards, player1_rewards, [agent0, agent1]\n" ] },
{ "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 100000/100000 [10:33<00:00, 157.84it/s]\n" ] } ], "source": [ "num_episodes = 100_000\n", "r0, r1, agents = run_experiment(env, num_episodes)" ] },
{ "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "N=500\n", "r0_conv = np.convolve(r0, np.ones(N)/N, mode='valid')\n", "r1_conv = np.convolve(r1, np.ones(N)/N, mode='valid')\n", "fig = plt.figure(figsize=(12,8))\n", "plt.plot(range(1, len(r0_conv)+1), r0_conv, color='b', label='player0')\n", "plt.plot(range(1, len(r1_conv)+1), r1_conv, color='r', label='player1')\n", "plt.legend()\n", "plt.xlabel('episode')\n", "plt.ylabel('reward')\n", "plt.grid(True,'major',linestyle='-',linewidth=0.5)\n", "plt.grid(True,'minor',linestyle='--',linewidth=0.25) \n", "plt.title(r'SARSA against SARSA (first to play advantage?)')\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Q-table\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
(1, 1)(1, 2)(1, 3)(1, 4)(1, 5)(1, 6)(1, 7)(2, 1)(2, 2)(2, 3)...(2, 5)(2, 6)(2, 7)(3, 1)(3, 2)(3, 3)(3, 4)(3, 5)(3, 6)(3, 7)
(0, 0, 0)0.0000.0000.0000.0000.0000.0000.0000.0000.0000.000...0.0000.0000.0000.0000.0000.0000.0000.0000.0000.000
(0, 0, 1)0.0000.0000.0000.0000.0000.0000.0000.0000.0000.000...0.0000.0000.0001.0000.0000.0000.0000.0000.0000.000
(0, 0, 2)0.0000.0000.0000.0000.0000.0000.0000.0000.0000.000...0.0000.0000.000-1.0001.0000.0000.0000.0000.0000.000
(0, 0, 3)0.0000.0000.0000.0000.0000.0000.0000.0000.0000.000...0.0000.0000.000-0.977-0.9861.0000.0000.0000.0000.000
(0, 0, 4)0.0000.0000.0000.0000.0000.0000.0000.0000.0000.000...0.0000.0000.000-0.790-0.790-0.7381.0000.0000.0000.000
(0, 0, 5)0.0000.0000.0000.0000.0000.0000.0000.0000.0000.000...0.0000.0000.000-0.488-0.738-0.360-0.5901.0000.0000.000
(0, 0, 6)0.0000.0000.0000.0000.0000.0000.0000.0000.0000.000...0.0000.0000.000-0.2000.0000.0000.0000.0000.9880.000
(0, 0, 7)0.0000.0000.0000.0000.0000.0000.0000.0000.0000.000...0.0000.0000.000-0.200-0.200-0.200-0.360-0.360-0.2000.956
(0, 1, 0)0.0000.0000.0000.0000.0000.0000.0001.0000.0000.000...0.0000.0000.0000.0000.0000.0000.0000.0000.0000.000
(0, 1, 1)0.0000.0000.0000.0000.0000.0000.000-1.0000.0000.000...0.0000.0000.000-1.0000.0000.0000.0000.0000.0000.000
\n", "

10 rows × 21 columns

\n", "
" ], "text/plain": [ " (1, 1) (1, 2) (1, 3) (1, 4) (1, 5) (1, 6) (1, 7) (2, 1) \\\n", "(0, 0, 0) 0.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 \n", "(0, 0, 1) 0.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 \n", "(0, 0, 2) 0.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 \n", "(0, 0, 3) 0.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 \n", "(0, 0, 4) 0.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 \n", "(0, 0, 5) 0.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 \n", "(0, 0, 6) 0.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 \n", "(0, 0, 7) 0.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 \n", "(0, 1, 0) 0.000 0.000 0.000 0.000 0.000 0.000 0.000 1.000 \n", "(0, 1, 1) 0.000 0.000 0.000 0.000 0.000 0.000 0.000 -1.000 \n", "\n", " (2, 2) (2, 3) ... (2, 5) (2, 6) (2, 7) (3, 1) (3, 2) \\\n", "(0, 0, 0) 0.000 0.000 ... 0.000 0.000 0.000 0.000 0.000 \n", "(0, 0, 1) 0.000 0.000 ... 0.000 0.000 0.000 1.000 0.000 \n", "(0, 0, 2) 0.000 0.000 ... 0.000 0.000 0.000 -1.000 1.000 \n", "(0, 0, 3) 0.000 0.000 ... 0.000 0.000 0.000 -0.977 -0.986 \n", "(0, 0, 4) 0.000 0.000 ... 0.000 0.000 0.000 -0.790 -0.790 \n", "(0, 0, 5) 0.000 0.000 ... 0.000 0.000 0.000 -0.488 -0.738 \n", "(0, 0, 6) 0.000 0.000 ... 0.000 0.000 0.000 -0.200 0.000 \n", "(0, 0, 7) 0.000 0.000 ... 0.000 0.000 0.000 -0.200 -0.200 \n", "(0, 1, 0) 0.000 0.000 ... 0.000 0.000 0.000 0.000 0.000 \n", "(0, 1, 1) 0.000 0.000 ... 
0.000 0.000 0.000 -1.000 0.000 \n", "\n", " (3, 3) (3, 4) (3, 5) (3, 6) (3, 7) \n", "(0, 0, 0) 0.000 0.000 0.000 0.000 0.000 \n", "(0, 0, 1) 0.000 0.000 0.000 0.000 0.000 \n", "(0, 0, 2) 0.000 0.000 0.000 0.000 0.000 \n", "(0, 0, 3) 1.000 0.000 0.000 0.000 0.000 \n", "(0, 0, 4) -0.738 1.000 0.000 0.000 0.000 \n", "(0, 0, 5) -0.360 -0.590 1.000 0.000 0.000 \n", "(0, 0, 6) 0.000 0.000 0.000 0.988 0.000 \n", "(0, 0, 7) -0.200 -0.360 -0.360 -0.200 0.956 \n", "(0, 1, 0) 0.000 0.000 0.000 0.000 0.000 \n", "(0, 1, 1) 0.000 0.000 0.000 0.000 0.000 \n", "\n", "[10 rows x 21 columns]" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "print('Q-table')\n", "agents[1].q.head(10)" ] },
{ "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The environment at beginning is: [3, 1, 5]\n", "agent took action [3, 3]\n", "[3, 1, 2]\n", "you took action [1, 1]\n", "[2, 1, 2]\n", "agent took action [2, 1]\n", "[2, 0, 2]\n", "you took action [1, 1]\n", "[1, 0, 2]\n", "agent took action [3, 1]\n", "[1, 0, 1]\n", "you took action [3, 1]\n", "[1, 0, 0]\n", "agent took action [1, 1]\n", "[0, 0, 0]\n", "agent win(s)!\n" ] } ], "source": [
"# Play one game against the trained agent; the agent moves first (player 0).\n",
"heaps = random.sample(range(1, 8), 3)\n",
"env = NimEnv(heaps)\n",
"agent = agents[1]\n",
"print('The environment at beginning is:', end=' ')\n",
"env.render(simple=True)\n",
"done = False\n",
"\n",
"while not done:\n",
"    action = agent.exploit(heaps)\n",
"    heaps, winner, reward, done, turn = env.step(action)\n",
"    print('agent took action ', action)\n",
"    env.render(simple=True)\n",
"    if done:\n",
"        break\n",
"    # Keep only the digits of the reply, so '1 2', '1,2' and '[1, 2]' (the format shown in the\n",
"    # prompt) all parse; the old entry[0]/entry[2] indexing crashed on the bracketed form.\n",
"    numbers = []\n",
"    while len(numbers) != 2:\n",
"        entry = input('enter [heap, n_objects]: ')\n",
"        numbers = ''.join(ch if ch.isdecimal() else ' ' for ch in entry).split()\n",
"    move = [int(numbers[0]), int(numbers[1])]\n",
"    heaps, winner, reward, done, turn = env.step(move)\n",
"    print('you took action ', move)\n",
"    env.render(simple=True)\n",
"\n",
"# Separate name so `winner` keeps the numeric player index returned by env.step.\n",
"winner_name = 'agent' if winner == 0 else 'you'\n",
"print(f'{winner_name} win(s)!')" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Old\n", "\n", "Superseded training drafts, kept for reference only. Their functions are renamed so that running these cells no longer replaces the `run_experiment` used above; they still rebind the globals `num_episodes`, `num_seed`, `heaps` and `env`." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# NOTE(review): unfinished self-play draft. It uses names it never defines here\n",
"# (`action`, `action0`, `action1`, `agent0`, `agent1`, `reward0`), so calling it either fails\n",
"# with NameError or silently picks up leftover globals from earlier cells.\n",
"num_episodes = 500\n",
"num_seed = 10\n",
"heaps = random.sample(range(1, 8), 3)\n",
"env = NimEnv(heaps)\n",
"\n",
"\n",
"def run_experiment_old_selfplay(env, num_episodes):\n",
"    player0_rewards = []\n",
"    player1_rewards = []\n",
"    agent = [RLAgent(), RLAgent()] # agents for player 0,1\n",
"    # agent = RLAgent() # player 0 plays against e-greedy optimal policy\n",
"    # agent0 = RLAgent() # RL agent for player 0\n",
"    # agent1 = RLAgent() # RL agent for player 1\n",
"    actions = [None, None]\n",
"    reward = [None, None]\n",
"\n",
"    for episode in tqdm(range(1, num_episodes+1)):\n",
"        player0_episode_reward = []\n",
"        player1_episode_reward = []\n",
"\n",
"        for seed in range(num_seed):\n",
"            heaps = env.reset(seed) # fixed seeds to evaluate performance on same set of initial conditions\n",
"            env = NimEnv(heaps)\n",
"            winner = None\n",
"            done = False\n",
"            old_heaps = None\n",
"            turn = 0\n",
"            next_turn = None\n",
"            action[0] = agent[0].decide(heaps)\n",
"            while not done:\n",
"                # print(heaps)\n",
"                # print(action)\n",
"                next_heaps = None\n",
"                nextnext_heaps = None\n",
"                next_heaps, winner, reward0, done, next_turn = env.step(action0) # agent0 takes action a\n",
"                if done:\n",
"                    agent[turn].learn(heaps, action[turn], reward[turn][0])\n",
"                    if old_heaps is not None:\n",
"                        agent[next_turn].learn(old_heaps, action1, reward[turn][1])\n",
"                    break\n",
"                action1 = agent1.decide(next_heaps)\n",
"                nextnext_heaps, winner, reward1, done, _ = env.step(action1) # agent1 takes action a'\n",
"                if done:\n",
"                    agent0.learn(heaps, action0, reward1[0])\n",
"                    agent1.learn(next_heaps, action1, reward1[1])\n",
"                    break\n",
"                next_action0 = agent0.decide(nextnext_heaps)\n",
"                next3_heaps, winner, reward0, done, _ = env.step(next_action0)\n",
"                agent0.learn(heaps, action0, reward0[0], nextnext_heaps, next_action0)\n",
"                if done:\n",
"                    agent1.learn(next_heaps, action1, reward0[1])\n",
"                    break\n",
"                next_action1 = agent1.decide(next3_heaps)\n",
"                agent1.learn(next_heaps, action1, reward1[1], next3_heaps, next_action1)\n",
"                next4_heaps, winner, reward1, done, _ = env.step(next_action1)\n",
"                if done:\n",
"                    agent0.learn(nextnext_heaps, next_action0, reward1[0])\n",
"                    agent1.learn(next3_heaps, next_action1, reward1[1])\n",
"                    break\n",
"                nextnext_action0 = agent0.decide(next4_heaps)\n",
"\n",
"                old_heaps = next3_heaps.copy()\n",
"                action1 = next_action1.copy()\n",
"                heaps = next4_heaps.copy()\n",
"                action0 = nextnext_action0.copy()\n",
"\n",
"            if winner == 0:\n",
"                player0_episode_reward.append(1)\n",
"                player1_episode_reward.append(-1)\n",
"            else: # if winner == 1\n",
"                player0_episode_reward.append(-1)\n",
"                player1_episode_reward.append(1)\n",
"\n",
"        player0_rewards.append(np.mean(player0_episode_reward))\n",
"        player1_rewards.append(np.mean(player1_episode_reward))\n",
"    return player0_rewards, player1_rewards, agent\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# NOTE(review): superseded draft of SARSA-vs-optimal-policy training that plays num_seed\n",
"# seeded games per episode; kept for reference. The live version is run_experiment above.\n",
"num_episodes = 5000\n",
"num_seed = 100\n",
"heaps = random.sample(range(1, 8), 3)\n",
"env = NimEnv(heaps)\n",
"\n",
"\n",
"def run_experiment_old_fixed(env, num_episodes):\n",
"    player0_rewards = []\n",
"    player1_rewards = []\n",
"    # agent = [RLAgent(), RLAgent()] # agents for player 0,1\n",
"    agent = RLAgent() # player 0 plays against e-greedy optimal policy\n",
"    # agent0 = RLAgent() # RL agent for player 0\n",
"    # agent1 = RLAgent() # RL agent for player 1\n",
"\n",
"    for episode in tqdm(range(1, num_episodes+1)):\n",
"        player0_episode_reward = []\n",
"        player1_episode_reward = []\n",
"\n",
"        for seed in range(num_seed):\n",
"            heaps = env.reset(seed) # fixed seeds to evaluate performance on same set of initial conditions\n",
"            #heaps = [7,7,7]\n",
"            env = NimEnv(heaps)\n",
"            winner = None\n",
"            done = False\n",
"            action = agent.decide(heaps)\n",
"            while not done:\n",
"                # print(heaps)\n",
"                # print(action)\n",
"                next_heaps = None\n",
"                nextnext_heaps = None\n",
"                next_heaps, winner, rewards, done, _ = env.step(action)\n",
"                if done:\n",
"                    agent.learn(heaps, action, rewards[0])\n",
"                    break\n",
"                adv_action = optimal_policy(next_heaps, randomness=0.2)\n",
"                # print('nextheaps1=', next_heaps)\n",
"                # print('adv_action=', adv_action)\n",
"                nextnext_heaps, winner, adv_reward, done, _ = env.step(adv_action)\n",
"                if done:\n",
"                    agent.learn(heaps, action, adv_reward[0]) #adv_reward[0]\n",
"                    break\n",
"                nextnext_action = agent.decide(nextnext_heaps)\n",
"                # print(f'Q({heaps}, {action}) = ', agent.getQ(heaps, action))\n",
"                agent.learn(heaps, action, rewards[0], nextnext_heaps, nextnext_action)\n",
"                # print('Q(heaps, action) = ', agent.getQ(heaps, action))\n",
"\n",
"                heaps = nextnext_heaps.copy()\n",
"                action = nextnext_action.copy()\n",
"\n",
"            if winner == 0:\n",
"                player0_episode_reward.append(1)\n",
"                player1_episode_reward.append(-1)\n",
"            else: # if winner == 1\n",
"                player0_episode_reward.append(-1)\n",
"                player1_episode_reward.append(1)\n",
"\n",
"        player0_rewards.append(np.mean(player0_episode_reward))\n",
"        player1_rewards.append(np.mean(player1_episode_reward))\n",
"    return player0_rewards, player1_rewards, agent\n" ] } ], "metadata": { "interpreter": { "hash": "4369559244255f10d34bca352df9b3f8794934e60b17f3451fa3b0f2f96527a8" }, "kernelspec": { "display_name": "Python 3.8.8 64-bit ('base': conda)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.8" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }