diff --git a/onepass.py b/onepass.py new file mode 100644 index 0000000..270be64 --- /dev/null +++ b/onepass.py @@ -0,0 +1,71 @@ +import numpy as np +import timeit +import gurobipy as gp +from gurobipy import GRB + +data=np.load("data.npz", allow_pickle=True) +Z = gp.Model() +Z.setParam('OutputFlag', 0) +target_index=2 +CT=data['target_ncharges'][target_index] +T=data['target_CMs'][target_index] +size_database=3 #len(data['database_ncharges']) +n=len(data['target_ncharges'][target_index]) + +start=timeit.default_timer() +I=[] +for M in range(size_database): + print(M, " / ", size_database) + m=len(data['database_ncharges'][M]) + I=I+[(i,j,M) for i in range(m) for j in range(n)] + +x=Z.addVars(I, vtype=GRB.BINARY) + +# injection into [n], ideally set to equality otherwise the 0 solution is possible +Z.addConstrs(x.sum('*',j,'*') <= 1 for j in range(n)) + +# dummy variables y associated to each molecule M (1 if M is taken, 0 otherwise) +y=Z.addVars(range(size_database), vtype=GRB.BINARY) + +# additional constraints: take whole molecule or leave it out, and matching charges +for M in range(size_database): + CM=data['database_ncharges'][M] + m=len(CM) + Z.addConstr(y[M]==1) # temporarily force taking every molecule to ignore 0 solution + for i in range(m): + Z.addConstr(x.sum(i,'*',M)==y[M]) + for j in range(n): + if(CM[i] != CT[j]): + Z.addConstr(x[i,j,M]==0) + +expr = 2*x[0,0,0]*x[0,0,0] +expr.clear() + +for M in range(size_database): + Mol=data['database_CMs'][M] + m=len(Mol) + for i in range(m): + for j in range(m): + for k in range(n): + for l in range(n): + expr.add(x[i,k,M] * x[j,l,M], (T[k,l]-Mol[i,j])**2) +Z.setObjective(expr, GRB.MINIMIZE) +stop=timeit.default_timer() +print("Model setup: ", stop-start, "s") + +Z.optimize() +print("Optimization runtime: ", Z.RunTime, "s") + +assert Z.status == 2 + +for M in range(size_database): + if y[M].X == 1: + print("Molecule", data['database_labels'][M], "has been picked") + +matched_targets=[] +for i in I: + x[i] = x[i].X + if x[i]==1: + matched_targets.append(i[1]) + +print(matched_targets)