diff --git a/src/server/ch/epfl/lca1/medco/StandardQuery.java b/src/server/ch/epfl/lca1/medco/StandardQuery.java index 7af3a34..100c55e 100644 --- a/src/server/ch/epfl/lca1/medco/StandardQuery.java +++ b/src/server/ch/epfl/lca1/medco/StandardQuery.java @@ -1,263 +1,270 @@ /* * Copyright (c) 2006-2007 Massachusetts General Hospital * All rights reserved. This program and the accompanying materials * are made available under the terms of the i2b2 Software License v1.0 * which accompanies this distribution. * * Contributors: * Rajesh Kuttan */ package ch.epfl.lca1.medco; import ch.epfl.lca1.medco.i2b2.crc.I2B2CRCCell; import ch.epfl.lca1.medco.i2b2.crc.I2B2QueryRequest; import ch.epfl.lca1.medco.i2b2.crc.I2B2QueryResponse; import ch.epfl.lca1.medco.i2b2.pm.I2B2PMCell; import ch.epfl.lca1.medco.i2b2.pm.UserInformation; import ch.epfl.lca1.medco.unlynx.UnlynxClient; import ch.epfl.lca1.medco.util.Constants; import ch.epfl.lca1.medco.util.Logger; import ch.epfl.lca1.medco.util.Timers; import ch.epfl.lca1.medco.util.exceptions.MedCoError; import ch.epfl.lca1.medco.util.exceptions.MedCoException; import edu.harvard.i2b2.common.exception.I2B2Exception; import edu.harvard.i2b2.common.util.jaxb.JAXBUtilException; import edu.harvard.i2b2.crc.datavo.setfinder.query.PanelType; import edu.harvard.i2b2.crc.datavo.setfinder.query.QueryDefinitionType; import org.javatuples.Pair; import java.io.StringWriter; import java.util.ArrayList; import java.util.List; import java.util.regex.Matcher; // //todo doc: https://github.com/chb/shrine/tree/master/doc /** * Represents a query to MedCo. * From the XML query (in CRC format), parse to extract the sensitive attributes, * make query to CRC for non-sensitive attributes, get the patient set from CRC, * query the cothority with the patient sets and sensitive attributes and answer. * * * everything under that sohuld not use the config!! */ public class StandardQuery { private I2B2QueryRequest queryRequest; private I2B2CRCCell crcCell; private I2B2PMCell pmCell; private UnlynxClient unlynxClient; //int resultMode, String clientPubKey, long timoutSeconds public StandardQuery(I2B2QueryRequest request, String unlynxBinPath, String unlynxGroupFilePath, int unlynxDebugLevel, int unlynxEntryPointIdx, int unlynxProofsFlag, long unlynxTimeoutSeconds, String crcCellUrl, String pmCellUrl) throws I2B2Exception { this.queryRequest = request; unlynxClient = new UnlynxClient(unlynxBinPath, unlynxGroupFilePath, unlynxDebugLevel, unlynxEntryPointIdx, unlynxProofsFlag, unlynxTimeoutSeconds); crcCell = new I2B2CRCCell(crcCellUrl, queryRequest.getMessageHeader()); pmCell = new I2B2PMCell(pmCellUrl, queryRequest.getMessageHeader()); } /** * * @return the query answer in CRC XML format. * @throws JAXBUtilException */ public I2B2QueryResponse executeQuery() throws MedCoException, I2B2Exception { Timers.resetTimers(); Timers.get("overall").start(); // get user information (auth., privacy budget, authorizations, public key) // todo: get and check budget query / user // todo: get user permissions Timers.get("steps").start("User information retrieval"); UserInformation user = pmCell.getUserInformation(queryRequest.getMessageHeader()); if (!user.isAuthenticated()) { Logger.warn("Authentication failed for user " + user.getUsername()); // todo: proper auth failed response return null; } QueryType queryType = QueryType.resolveUserPermission(user.getRoles()); Timers.get("steps").stop(); // retrieve the encrypted query terms Timers.get("steps").start("Query parsing/splitting"); List encryptedQueryItems = extractEncryptedQueryTerms(false, false); Timers.get("steps").stop(); + // intercept test query from SHRINE and bypass unlynx + if (encryptedQueryItems.contains(Constants.CONCEPT_NAME_TEST_FLAG)) { + Logger.info("Intercepted SHRINE status query (" + queryRequest.getQueryName() + ")."); + replaceEncryptedQueryTerms(encryptedQueryItems); + return crcCell.queryRequest(queryRequest); + } + // query unlynx to tag the query terms Timers.get("steps").start("Query tagging"); List taggedItems = unlynxClient.computeDistributedDetTags(queryRequest.getQueryName(), encryptedQueryItems); Timers.addAdditionalTimes(unlynxClient.getLastTimingMeasurements()); Timers.get("steps").stop(); // replace the query terms, query i2b2 with the original clear query terms + the tagged ones Timers.get("steps").start("i2b2 query"); replaceEncryptedQueryTerms(taggedItems); overrideResultOutputTypes(new String[]{"PATIENTSET", "PATIENT_COUNT_XML"}); I2B2QueryResponse i2b2Response = crcCell.queryRequest(queryRequest); Timers.get("steps").stop(); // retrieve the patient set, including the encrypted dummy flags Timers.get("steps").start("i2b2 patient set retrieval"); Pair, List> patientSet = crcCell.queryForPatientSet(i2b2Response.getPatientSetId(), true); Timers.get("steps").stop(); String aggResult; switch (queryType) { case AGGREGATED_PER_SITE: aggResult = unlynxClient.aggregateData(queryRequest.getQueryName(), user.getUserPublicKey(), patientSet.getValue1()); Timers.addAdditionalTimes(unlynxClient.getLastTimingMeasurements()); break; case OBFUSCATED_PER_SITE: case AGGREGATED_TOTAL: default: throw new MedCoError("Query type not supported yet."); } i2b2Response.resetResultInstanceListToEncryptedCountOnly(); Timers.get("overall").stop(); i2b2Response.setQueryResults(user.getUserPublicKey(), aggResult, Timers.generateFullReport()); Logger.info("MedCo query successful (" + queryRequest.getQueryName() + ")."); return i2b2Response; } /** * TODO * No checks on panels are done (i.e. if they contain mixed query types or not) * * @param taggedItems * @throws MedCoException */ private void replaceEncryptedQueryTerms(List taggedItems) throws MedCoException { QueryDefinitionType qd = queryRequest.getQueryDefinition(); int encTermCount = 0; // iter on the panels for (int p = 0; p < qd.getPanel().size(); p++) { PanelType panel = qd.getPanel().get(p); // iter on the items int nbItems = panel.getItem().size(); for (int i = 0; i < nbItems; i++) { // replace encrypted item with its tagged version Matcher medcoKeyMatcher = Constants.REGEX_QUERY_KEY_ENC.matcher(panel.getItem().get(i).getItemKey()); if (medcoKeyMatcher.matches()) { panel.getItem().get(i).setItemKey(Constants.CONCEPT_PATH_TAGGED_PREFIX + taggedItems.get(encTermCount++) + "\\"); } } } // check the provided taggedItems match the number of encrypted terms if (encTermCount != taggedItems.size()) { Logger.warn("Mismatch in provided number of tagged items (" + taggedItems.size() + ") and number of encrypted items in query (" + encTermCount + ")"); } } private void overrideResultOutputTypes(String[] outputTypes) throws MedCoException { queryRequest.setOutputTypes(outputTypes); } /** * Extract from the i2b2 query the sensitive / encrypted items recognized by the prefix defined in {@link Constants}. * Accepts only panels fully clear or encrypted, i.e. no mix is allowed. *

* The predicate, if returned, has the following format: * (exists(v0, r) || exists(v1, r)) && (exists(v2, r) || exists(v3, r)) && exists(v4, r) * * @param removePanels removes the encrypted panels when encountered if true * @param getPredicate adds at the end of the returned list the corresponding predicate if true * @return the list of encrypted query terms and optionally the corresponding predicate * @throws MedCoException if a panel contains mixed clear and encrypted query terms */ private List extractEncryptedQueryTerms(boolean removePanels, boolean getPredicate) throws MedCoException { // todo: handle cases: only clear no encrypt / only encrypt no clear // todo: must be modified if invertion implementation QueryDefinitionType qd = queryRequest.getQueryDefinition(); StringWriter predicateSw = new StringWriter(); List extractedItems = new ArrayList<>(); int encTermCount = 0; // iter on the panels for (int p = 0; p < qd.getPanel().size(); p++) { boolean panelIsEnc = false, panelIsClear = false; PanelType panel = qd.getPanel().get(p); // iter on the items int nbItems = panel.getItem().size(); for (int i = 0; i < nbItems; i++) { // check if item is clear or encrypted, extract and generate predicate if yes Matcher medcoKeyMatcher = Constants.REGEX_QUERY_KEY_ENC.matcher(panel.getItem().get(i).getItemKey()); if (medcoKeyMatcher.matches()) { if (i == 0) { predicateSw.append("("); } extractedItems.add(medcoKeyMatcher.group(1)); predicateSw.append("exists(v" + encTermCount++ + ", r)"); if (i < nbItems - 1) { predicateSw.append(" || "); } else if (i == nbItems - 1) { predicateSw.append(")"); if (p < qd.getPanel().size() - 1) { predicateSw.append(" && "); } } Logger.debug("Extracted item " + extractedItems.get(extractedItems.size() - 1)); panelIsEnc = true; } else { panelIsClear = true; } // enforce that a panel can only be one type if (panelIsClear && panelIsEnc) { throw Logger.error(new MedCoException("Encountered panel with mixed clear and encrypted query terms: not allowed.")); } } // remove panel and log if (panelIsEnc) { if (removePanels) { qd.getPanel().remove(panel); p--; Logger.debug("Removed encrypted panel"); } Logger.debug("Encountered encrypted panel"); } else if (panelIsClear) { Logger.debug("Encountered clear panel"); } else { Logger.warn("Encountered empty panel in query " + qd.getQueryName()); } } String predicate = predicateSw.toString(); Logger.info("Extracted " + extractedItems.size() + " encrypted query terms and generated unlynx predicate with " + encTermCount + " terms: " + predicate + " for query " + queryRequest.getQueryName()); if (getPredicate) { extractedItems.add(predicate); } return extractedItems; } } diff --git a/src/server/ch/epfl/lca1/medco/unlynx/UnlynxClient.java b/src/server/ch/epfl/lca1/medco/unlynx/UnlynxClient.java index acd0e21..a8cfecd 100644 --- a/src/server/ch/epfl/lca1/medco/unlynx/UnlynxClient.java +++ b/src/server/ch/epfl/lca1/medco/unlynx/UnlynxClient.java @@ -1,222 +1,216 @@ package ch.epfl.lca1.medco.unlynx; import java.io.*; import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.List; import java.util.concurrent.TimeUnit; import ch.epfl.lca1.medco.util.*; import ch.epfl.lca1.medco.util.exceptions.I2B2XMLException; import ch.epfl.lca1.medco.util.exceptions.UnlynxException; import org.jdom.Document; import org.jdom.Element; import org.jdom.JDOMException; import org.jdom.input.SAXBuilder; /** * Manages connection to the local Unlynx client. * It runs in a separate to execute the binary specified by the configuration. * * Start query by using executeQuery() and then join() to wait for the process to end. * Next check with getQueryState() the state of the query and finally get the end result with getQueryResult(). */ public class UnlynxClient { private String binPath; private String groupFilePath; private int debugLevel; private int entryPointIdx; private int computeProofsFlag; private long timeoutSeconds; private String lastTimingMeasurements; public UnlynxClient(String binPath, String groupFilePath, int debugLevel, int entryPointIdx, int computeProofsFlag, long timeoutSeconds) { this.binPath = binPath; this.groupFilePath = groupFilePath; this.debugLevel = debugLevel; this.entryPointIdx = entryPointIdx; this.computeProofsFlag = computeProofsFlag; this.timeoutSeconds = timeoutSeconds; } public String getLastTimingMeasurements() { return lastTimingMeasurements; } public List computeDistributedDetTags(String queryId, List encryptedQueryItems) throws UnlynxException, I2B2XMLException { // if empty save a request to unlynx if (encryptedQueryItems.size() == 0) { return new ArrayList<>(); - - // detect test case - } else if (encryptedQueryItems.contains(Constants.CONCEPT_NAME_TEST_FLAG)) { - List returnList = new ArrayList<>(); - returnList.add(Constants.CONCEPT_NAME_TEST_FLAG); - return returnList; } // generate input stdout StringBuilder sb = new StringBuilder(); sb.append(Constants.DDT_REQ_XML_START_TAG + "\n"); sb.append(""); sb.append(queryId); sb.append("\n"); sb.append("\n"); for (String encValue : encryptedQueryItems) { sb.append(""); sb.append(encValue); sb.append("\n"); } sb.append("\n"); sb.append(Constants.DDT_REQ_XML_END_TAG + "\n"); // run unlynx SystemBinaryRunThread process = new SystemBinaryRunThread(getUnlynxRunCall(), sb.toString(), timeoutSeconds); process.start(); process.waitForCompletion(); // process result if (process.getRunState() == SystemBinaryRunThread.RunState.COMPLETED) { Logger.info("Unlynx DDT request successfully completed"); return parseDistributedDetTagsCallResult(process.getStdIn()); } else { throw Logger.error(new UnlynxException("Unlynx DDT request failed, run state is: " + process.getRunState().toString())); } } public String aggregateData(String queryId, String clientPubKey, List encDummyFlags) throws UnlynxException, I2B2XMLException { // generate input stdout StringBuilder sb = new StringBuilder(); sb.append(Constants.AGG_REQ_XML_START_TAG + "\n"); sb.append(""); sb.append(queryId); sb.append("\n"); sb.append(""); sb.append(clientPubKey); sb.append(""); sb.append("\n"); for (String encFlag : encDummyFlags) { sb.append(""); sb.append(encFlag); sb.append("\n"); } sb.append("\n"); sb.append(Constants.AGG_REQ_XML_END_TAG + "\n"); // run unlynx SystemBinaryRunThread process = new SystemBinaryRunThread(getUnlynxRunCall(), sb.toString(), timeoutSeconds); process.start(); process.waitForCompletion(); // process result if (process.getRunState() == SystemBinaryRunThread.RunState.COMPLETED) { Logger.info("Unlynx DDT request successfully completed"); return parseAggregateCallResult(process.getStdIn()); } else { throw Logger.error(new UnlynxException("Unlynx DDT request failed, run state is: " + process.getRunState().toString())); } } /** * Construct the binary system call of the Unlynx client. * * @return array of tokens for system call to the Unlynx client */ private String[] getUnlynxRunCall() { ArrayList arr = new ArrayList<>(); arr.add(binPath); arr.add("-d"); arr.add(debugLevel + ""); arr.add("run"); arr.add("-f"); arr.add(groupFilePath); arr.add("--entryPointIdx"); arr.add(entryPointIdx + ""); arr.add("--proofs"); arr.add(computeProofsFlag + ""); return arr.toArray(new String[arr.size()]); } private List parseDistributedDetTagsCallResult(String stdinString) throws UnlynxException, I2B2XMLException { // XXX: hackish InputStream stdin = new ByteArrayInputStream(stdinString.getBytes(StandardCharsets.UTF_8)); String resultXMLString = XMLUtil.xmlStringFromStream(stdin, Constants.DDT_RESP_XML_START_TAG, Constants.DDT_RESP_XML_END_TAG, false); SAXBuilder sxb = new SAXBuilder(); try { Document doc = sxb.build(new ByteArrayInputStream(resultXMLString.getBytes(StandardCharsets.UTF_8))); Element root = doc.getRootElement(); // sanity check if (!root.getName().equals(Constants.DDT_RESP_XML_EL)) { throw Logger.error(new I2B2XMLException("XML not properly formed.")); } // error check, exception if yes String errorMsg = root.getChildTextNormalize("error"); if (errorMsg != null && !errorMsg.trim().isEmpty()) { throw Logger.error(new UnlynxException(errorMsg)); } // extract tagged values List encValuesXml = root.getChild("tagged_values").getChildren("tagged_value"); List encValues = new ArrayList<>(encValuesXml.size()); for (Object anEncValuesXml : encValuesXml) { encValues.add(((Element) anEncValuesXml).getValue()); } // extract times lastTimingMeasurements = root.getChildText("times"); return encValues; } catch(IOException | JDOMException e) { throw Logger.error(new I2B2XMLException("XML parsing error", e)); } } private String parseAggregateCallResult(String stdinString) throws UnlynxException, I2B2XMLException { // XXX: hackish InputStream stdin = new ByteArrayInputStream(stdinString.getBytes(StandardCharsets.UTF_8)); String resultXMLString = XMLUtil.xmlStringFromStream(stdin, Constants.AGG_RESP_XML_START_TAG, Constants.AGG_RESP_XML_END_TAG, false); SAXBuilder sxb = new SAXBuilder(); try { Document doc = sxb.build(new ByteArrayInputStream(resultXMLString.getBytes(StandardCharsets.UTF_8))); Element root = doc.getRootElement(); // sanity check if (!root.getName().equals(Constants.AGG_RESP_XML_EL)) { throw Logger.error(new I2B2XMLException("XML not properly formed.")); } // error check, exception if yes String errorMsg = root.getChildTextNormalize("error"); if (errorMsg != null && !errorMsg.trim().isEmpty()) { throw Logger.error(new UnlynxException(errorMsg)); } // extract times lastTimingMeasurements = root.getChildText("times"); // extract aggregated value return root.getChildText("aggregate"); } catch(IOException | JDOMException e) { throw Logger.error(new I2B2XMLException("XML parsing error", e)); } } }