diff --git a/onepass.py b/onepass.py index 047de1a..cbee7b4 100644 --- a/onepass.py +++ b/onepass.py @@ -1,157 +1,186 @@ import numpy as np import timeit import gurobipy as gp from gurobipy import GRB def addvariables(Z): upperbounds=[] I=[] + J=[] for M in database_indices: m=len(data['database_ncharges'][M]) + #adjM=connectivity_data['frag_adj_matrices'][M] I=I+[(i,j,M) for i in range(m) for j in range(n)] + #J=J+[(i,j,M) for i in range(m) for j in range(i+1,m) if adjM[i,j]] # for later + J=J+[(i,j,M) for i in range(m) for j in range(i+1,m)] upperbounds.append(int(n/m)) x=Z.addVars(I, vtype=GRB.BINARY) + e=Z.addVars(J, vtype=GRB.BINARY) # dummy variables y associated to number of times molecule M is picked - y=Z.addVars(database_indices, vtype="I", lb=0, ub=upperbounds) + # temporarily set binary + y=Z.addVars(database_indices, vtype="B", lb=0, ub=upperbounds) + print("Variables added.") - return x,y,I + return x,y,e -def addconstraints(Z,x,y): +def addconstraints(Z,x,y,e): # bijection into [n] Z.addConstrs(x.sum('*',j,'*') == 1 for j in range(n)) # additional constraints: take whole molecule or leave it out, and matching charges for M in database_indices: CM=data['database_ncharges'][M] m=len(CM) # the number of indices of M used (counting multiple use) is at least mol_percent of its size Z.addConstr(x.sum('*','*',M) >= mol_percent*m*y[M]) # each index in M is used at most y[M] times Z.addConstrs(x.sum(i,'*',M) <= y[M] for i in range(m)) - # ignore incompatible charges and atoms of charge 1 + #Z.addConstr(t.sum('*',M) == y[M]) + #Z.addConstrs(t[i,M] <= x.sum(i,'*',M) for i in range(m)) + + # ignore incompatible charges for i in range(m): for j in range(n): if(CM[i] != CT[j]): Z.addConstr(x[i,j,M]==0) - + # forces #edges >= #nodes - 1, which must be true for any connected graph + adjM=connectivity_data['frag_adj_matrices'][M] + expr=gp.LinExpr() + for i in range(m): + for j in range(i+1,m): + if adjM[i,j]: + Z.addConstr(e[i,j,M] <= x.sum(i,'*',M)) + Z.addConstr(e[i,j,M] <= x.sum(j,'*',M)) + Z.addConstr(e[i,j,M] >= x.sum(i,'*',M)+x.sum(j,'*',M) - 1) + #Z.addGenConstrAnd(e[i,j,M], [x.sum(i,'*',M), x.sum(j,'*',M)]) #this doesn't work? + expr += e[i,j,M] + for k in range(n): + expr -= x[i,k,M] + Z.addConstr(expr+1 >= 0) + print("Constraints added.") return 0 def setobjective(Z,x,y): expr=gp.QuadExpr() print("Constructing objective function... ") key=0 for M in database_indices: key=key+1 Mol=data['database_CMs'][M] m=len(Mol) - expr.addTerms(-m, y[M]) + expr.addTerms(-1, y[M]) + adjM=connectivity_data['frag_adj_matrices'][M] for i in range(m): for j in range(m): for k in range(n): for l in range(k,n): - expr.add(x[i,k,M] * x[j,l,M], np.abs(T[k,l]-Mol[i,j])**2) + expr.add(x[i,k,M] * x[j,l,M], np.abs(T[k,l]-Mol[i,j])**2) + Z.addConstr(x[i,k,M]+x[j,l,M]-1 <= (adjM[i,j] == adjT[k,l])) print(key, " / ", size_database) Z.setObjective(expr, GRB.MINIMIZE) print("Objective function set.") return 0 # prints mappings of positions (indices+1) of each molecule (before preprocess) to positions inside target (before preprocess, but the hydrogens are at the end anyway) def print_sols(Z, x, y): SolCount=Z.SolCount for solnb in range(SolCount): print() print("--------------------------------") Z.setParam("SolutionNumber",solnb) print("Solution number", solnb+1, ", objective value", Z.PoolObjVal) for M in database_indices: amount_picked=int(np.rint(y[M].Xn)) if amount_picked != 0: m=len(data['database_ncharges'][M]) U=np.zeros((m,amount_picked)) # constructing U for i in range(m): k=0 for j in range(n): if x[i,j,M].Xn==1 and sum(U[:,k]!=0) < mol_percent*m: U[i,k]=j+1 k=k+1 # reading U for k in range(amount_picked): if np.any(U[:,k] != 0): if k==0: print("Molecule", data['database_labels'][M], "has been picked", amount_picked, "time(s) ( size", len(data['database_ncharges'][M]), ", used", sum([x[i,j,M].Xn for i in range(m) for j in range(n)]), ")") print(k+1, end=": ") for i in range(m): if U[i,k]!=0: print(oldindex(i,M)+1, "->", U[i,k], end=", ") print("used", sum(U[:,k]!=0)) # converts new index (with hydrogens removed) to old index. def oldindex(i, M): k=0 notones=0 CM=olddata['database_ncharges'][M] while notones