diff --git a/GetaCM.ipynb b/GetaCM.ipynb index bd6d2de..a99b8d9 100644 --- a/GetaCM.ipynb +++ b/GetaCM.ipynb @@ -1,374 +1,426 @@ { "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": 219, "metadata": {}, "outputs": [], "source": [ "from glob import glob\n", "import numpy as np" ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 220, "metadata": {}, "outputs": [], "source": [ "from rdkit import Chem" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 221, "metadata": {}, "outputs": [], "source": [ "target_xyzs = sorted(glob(\"targets/*.xyz\"))" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 222, "metadata": {}, "outputs": [], "source": [ "def read_sdf(sdf):\n", " with open(sdf, \"r\") as f:\n", " txt = f.read().rstrip()\n", " return txt" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 223, "metadata": {}, "outputs": [], "source": [ "def get_ncharges_coords(sdf):\n", " mol = Chem.MolFromMolBlock(sdf)\n", " #mol = Chem.AddHs(mol)\n", " # rdkit molobj\n", " ncharges = [atom.GetAtomicNum() for atom in mol.GetAtoms()]\n", " conf = mol.GetConformer()\n", " coords = np.asarray(conf.GetPositions())\n", " return ncharges, coords" ] }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 224, "metadata": {}, "outputs": [], "source": [ - "def cutoff_func(R_ij, central_cutoff=4.8, central_decay=0.03):\n", + "def cutoff_func(R_ij, central_cutoff=4.8, central_decay=1):\n", " if R_ij <= (central_cutoff - central_decay):\n", - " # print('1')\n", " func = 1.\n", " elif ((central_cutoff - central_decay) < R_ij) and (R_ij <= (central_cutoff + central_decay)):\n", - " # print('function')\n", - " func = 0.5 * (1. + np.cos((np.pi * R_ij - central_cutoff + central_decay)))\n", + " func = 0.5 * (1. + np.cos((np.pi * R_ij - central_cutoff + central_decay)/central_decay))\n", " else:\n", - " # print('zero')\n", " func = 0.\n", " return func" ] }, { "cell_type": "code", - "execution_count": 36, + "execution_count": 235, "metadata": {}, "outputs": [], "source": [ - "def get_atomic_CM(ncharges, coords, max_natoms, central_cutoff=4.8, central_decay=0.03):\n", + "def get_atomic_CM(ncharges, coords, max_natoms, central_cutoff=4.8, central_decay=1):\n", " size = int((max_natoms + 1)*max_natoms / 2)\n", " rep = np.zeros((len(ncharges), size))\n", " \n", " # central atom loop\n", " for k in range(len(ncharges)):\n", " M = np.zeros((len(ncharges), len(ncharges)))\n", " for i in range(len(ncharges)):\n", " R_ik = np.linalg.norm(coords[i]-coords[k])\n", " # print('R_ik', R_ik)\n", " f_ik = cutoff_func(R_ik, central_cutoff=central_cutoff,\n", " central_decay=central_decay)\n", - " for j in range(i):\n", - " if i == j:\n", - " M[i,j] = 0.5 * ncharges[i]**2.4 * f_ik**2\n", - " M[j,i] = M[i,j]\n", - " \n", - " else:\n", - " R_jk = np.linalg.norm(coords[j]-coords[k])\n", - " # print('R_jk', R_jk)\n", - " f_jk = cutoff_func(R_jk, central_cutoff=central_cutoff,\n", - " central_decay=central_decay)\n", - " R_ij = np.linalg.norm(coords[i]-coords[j])\n", - " # print('R_ij', R_ij)\n", - " f_ij = cutoff_func(R_ij, central_cutoff=central_cutoff,\n", - " central_decay=central_decay)\n", - " M[i,j] = (ncharges[i]*ncharges[j]/R_ij)*f_ik*f_jk*f_ij\n", - " M[j,i] = M[i,j]\n", - " \n", + " for j in range(len(ncharges)):\n", + " if i <=j:\n", + " if i == j:\n", + " M[i,j] = 0.5 * ncharges[i]**2.4 * f_ik**2\n", + " M[j,i] = M[i,j]\n", + "\n", + " else:\n", + " R_jk = np.linalg.norm(coords[j]-coords[k])\n", + " # print('R_jk', R_jk)\n", + " f_jk = cutoff_func(R_jk, central_cutoff=central_cutoff,\n", + " central_decay=central_decay)\n", + " R_ij = np.linalg.norm(coords[i]-coords[j])\n", + " # print('R_ij', R_ij)\n", + " f_ij = cutoff_func(R_ij, central_cutoff=central_cutoff,\n", + " central_decay=central_decay)\n", + " M[i,j] = (ncharges[i]*ncharges[j]/R_ij)*f_ik*f_jk*f_ij\n", + " M[j,i] = M[i,j]\n", + "\n", + "\n", " # concat upper triangular and diagonal\n", - " upper_triang = np.triu(M)\n", - " non_zero_i, non_zero_j = np.nonzero(upper_triang)\n", - " unpadded_rep = upper_triang[non_zero_i, non_zero_j]\n", + " upper_triang = M[np.triu_indices(len(M))]\n", + " s_upper_triang = np.sort(upper_triang)[::-1]\n", + " \n", " # pad to full size\n", - " n_zeros = size - len(unpadded_rep)\n", + " n_zeros = size - len(s_upper_triang)\n", " zeros = np.zeros(n_zeros)\n", - " rep[k] = np.concatenate((unpadded_rep, zeros))\n", - " \n", + " rep[k] = np.concatenate((s_upper_triang, zeros))\n", + "\n", " return rep" ] }, { "cell_type": "code", - "execution_count": 37, + "execution_count": 236, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['targets/qm9.sdf', 'targets/vitc.sdf', 'targets/vitd.sdf']" ] }, - "execution_count": 37, + "execution_count": 236, "metadata": {}, "output_type": "execute_result" } ], "source": [ "target_files = sorted(glob(\"targets/*.sdf\"))\n", "target_files" ] }, { "cell_type": "code", - "execution_count": 38, + "execution_count": 237, "metadata": {}, "outputs": [], "source": [ "target_sdfs = [read_sdf(x) for x in target_files]" ] }, { "cell_type": "code", - "execution_count": 39, + "execution_count": 238, "metadata": {}, "outputs": [], "source": [ "conf_data = [get_ncharges_coords(x) for x in target_sdfs]" ] }, { "cell_type": "code", - "execution_count": 40, + "execution_count": 239, "metadata": {}, "outputs": [], "source": [ "ncharges_list, coords_list = zip(*conf_data)" ] }, { "cell_type": "code", - "execution_count": 41, + "execution_count": 240, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[9, 12, 28]" ] }, - "execution_count": 41, + "execution_count": 240, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sizes = [len(x) for x in ncharges_list]\n", "sizes" ] }, { "cell_type": "code", - "execution_count": 42, + "execution_count": 241, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/puck/anaconda3/envs/rdkit/lib/python3.7/site-packages/ipykernel_launcher.py:4: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray\n", " after removing the cwd from sys.path.\n" ] } ], "source": [ "target_reps = np.array(\n", "[np.array(get_atomic_CM(np.array(ncharges_list[i]), np.array(coords_list[i]),\n", " max_natoms=sizes[i]))\n", "for i in range(len(ncharges_list))])" ] }, { "cell_type": "code", - "execution_count": 43, + "execution_count": 242, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "array([[39.89268481, 20.19979599, 16.0445292 , 10.27020053, 13.46303348,\n", - " 16.70306767, 23.89446837, 17.57023229, 9.7738549 , 11.22347027,\n", - " 14.10327203, 32.57835982, 15.64275949, 17.59643875, 24.93677728,\n", - " 30.00133532, 17.36551167, 17.48814725, 17.13482138, 13.22143102,\n", - " 32.63942034, 0. , 0. , 0. , 0. ,\n", - " 0. , 0. , 0. , 0. , 0. ,\n", - " 0. , 0. , 0. , 0. , 0. ,\n", - " 0. , 0. , 0. , 0. , 0. ,\n", - " 0. , 0. , 0. , 0. , 0. ],\n", - " [39.89268481, 20.19979599, 16.0445292 , 10.27020053, 13.46303348,\n", - " 16.70306767, 23.89446837, 17.57023229, 9.7738549 , 10.24376593,\n", - " 11.5318783 , 11.22347027, 14.10327203, 32.57835982, 15.64275949,\n", - " 14.05740332, 18.07046347, 17.59643875, 24.93677728, 30.00133532,\n", - " 24.37881483, 23.29821822, 17.36551167, 17.48814725, 40.31759536,\n", - " 34.11172854, 17.13482138, 13.22143102, 28.78715844, 16.01650232,\n", - " 12.2823109 , 40.74474394, 21.22940297, 32.63942034, 0. ,\n", - " 0. , 0. , 0. , 0. , 0. ,\n", - " 0. , 0. , 0. , 0. , 0. ],\n", - " [39.89268481, 20.19979599, 16.0445292 , 10.27020053, 13.46303348,\n", - " 16.70306767, 23.89446837, 17.57023229, 9.7738549 , 10.24376593,\n", - " 11.5318783 , 11.22347027, 14.10327203, 32.57835982, 15.64275949,\n", - " 14.05740332, 18.07046347, 17.59643875, 24.93677728, 30.00133532,\n", - " 24.37881483, 23.29821822, 17.36551167, 17.48814725, 40.31759536,\n", - " 34.11172854, 17.13482138, 13.22143102, 28.78715844, 16.01650232,\n", - " 12.2823109 , 40.74474394, 21.22940297, 32.63942034, 0. ,\n", - " 0. , 0. , 0. , 0. , 0. ,\n", - " 0. , 0. , 0. , 0. , 0. ],\n", - " [39.89268481, 20.19979599, 16.0445292 , 10.27020053, 13.46303348,\n", - " 16.70306767, 23.89446837, 17.57023229, 9.7738549 , 10.24376593,\n", - " 11.5318783 , 11.22347027, 14.10327203, 32.57835982, 15.64275949,\n", - " 14.05740332, 18.07046347, 17.59643875, 24.93677728, 30.00133532,\n", - " 24.37881483, 23.29821822, 17.36551167, 17.48814725, 40.31759536,\n", - " 34.11172854, 17.13482138, 13.22143102, 28.78715844, 16.01650232,\n", - " 12.2823109 , 40.74474394, 21.22940297, 32.63942034, 0. ,\n", - " 0. , 0. , 0. , 0. , 0. ,\n", - " 0. , 0. , 0. , 0. , 0. ],\n", - " [39.89268481, 20.19979599, 16.0445292 , 10.27020053, 13.46303348,\n", - " 16.70306767, 23.89446837, 17.57023229, 9.7738549 , 10.24376593,\n", - " 11.5318783 , 11.22347027, 14.10327203, 32.57835982, 15.64275949,\n", - " 14.05740332, 18.07046347, 17.59643875, 24.93677728, 30.00133532,\n", - " 24.37881483, 23.29821822, 17.36551167, 17.48814725, 40.31759536,\n", - " 34.11172854, 17.13482138, 13.22143102, 28.78715844, 16.01650232,\n", - " 12.2823109 , 40.74474394, 21.22940297, 32.63942034, 0. ,\n", - " 0. , 0. , 0. , 0. , 0. ,\n", - " 0. , 0. , 0. , 0. , 0. ],\n", - " [23.89446837, 17.57023229, 9.7738549 , 10.24376593, 11.5318783 ,\n", - " 11.22347027, 14.10327203, 32.57835982, 15.64275949, 14.05740332,\n", - " 18.07046347, 17.59643875, 24.93677728, 30.00133532, 24.37881483,\n", - " 23.29821822, 17.36551167, 17.48814725, 40.31759536, 34.11172854,\n", - " 17.13482138, 13.22143102, 28.78715844, 16.01650232, 12.2823109 ,\n", - " 40.74474394, 21.22940297, 32.63942034, 0. , 0. ,\n", - " 0. , 0. , 0. , 0. , 0. ,\n", - " 0. , 0. , 0. , 0. , 0. ,\n", - " 0. , 0. , 0. , 0. , 0. ],\n", - " [23.89446837, 17.57023229, 9.7738549 , 10.24376593, 11.5318783 ,\n", - " 11.22347027, 14.10327203, 32.57835982, 15.64275949, 14.05740332,\n", - " 18.07046347, 17.59643875, 24.93677728, 30.00133532, 24.37881483,\n", - " 23.29821822, 17.36551167, 17.48814725, 40.31759536, 34.11172854,\n", - " 17.13482138, 13.22143102, 28.78715844, 16.01650232, 12.2823109 ,\n", - " 40.74474394, 21.22940297, 32.63942034, 0. , 0. ,\n", - " 0. , 0. , 0. , 0. , 0. ,\n", - " 0. , 0. , 0. , 0. , 0. ,\n", - " 0. , 0. , 0. , 0. , 0. ],\n", - " [39.89268481, 20.19979599, 16.0445292 , 10.27020053, 13.46303348,\n", - " 16.70306767, 23.89446837, 17.57023229, 9.7738549 , 10.24376593,\n", - " 11.5318783 , 11.22347027, 14.10327203, 32.57835982, 15.64275949,\n", - " 14.05740332, 18.07046347, 17.59643875, 24.93677728, 30.00133532,\n", - " 24.37881483, 23.29821822, 17.36551167, 17.48814725, 40.31759536,\n", - " 34.11172854, 17.13482138, 13.22143102, 28.78715844, 16.01650232,\n", - " 12.2823109 , 40.74474394, 21.22940297, 32.63942034, 0. ,\n", - " 0. , 0. , 0. , 0. , 0. ,\n", - " 0. , 0. , 0. , 0. , 0. ],\n", - " [39.89268481, 20.19979599, 16.0445292 , 10.27020053, 13.46303348,\n", - " 16.70306767, 23.89446837, 17.57023229, 9.7738549 , 10.24376593,\n", - " 11.5318783 , 11.22347027, 14.10327203, 32.57835982, 15.64275949,\n", - " 14.05740332, 18.07046347, 17.59643875, 24.93677728, 30.00133532,\n", - " 24.37881483, 23.29821822, 17.36551167, 17.48814725, 40.31759536,\n", - " 34.11172854, 17.13482138, 13.22143102, 28.78715844, 16.01650232,\n", - " 12.2823109 , 40.74474394, 21.22940297, 32.63942034, 0. ,\n", - " 0. , 0. , 0. , 0. , 0. ,\n", - " 0. , 0. , 0. , 0. , 0. ]])" + "array([[7.35166947e+01, 5.33587074e+01, 4.35877623e+01, 3.98926848e+01,\n", + " 3.68581052e+01, 3.68581052e+01, 3.68581052e+01, 3.25783598e+01,\n", + " 2.49367773e+01, 2.38944684e+01, 2.01997960e+01, 1.79395692e+01,\n", + " 1.75702323e+01, 1.74881473e+01, 1.67030677e+01, 1.63465866e+01,\n", + " 1.60445292e+01, 1.41032720e+01, 1.39142112e+01, 1.37269310e+01,\n", + " 1.33142263e+01, 1.16565056e+01, 1.05343157e+01, 9.57815448e+00,\n", + " 7.75795874e+00, 7.73149910e+00, 7.25911544e+00, 6.94206566e+00,\n", + " 6.07433646e+00, 5.86750965e+00, 4.33751746e+00, 2.07719952e+00,\n", + " 2.04788044e+00, 2.02269137e+00, 1.10413147e+00, 2.01239801e-01,\n", + " 1.93433788e-01, 1.08491628e-01, 1.07067836e-01, 6.91988063e-02,\n", + " 4.87267466e-02, 4.68841852e-02, 4.26710041e-02, 2.02837470e-03,\n", + " 5.11782947e-04],\n", + " [7.35166947e+01, 5.33587074e+01, 5.33587074e+01, 3.98926848e+01,\n", + " 3.68581052e+01, 3.68581052e+01, 3.68581052e+01, 3.68581052e+01,\n", + " 3.26394203e+01, 3.25783598e+01, 3.00013353e+01, 2.49367773e+01,\n", + " 2.38944684e+01, 2.01997960e+01, 1.86528773e+01, 1.75964388e+01,\n", + " 1.75702323e+01, 1.74881473e+01, 1.73655117e+01, 1.71348214e+01,\n", + " 1.67030677e+01, 1.60445292e+01, 1.57357844e+01, 1.56427595e+01,\n", + " 1.41032720e+01, 1.32214310e+01, 1.12788235e+01, 1.12234703e+01,\n", + " 9.77385490e+00, 7.41001169e+00, 6.50363736e+00, 4.55778959e+00,\n", + " 2.22400463e+00, 2.19261343e+00, 1.18216544e+00, 2.23588601e-01,\n", + " 1.87189633e-01, 1.27850012e-01, 1.16497296e-01, 9.91624747e-02,\n", + " 8.30069332e-02, 7.30849704e-02, 5.52886812e-02, 2.21381896e-03,\n", + " 3.47261135e-04],\n", + " [7.35166947e+01, 7.35166947e+01, 7.35166947e+01, 5.33587074e+01,\n", + " 5.33587074e+01, 4.07447439e+01, 4.03175954e+01, 3.98926848e+01,\n", + " 3.68581052e+01, 3.68581052e+01, 3.68581052e+01, 3.68581052e+01,\n", + " 3.41117285e+01, 3.26394203e+01, 3.25783598e+01, 3.00013353e+01,\n", + " 2.87871584e+01, 2.49367773e+01, 2.43788148e+01, 2.38944684e+01,\n", + " 2.32982182e+01, 2.12294030e+01, 2.01997960e+01, 1.80704635e+01,\n", + " 1.75964388e+01, 1.75702323e+01, 1.74881473e+01, 1.73655117e+01,\n", + " 1.71348214e+01, 1.67030677e+01, 1.60445292e+01, 1.60165023e+01,\n", + " 1.56427595e+01, 1.41032720e+01, 1.40574033e+01, 1.32214310e+01,\n", + " 1.12234703e+01, 1.00753042e+01, 9.77385490e+00, 4.80711459e+00,\n", + " 4.73926353e+00, 4.55778959e+00, 2.55521264e+00, 8.30069332e-02,\n", + " 6.32816968e-02],\n", + " [7.35166947e+01, 7.35166947e+01, 7.35166947e+01, 5.33587074e+01,\n", + " 5.33587074e+01, 4.07447439e+01, 4.03175954e+01, 3.98926848e+01,\n", + " 3.68581052e+01, 3.68581052e+01, 3.68581052e+01, 3.68581052e+01,\n", + " 3.41117285e+01, 3.26394203e+01, 3.25783598e+01, 3.00013353e+01,\n", + " 2.87871584e+01, 2.49367773e+01, 2.43788148e+01, 2.38944684e+01,\n", + " 2.32982182e+01, 2.12294030e+01, 2.01997960e+01, 1.80704635e+01,\n", + " 1.75964388e+01, 1.75702323e+01, 1.74881473e+01, 1.73655117e+01,\n", + " 1.71348214e+01, 1.67030677e+01, 1.60445292e+01, 1.60165023e+01,\n", + " 1.56427595e+01, 1.41032720e+01, 1.40574033e+01, 1.32214310e+01,\n", + " 1.12234703e+01, 1.00753042e+01, 9.77385490e+00, 4.80711459e+00,\n", + " 4.73926353e+00, 4.55778959e+00, 2.55521264e+00, 8.30069332e-02,\n", + " 6.32816968e-02],\n", + " [7.35166947e+01, 7.35166947e+01, 5.33587074e+01, 5.33587074e+01,\n", + " 4.07447439e+01, 4.03175954e+01, 3.68581052e+01, 3.68581052e+01,\n", + " 3.68581052e+01, 3.68581052e+01, 3.41117285e+01, 3.26394203e+01,\n", + " 3.25783598e+01, 3.00013353e+01, 2.87871584e+01, 2.49367773e+01,\n", + " 2.43788148e+01, 2.38944684e+01, 2.32982182e+01, 2.12294030e+01,\n", + " 1.80704635e+01, 1.77038864e+01, 1.75964388e+01, 1.75702323e+01,\n", + " 1.74881473e+01, 1.73655117e+01, 1.71348214e+01, 1.60165023e+01,\n", + " 1.56427595e+01, 1.44789368e+01, 1.41032720e+01, 1.40574033e+01,\n", + " 1.32214310e+01, 1.12234703e+01, 9.77385490e+00, 8.96442281e+00,\n", + " 7.41261748e+00, 7.12036615e+00, 4.73926353e+00, 4.47129696e+00,\n", + " 2.55521264e+00, 2.13333876e+00, 2.02269137e+00, 6.32816968e-02,\n", + " 3.68374634e-02],\n", + " [7.35166947e+01, 7.35166947e+01, 5.33587074e+01, 5.33587074e+01,\n", + " 4.07447439e+01, 4.03175954e+01, 3.68581052e+01, 3.68581052e+01,\n", + " 3.41117285e+01, 3.25783598e+01, 3.00013353e+01, 2.87871584e+01,\n", + " 2.43788148e+01, 2.32982182e+01, 1.80704635e+01, 1.75964388e+01,\n", + " 1.73655117e+01, 1.71348214e+01, 1.60165023e+01, 1.56427595e+01,\n", + " 1.40574033e+01, 1.37269310e+01, 1.10547413e+01, 8.72852223e+00,\n", + " 8.12884262e+00, 7.97513819e+00, 7.88924475e+00, 6.93299229e+00,\n", + " 6.79030682e+00, 5.19252233e+00, 5.18784853e+00, 4.52185987e+00,\n", + " 4.41656617e+00, 4.35363388e+00, 3.63823513e+00, 2.75058724e+00,\n", + " 2.19261343e+00, 2.07719952e+00, 1.96946384e+00, 1.59524329e+00,\n", + " 1.50154012e+00, 1.35743192e+00, 5.31586580e-01, 3.58680782e-02,\n", + " 2.92771857e-02],\n", + " [7.35166947e+01, 7.35166947e+01, 5.33587074e+01, 5.33587074e+01,\n", + " 4.35877623e+01, 4.07447439e+01, 4.03175954e+01, 3.68581052e+01,\n", + " 3.68581052e+01, 3.68581052e+01, 3.41117285e+01, 3.26394203e+01,\n", + " 3.25783598e+01, 3.00013353e+01, 2.87871584e+01, 2.49367773e+01,\n", + " 2.43788148e+01, 2.32982182e+01, 2.12294030e+01, 1.80704635e+01,\n", + " 1.75964388e+01, 1.74881473e+01, 1.73655117e+01, 1.71348214e+01,\n", + " 1.60165023e+01, 1.56427595e+01, 1.55537919e+01, 1.40574033e+01,\n", + " 1.32214310e+01, 1.28613199e+01, 1.23542469e+01, 7.75795874e+00,\n", + " 3.70146609e+00, 3.50948647e+00, 2.55521264e+00, 1.68562371e-01,\n", + " 1.31121962e-01, 9.64174338e-02, 7.73923346e-02, 6.39151288e-02,\n", + " 6.15892940e-02, 5.36344649e-02, 2.60069201e-02, 1.10991350e-03,\n", + " 3.47261135e-04],\n", + " [7.35166947e+01, 7.35166947e+01, 5.33587074e+01, 5.33587074e+01,\n", + " 4.07447439e+01, 4.03175954e+01, 3.68581052e+01, 3.68581052e+01,\n", + " 3.68581052e+01, 3.68581052e+01, 3.41117285e+01, 3.26394203e+01,\n", + " 3.25783598e+01, 3.00013353e+01, 2.87871584e+01, 2.49367773e+01,\n", + " 2.43788148e+01, 2.38944684e+01, 2.32982182e+01, 2.12294030e+01,\n", + " 1.80704635e+01, 1.75964388e+01, 1.75702323e+01, 1.74881473e+01,\n", + " 1.73655117e+01, 1.71348214e+01, 1.60165023e+01, 1.56427595e+01,\n", + " 1.41032720e+01, 1.40574033e+01, 1.32214310e+01, 1.12234703e+01,\n", + " 9.77385490e+00, 4.73926353e+00, 2.55521264e+00, 2.45960127e-01,\n", + " 1.24542743e-01, 1.02983508e-01, 9.89232601e-02, 6.32816968e-02,\n", + " 6.21197370e-02, 2.96384793e-02, 2.81012549e-02, 2.79465922e-03,\n", + " 5.11782947e-04],\n", + " [7.35166947e+01, 7.35166947e+01, 5.33587074e+01, 5.33587074e+01,\n", + " 4.07447439e+01, 3.98926848e+01, 3.68581052e+01, 3.68581052e+01,\n", + " 3.68581052e+01, 3.68581052e+01, 3.41117285e+01, 3.26394203e+01,\n", + " 3.25783598e+01, 3.00013353e+01, 2.49367773e+01, 2.38944684e+01,\n", + " 2.32982182e+01, 2.12294030e+01, 2.01997960e+01, 1.80704635e+01,\n", + " 1.75964388e+01, 1.75702323e+01, 1.74881473e+01, 1.73655117e+01,\n", + " 1.71348214e+01, 1.67030677e+01, 1.60445292e+01, 1.56427595e+01,\n", + " 1.41032720e+01, 1.32214310e+01, 1.12234703e+01, 1.00753042e+01,\n", + " 9.77385490e+00, 8.38767478e+00, 5.98888204e+00, 5.07176999e+00,\n", + " 4.55778959e+00, 3.33207403e+00, 3.18185142e+00, 2.92450297e+00,\n", + " 1.00007239e+00, 9.85956648e-01, 5.31586580e-01, 8.30069332e-02,\n", + " 6.32816968e-02]])" ] }, - "execution_count": 43, + "execution_count": 242, "metadata": {}, "output_type": "execute_result" } ], "source": [ "target_reps[0]" ] }, { "cell_type": "code", - "execution_count": 44, + "execution_count": 233, "metadata": {}, "outputs": [], "source": [ "target_labels = [t.split(\"/\")[-1].split(\".xyz\")[0] for t in target_sdfs]" ] }, { "cell_type": "code", - "execution_count": 45, + "execution_count": 234, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/puck/anaconda3/envs/rdkit/lib/python3.7/site-packages/numpy/core/_asarray.py:136: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray\n", - " return array(a, dtype, copy=False, order=order, subok=True)\n" - ] - } - ], + "outputs": [], "source": [ "np.savez(\"target_aCM_data.npz\", \n", " target_labels=target_labels, \n", " target_reps=target_reps, \n", " target_ncharges=ncharges_list,)" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.9" } }, "nbformat": 4, "nbformat_minor": 4 } diff --git a/GetaCMAmons.ipynb b/GetaCMAmons.ipynb index 90c98fa..857d1c5 100644 --- a/GetaCMAmons.ipynb +++ b/GetaCMAmons.ipynb @@ -1,481 +1,599 @@ { "cells": [ { "cell_type": "code", - "execution_count": 4, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from glob import glob\n", "import numpy as np" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from rdkit import Chem" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "def read_sdf(sdf):\n", " with open(sdf, \"r\") as f:\n", " txt = f.read().rstrip()\n", " return txt" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "def get_ncharges_coords(sdf):\n", " mol = Chem.MolFromMolBlock(sdf)\n", " #mol = Chem.AddHs(mol)\n", " # rdkit molobj\n", " ncharges = [atom.GetAtomicNum() for atom in mol.GetAtoms()]\n", " conf = mol.GetConformer()\n", " coords = np.asarray(conf.GetPositions())\n", " return ncharges, coords" ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ - "def cutoff_func(coord_a, coord_b, central_cutoff=1e6, central_decay=-1):\n", - " R_ij = np.linalg.norm(coord_a - coord_b)\n", + "def cutoff_func(R_ij, central_cutoff=4.8, central_decay=1):\n", " if R_ij <= (central_cutoff - central_decay):\n", " func = 1.\n", " elif ((central_cutoff - central_decay) < R_ij) and (R_ij <= (central_cutoff + central_decay)):\n", - " func = 0.5 * (1. + np.cos((np.pi * R_ij - central_cutoff + central_decay)))\n", + " func = 0.5 * (1. + np.cos((np.pi * R_ij - central_cutoff + central_decay)/central_decay))\n", " else:\n", " func = 0.\n", " return func" ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 65, "metadata": {}, "outputs": [], "source": [ - "def get_atomic_CM(ncharges, coords, max_natoms, central_cutoff=1e6, central_decay=-1):\n", + "def get_atomic_CM(ncharges, coords, max_natoms, central_cutoff=4.8, central_decay=1):\n", " size = int((max_natoms + 1)*max_natoms / 2)\n", " rep = np.zeros((len(ncharges), size))\n", " \n", " # central atom loop\n", " for k in range(len(ncharges)):\n", " M = np.zeros((len(ncharges), len(ncharges)))\n", " for i in range(len(ncharges)):\n", - " f_ik = cutoff_func(coords[i], coords[k])\n", - " for j in range(i):\n", - " if i == j:\n", - " M[i,j] = 0.5 * ncharges[i]**2.4 * f_ik**2\n", - " M[j,i] = M[i,j]\n", - " \n", - " else:\n", - " f_jk = cutoff_func(coords[j], coords[k])\n", - " f_ij = cutoff_func(coords[i], coords[j])\n", - " M[i,j] = (ncharges[i]*ncharges[j]/np.linalg.norm(coords[i]-coords[j]))*f_ik*f_jk*f_ij\n", - " M[j,i] = M[i,j]\n", - " \n", + " R_ik = np.linalg.norm(coords[i]-coords[k])\n", + " # print('R_ik', R_ik)\n", + " f_ik = cutoff_func(R_ik, central_cutoff=central_cutoff,\n", + " central_decay=central_decay)\n", + " for j in range(len(ncharges)):\n", + " if i <=j:\n", + " if i == j:\n", + " M[i,j] = 0.5 * ncharges[i]**2.4 * f_ik**2\n", + " M[j,i] = M[i,j]\n", + "\n", + " else:\n", + " R_jk = np.linalg.norm(coords[j]-coords[k])\n", + " # print('R_jk', R_jk)\n", + " f_jk = cutoff_func(R_jk, central_cutoff=central_cutoff,\n", + " central_decay=central_decay)\n", + " R_ij = np.linalg.norm(coords[i]-coords[j])\n", + " # print('R_ij', R_ij)\n", + " f_ij = cutoff_func(R_ij, central_cutoff=central_cutoff,\n", + " central_decay=central_decay)\n", + " M[i,j] = (ncharges[i]*ncharges[j]/R_ij)*f_ik*f_jk*f_ij\n", + " M[j,i] = M[i,j]\n", + "\n", + "\n", " # concat upper triangular and diagonal\n", - " upper_triang = np.triu(M)\n", - " non_zero_i, non_zero_j = np.nonzero(upper_triang)\n", - " unpadded_rep = upper_triang[non_zero_i, non_zero_j]\n", + " upper_triang = M[np.triu_indices(len(M))]\n", + " s_upper_triang = np.sort(upper_triang)[::-1]\n", + " \n", " # pad to full size\n", - " n_zeros = size - len(unpadded_rep)\n", + " n_zeros = size - len(s_upper_triang)\n", " zeros = np.zeros(n_zeros)\n", - " rep[k] = np.concatenate((unpadded_rep, zeros))\n", - " \n", + " rep[k] = np.concatenate((s_upper_triang, zeros))\n", + "\n", " return rep" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 66, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['targets/qm9.sdf', 'targets/vitc.sdf', 'targets/vitd.sdf']" ] }, - "execution_count": 10, + "execution_count": 66, "metadata": {}, "output_type": "execute_result" } ], "source": [ "target_sdfs = sorted(glob(\"targets/*.sdf\"))\n", "target_sdfs" ] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 67, "metadata": {}, "outputs": [], "source": [ "qm9_amons_files = sorted(glob(\"amons-qm9/*.sdf\"))" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 68, "metadata": {}, "outputs": [], "source": [ "qm9_amons_sdfs = [read_sdf(x) for x in qm9_amons_files]" ] }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 69, "metadata": {}, "outputs": [], "source": [ "conf_data = [get_ncharges_coords(x) for x in qm9_amons_sdfs]" ] }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 70, "metadata": {}, "outputs": [], "source": [ "ncharges_list, coords_list = zip(*conf_data)" ] }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 71, "metadata": {}, "outputs": [], "source": [ "qm9_ncharges = ncharges_list" ] }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 72, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "([8],\n", + " [6, 6],\n", + " [6, 7],\n", + " [8, 6],\n", + " [6, 8, 8],\n", + " [6, 8, 7],\n", + " [6, 6, 7, 7],\n", + " [8, 6, 6, 7],\n", + " [8, 6, 6, 6],\n", + " [6, 7, 6, 8],\n", + " [6, 7, 6, 8, 8],\n", + " [6, 8, 8, 7, 6],\n", + " [8, 6, 6, 7, 7, 6],\n", + " [8, 6, 6, 7, 6, 8],\n", + " [8, 6, 6, 7, 6, 8, 8],\n", + " [6, 7, 6, 8, 8, 7, 6])" + ] + }, + "execution_count": 72, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "qm9_ncharges" + ] + }, + { + "cell_type": "code", + "execution_count": 73, "metadata": {}, "outputs": [], "source": [ "qm9_reps = [np.array(get_atomic_CM(np.array(ncharges_list[i]),\n", " np.array(coords_list[i]), \n", " max_natoms=9))\n", " for i in range(len(ncharges_list))]" ] }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 74, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/puck/anaconda3/envs/rdkit/lib/python3.7/site-packages/ipykernel_launcher.py:1: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray\n", " \"\"\"Entry point for launching an IPython kernel.\n" ] } ], "source": [ "qm9_reps = np.array(qm9_reps)" ] }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 75, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(1, 45)" ] }, - "execution_count": 18, + "execution_count": 75, "metadata": {}, "output_type": "execute_result" } ], "source": [ "qm9_reps[0].shape" ] }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 76, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[73.51669472, 73.51669472, 53.3587074 , 53.3587074 , 39.59747102,\n", + " 38.24346354, 36.8581052 , 36.8581052 , 36.8581052 , 35.01513065,\n", + " 32.97763233, 32.88982286, 30.62177746, 28.63343371, 24.38002738,\n", + " 24.32694184, 23.45775915, 21.04930695, 17.96633265, 17.46140502,\n", + " 17.44163762, 17.36912369, 17.26721941, 15.88104653, 15.79556841,\n", + " 14.00224752, 13.37087415, 2.00479248, 0. , 0. ,\n", + " 0. , 0. , 0. , 0. , 0. ,\n", + " 0. , 0. , 0. , 0. , 0. ,\n", + " 0. , 0. , 0. , 0. , 0. ],\n", + " [73.51669472, 73.51669472, 53.3587074 , 53.3587074 , 39.59747102,\n", + " 38.24346354, 36.8581052 , 36.8581052 , 36.8581052 , 35.01513065,\n", + " 32.97763233, 32.88982286, 30.62177746, 28.63343371, 24.38002738,\n", + " 24.32694184, 23.45775915, 21.04930695, 17.96633265, 17.46140502,\n", + " 17.44163762, 17.36912369, 17.26721941, 15.88104653, 15.79556841,\n", + " 14.00224752, 13.37087415, 2.00479248, 0. , 0. ,\n", + " 0. , 0. , 0. , 0. , 0. ,\n", + " 0. , 0. , 0. , 0. , 0. ,\n", + " 0. , 0. , 0. , 0. , 0. ],\n", + " [73.51669472, 73.51669472, 53.3587074 , 53.3587074 , 39.59747102,\n", + " 38.24346354, 36.8581052 , 36.8581052 , 36.8581052 , 35.01513065,\n", + " 32.97763233, 32.88982286, 30.62177746, 28.63343371, 24.38002738,\n", + " 24.32694184, 23.45775915, 21.04930695, 17.96633265, 17.46140502,\n", + " 17.44163762, 17.36912369, 17.26721941, 15.88104653, 15.79556841,\n", + " 14.00224752, 13.37087415, 2.00479248, 0. , 0. ,\n", + " 0. , 0. , 0. , 0. , 0. ,\n", + " 0. , 0. , 0. , 0. , 0. ,\n", + " 0. , 0. , 0. , 0. , 0. ],\n", + " [73.51669472, 73.51669472, 53.3587074 , 53.3587074 , 39.59747102,\n", + " 38.24346354, 36.8581052 , 36.8581052 , 35.01513065, 32.88982286,\n", + " 30.62177746, 28.63343371, 24.32694184, 23.45775915, 17.96633265,\n", + " 17.46140502, 17.36912369, 17.26721941, 15.88104653, 15.79556841,\n", + " 14.00224752, 5.43166876, 4.01557734, 3.46698216, 2.87277137,\n", + " 2.20228544, 0.99990932, 0.33020468, 0. , 0. ,\n", + " 0. , 0. , 0. , 0. , 0. ,\n", + " 0. , 0. , 0. , 0. , 0. ,\n", + " 0. , 0. , 0. , 0. , 0. ],\n", + " [73.51669472, 73.51669472, 53.3587074 , 53.3587074 , 39.59747102,\n", + " 38.24346354, 36.8581052 , 36.8581052 , 36.8581052 , 35.01513065,\n", + " 32.97763233, 32.88982286, 30.62177746, 28.63343371, 24.38002738,\n", + " 24.32694184, 23.45775915, 21.04930695, 17.96633265, 17.46140502,\n", + " 17.44163762, 17.36912369, 17.26721941, 15.88104653, 15.79556841,\n", + " 14.00224752, 13.37087415, 2.00479248, 0. , 0. ,\n", + " 0. , 0. , 0. , 0. , 0. ,\n", + " 0. , 0. , 0. , 0. , 0. ,\n", + " 0. , 0. , 0. , 0. , 0. ],\n", + " [73.51669472, 73.51669472, 53.3587074 , 53.3587074 , 39.59747102,\n", + " 38.24346354, 36.8581052 , 36.8581052 , 36.8581052 , 35.01513065,\n", + " 32.97763233, 32.88982286, 30.62177746, 28.63343371, 24.38002738,\n", + " 24.32694184, 23.45775915, 21.04930695, 17.96633265, 17.46140502,\n", + " 17.44163762, 17.36912369, 17.26721941, 15.88104653, 15.79556841,\n", + " 14.00224752, 13.37087415, 2.00479248, 0. , 0. ,\n", + " 0. , 0. , 0. , 0. , 0. ,\n", + " 0. , 0. , 0. , 0. , 0. ,\n", + " 0. , 0. , 0. , 0. , 0. ],\n", + " [73.51669472, 53.3587074 , 53.3587074 , 39.59747102, 36.8581052 ,\n", + " 36.8581052 , 36.8581052 , 35.01513065, 32.97763233, 32.88982286,\n", + " 30.62177746, 24.38002738, 23.45775915, 21.04930695, 17.96633265,\n", + " 17.46140502, 17.44163762, 17.36912369, 17.26721941, 15.88104653,\n", + " 13.37087415, 6.29899151, 4.71614595, 4.00683374, 2.60165116,\n", + " 2.30627747, 1.99440606, 0.33020468, 0. , 0. ,\n", + " 0. , 0. , 0. , 0. , 0. ,\n", + " 0. , 0. , 0. , 0. , 0. ,\n", + " 0. , 0. , 0. , 0. , 0. ]])" + ] + }, + "execution_count": 76, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "qm9_reps[-1]" + ] + }, + { + "cell_type": "code", + "execution_count": 77, "metadata": {}, "outputs": [], "source": [ "qm9_amons_labels = [t.split(\"/\")[-1].split(\".sdf\")[0] for t in qm9_amons_files]" ] }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 78, "metadata": {}, "outputs": [], "source": [ "vitc_amons_files = sorted(glob(\"amons-vitc/*.sdf\"))" ] }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 79, "metadata": {}, "outputs": [], "source": [ "vitc_amons_sdfs = [read_sdf(x) for x in vitc_amons_files]" ] }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 80, "metadata": {}, "outputs": [], "source": [ "conf_data = [get_ncharges_coords(x) for x in vitc_amons_sdfs]" ] }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 81, "metadata": {}, "outputs": [], "source": [ "ncharges_list, coords_list = zip(*conf_data)" ] }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 82, "metadata": {}, "outputs": [], "source": [ "vitc_ncharges = ncharges_list" ] }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 83, "metadata": {}, "outputs": [], "source": [ "vitc_reps = [np.array(get_atomic_CM(np.array(ncharges_list[i]), np.array(coords_list[i]), \n", " max_natoms=12)) for i in \n", " range(len(ncharges_list))]" ] }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 84, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/puck/anaconda3/envs/rdkit/lib/python3.7/site-packages/ipykernel_launcher.py:1: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray\n", " \"\"\"Entry point for launching an IPython kernel.\n" ] } ], "source": [ "vitc_reps = np.array(vitc_reps)" ] }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 85, "metadata": {}, "outputs": [], "source": [ "vitc_amons_labels = [t.split(\"/\")[-1].split(\".sdf\")[0] for t in vitc_amons_files]" ] }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 86, "metadata": {}, "outputs": [], "source": [ "vitd_amons_files = sorted(glob(\"amons-vitd/*.sdf\"))" ] }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 87, "metadata": {}, "outputs": [], "source": [ "vitd_amons_sdfs = [read_sdf(x) for x in vitd_amons_files]" ] }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 88, "metadata": {}, "outputs": [], "source": [ "conf_data = [get_ncharges_coords(x) for x in vitd_amons_sdfs]" ] }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 89, "metadata": {}, "outputs": [], "source": [ "ncharges_list, coords_list = zip(*conf_data)" ] }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 90, "metadata": {}, "outputs": [], "source": [ "vitd_ncharges = ncharges_list" ] }, { "cell_type": "code", - "execution_count": 33, + "execution_count": 91, "metadata": {}, "outputs": [], "source": [ "vitd_reps = [np.array(get_atomic_CM(np.array(ncharges_list[i]), np.array(coords_list[i]),\n", " max_natoms=28))\n", " for i in range(len(ncharges_list))]" ] }, { "cell_type": "code", - "execution_count": 34, + "execution_count": 92, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/puck/anaconda3/envs/rdkit/lib/python3.7/site-packages/ipykernel_launcher.py:1: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray\n", " \"\"\"Entry point for launching an IPython kernel.\n" ] } ], "source": [ "vitd_reps = np.array(vitd_reps)" ] }, { "cell_type": "code", - "execution_count": 35, + "execution_count": 93, "metadata": {}, "outputs": [], "source": [ "vitd_amons_labels = [t.split(\"/\")[-1].split(\".sdf\")[0] for t in vitd_amons_files]" ] }, { "cell_type": "code", - "execution_count": 36, + "execution_count": 94, "metadata": {}, "outputs": [], "source": [ "# np save " ] }, { "cell_type": "code", - "execution_count": 37, + "execution_count": 95, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/puck/anaconda3/envs/rdkit/lib/python3.7/site-packages/numpy/core/_asarray.py:136: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray\n", - " return array(a, dtype, copy=False, order=order, subok=True)\n" - ] - } - ], + "outputs": [], "source": [ "np.savez(\"amons_aCM_data.npz\", \n", " vitd_amons_labels=vitd_amons_labels,\n", " vitc_amons_labels=vitc_amons_labels,\n", " qm9_amons_labels=qm9_amons_labels,\n", " vitd_amons_ncharges=vitd_ncharges,\n", " vitc_amons_ncharges=vitc_ncharges,\n", " qm9_amons_ncharges=qm9_ncharges,\n", " vitd_amons_reps=vitd_reps,\n", " vitc_amons_reps=vitc_reps,\n", " qm9_amons_reps=qm9_reps)" ] }, { "cell_type": "code", - "execution_count": 38, + "execution_count": 96, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(1, 406)" ] }, - "execution_count": 38, + "execution_count": 96, "metadata": {}, "output_type": "execute_result" } ], "source": [ "vitd_reps[0].shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.9" } }, "nbformat": 4, "nbformat_minor": 4 } diff --git a/amons_aCM_data.npz b/amons_aCM_data.npz index 92d1760..564bdb0 100644 Binary files a/amons_aCM_data.npz and b/amons_aCM_data.npz differ diff --git a/onepass.py b/onepass.py index 163f47c..ca94935 100644 --- a/onepass.py +++ b/onepass.py @@ -1,165 +1,165 @@ import numpy as np import timeit import gurobipy as gp from gurobipy import GRB def addvariables(Z): upperbounds=[] I=[] J=[] for M in database_indices: CM=data[targetname+"_amons_ncharges"][M] m=len(CM) I=I+[(i,j,M,G) for G in range(maxduplicates) for i in range(m) for j in range(n) if CM[i] == CT[j]] # if condition excludes j; i always takes all m values J=J+[(M,G) for G in range(maxduplicates)] x=Z.addVars(I, vtype=GRB.BINARY) y=Z.addVars(J, vtype=GRB.BINARY) print("Variables added.") return x,I,y def addconstraints(Z,x,I,y): # bijection into [n] Z.addConstrs(x.sum('*',j,'*', '*') == 1 for j in range(n)) for M in database_indices: CM=data[targetname+"_amons_ncharges"][M] m=len(CM) # each i of each group is used at most once Z.addConstrs(x.sum(i,'*',M,G) <= 1 for i in range(m) for G in range(maxduplicates)) # y[M,G] = OR gate of the x[i,j,M,G] for each (M,G) Z.addConstrs(y[M,G] >= x[v] for G in range(maxduplicates) for v in I if v[2:]==(M,G)) Z.addConstrs(y[M,G] <= x.sum('*','*',M,G) for G in range(maxduplicates)) print("Constraints added.") return 0 # objective value should then be square rooted in the end (doesn't change optimality) def setobjective(Z,x,I,y): print("Constructing objective function... ") key=0 if(representation==0): # Coulomb case expr=gp.QuadExpr() T=targetdata['target_CMs'][target_index] for k in range(n): for l in range(n): expr += T[k,l]**2 for M in database_indices: key=key+1 Mol=data[targetname+"_amons_CMs"][M] m=len(Mol) for G in range(maxduplicates): for (i,k) in [v[:2] for v in I if v[2:]==(M,G)]: for (j,l) in [v[:2] for v in I if v[2:]==(M,G)]: expr += (Mol[i,j]**2 - 2*T[k,l]*Mol[i,j])*x[i,k,M,G]*x[j,l,M,G] expr += y[M,G]*m*penaltyconst print(key, " / ", size_database) expr=expr-n*penaltyconst else: #SLATM case expr=gp.LinExpr() T=targetdata["target_reps"][target_index] for M in database_indices: key=key+1 Mol=data[targetname+"_amons_reps"][M] m=len(Mol) for G in range(maxduplicates): for (i,j) in [v[:2] for v in I if v[2:]==(M,G)]: C=np.linalg.norm(Mol[i]-T[j])**2 expr += C*x[i,j,M,G] expr += y[M,G]*m*penaltyconst print(key, " / ", size_database) expr=expr-n*penaltyconst Z.setObjective(expr, GRB.MINIMIZE) print("Objective function set.") return 0 # prints mappings of positions (indices+1) of each molecule to positions inside target def print_sols(Z, x, I, y): SolCount=Z.SolCount print("Target has size", n) print("Using representation", repname) for solnb in range(SolCount): print() print("--------------------------------") Z.setParam("SolutionNumber",solnb) print("Solution number", solnb+1, ", objective value with size penalty", (Z.PoolObjVal)) for M in database_indices: groups=[] for G in range(maxduplicates): if np.rint(y[M,G].Xn) == 1: groups.append(G) amount_picked=len(groups) for k in range(amount_picked): G=groups[k] m=len(data[targetname+"_amons_ncharges"][M]) label=data[targetname+"_amons_labels"][M] if k==0: print("Molecule", label, "has been picked", amount_picked, "time(s) ( size", m, ", used", sum([x[v].Xn for v in I if v[2]==M]), ")") print(k+1, end=": ") for (i,j) in [v[:2] for v in I if v[2:]==(M,G) and np.rint(x[v].Xn)==1]: print(i+1, "->", j+1, end=", ") print() def main(): # construction of the model start=timeit.default_timer() Z = gp.Model() Z.setParam('OutputFlag',1) x,I,y=addvariables(Z) addconstraints(Z,x,I,y) setobjective(Z,x,I,y) stop=timeit.default_timer() print("Model setup: ", stop-start, "s") # model parameters # PoolSearchMode 1/2 forces to fill the solution pool. 2 finds the best solutions. # Set to 1 because of duplicating solutions which differ by 1e-9 and are seen as different. Z.setParam("PoolSearchMode", 1) # these prevent non integral values although some solutions are still duplicating -- to fix? Z.setParam("IntFeasTol", 1e-9) Z.setParam("IntegralityFocus", 1) Z.setParam("TimeLimit", timelimit) Z.setParam("PoolSolutions", numbersolutions) # optimization print("------------") print("Optimization") print("------------") Z.optimize() print("------------") print() print("Optimization runtime: ", Z.RunTime, "s") if(Z.status == 3): print("Model was proven to be infeasible.") return 1 print_sols(Z,x,I,y) return 0 # modifiable global settings target_index=1 # 0, 1, or 2 for qm9, vitc, or vitd. maxduplicates=2 # number of possible copies of each molecule of the database -timelimit=120 # in seconds (not counting setup) -numbersolutions=10 # size of solution pool -representation=2 # 0 for Coulomb Matrix (CM), 1 for SLATM, 2 for aCM, 3 for SOAP, 4 for FCHL -penaltyconst=[1,1,10000,1,1][representation] # constant in front of size penalty +timelimit=360 # in seconds (not counting setup) +numbersolutions=5 # size of solution pool +representation=1 # 0 for Coulomb Matrix (CM), 1 for SLATM, 2 for aCM, 3 for SOAP, 4 for FCHL +penaltyconst=[1,1,1,1,1][representation] # constant in front of size penalty # global constants repname=["CM", "SLATM", "aCM", "SOAP", "FCHL"][representation] dataname="amons_"+repname+"_data.npz" data=np.load(dataname, allow_pickle=True) targetdataname="target_"+repname+"_data.npz" targetdata=np.load(targetdataname, allow_pickle=True) CT=targetdata['target_ncharges'][target_index] n=len(CT) targetname=["qm9", "vitc", "vitd"][target_index] size_database=len(data[targetname+"_amons_labels"]) database_indices=range(size_database) main() diff --git a/target_aCM_data.npz b/target_aCM_data.npz index 11ffe2d..2c89505 100644 Binary files a/target_aCM_data.npz and b/target_aCM_data.npz differ