# NOTE(review): this list of .xyz targets is not referenced by any later cell.
target_xyzs = sorted(glob("targets/*.xyz"))


def read_sdf(sdf):
    """Return the text content of the file at path ``sdf``.

    Trailing whitespace (including the final newline) is stripped.
    """
    with open(sdf, "r") as handle:
        return handle.read().rstrip()


def get_ncharges_coords(sdf):
    """Extract atomic numbers and coordinates from a mol-block string.

    Parameters
    ----------
    sdf : str
        Mol-block text (e.g. as returned by ``read_sdf``), not a file path.

    Returns
    -------
    ncharges : list of int
        Atomic number of every atom RDKit keeps in the parsed molecule.
        No hydrogens are added here, so the atom count follows RDKit's
        parsing defaults.
    coords : numpy.ndarray
        Atom positions of the molecule's conformer, one row per atom.

    Raises
    ------
    ValueError
        If RDKit cannot parse the mol block.
    """
    mol = Chem.MolFromMolBlock(sdf)
    if mol is None:
        # Without this check the failure would surface as an opaque
        # AttributeError on ``None`` further down.
        raise ValueError("RDKit could not parse the supplied mol block")
    ncharges = [atom.GetAtomicNum() for atom in mol.GetAtoms()]
    coords = np.asarray(mol.GetConformer().GetPositions())
    return ncharges, coords


def cutoff_func(R_ij, central_cutoff=4.8, central_decay=1):
    """Smooth cosine damping factor for a scalar distance ``R_ij``.

    The factor is

    * ``1`` for ``R_ij <= central_cutoff - central_decay``,
    * ``0.5 * (1 + cos(pi * (R_ij - central_cutoff + central_decay) / central_decay))``
      inside the decay window, falling continuously from 1 to 0,
    * ``0`` for ``R_ij > central_cutoff``.

    Two defects of the previous version are fixed here: ``pi`` multiplied only
    ``R_ij`` instead of the whole shifted distance, and the decay window
    extended to ``central_cutoff + central_decay``, which made the factor
    neither continuous nor monotone.

    Parameters
    ----------
    R_ij : float
        Distance to damp (same length unit as ``central_cutoff``).
    central_cutoff : float
        Distance beyond which the factor is exactly zero.
    central_decay : float
        Width of the window below ``central_cutoff`` over which the factor
        decays from 1 to 0.

    Returns
    -------
    float
        Damping factor in ``[0, 1]``.
    """
    decay_start = central_cutoff - central_decay
    if R_ij <= decay_start:
        return 1.
    if R_ij <= central_cutoff:
        # central_decay cannot be zero here: that would make
        # decay_start == central_cutoff and the first branch would have fired.
        return 0.5 * (1. + np.cos(np.pi * (R_ij - decay_start) / central_decay))
    return 0.


def get_atomic_CM(ncharges, coords, max_natoms, central_cutoff=4.8, central_decay=1):
    """Build one sorted, damped Coulomb-matrix vector per atom.

    For every central atom ``k`` a symmetric matrix ``M`` is formed with

    * ``M[i, i] = 0.5 * Z_i**2.4 * f(R_ik)**2``
    * ``M[i, j] = Z_i * Z_j / R_ij * f(R_ik) * f(R_jk) * f(R_ij)`` for ``i != j``

    where ``f`` is ``cutoff_func`` evaluated with ``central_cutoff`` and
    ``central_decay``. The upper triangle of ``M`` (diagonal included) is
    sorted in descending order and zero-padded to
    ``max_natoms * (max_natoms + 1) / 2`` entries.

    Parameters
    ----------
    ncharges : sequence of numbers
        Nuclear charges, one per atom.
    coords : array-like
        Atom positions, one row per atom, in the same order as ``ncharges``.
    max_natoms : int
        Atom count that fixes the padded length of each row.
    central_cutoff, central_decay : float
        Forwarded to ``cutoff_func``.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(len(ncharges), max_natoms * (max_natoms + 1) // 2)``.

    Raises
    ------
    ValueError
        If there are more atoms than ``max_natoms``.
    """
    charges = np.asarray(ncharges, dtype=float)
    natoms = len(charges)
    if natoms > max_natoms:
        raise ValueError(
            "molecule has %d atoms but max_natoms is %d" % (natoms, max_natoms)
        )

    size = int((max_natoms + 1) * max_natoms // 2)
    rep = np.zeros((natoms, size))
    if natoms == 0:
        return rep

    xyz = np.asarray(coords, dtype=float)

    # Pair distances and their damping factors do not depend on the central
    # atom, so compute them once instead of inside the k/i/j loops.
    dist = np.linalg.norm(xyz[:, None, :] - xyz[None, :, :], axis=-1)
    damp = np.array([
        [cutoff_func(r, central_cutoff=central_cutoff, central_decay=central_decay)
         for r in row]
        for row in dist
    ])

    # Central-atom-independent part of M. The diagonal of the divisor is set
    # to 1 only to avoid 0/0; those entries are overwritten right after.
    divisor = dist.copy()
    np.fill_diagonal(divisor, 1.0)
    base = np.outer(charges, charges) / divisor * damp
    np.fill_diagonal(base, 0.5 * charges ** 2.4)

    upper = np.triu_indices(natoms)
    n_entries = len(upper[0])
    for k in range(natoms):
        f_k = damp[k]  # f(R_ik) for every atom i
        # outer(f_k, f_k) supplies f_ik * f_jk off-diagonal and f_ik**2 on it.
        env = base * np.outer(f_k, f_k)
        # Descending sort; the tail of the row stays zero as padding.
        rep[k, :n_entries] = np.sort(env[upper])[::-1]

    return rep


target_files = sorted(glob("targets/*.sdf"))
target_files

target_sdfs = [read_sdf(path) for path in target_files]

conf_data = [get_ncharges_coords(block) for block in target_sdfs]

# Unpacked with comprehensions rather than ``zip(*conf_data)`` so that an
# empty ``targets/`` directory yields two empty tuples instead of a ValueError.
ncharges_list = tuple(charges for charges, _ in conf_data)
coords_list = tuple(xyz for _, xyz in conf_data)

sizes = [len(charges) for charges in ncharges_list]
sizes
# One representation matrix per target molecule. Each matrix has one row per
# atom and is padded to that molecule's own atom count, so the matrices have
# different shapes.
per_target_reps = [
    get_atomic_CM(np.asarray(charges), np.asarray(xyz), max_natoms=natoms)
    for charges, xyz, natoms in zip(ncharges_list, coords_list, sizes)
]

# Store the differently shaped matrices in an explicit 1-D object array.
# Calling np.array() on the ragged list emits a VisibleDeprecationWarning on
# older NumPy and raises on recent releases. Filling element by element also
# keeps the result 1-D when all matrices happen to share a shape.
target_reps = np.empty(len(per_target_reps), dtype=object)
for idx, rep in enumerate(per_target_reps):
    target_reps[idx] = rep
target_reps[0]

# Labels are the file names without directory or extension, e.g.
# "targets/qm9.sdf" -> "qm9". They must come from the file paths
# (target_files), not from the file contents (target_sdfs).
target_labels = [path.split("/")[-1].rsplit(".", 1)[0] for path in target_files]

# The per-molecule charge lists have different lengths, so pack them into an
# explicit 1-D object array. Passing the ragged tuple straight to np.savez
# triggers NumPy's ragged-sequence deprecation warning (an error on recent
# releases).
target_ncharges = np.empty(len(ncharges_list), dtype=object)
for idx, charges in enumerate(ncharges_list):
    target_ncharges[idx] = np.asarray(charges)

# Object arrays are pickled inside the archive; read it back with
# np.load("target_aCM_data.npz", allow_pickle=True).
np.savez("target_aCM_data.npz",
         target_labels=target_labels,
         target_reps=target_reps,
         target_ncharges=target_ncharges)