diff --git a/adapter/adapter-service/src/main/scala/net/shrine/adapter/Adapter.scala b/adapter/adapter-service/src/main/scala/net/shrine/adapter/Adapter.scala index 9654cbbf9..3914fedc6 100644 --- a/adapter/adapter-service/src/main/scala/net/shrine/adapter/Adapter.scala +++ b/adapter/adapter-service/src/main/scala/net/shrine/adapter/Adapter.scala @@ -1,94 +1,104 @@ package net.shrine.adapter import java.sql.SQLException +import java.util.Date +import net.shrine.adapter.dao.BotDetectedException import net.shrine.log.Loggable import net.shrine.problem.{AbstractProblem, LoggingProblemHandler, Problem, ProblemNotYetEncoded, ProblemSources} import net.shrine.protocol.{AuthenticationInfo, BaseShrineResponse, BroadcastMessage, ErrorResponse, ShrineRequest} import scala.util.control.NonFatal /** * @author Bill Simons * @since 4/8/11 * @see http://cbmi.med.harvard.edu * @see http://chip.org *

* NOTICE: This software comes with NO guarantees whatsoever and is * licensed as Lgpl Open Source * @see http://www.gnu.org/licenses/lgpl.html */ abstract class Adapter extends Loggable { //noinspection RedundantBlock final def perform(message: BroadcastMessage): BaseShrineResponse = { def problemToErrorResponse(problem:Problem):ErrorResponse = { LoggingProblemHandler.handleProblem(problem) ErrorResponse(problem) } val shrineResponse = try { processRequest(message) } catch { case e: AdapterLockoutException => problemToErrorResponse(AdapterLockout(message.request.authn,e)) + case e: BotDetectedException => problemToErrorResponse(BotDetected(e)) + case e @ CrcInvocationException(invokedCrcUrl, request, cause) => problemToErrorResponse(CrcCouldNotBeInvoked(invokedCrcUrl,request,e)) case e: AdapterMappingException => problemToErrorResponse(AdapterMappingProblem(e)) case e: SQLException => problemToErrorResponse(AdapterDatabaseProblem(e)) case NonFatal(e) => { val summary = if(message == null) "Unknown problem in Adapter.perform with null BroadcastMessage" else s"Unexpected exception in Adapter" problemToErrorResponse(ProblemNotYetEncoded(summary,e)) } } shrineResponse } protected[adapter] def processRequest(message: BroadcastMessage): BaseShrineResponse //NOOP, may be overridden by subclasses def shutdown(): Unit = () } case class AdapterLockout(authn:AuthenticationInfo,x:AdapterLockoutException) extends AbstractProblem(ProblemSources.Adapter) { override val throwable = Some(x) override val summary: String = s"User '${authn.domain}:${authn.username}' locked out." override val description:String = s"User '${authn.domain}:${authn.username}' has run too many queries that produce the same result at ${x.url} ." createAndLog } case class CrcCouldNotBeInvoked(crcUrl:String,request:ShrineRequest,x:CrcInvocationException) extends AbstractProblem(ProblemSources.Adapter) { override val throwable = Some(x) override val summary: String = s"Error communicating with I2B2 CRC." override val description: String = s"Error invoking the CRC at '$crcUrl' with a ${request.getClass.getSimpleName} due to ${throwable.get}." override val detailsXml =

Request is {request} {throwableDetail.getOrElse("")}
createAndLog } case class AdapterMappingProblem(x:AdapterMappingException) extends AbstractProblem(ProblemSources.Adapter) { override val throwable = Some(x) override val summary: String = "Could not map query term(s)." override val description = s"The Shrine Adapter on ${stamp.host.getHostName} cannot map this query to its local terms." override val detailsXml =
Query Defitiontion is {x.runQueryRequest.queryDefinition} RunQueryRequest is ${x.runQueryRequest.elideAuthenticationInfo} {throwableDetail.getOrElse("")}
createAndLog } case class AdapterDatabaseProblem(x:SQLException) extends AbstractProblem(ProblemSources.Adapter) { override val throwable = Some(x) override val summary: String = "Problem using the Adapter database." override val description = "The Shrine Adapter encountered a problem using a database." createAndLog +} + +case class BotDetected(bdx:BotDetectedException) extends AbstractProblem(ProblemSources.Adapter) { + override val summary: String = s"A user has run enough queries in a short period of time the adapter suspects a bot." + + override val description: String = s"${bdx.domain}:${bdx.username} has run ${bdx.detectedCount} queries since ${new Date(bdx.sinceMs)}, more than the limit of ${bdx.limit} allowed in this time frame." } \ No newline at end of file diff --git a/adapter/adapter-service/src/main/scala/net/shrine/adapter/AdapterComponents.scala b/adapter/adapter-service/src/main/scala/net/shrine/adapter/AdapterComponents.scala index b1dd2d9b3..50943e969 100644 --- a/adapter/adapter-service/src/main/scala/net/shrine/adapter/AdapterComponents.scala +++ b/adapter/adapter-service/src/main/scala/net/shrine/adapter/AdapterComponents.scala @@ -1,136 +1,137 @@ package net.shrine.adapter import com.typesafe.config.Config import net.shrine.adapter.dao.{AdapterDao, I2b2AdminDao} import net.shrine.adapter.dao.squeryl.{SquerylAdapterDao, SquerylI2b2AdminDao} import net.shrine.adapter.dao.squeryl.tables.Tables import net.shrine.adapter.service.{AdapterService, I2b2AdminService} import net.shrine.adapter.translators.{ExpressionTranslator, QueryDefinitionTranslator} import net.shrine.client.{EndpointConfig, Poster} import net.shrine.config.mappings.{AdapterMappings, AdapterMappingsSource, ClasspathFormatDetectingAdapterMappingsSource} import net.shrine.crypto.{DefaultSignerVerifier, KeyStoreCertCollection} import net.shrine.dao.squeryl.SquerylInitializer import net.shrine.protocol.{HiveCredentials, NodeId, RequestType, ResultOutputType} import net.shrine.config.{ConfigExtensions, DurationConfigParser} /** * All the parts required for an adapter. * * @author david * @since 1.22 */ case class AdapterComponents( adapterService: AdapterService, i2b2AdminService: I2b2AdminService, adapterDao: AdapterDao, adapterMappings: AdapterMappings) object AdapterComponents { //todo try and trim this argument list back def apply( adapterConfig:Config, //config is "shrine.adapter" certCollection: KeyStoreCertCollection, squerylInitializer: SquerylInitializer, breakdownTypes: Set[ResultOutputType], crcHiveCredentials: HiveCredentials, signerVerifier: DefaultSignerVerifier, pmPoster: Poster, nodeId: NodeId ):AdapterComponents = { val crcEndpoint: EndpointConfig = adapterConfig.getConfigured("crcEndpoint",EndpointConfig(_)) val crcPoster: Poster = Poster(certCollection,crcEndpoint) val squerylAdapterTables: Tables = new Tables val adapterDao: AdapterDao = new SquerylAdapterDao(squerylInitializer, squerylAdapterTables)(breakdownTypes) //NB: Is i2b2HiveCredentials.projectId the right project id to use? val i2b2AdminDao: I2b2AdminDao = new SquerylI2b2AdminDao(crcHiveCredentials.projectId, squerylInitializer, squerylAdapterTables) val adapterMappingsFile = adapterConfig.getString("adapterMappingsFileName") val adapterMappingsSource: AdapterMappingsSource = ClasspathFormatDetectingAdapterMappingsSource(adapterMappingsFile) //NB: Fail fast val adapterMappings: AdapterMappings = adapterMappingsSource.load(adapterMappingsFile).get val expressionTranslator: ExpressionTranslator = ExpressionTranslator(adapterMappings) val queryDefinitionTranslator: QueryDefinitionTranslator = new QueryDefinitionTranslator(expressionTranslator) val doObfuscation = adapterConfig.getBoolean("setSizeObfuscation") val collectAdapterAudit = adapterConfig.getBoolean("audit.collectAdapterAudit") val runQueryAdapter = new RunQueryAdapter( crcPoster, adapterDao, crcHiveCredentials, queryDefinitionTranslator, adapterConfig.getInt("adapterLockoutAttemptsThreshold"), doObfuscation, adapterConfig.getOption("immediatelyRunIncomingQueries", _.getBoolean).getOrElse(true), //todo use reference.conf breakdownTypes, - collectAdapterAudit + collectAdapterAudit, + Map.empty //todo pull the map out of config ) val readInstanceResultsAdapter: Adapter = new ReadInstanceResultsAdapter( crcPoster, crcHiveCredentials, adapterDao, doObfuscation, breakdownTypes, collectAdapterAudit ) val readQueryResultAdapter: Adapter = new ReadQueryResultAdapter( crcPoster, crcHiveCredentials, adapterDao, doObfuscation, breakdownTypes, collectAdapterAudit ) val readPreviousQueriesAdapter: Adapter = new ReadPreviousQueriesAdapter(adapterDao) val deleteQueryAdapter: Adapter = new DeleteQueryAdapter(adapterDao) val renameQueryAdapter: Adapter = new RenameQueryAdapter(adapterDao) val readQueryDefinitionAdapter: Adapter = new ReadQueryDefinitionAdapter(adapterDao) val readTranslatedQueryDefinitionAdapter: Adapter = new ReadTranslatedQueryDefinitionAdapter(nodeId, queryDefinitionTranslator) val flagQueryAdapter: Adapter = new FlagQueryAdapter(adapterDao) val unFlagQueryAdapter: Adapter = new UnFlagQueryAdapter(adapterDao) val adapterMap = AdapterMap(Map( RequestType.QueryDefinitionRequest -> runQueryAdapter, RequestType.GetRequestXml -> readQueryDefinitionAdapter, RequestType.UserRequest -> readPreviousQueriesAdapter, RequestType.InstanceRequest -> readInstanceResultsAdapter, RequestType.MasterDeleteRequest -> deleteQueryAdapter, RequestType.MasterRenameRequest -> renameQueryAdapter, RequestType.GetQueryResult -> readQueryResultAdapter, RequestType.ReadTranslatedQueryDefinitionRequest -> readTranslatedQueryDefinitionAdapter, RequestType.FlagQueryRequest -> flagQueryAdapter, RequestType.UnFlagQueryRequest -> unFlagQueryAdapter)) AdapterComponents( adapterService = new AdapterService( nodeId = nodeId, signatureVerifier = signerVerifier, maxSignatureAge = adapterConfig.getConfigured("maxSignatureAge", DurationConfigParser(_)), adapterMap = adapterMap ), i2b2AdminService = new I2b2AdminService( dao = adapterDao, i2b2AdminDao = i2b2AdminDao, pmPoster = pmPoster, runQueryAdapter = runQueryAdapter ), adapterDao = adapterDao, adapterMappings = adapterMappings) } } \ No newline at end of file diff --git a/adapter/adapter-service/src/main/scala/net/shrine/adapter/RunQueryAdapter.scala b/adapter/adapter-service/src/main/scala/net/shrine/adapter/RunQueryAdapter.scala index 743858a11..9dd269990 100644 --- a/adapter/adapter-service/src/main/scala/net/shrine/adapter/RunQueryAdapter.scala +++ b/adapter/adapter-service/src/main/scala/net/shrine/adapter/RunQueryAdapter.scala @@ -1,287 +1,291 @@ package net.shrine.adapter import net.shrine.adapter.audit.AdapterAuditDb import scala.util.Failure import scala.util.Success import scala.util.Try import scala.xml.NodeSeq import net.shrine.adapter.dao.AdapterDao import net.shrine.adapter.translators.QueryDefinitionTranslator import net.shrine.protocol.{AuthenticationInfo, BroadcastMessage, Credential, ErrorFromCrcException, ErrorResponse, HiveCredentials, I2b2ResultEnvelope, MissingCrCXmlResultException, QueryResult, RawCrcRunQueryResponse, ReadResultRequest, ReadResultResponse, ResultOutputType, RunQueryRequest, RunQueryResponse, ShrineResponse} import net.shrine.client.Poster import net.shrine.problem.{AbstractProblem, LoggingProblemHandler, Problem, ProblemNotYetEncoded, ProblemSources} import scala.util.control.NonFatal import net.shrine.util.XmlDateHelper +import scala.concurrent.duration.Duration import scala.xml.XML /** * @author Bill Simons * @author clint * @since 4/15/11 * @see http://cbmi.med.harvard.edu * @see http://chip.org *

* NOTICE: This software comes with NO guarantees whatsoever and is * licensed as Lgpl Open Source * @see http://www.gnu.org/licenses/lgpl.html */ final case class RunQueryAdapter( poster: Poster, dao: AdapterDao, override val hiveCredentials: HiveCredentials, conceptTranslator: QueryDefinitionTranslator, adapterLockoutAttemptsThreshold: Int, doObfuscation: Boolean, runQueriesImmediately: Boolean, breakdownTypes: Set[ResultOutputType], - collectAdapterAudit:Boolean + collectAdapterAudit:Boolean, + botCountTimeThresholds:Map[Long,Duration] ) extends CrcAdapter[RunQueryRequest, RunQueryResponse](poster, hiveCredentials) { logStartup() import RunQueryAdapter._ override protected[adapter] def parseShrineResponse(xml: NodeSeq) = RawCrcRunQueryResponse.fromI2b2(breakdownTypes)(xml).get //TODO: Avoid .get call override protected[adapter] def translateNetworkToLocal(request: RunQueryRequest): RunQueryRequest = { try { request.mapQueryDefinition(conceptTranslator.translate) } catch { case NonFatal(e) => throw new AdapterMappingException(request,s"Error mapping query terms from network to local forms.", e) } } override protected[adapter] def processRequest(message: BroadcastMessage): ShrineResponse = { if (collectAdapterAudit) AdapterAuditDb.db.insertQueryReceived(message) if (isLockedOut(message.networkAuthn)) { throw new AdapterLockoutException(message.networkAuthn,poster.url) } + dao.checkIfBot(message.networkAuthn,botCountTimeThresholds) + val runQueryReq = message.request.asInstanceOf[RunQueryRequest] //We need to use the network identity from the BroadcastMessage, since that will have the network username //(ie, ecommons) of the querying user. Using the AuthenticationInfo from the incoming request breaks the fetching //of previous queries on deployed systems where the credentials in the identity param to this method and the authn //field of the incoming request are different, like the HMS Shrine deployment. //NB: Credential field is wiped out to preserve old behavior -Clint 14 Nov, 2013 val authnToUse = message.networkAuthn.copy(credential = Credential("", isToken = false)) if (!runQueriesImmediately) { debug(s"Queueing query from user ${message.networkAuthn.domain}:${message.networkAuthn.username}") storeQuery(authnToUse, message, runQueryReq) } else { debug(s"Performing query from user ${message.networkAuthn.domain}:${message.networkAuthn.username}") val result: ShrineResponse = runQuery(authnToUse, message.copy(request = runQueryReq.withAuthn(authnToUse)), runQueryReq.withAuthn(authnToUse)) if (collectAdapterAudit) AdapterAuditDb.db.insertResultSent(runQueryReq.networkQueryId,result) result } } private def storeQuery(authnToUse: AuthenticationInfo, message: BroadcastMessage, request: RunQueryRequest): RunQueryResponse = { //Use dummy ids for what we would have received from the CRC val masterId: Long = -1L val queryInstanceId: Long = -1L val resultId: Long = -1L //TODO: is this right?? Or maybe it's project id? val groupId = authnToUse.domain val invalidSetSize = -1L val now = XmlDateHelper.now val queryResult = QueryResult(resultId, queryInstanceId, Some(ResultOutputType.PATIENT_COUNT_XML), invalidSetSize, Some(now), Some(now), Some("Query enqueued for later processing"), QueryResult.StatusType.Held, Some("Query enqueued for later processing")) dao.inTransaction { val insertedQueryId = dao.insertQuery(masterId.toString, request.networkQueryId, authnToUse, request.queryDefinition, isFlagged = false, hasBeenRun = false, flagMessage = None) val insertedQueryResultIds = dao.insertQueryResults(insertedQueryId, Seq(queryResult)) //NB: We need to insert dummy QueryResult and Count records so that calls to StoredQueries.retrieve() in //AbstractReadQueryResultAdapter, called when retrieving results for previously-queued-or-incomplete //queries, will work. val countQueryResultId = insertedQueryResultIds(ResultOutputType.PATIENT_COUNT_XML).head dao.insertCountResult(countQueryResultId, -1L, -1L) } RunQueryResponse(masterId, XmlDateHelper.now, authnToUse.username, groupId, request.queryDefinition, queryInstanceId, queryResult) } private def runQuery(authnToUse: AuthenticationInfo, message: BroadcastMessage, request: RunQueryRequest): ShrineResponse = { if (collectAdapterAudit) AdapterAuditDb.db.insertExecutionStarted(request) //NB: Pass through ErrorResponses received from the CRC. //See: https://open.med.harvard.edu/jira/browse/SHRINE-794 val result = super.processRequest(message) match { case e: ErrorResponse => e case rawRunQueryResponse: RawCrcRunQueryResponse => processRawCrcRunQueryResponse(authnToUse, request, rawRunQueryResponse) } if (collectAdapterAudit) AdapterAuditDb.db.insertExecutionCompletedShrineResponse(request,result) result } private[adapter] def processRawCrcRunQueryResponse(authnToUse: AuthenticationInfo, request: RunQueryRequest, rawRunQueryResponse: RawCrcRunQueryResponse): RunQueryResponse = { def isBreakdown(result: QueryResult) = result.resultType.exists(_.isBreakdown) val originalResults: Seq[QueryResult] = rawRunQueryResponse.results val (originalBreakdownResults, originalNonBreakDownResults): (Seq[QueryResult],Seq[QueryResult]) = originalResults.partition(isBreakdown) val originalBreakdownCountAttempts: Seq[(QueryResult, Try[QueryResult])] = attemptToRetrieveBreakdowns(request, originalBreakdownResults) val (successfulBreakdownCountAttempts, failedBreakdownCountAttempts) = originalBreakdownCountAttempts.partition { case (_, t) => t.isSuccess } val failedBreakdownCountAttemptsWithProblems = failedBreakdownCountAttempts.map { attempt => val originalResult: QueryResult = attempt._1 val queryResult:QueryResult = if (originalResult.problemDigest.isDefined) originalResult else { attempt._2 match { case Success(_) => originalResult case Failure(x) => //noinspection RedundantBlock { val problem:Problem = x match { case e: ErrorFromCrcException => ErrorFromCrcBreakdown(e) case e: MissingCrCXmlResultException => CannotInterpretCrcBreakdownXml(e) case NonFatal(e) => { val summary = s"Unexpected exception while interpreting breakdown response" ProblemNotYetEncoded(summary, e) } } LoggingProblemHandler.handleProblem(problem) originalResult.copy(problemDigest = Some(problem.toDigest)) } } } (queryResult,attempt._2) } logBreakdownFailures(rawRunQueryResponse, failedBreakdownCountAttemptsWithProblems) val originalMergedBreakdowns: Map[ResultOutputType, I2b2ResultEnvelope] = { val withBreakdownCounts = successfulBreakdownCountAttempts.collect { case (_, Success(queryResultWithBreakdowns)) => queryResultWithBreakdowns } withBreakdownCounts.map(_.breakdowns).fold(Map.empty)(_ ++ _) } val obfuscatedQueryResults = originalResults.map(Obfuscator.obfuscate) val obfuscatedNonBreakdownQueryResults = obfuscatedQueryResults.filterNot(isBreakdown) val obfuscatedMergedBreakdowns = obfuscateBreakdowns(originalMergedBreakdowns) val failedBreakdownTypes = failedBreakdownCountAttemptsWithProblems.flatMap { case (qr, _) => qr.resultType } dao.storeResults( authn = authnToUse, masterId = rawRunQueryResponse.queryId.toString, networkQueryId = request.networkQueryId, queryDefinition = request.queryDefinition, rawQueryResults = originalResults, obfuscatedQueryResults = obfuscatedQueryResults, failedBreakdownTypes = failedBreakdownTypes, mergedBreakdowns = originalMergedBreakdowns, obfuscatedBreakdowns = obfuscatedMergedBreakdowns) // at this point the queryResult could be a mix of successes and failures. // SHRINE reports only the successes. See SHRINE-1567 for details val queryResults: Seq[QueryResult] = if (doObfuscation) obfuscatedNonBreakdownQueryResults else originalNonBreakDownResults val breakdownsToReturn: Map[ResultOutputType, I2b2ResultEnvelope] = if (doObfuscation) obfuscatedMergedBreakdowns else originalMergedBreakdowns //TODO: Will fail in the case of NO non-breakdown QueryResults. Can this ever happen, and is it worth protecting against here? //can failedBreakdownCountAttempts be mixed back in here? val resultWithBreakdowns: QueryResult = queryResults.head.withBreakdowns(breakdownsToReturn) if(debugEnabled) { def justBreakdowns(breakdowns: Map[ResultOutputType, I2b2ResultEnvelope]) = breakdowns.mapValues(_.data) val obfuscationMessage = s"obfuscation is ${if(doObfuscation) "ON" else "OFF"}" debug(s"Returning QueryResult with count ${resultWithBreakdowns.setSize} (original count: ${originalNonBreakDownResults.headOption.map(_.setSize)} ; $obfuscationMessage)") debug(s"Returning QueryResult with breakdowns ${justBreakdowns(resultWithBreakdowns.breakdowns)} (original breakdowns: ${justBreakdowns(originalMergedBreakdowns)} ; $obfuscationMessage)") debug(s"Full QueryResult: $resultWithBreakdowns") } //if any results had problems, this commented out code can turn it into an error QueryResult //See SHRINE-1619 //val problem: Option[ProblemDigest] = failedBreakdownCountAttemptsWithProblems.headOption.flatMap(x => x._1.problemDigest) //val queryResult = problem.fold(resultWithBreakdowns)(pd => QueryResult.errorResult(Some(pd.description),"Error with CRC",pd)) rawRunQueryResponse.toRunQueryResponse.withResult(resultWithBreakdowns) } private def getResultFromCrc(parentRequest: RunQueryRequest, networkResultId: Long): Try[ReadResultResponse] = { def readResultRequest(runQueryReq: RunQueryRequest, networkResultId: Long) = ReadResultRequest(hiveCredentials.projectId, runQueryReq.waitTime, hiveCredentials.toAuthenticationInfo, networkResultId.toString) Try(XML.loadString(callCrc(readResultRequest(parentRequest, networkResultId)))).flatMap(ReadResultResponse.fromI2b2(breakdownTypes)) } private[adapter] def attemptToRetrieveCount(runQueryReq: RunQueryRequest, originalCountQueryResult: QueryResult): (QueryResult, Try[QueryResult]) = { originalCountQueryResult -> (for { countData <- getResultFromCrc(runQueryReq, originalCountQueryResult.resultId) } yield originalCountQueryResult.withSetSize(countData.metadata.setSize)) } private[adapter] def attemptToRetrieveBreakdowns(runQueryReq: RunQueryRequest, breakdownResults: Seq[QueryResult]): Seq[(QueryResult, Try[QueryResult])] = { breakdownResults.map { origBreakdownResult => origBreakdownResult -> (for { breakdownData <- getResultFromCrc(runQueryReq, origBreakdownResult.resultId).map(_.data) } yield origBreakdownResult.withBreakdown(breakdownData)) } } private[adapter] def logBreakdownFailures(response: RawCrcRunQueryResponse, failures: Seq[(QueryResult, Try[QueryResult])]) { for { (origQueryResult, Failure(e)) <- failures } { error(s"Couldn't load breakdown for QueryResult with masterId: ${response.queryId}, instanceId: ${origQueryResult.instanceId}, resultId: ${origQueryResult.resultId}. Asked for result type: ${origQueryResult.resultType}", e) } } private def isLockedOut(authn: AuthenticationInfo): Boolean = { adapterLockoutAttemptsThreshold match { case 0 => false case _ => dao.isUserLockedOut(authn, adapterLockoutAttemptsThreshold) } } private def logStartup(): Unit = { val message = { if (runQueriesImmediately) { s"${getClass.getSimpleName} will run queries immediately" } else { s"${getClass.getSimpleName} will queue queries for later execution" } } info(message) } } object RunQueryAdapter { private[adapter] def obfuscateBreakdowns(breakdowns: Map[ResultOutputType, I2b2ResultEnvelope]): Map[ResultOutputType, I2b2ResultEnvelope] = { breakdowns.mapValues(_.mapValues(Obfuscator.obfuscate)) } } case class ErrorFromCrcBreakdown(x:ErrorFromCrcException) extends AbstractProblem(ProblemSources.Adapter) { override val throwable = Some(x) override val summary: String = "The CRC reported an error." override val description = "The CRC reported an internal error." createAndLog } case class CannotInterpretCrcBreakdownXml(x:MissingCrCXmlResultException) extends AbstractProblem(ProblemSources.Adapter) { override val throwable = Some(x) override val summary: String = "SHRINE cannot interpret the CRC response." override val description = "The CRC responded, but SHRINE could not interpret that response." createAndLog } \ No newline at end of file diff --git a/adapter/adapter-service/src/main/scala/net/shrine/adapter/dao/AdapterDao.scala b/adapter/adapter-service/src/main/scala/net/shrine/adapter/dao/AdapterDao.scala index 50fd75c72..02fdf6bf1 100644 --- a/adapter/adapter-service/src/main/scala/net/shrine/adapter/dao/AdapterDao.scala +++ b/adapter/adapter-service/src/main/scala/net/shrine/adapter/dao/AdapterDao.scala @@ -1,73 +1,91 @@ package net.shrine.adapter.dao +import java.util.Date + import net.shrine.protocol.query.QueryDefinition import net.shrine.protocol.AuthenticationInfo import net.shrine.adapter.dao.model.ShrineQueryResult import net.shrine.protocol.QueryResult import net.shrine.protocol.I2b2ResultEnvelope import net.shrine.protocol.ResultOutputType import net.shrine.adapter.dao.model.ShrineQuery +import scala.concurrent.duration.Duration import scala.xml.NodeSeq /** * @author clint * @since Oct 15, 2012 */ trait AdapterDao { /** * @return the id column of the inserted row */ def insertQuery(masterId: String, networkId: Long, authn: AuthenticationInfo, queryDefinition: QueryDefinition, isFlagged: Boolean, hasBeenRun: Boolean, flagMessage: Option[String]): Int //Returns a Map of output types to Seqs of inserted ids, since the ERROR output type can be used for multiple query_result rows, //Say for a run query operation that results in multiple error responses from the CRC. def insertQueryResults(parentQueryId: Int, results: Seq[QueryResult]): Map[ResultOutputType, Seq[Int]] def insertCountResult(resultId: Int, originalCount: Long, obfuscatedCount: Long): Unit def insertBreakdownResults(parentResultIds: Map[ResultOutputType, Seq[Int]], originalBreakdowns: Map[ResultOutputType, I2b2ResultEnvelope], obfuscatedBreakdowns: Map[ResultOutputType, I2b2ResultEnvelope]): Unit def insertErrorResult(parentResultId: Int, errorMessage: String, codec:String, stampText:String, summary:String, digestDescription:String,detailsXml:NodeSeq): Unit def findQueriesByUserAndDomain(domain: String, username: String, howMany: Int): Seq[ShrineQuery] def findQueriesByDomain(domain:String):Seq[ShrineQuery] def findQueryByNetworkId(networkQueryId: Long): Option[ShrineQuery] def findResultsFor(networkQueryId: Long): Option[ShrineQueryResult] - + def isUserLockedOut(authn: AuthenticationInfo, defaultThreshold: Int): Boolean - + + /** + * @throws BotDetectedException if it detects a bot attack + */ + def checkIfBot(authn:AuthenticationInfo, countTimeThresholds:Map[Long,Duration]): Unit + def renameQuery(networkQueryId: Long, newName: String): Unit def deleteQuery(networkQueryId: Long): Unit def deleteQueryResultsFor(networkQueryId: Long): Unit def findRecentQueries(howMany: Int): Seq[ShrineQuery] def storeResults(authn: AuthenticationInfo, masterId: String, networkQueryId: Long, queryDefinition: QueryDefinition, rawQueryResults: Seq[QueryResult], obfuscatedQueryResults: Seq[QueryResult], failedBreakdownTypes: Seq[ResultOutputType], mergedBreakdowns: Map[ResultOutputType, I2b2ResultEnvelope], obfuscatedBreakdowns: Map[ResultOutputType, I2b2ResultEnvelope]): Unit def flagQuery(networkQueryId: Long, message: Option[String]): Unit def unFlagQuery(networkQueryId: Long): Unit def inTransaction[T](f: => T): T = f } + + +case class BotDetectedException(domain:String, + username:String, + detectedCount:Long, + sinceMs:Long, + limit:Long) extends Exception() { + + override def getMessage = s"$domain:$username has run $detectedCount queries since ${new Date(sinceMs)}, more than the limit of $limit" +} diff --git a/adapter/adapter-service/src/main/scala/net/shrine/adapter/dao/squeryl/SquerylAdapterDao.scala b/adapter/adapter-service/src/main/scala/net/shrine/adapter/dao/squeryl/SquerylAdapterDao.scala index 14d448e39..9c36a5bf8 100644 --- a/adapter/adapter-service/src/main/scala/net/shrine/adapter/dao/squeryl/SquerylAdapterDao.scala +++ b/adapter/adapter-service/src/main/scala/net/shrine/adapter/dao/squeryl/SquerylAdapterDao.scala @@ -1,465 +1,490 @@ package net.shrine.adapter.dao.squeryl +import java.sql.Timestamp import javax.xml.datatype.XMLGregorianCalendar -import net.shrine.adapter.dao.AdapterDao +import net.shrine.adapter.dao.{AdapterDao, BotDetectedException} import net.shrine.adapter.dao.model.{ObfuscatedPair, ShrineQuery, ShrineQueryResult} import net.shrine.adapter.dao.model.squeryl.{SquerylBreakdownResultRow, SquerylCountRow, SquerylPrivilegedUser, SquerylQueryResultRow, SquerylShrineError, SquerylShrineQuery} import net.shrine.adapter.dao.squeryl.tables.Tables import net.shrine.dao.DateHelpers import net.shrine.dao.squeryl.{SquerylEntryPoint, SquerylInitializer} import net.shrine.log.Loggable import net.shrine.problem.{AbstractProblem, ProblemSources} import net.shrine.protocol.{AuthenticationInfo, I2b2ResultEnvelope, QueryResult, ResultOutputType} import net.shrine.protocol.query.QueryDefinition import net.shrine.util.XmlDateHelper import org.squeryl.Query -import org.squeryl.dsl.GroupWithMeasures +import org.squeryl.dsl.{GroupWithMeasures, Measures} +import scala.concurrent.duration.Duration import scala.util.Try import scala.xml.NodeSeq /** * @author clint * @since May 22, 2013 */ final class SquerylAdapterDao(initializer: SquerylInitializer, tables: Tables)(implicit breakdownTypes: Set[ResultOutputType]) extends AdapterDao with Loggable { initializer.init() override def inTransaction[T](f: => T): T = SquerylEntryPoint.inTransaction { f } import SquerylEntryPoint._ override def flagQuery(networkQueryId: Long, flagMessage: Option[String]): Unit = mutateFlagField(networkQueryId, newIsFlagged = true, flagMessage) override def unFlagQuery(networkQueryId: Long): Unit = mutateFlagField(networkQueryId, newIsFlagged = false, None) private def mutateFlagField(networkQueryId: Long, newIsFlagged: Boolean, newFlagMessage: Option[String]): Unit = { inTransaction { update(tables.shrineQueries) { queryRow => where(queryRow.networkId === networkQueryId). set(queryRow.isFlagged := newIsFlagged, queryRow.flagMessage := newFlagMessage) } } } override def storeResults( authn: AuthenticationInfo, masterId: String, networkQueryId: Long, queryDefinition: QueryDefinition, rawQueryResults: Seq[QueryResult], obfuscatedQueryResults: Seq[QueryResult], failedBreakdownTypes: Seq[ResultOutputType], mergedBreakdowns: Map[ResultOutputType, I2b2ResultEnvelope], obfuscatedBreakdowns: Map[ResultOutputType, I2b2ResultEnvelope]): Unit = { inTransaction { val insertedQueryId = insertQuery(masterId, networkQueryId, authn, queryDefinition, isFlagged = false, hasBeenRun = true, flagMessage = None) val insertedQueryResultIds = insertQueryResults(insertedQueryId, rawQueryResults) storeCountResults(rawQueryResults, obfuscatedQueryResults, insertedQueryResultIds) storeErrorResults(rawQueryResults, insertedQueryResultIds) storeBreakdownFailures(failedBreakdownTypes.toSet, insertedQueryResultIds) insertBreakdownResults(insertedQueryResultIds, mergedBreakdowns, obfuscatedBreakdowns) } } private[adapter] def storeCountResults(raw: Seq[QueryResult], obfuscated: Seq[QueryResult], insertedIds: Map[ResultOutputType, Seq[Int]]): Unit = { val notErrors = raw.filter(!_.isError) val obfuscatedNotErrors = obfuscated.filter(!_.isError) if(notErrors.size > 1) { warn(s"Got ${notErrors.size} raw (hopefully-)count results; more than 1 is unusual.") } if(obfuscatedNotErrors.size > 1) { warn(s"Got ${obfuscatedNotErrors.size} obfuscated (hopefully-)count results; more than 1 is unusual.") } if(notErrors.size != obfuscatedNotErrors.size) { warn(s"Got ${notErrors.size} raw and ${obfuscatedNotErrors.size} obfuscated (hopefully-)count results; that these numbers are different is unusual.") } import ResultOutputType.PATIENT_COUNT_XML def isCount(qr: QueryResult): Boolean = qr.resultType.contains(PATIENT_COUNT_XML) inTransaction { //NB: Take the count/setSize from the FIRST PATIENT_COUNT_XML QueryResult, //though the same count should be there for all of them, if there are more than one for { Seq(insertedCountQueryResultId) <- insertedIds.get(PATIENT_COUNT_XML) notError <- notErrors.find(isCount) //NB: Find a count result, just to be sure obfuscatedNotError <- obfuscatedNotErrors.find(isCount) //NB: Find a count result, just to be sure } { insertCountResult(insertedCountQueryResultId, notError.setSize, obfuscatedNotError.setSize) } } } private[adapter] def storeErrorResults(results: Seq[QueryResult], insertedIds: Map[ResultOutputType, Seq[Int]]): Unit = { val errors = results.filter(_.isError) val insertedErrorResultIds = insertedIds.getOrElse(ResultOutputType.ERROR,Nil) val insertedIdsToErrors = insertedErrorResultIds zip errors inTransaction { for { (insertedErrorResultId, errorQueryResult) <- insertedIdsToErrors } { val pd = errorQueryResult.problemDigest.get //it's an error so it will have a problem digest insertErrorResult( insertedErrorResultId, errorQueryResult.statusMessage.getOrElse("Unknown failure"), pd.codec, pd.stampText, pd.summary, pd.description, pd.detailsXml ) } } } private[adapter] def storeBreakdownFailures(failedBreakdownTypes: Set[ResultOutputType], insertedIds: Map[ResultOutputType, Seq[Int]]): Unit = { val insertedIdsForFailedBreakdownTypes = insertedIds.filterKeys(failedBreakdownTypes.contains) inTransaction { for { (failedBreakdownType, Seq(resultId)) <- insertedIdsForFailedBreakdownTypes } { //todo propagate backwards to the breakdown failure to create the corect problem object BreakdownFailure extends AbstractProblem(ProblemSources.Adapter) { override val summary: String = "Couldn't retrieve result breakdown" override val description:String = s"Couldn't retrieve result breakdown of type '$failedBreakdownType'" createAndLog } val pd = BreakdownFailure.toDigest insertErrorResult( resultId, s"Couldn't retrieve breakdown of type '$failedBreakdownType'", pd.codec, pd.stampText, pd.summary, pd.description, pd.detailsXml ) } } } override def findRecentQueries(howMany: Int): Seq[ShrineQuery] = { inTransaction { Queries.queriesForAllUsers.take(howMany).map(_.toShrineQuery).toSeq } } def findAllCounts():Seq[SquerylCountRow] = { inTransaction{ Queries.allCountResults.toSeq } } override def renameQuery(networkQueryId: Long, newName: String) { inTransaction { update(tables.shrineQueries) { queryRow => where(queryRow.networkId === networkQueryId). set(queryRow.name := newName) } } } override def deleteQuery(networkQueryId: Long): Unit = { inTransaction { tables.shrineQueries.deleteWhere(_.networkId === networkQueryId) } } override def deleteQueryResultsFor(networkQueryId: Long): Unit = { inTransaction { val resultIdsForNetworkQueryId = join(tables.shrineQueries, tables.queryResults) { (queryRow, resultRow) => where(queryRow.networkId === networkQueryId). select(resultRow.id). on(queryRow.id === resultRow.queryId) }.toSet tables.queryResults.deleteWhere(_.id in resultIdsForNetworkQueryId) } } override def isUserLockedOut(authn: AuthenticationInfo, defaultThreshold: Int): Boolean = Try { inTransaction { val privilegedUserOption = Queries.privilegedUsers(authn.domain, authn.username).singleOption val threshold:Int = privilegedUserOption.flatMap(_.threshold).getOrElse(defaultThreshold.intValue) val thirtyDaysInThePast: XMLGregorianCalendar = DateHelpers.daysFromNow(-30) val overrideDate: XMLGregorianCalendar = privilegedUserOption.map(_.toPrivilegedUser).flatMap(_.overrideDate).getOrElse(thirtyDaysInThePast) //sorted instead of just finding max val counts: Seq[Long] = Queries.repeatedResults(authn.domain, authn.username, overrideDate).toSeq.sorted //and then grabbing the last, highest value in the sorted sequence val repeatedResultCount: Long = counts.lastOption.getOrElse(0L) val result = repeatedResultCount > threshold debug(s"User ${authn.domain}:${authn.username} locked out? $result") result } }.getOrElse(false) + override def checkIfBot(authn:AuthenticationInfo, botTimeThresholds:Map[Long,Duration]): Unit = { + val now = System.currentTimeMillis() + + botTimeThresholds.foreach{countDuration => inTransaction { + val sinceMs: Long = now - countDuration._2.toMillis + val query: Query[Measures[Long]] = Queries.countQueriesForUserSince(authn.domain, authn.username, sinceMs) + val queriesSince = query.headOption.map(_.measures).getOrElse(0L) + if (queriesSince > countDuration._1) throw new BotDetectedException(domain = authn.domain, + username = authn.username, + detectedCount = queriesSince, + sinceMs = sinceMs, + limit = countDuration._1) + }} + } + override def insertQuery(localMasterId: String, networkId: Long, authn: AuthenticationInfo, queryDefinition: QueryDefinition, isFlagged: Boolean, hasBeenRun: Boolean, flagMessage: Option[String]): Int = { inTransaction { val inserted = tables.shrineQueries.insert(new SquerylShrineQuery( 0, localMasterId, networkId, authn.username, authn.domain, XmlDateHelper.now, isFlagged, flagMessage, hasBeenRun, queryDefinition)) inserted.id } } /** * Insert rows into QueryResults, one for each QueryResult in the passed RunQueryResponse * Inserted rows are 'children' of the passed ShrineQuery (ie, they are the results of the query) */ override def insertQueryResults(parentQueryId: Int, results: Seq[QueryResult]): Map[ResultOutputType, Seq[Int]] = { def execTime(result: QueryResult): Option[Long] = { //TODO: How are locales handled here? Do we care? def toMillis(xmlGc: XMLGregorianCalendar) = xmlGc.toGregorianCalendar.getTimeInMillis for { start <- result.startDate end <- result.endDate } yield toMillis(end) - toMillis(start) } val typeToIdTuples = inTransaction { for { result <- results resultType = result.resultType.getOrElse(ResultOutputType.ERROR) //TODO: under what circumstances can QueryResults NOT have start and end dates set? elapsed = execTime(result) } yield { val lastInsertedQueryResultRow = tables.queryResults.insert(new SquerylQueryResultRow(0, result.resultId, parentQueryId, resultType, result.statusType, elapsed, XmlDateHelper.now)) (resultType, lastInsertedQueryResultRow.id) } } typeToIdTuples.groupBy { case (resultType, _) => resultType }.mapValues(_.map { case (_, count) => count }) } override def insertCountResult(resultId: Int, originalCount: Long, obfuscatedCount: Long) { //NB: Squeryl steers us toward inserting with dummy ids :( inTransaction { tables.countResults.insert(new SquerylCountRow(0, resultId, originalCount, obfuscatedCount, XmlDateHelper.now)) } } override def insertBreakdownResults(parentResultIds: Map[ResultOutputType, Seq[Int]], originalBreakdowns: Map[ResultOutputType, I2b2ResultEnvelope], obfuscatedBreakdowns: Map[ResultOutputType, I2b2ResultEnvelope]) { def merge(original: I2b2ResultEnvelope, obfuscated: I2b2ResultEnvelope): Map[String, ObfuscatedPair] = { Map.empty ++ (for { (key, originalValue) <- original.data obfuscatedValue <- obfuscated.data.get(key) } yield (key, ObfuscatedPair(originalValue, obfuscatedValue))) } inTransaction { for { (resultType, Seq(resultId)) <- parentResultIds if resultType.isBreakdown originalBreakdown <- originalBreakdowns.get(resultType) obfuscatedBreakdown <- obfuscatedBreakdowns.get(resultType) (key, ObfuscatedPair(original, obfuscated)) <- merge(originalBreakdown, obfuscatedBreakdown) } { tables.breakdownResults.insert(SquerylBreakdownResultRow(0, resultId, key, original, obfuscated)) } } } override def insertErrorResult(parentResultId: Int, errorMessage: String, codec:String, stampText:String, summary:String, digestDescription:String,detailsXml:NodeSeq) { //NB: Squeryl steers us toward inserting with dummy ids :( inTransaction { tables.errorResults.insert(SquerylShrineError(0, parentResultId, errorMessage, codec, stampText, summary, digestDescription, detailsXml.toString())) } } override def findQueryByNetworkId(networkQueryId: Long): Option[ShrineQuery] = { inTransaction { Queries.queriesByNetworkId(networkQueryId).headOption.map(_.toShrineQuery) } } override def findQueriesByUserAndDomain(domain: String, username: String, howMany: Int): Seq[ShrineQuery] = { inTransaction { Queries.queriesForUser(username, domain).take(howMany).toSeq.map(_.toShrineQuery) } } override def findQueriesByDomain(domain: String): Seq[ShrineQuery] = { inTransaction { Queries.queriesForDomain(domain).toList.map(_.toShrineQuery) } } override def findResultsFor(networkQueryId: Long): Option[ShrineQueryResult] = { inTransaction { val breakdownRowsByType = Queries.breakdownResults(networkQueryId).toSeq.groupBy { case (outputType, _) => outputType.toQueryResultRow.resultType }.mapValues(_.map { case (_, row) => row.toBreakdownResultRow }) val queryRowOption = Queries.queriesByNetworkId(networkQueryId).headOption.map(_.toShrineQuery) val countRowOption = Queries.countResults(networkQueryId).headOption.map(_.toCountRow) val queryResultRows = Queries.resultsForQuery(networkQueryId).toSeq.map(_.toQueryResultRow) val errorResultRows = Queries.errorResults(networkQueryId).toSeq.map(_.toShrineError) for { queryRow <- queryRowOption countRow <- countRowOption shrineQueryResult <- ShrineQueryResult.fromRows(queryRow, queryResultRows, countRow, breakdownRowsByType, errorResultRows) } yield { shrineQueryResult } } } /** * @author clint * @since Nov 19, 2012 */ object Queries { def privilegedUsers(domain: String, username: String): Query[SquerylPrivilegedUser] = { from(tables.privilegedUsers) { user => where(user.username === username and user.domain === domain).select(user) } } + def countQueriesForUserSince(domain:String, username:String, sinceMs:Long): Query[Measures[Long]] = { + val since = new Timestamp(sinceMs) + from(tables.shrineQueries) { queryRow => + where(queryRow.domain === domain and queryRow.username === username and queryRow.dateCreated >= since). + compute(count) + } + } + def repeatedResults(domain: String, username: String, overrideDate: XMLGregorianCalendar): Query[Long] = { val counts: Query[GroupWithMeasures[Long, Long]] = join(tables.shrineQueries, tables.queryResults, tables.countResults) { (queryRow, resultRow, countRow) => where(queryRow.username === username and queryRow.domain === domain and (countRow.originalValue <> 0L) and queryRow.dateCreated > DateHelpers.toTimestamp(overrideDate)). groupBy(countRow.originalValue). compute(count(countRow.originalValue)). on(queryRow.id === resultRow.queryId, resultRow.id === countRow.resultId) } //Filter for result counts > 0 from(counts) { cnt => where(cnt.measures gt 0).select(cnt.measures) } } val queriesForAllUsers: Query[SquerylShrineQuery] = { from(tables.shrineQueries) { queryRow => select(queryRow).orderBy(queryRow.dateCreated.desc) } } //TODO: Find a way to parameterize on limit, to avoid building the query every time //TODO: limit def queriesForUser(username: String, domain: String): Query[SquerylShrineQuery] = { from(tables.shrineQueries) { queryRow => where(queryRow.domain === domain and queryRow.username === username). select(queryRow). orderBy(queryRow.dateCreated.desc) } } def queriesForDomain(domain: String): Query[SquerylShrineQuery] = { from(tables.shrineQueries) { queryRow => where(queryRow.domain === domain). select(queryRow). orderBy(queryRow.dateCreated.desc) } } val allCountResults: Query[SquerylCountRow] = { from(tables.countResults) { queryRow => select(queryRow) } } def queriesByNetworkId(networkQueryId: Long): Query[SquerylShrineQuery] = { from(tables.shrineQueries) { queryRow => where(queryRow.networkId === networkQueryId).select(queryRow) } } //TODO: Find out how to compose queries, to re-use queriesByNetworkId def queryNamesByNetworkId(networkQueryId: Long): Query[String] = { from(tables.shrineQueries) { queryRow => where(queryRow.networkId === networkQueryId).select(queryRow.name) } } def resultsForQuery(networkQueryId: Long): Query[SquerylQueryResultRow] = { val resultsForNetworkQueryId = join(tables.shrineQueries, tables.queryResults) { (queryRow, resultRow) => where(queryRow.networkId === networkQueryId). select(resultRow). on(queryRow.id === resultRow.queryId) } from(resultsForNetworkQueryId)(select(_)) } def countResults(networkQueryId: Long): Query[SquerylCountRow] = { join(tables.shrineQueries, tables.queryResults, tables.countResults) { (queryRow, resultRow, countRow) => where(queryRow.networkId === networkQueryId). select(countRow). on(queryRow.id === resultRow.queryId, resultRow.id === countRow.resultId) } } def errorResults(networkQueryId: Long): Query[SquerylShrineError] = { join(tables.shrineQueries, tables.queryResults, tables.errorResults) { (queryRow, resultRow, errorRow) => where(queryRow.networkId === networkQueryId). select(errorRow). on(queryRow.id === resultRow.queryId, resultRow.id === errorRow.resultId) } } //NB: using groupBy here is too much of a pain; do it 'manually' later def breakdownResults(networkQueryId: Long): Query[(SquerylQueryResultRow, SquerylBreakdownResultRow)] = { join(tables.shrineQueries, tables.queryResults, tables.breakdownResults) { (queryRow, resultRow, breakdownRow) => where(queryRow.networkId === networkQueryId). select((resultRow, breakdownRow)). on(queryRow.id === resultRow.queryId, resultRow.id === breakdownRow.resultId) } } } -} \ No newline at end of file +} diff --git a/adapter/adapter-service/src/test/scala/net/shrine/adapter/AbstractQueryRetrievalTestCase.scala b/adapter/adapter-service/src/test/scala/net/shrine/adapter/AbstractQueryRetrievalTestCase.scala index 6ee1c050d..44fb9d53d 100644 --- a/adapter/adapter-service/src/test/scala/net/shrine/adapter/AbstractQueryRetrievalTestCase.scala +++ b/adapter/adapter-service/src/test/scala/net/shrine/adapter/AbstractQueryRetrievalTestCase.scala @@ -1,378 +1,379 @@ package net.shrine.adapter import scala.xml.NodeSeq import net.shrine.util.ShouldMatchersForJUnit import ObfuscatorTest.within3 import javax.xml.datatype.XMLGregorianCalendar import net.shrine.adapter.dao.AdapterDao import net.shrine.adapter.dao.squeryl.AbstractSquerylAdapterTest import net.shrine.client.HttpClient import net.shrine.client.HttpResponse import net.shrine.protocol.{AuthenticationInfo, BaseShrineRequest, BaseShrineResponse, BroadcastMessage, CrcRequest, Credential, DefaultBreakdownResultOutputTypes, ErrorResponse, HiveCredentials, I2b2ResultEnvelope, QueryResult, ReadResultRequest, ReadResultResponse, ResultOutputType, RunQueryRequest, RunQueryResponse, ShrineRequest, ShrineResponse} import net.shrine.protocol.DefaultBreakdownResultOutputTypes.PATIENT_AGE_COUNT_XML import net.shrine.protocol.ResultOutputType.PATIENT_COUNT_XML import net.shrine.protocol.DefaultBreakdownResultOutputTypes.PATIENT_GENDER_COUNT_XML import net.shrine.protocol.query.{QueryDefinition, Term} import net.shrine.util.XmlDateHelper import net.shrine.util.XmlDateHelper.now import net.shrine.util.XmlGcEnrichments import net.shrine.client.Poster import net.shrine.adapter.translators.QueryDefinitionTranslator import net.shrine.adapter.translators.ExpressionTranslator import net.shrine.problem.TestProblem import scala.util.Success /** * @author clint * @since Nov 8, 2012 */ //noinspection UnitMethodIsParameterless abstract class AbstractQueryRetrievalTestCase[R <: BaseShrineResponse]( makeAdapter: (AdapterDao, HttpClient) => WithHiveCredentialsAdapter, makeRequest: (Long, AuthenticationInfo) => BaseShrineRequest, extractor: R => Option[(Long, QueryResult)]) extends AbstractSquerylAdapterTest with ShouldMatchersForJUnit { private val authn = AuthenticationInfo("some-domain", "some-user", Credential("alskdjlkasd", false)) def doTestProcessRequestMissingQuery { val adapter = makeAdapter(dao, MockHttpClient) val response = adapter.processRequest(BroadcastMessage(0L, authn, makeRequest(-1L, authn))) response.isInstanceOf[ErrorResponse] should be(true) } def doTestProcessInvalidRequest { val adapter = makeAdapter(dao, MockHttpClient) intercept[ClassCastException] { //request must be a type of request we can handle adapter.processRequest(BroadcastMessage(0L, authn, new AbstractQueryRetrievalTestCase.BogusRequest)) } } private val localMasterId = "alksjdkalsdjlasdjlkjsad" private val shrineNetworkQueryId = 123L private def doGetResults(adapter: Adapter) = adapter.processRequest(BroadcastMessage(shrineNetworkQueryId, authn, makeRequest(shrineNetworkQueryId, authn))) private def toMillis(xmlGc: XMLGregorianCalendar): Long = xmlGc.toGregorianCalendar.getTimeInMillis private val instanceId = 999L private val setSize = 12345L private val obfSetSize = setSize + 1 private val queryExpr = Term("foo") private val topicId = "laskdjlkasd" private val fooQuery = QueryDefinition("some-query",queryExpr) def doTestProcessRequestIncompleteQuery(countQueryShouldWork: Boolean = true): Unit = afterCreatingTables { val dbQueryId = dao.insertQuery(localMasterId, shrineNetworkQueryId, authn, fooQuery, isFlagged = false, hasBeenRun = true, flagMessage = None) import ResultOutputType._ import XmlDateHelper.now val breakdowns = Map(PATIENT_AGE_COUNT_XML -> I2b2ResultEnvelope(PATIENT_AGE_COUNT_XML, Map("a" -> 1L, "b" -> 2L))) val obfscBreakdowns = breakdowns.mapValues(_.mapValues(_ + 1)) val startDate = now val elapsed = 100L val endDate = { import XmlGcEnrichments._ import scala.concurrent.duration._ startDate + elapsed.milliseconds } val countResultId = 456L val breakdownResultId = 98237943265436L val incompleteCountResult = QueryResult( resultId = countResultId, instanceId = instanceId, resultType = Some(PATIENT_COUNT_XML), setSize = setSize, startDate = Option(startDate), endDate = Option(endDate), description = Some("results from node X"), statusType = QueryResult.StatusType.Processing, statusMessage = None, breakdowns = breakdowns) val breakdownResult = breakdowns.head match { case (resultType, data) => incompleteCountResult.withId(breakdownResultId).withBreakdowns(Map(resultType -> data)).withResultType(resultType) } val queryStartDate = now val idsByResultType = dao.insertQueryResults(dbQueryId, incompleteCountResult :: breakdownResult :: Nil) final class MightWorkMockHttpClient(expectedHiveCredentials: HiveCredentials) extends HttpClient { override def post(input: String, url: String): HttpResponse = { def makeFinished(queryResult: QueryResult) = queryResult.copy(statusType = QueryResult.StatusType.Finished) def validateAuthnAndProjectId(req: ShrineRequest) { req.authn should equal(expectedHiveCredentials.toAuthenticationInfo) req.projectId should equal(expectedHiveCredentials.projectId) } val response = CrcRequest.fromI2b2String(DefaultBreakdownResultOutputTypes.toSet)(input) match { case Success(req: ReadResultRequest) if req.localResultId == countResultId.toString => { validateAuthnAndProjectId(req) if (countQueryShouldWork) { ReadResultResponse(123L, makeFinished(incompleteCountResult), I2b2ResultEnvelope(PATIENT_COUNT_XML, Map(PATIENT_COUNT_XML.name -> incompleteCountResult.setSize))) } else { ErrorResponse(TestProblem(summary = "Retrieving count result failed")) } } case Success(req: ReadResultRequest) if req.localResultId == breakdownResultId.toString => { validateAuthnAndProjectId(req) ReadResultResponse(123L, makeFinished(breakdownResult), breakdowns.head._2) } case _ => fail(s"Unknown input: $input") } HttpResponse.ok(response.toI2b2String) } } val adapter: WithHiveCredentialsAdapter = makeAdapter(dao, new MightWorkMockHttpClient(AbstractQueryRetrievalTestCase.hiveCredentials)) def getResults = doGetResults(adapter) getResults.isInstanceOf[ErrorResponse] should be(true) dao.insertCountResult(idsByResultType(PATIENT_COUNT_XML).head, setSize, obfSetSize) dao.insertBreakdownResults(idsByResultType, breakdowns, obfscBreakdowns) //The query shouldn't be 'done', since its status is PROCESSING dao.findResultsFor(shrineNetworkQueryId).get.count.statusType should be(QueryResult.StatusType.Processing) //Now, calling processRequest (via getResults) should cause the query to be re-retrieved from the CRC val result = getResults.asInstanceOf[R] //Which should cause the query to be re-stored with a 'done' status (since that's what our mock CRC returns) val expectedStatusType = if (countQueryShouldWork) QueryResult.StatusType.Finished else QueryResult.StatusType.Processing dao.findResultsFor(shrineNetworkQueryId).get.count.statusType should be(expectedStatusType) if (!countQueryShouldWork) { result.isInstanceOf[ErrorResponse] should be(true) } else { val Some((actualNetworkQueryId, actualQueryResult)) = extractor(result) actualNetworkQueryId should equal(shrineNetworkQueryId) import ObfuscatorTest.within3 actualQueryResult.resultType should equal(Some(PATIENT_COUNT_XML)) within3(setSize, actualQueryResult.setSize) should be(true) actualQueryResult.description should be(Some("results from node X")) actualQueryResult.statusType should equal(QueryResult.StatusType.Finished) actualQueryResult.statusMessage should be(Some(QueryResult.StatusType.Finished.name)) actualQueryResult.breakdowns.foreach { case (rt, I2b2ResultEnvelope(_, data)) => { data.forall { case (key, value) => within3(value, breakdowns.get(rt).get.data.get(key).get) } } } for { startDate <- actualQueryResult.startDate endDate <- actualQueryResult.endDate } { val actualElapsed = toMillis(endDate) - toMillis(startDate) actualElapsed should equal(elapsed) } } } def doTestProcessRequestQueuedQuery: Unit = afterCreatingTables { import ResultOutputType._ import XmlDateHelper.now val startDate = now val elapsed = 100L val endDate = { import XmlGcEnrichments._ import scala.concurrent.duration._ startDate + elapsed.milliseconds } val countResultId = 456L val incompleteCountResult = QueryResult(-1L, -1L, Some(PATIENT_COUNT_XML), -1L, Option(startDate), Option(endDate), Some("results from node X"), QueryResult.StatusType.Queued, None) dao.inTransaction { val insertedQueryId = dao.insertQuery(localMasterId, shrineNetworkQueryId, authn, fooQuery, isFlagged = false, hasBeenRun = false, flagMessage = None) //NB: We need to insert dummy QueryResult and Count records so that calls to StoredQueries.retrieve() in //AbstractReadQueryResultAdapter, called when retrieving results for previously-queued-or-incomplete //queries, will work. val insertedQueryResultIds = dao.insertQueryResults(insertedQueryId, Seq(incompleteCountResult)) val countQueryResultId = insertedQueryResultIds(ResultOutputType.PATIENT_COUNT_XML).head dao.insertCountResult(countQueryResultId, -1L, -1L) } val queryStartDate = now object MockHttpClient extends HttpClient { override def post(input: String, url: String): HttpResponse = ??? } val adapter: WithHiveCredentialsAdapter = makeAdapter(dao, MockHttpClient) def getResults = doGetResults(adapter) getResults.isInstanceOf[ErrorResponse] should be(true) //The query shouldn't be 'done', since its status is QUEUED dao.findResultsFor(shrineNetworkQueryId).get.count.statusType should be(QueryResult.StatusType.Queued) //Now, calling processRequest (via getResults) should NOT cause the query to be re-retrieved from the CRC, because the query was previously queued val result = getResults result.isInstanceOf[ErrorResponse] should be(true) dao.findResultsFor(shrineNetworkQueryId).get.count.statusType should be(QueryResult.StatusType.Queued) } def doTestProcessRequest = afterCreatingTables { val adapter = makeAdapter(dao, MockHttpClient) def getResults = doGetResults(adapter) getResults match { case errorResponse:ErrorResponse => errorResponse.problemDigest.codec should be (classOf[QueryNotFound].getName) case x => fail(s"Got $x, not an ErrorResponse") } val dbQueryId = dao.insertQuery(localMasterId, shrineNetworkQueryId, authn, fooQuery, isFlagged = false, hasBeenRun = false, flagMessage = None) getResults match { case errorResponse:ErrorResponse => errorResponse.problemDigest.codec should be (classOf[QueryResultNotAvailable].getName) case x => fail(s"Got $x, not an ErrorResponse") } import ResultOutputType._ import XmlDateHelper.now val breakdowns = Map( PATIENT_AGE_COUNT_XML -> I2b2ResultEnvelope(PATIENT_AGE_COUNT_XML, Map("a" -> 1L, "b" -> 2L)), PATIENT_GENDER_COUNT_XML -> I2b2ResultEnvelope(PATIENT_GENDER_COUNT_XML, Map("x" -> 3L, "y" -> 4L))) val obfscBreakdowns = breakdowns.mapValues(_.mapValues(_ + 1)) val startDate = now val elapsed = 100L val endDate = { import XmlGcEnrichments._ import scala.concurrent.duration._ startDate + elapsed.milliseconds } val countResult = QueryResult( resultId = 456L, instanceId = instanceId, resultType = Some(PATIENT_COUNT_XML), setSize = setSize, startDate = Option(startDate), endDate = Option(endDate), description = Some("results from node X"), statusType = QueryResult.StatusType.Finished, statusMessage = None, breakdowns = breakdowns ) val breakdownResults = breakdowns.map { case (resultType, data) => countResult.withBreakdowns(Map(resultType -> data)).withResultType(resultType) }.toSeq val queryStartDate = now val idsByResultType = dao.insertQueryResults(dbQueryId, countResult +: breakdownResults) getResults.isInstanceOf[ErrorResponse] should be(true) dao.insertCountResult(idsByResultType(PATIENT_COUNT_XML).head, setSize, obfSetSize) dao.insertBreakdownResults(idsByResultType, breakdowns, obfscBreakdowns) val result = getResults.asInstanceOf[R] val Some((actualNetworkQueryId, actualQueryResult)) = extractor(result) actualNetworkQueryId should equal(shrineNetworkQueryId) actualQueryResult.resultType should equal(Some(PATIENT_COUNT_XML)) actualQueryResult.setSize should equal(obfSetSize) actualQueryResult.description should be(None) //TODO: This is probably wrong actualQueryResult.statusType should equal(QueryResult.StatusType.Finished) actualQueryResult.statusMessage should be(None) actualQueryResult.breakdowns should equal(obfscBreakdowns) for { startDate <- actualQueryResult.startDate endDate <- actualQueryResult.endDate } { val actualElapsed = toMillis(endDate) - toMillis(startDate) actualElapsed should equal(elapsed) } } } object AbstractQueryRetrievalTestCase { val hiveCredentials = HiveCredentials("some-hive-domain", "hive-username", "hive-password", "hive-project") val doObfuscation = true def runQueryAdapter(dao: AdapterDao, poster: Poster): RunQueryAdapter = { val translator = new QueryDefinitionTranslator(new ExpressionTranslator(Map("foo" -> Set("bar")))) new RunQueryAdapter( - poster, - dao, - AbstractQueryRetrievalTestCase.hiveCredentials, - translator, - 10000, - doObfuscation, + poster = poster, + dao = dao, + hiveCredentials = AbstractQueryRetrievalTestCase.hiveCredentials, + conceptTranslator = translator, + adapterLockoutAttemptsThreshold = 10000, + doObfuscation = doObfuscation, runQueriesImmediately = true, - DefaultBreakdownResultOutputTypes.toSet, - collectAdapterAudit = false + breakdownTypes = DefaultBreakdownResultOutputTypes.toSet, + collectAdapterAudit = false, + botCountTimeThresholds = Map.empty ) } import scala.concurrent.duration._ final class BogusRequest extends ShrineRequest("fooProject", 1.second, null) { override val requestType = null protected override def i2b2MessageBody: NodeSeq = override def toXml = } } \ No newline at end of file diff --git a/adapter/adapter-service/src/test/scala/net/shrine/adapter/RunQueryAdapterTest.scala b/adapter/adapter-service/src/test/scala/net/shrine/adapter/RunQueryAdapterTest.scala index 73178795b..70a858a17 100644 --- a/adapter/adapter-service/src/test/scala/net/shrine/adapter/RunQueryAdapterTest.scala +++ b/adapter/adapter-service/src/test/scala/net/shrine/adapter/RunQueryAdapterTest.scala @@ -1,966 +1,974 @@ package net.shrine.adapter import scala.concurrent.duration.DurationInt import org.junit.Test import net.shrine.util.ShouldMatchersForJUnit import ObfuscatorTest.within3 import net.shrine.adapter.dao.AdapterDao import net.shrine.adapter.dao.squeryl.AbstractSquerylAdapterTest import net.shrine.adapter.translators.ExpressionTranslator import net.shrine.adapter.translators.QueryDefinitionTranslator import net.shrine.client.HttpClient import net.shrine.client.HttpResponse import net.shrine.client.Poster import net.shrine.protocol.{AuthenticationInfo, BaseShrineResponse, BroadcastMessage, CrcRequest, Credential, DefaultBreakdownResultOutputTypes, ErrorResponse, HiveCredentials, I2b2ResultEnvelope, QueryResult, RawCrcRunQueryResponse, ReadResultRequest, ReadResultResponse, ResultOutputType, RunQueryRequest, RunQueryResponse} import net.shrine.protocol.RawCrcRunQueryResponse.toQueryResultMap import net.shrine.protocol.DefaultBreakdownResultOutputTypes.PATIENT_AGE_COUNT_XML import net.shrine.protocol.ResultOutputType.PATIENT_COUNT_XML import net.shrine.protocol.DefaultBreakdownResultOutputTypes.PATIENT_GENDER_COUNT_XML import net.shrine.protocol.DefaultBreakdownResultOutputTypes.PATIENT_RACE_COUNT_XML import net.shrine.protocol.query.OccuranceLimited import net.shrine.protocol.query.Or import net.shrine.protocol.query.QueryDefinition import net.shrine.protocol.query.Term import net.shrine.util.XmlDateHelper import net.shrine.util.XmlUtil import scala.util.Success import net.shrine.dao.squeryl.SquerylEntryPoint import scala.concurrent.duration.Duration import net.shrine.adapter.dao.model.ShrineError import net.shrine.adapter.dao.model.QueryResultRow import net.shrine.problem.TestProblem /** * @author Bill Simons * @author Clint Gilbert * @since 4/19/11 * @see http://cbmi.med.harvard.edu */ final class RunQueryAdapterTest extends AbstractSquerylAdapterTest with ShouldMatchersForJUnit { private val queryDef = QueryDefinition("foo", Term("foo")) private val broadcastMessageId = 1234563789L private val queryId = 123L private val expectedNetworkQueryId = 999L private val expectedLocalMasterId = queryId.toString private val masterId = 99L private val instanceId = 456L private val resultId = 42L private val projectId = "projectId" private val setSize = 17L private val xmlResultId = 98765L private val userId = "userId" private val groupId = "groupId" private val topicId = "some-topic-id-123-foo" private val topicName = "Topic Name" private val justCounts = Set(PATIENT_COUNT_XML) private val now = XmlDateHelper.now private val countQueryResult = QueryResult(resultId, instanceId, Some(PATIENT_COUNT_XML), setSize, Some(now), Some(now), None, QueryResult.StatusType.Finished, None) private val dummyBreakdownData = Map("x" -> 99L, "y" -> 42L, "z" -> 3000L) private val hiveCredentials = HiveCredentials("some-hive-domain", "hive-username", "hive-password", "hive-project") private val authn = AuthenticationInfo("some-domain", "username", Credential("jksafhkjaf", false)) private val adapterLockoutThreshold = 99 private val altI2b2ErrorXml = XmlUtil.stripWhitespace { 1.1 2.4 edu.harvard.i2b2.crc 1.5 i2b2 Hive i2b2_QueryTool 0.2 i2b2 Hive 1 i2b2 Log information DONE Query result instance id 3126 not found }.toString private val otherNetworkId: Long = 12345L @Test def testProcessRawCrcRunQueryResponseCountQueryOnly: Unit = afterCreatingTables{ val outputTypes = Set(PATIENT_COUNT_XML) val translator = new QueryDefinitionTranslator(new ExpressionTranslator(Map("network" -> Set("local1a", "local1b")))) val adapter = new RunQueryAdapter( Poster("crc-url", null), dao, hiveCredentials, translator, adapterLockoutThreshold, doObfuscation = false, runQueriesImmediately = true, breakdownTypes = DefaultBreakdownResultOutputTypes.toSet, - collectAdapterAudit = false + collectAdapterAudit = false, + botCountTimeThresholds = Map.empty //todo this might be the right place to test bot defense ) val request = RunQueryRequest(projectId, 1.second, authn, expectedNetworkQueryId, Option(topicId), Option(topicName), outputTypes, queryDef) val networkAuthn = AuthenticationInfo("some-domain", "username", Credential("sadasdasdasd", false)) val broadcastMessage = BroadcastMessage(queryId, networkAuthn, request) val rawRunQueryResponse = RawCrcRunQueryResponse( queryId = queryId, createDate = XmlDateHelper.now, userId = request.authn.username, groupId = request.authn.domain, requestXml = request.queryDefinition, queryInstanceId = otherNetworkId, singleNodeResults = toQueryResultMap(Seq(countQueryResult))) val resp = adapter.processRawCrcRunQueryResponse(networkAuthn, request, rawRunQueryResponse).asInstanceOf[RunQueryResponse] resp should not be (null) //Validate the response resp.createDate should not be(null) resp.groupId should be(request.authn.domain) resp.userId should be(request.authn.username) resp.queryId should be(queryId) resp.queryInstanceId should be(otherNetworkId) resp.requestXml should equal(request.queryDefinition) (countQueryResult eq resp.singleNodeResult) should be(false) within3(resp.singleNodeResult.setSize, countQueryResult.setSize) should be(true) resp.singleNodeResult.resultType.get should equal(PATIENT_COUNT_XML) resp.singleNodeResult.breakdowns should equal(Map.empty) //validate the DB val expectedNetworkTerm = queryDef.expr.get.asInstanceOf[Term] //We should have one row in the shrine_query table, for the query just performed val Seq(queryRow) = list(queryRows) { queryRow.dateCreated should not be (null) queryRow.domain should equal(request.authn.domain) queryRow.name should equal(queryDef.name) queryRow.localId should equal(expectedLocalMasterId) queryRow.networkId should equal(expectedNetworkQueryId) queryRow.username should equal(authn.username) queryRow.queryDefinition.expr.get should equal(expectedNetworkTerm) } //We should have one row in the count_result table, with the right obfuscated value, which is within the expected amount from the original count val Seq(countRow) = list(countResultRows) { countRow.creationDate should not be (null) countRow.originalValue should equal(countQueryResult.setSize) within3(countRow.obfuscatedValue, countRow.originalValue) should be(true) } } @Test def testProcessRawCrcRunQueryResponseCountAndBreakdownQuery: Unit = afterCreatingTables { val allBreakdownTypes = DefaultBreakdownResultOutputTypes.toSet val breakdownTypes = Seq(PATIENT_GENDER_COUNT_XML) val outputTypes = Set(PATIENT_COUNT_XML) ++ breakdownTypes val translator = new QueryDefinitionTranslator(new ExpressionTranslator(Map("network" -> Set("local1a", "local1b")))) val request = RunQueryRequest(projectId, 1.second, authn, expectedNetworkQueryId, Option(topicId), Option(topicName), outputTypes, queryDef) val networkAuthn = AuthenticationInfo("some-domain", "username", Credential("sadasdasdasd", false)) val broadcastMessage = BroadcastMessage(queryId, networkAuthn, request) val breakdownQueryResults = breakdownTypes.zipWithIndex.map { case (rt, i) => countQueryResult.withId(resultId + i + 1).withResultType(rt) } val singleNodeResults = toQueryResultMap(countQueryResult +: breakdownQueryResults) val rawRunQueryResponse = RawCrcRunQueryResponse( queryId = queryId, createDate = XmlDateHelper.now, userId = request.authn.username, groupId = request.authn.domain, requestXml = request.queryDefinition, queryInstanceId = otherNetworkId, singleNodeResults = singleNodeResults) //Set up our mock CRC val poster = Poster("crc-url", new HttpClient { def post(input: String, url: String): HttpResponse = HttpResponse.ok { (RunQueryRequest.fromI2b2String(allBreakdownTypes)(input) orElse ReadResultRequest.fromI2b2String(allBreakdownTypes)(input)).get match { case runQueryReq: RunQueryRequest => rawRunQueryResponse.toI2b2String case readResultReq: ReadResultRequest => ReadResultResponse(xmlResultId = 42L, metadata = breakdownQueryResults.head, data = I2b2ResultEnvelope(PATIENT_GENDER_COUNT_XML, dummyBreakdownData)).toI2b2String case _ => sys.error(s"Unknown request: '$input'") //Fail loudly } } }) - val adapter = new RunQueryAdapter( - poster, - dao, - hiveCredentials, - translator, - adapterLockoutThreshold, + val adapter = RunQueryAdapter( + poster = poster, + dao = dao, + hiveCredentials = hiveCredentials, + conceptTranslator = translator, + adapterLockoutAttemptsThreshold = adapterLockoutThreshold, doObfuscation = false, runQueriesImmediately = true, breakdownTypes = DefaultBreakdownResultOutputTypes.toSet, - collectAdapterAudit = false + collectAdapterAudit = false, + botCountTimeThresholds = Map.empty //todo this might be the right place to test bot defense ) val resp = adapter.processRawCrcRunQueryResponse(networkAuthn, request, rawRunQueryResponse).asInstanceOf[RunQueryResponse] resp should not be (null) //Validate the response resp.createDate should not be(null) resp.groupId should be(request.authn.domain) resp.userId should be(request.authn.username) resp.queryId should be(queryId) resp.queryInstanceId should be(otherNetworkId) resp.requestXml should equal(request.queryDefinition) (countQueryResult eq resp.singleNodeResult) should be(false) within3(resp.singleNodeResult.setSize, countQueryResult.setSize) should be(true) resp.singleNodeResult.resultType.get should equal(PATIENT_COUNT_XML) resp.singleNodeResult.breakdowns.keySet should equal(Set(PATIENT_GENDER_COUNT_XML)) val breakdownEnvelope = resp.singleNodeResult.breakdowns.values.head breakdownEnvelope.resultType should equal(PATIENT_GENDER_COUNT_XML) breakdownEnvelope.data.keySet should equal(dummyBreakdownData.keySet) //All breakdowns are obfuscated for { (key, value) <- breakdownEnvelope.data } { within3(value, dummyBreakdownData(key)) should be(true) } //validate the DB val expectedNetworkTerm = queryDef.expr.get.asInstanceOf[Term] //We should have one row in the shrine_query table, for the query just performed val Seq(queryRow) = list(queryRows) { queryRow.dateCreated should not be (null) queryRow.domain should equal(request.authn.domain) queryRow.name should equal(queryDef.name) queryRow.localId should equal(expectedLocalMasterId) queryRow.networkId should equal(expectedNetworkQueryId) queryRow.username should equal(authn.username) queryRow.queryDefinition.expr.get should equal(expectedNetworkTerm) } //We should have one row in the count_result table, with the right obfuscated value, which is within the expected amount from the original count val Seq(countRow) = list(countResultRows) { countRow.creationDate should not be (null) countRow.originalValue should equal(countQueryResult.setSize) within3(countRow.obfuscatedValue, countRow.originalValue) should be(true) } val breakdownRows @ Seq(xRow, yRow, zRow) = list(breakdownResultRows) breakdownRows.map(_.dataKey).toSet should equal(dummyBreakdownData.keySet) within3(xRow.obfuscatedValue, xRow.originalValue) should be(true) xRow.originalValue should be(dummyBreakdownData(xRow.dataKey)) within3(yRow.obfuscatedValue, yRow.originalValue) should be(true) yRow.originalValue should be(dummyBreakdownData(yRow.dataKey)) within3(zRow.obfuscatedValue, zRow.originalValue) should be(true) zRow.originalValue should be(dummyBreakdownData(zRow.dataKey)) } //NB: See https://open.med.harvard.edu/jira/browse/SHRINE-745 @Test def testParseAltErrorXml { - val adapter = new RunQueryAdapter( - Poster("crc-url", null), - null, - hiveCredentials, - null, - adapterLockoutThreshold, + val adapter = RunQueryAdapter( + poster = Poster("crc-url", null), + dao = null, + hiveCredentials = hiveCredentials, + conceptTranslator = null, + adapterLockoutAttemptsThreshold = adapterLockoutThreshold, doObfuscation = false, runQueriesImmediately = false, - DefaultBreakdownResultOutputTypes.toSet, - collectAdapterAudit = false + breakdownTypes = DefaultBreakdownResultOutputTypes.toSet, + collectAdapterAudit = false, + botCountTimeThresholds = Map.empty ) val resp: ErrorResponse = adapter.parseShrineErrorResponseWithFallback(altI2b2ErrorXml).asInstanceOf[ErrorResponse] resp should not be (null) resp.errorMessage should be("Query result instance id 3126 not found") } @Test def testParseErrorXml { val xml = { 1.1 2.4 edu.harvard.i2b2.crc 1.4 i2b2 Hive i2b2web 1.4 i2b2 Hive 1 Demo Log information Message error connecting Project Management cell admin 0 0 CRC_QRY_runQueryInstance_fromQueryDefinition Age 0 1 0 0 1 2 Age \\i2b2\i2b2\Demographics\Age\ concept_dimension concept_path \i2b2\Demographics\Age\ T concept_cd false }.toString - val adapter = new RunQueryAdapter( - Poster("crc-url", null), - null, - hiveCredentials, - null, - adapterLockoutThreshold, + val adapter = RunQueryAdapter( + poster = Poster("crc-url", null), + dao = null, + hiveCredentials = hiveCredentials, + conceptTranslator = null, + adapterLockoutAttemptsThreshold = adapterLockoutThreshold, doObfuscation = false, runQueriesImmediately = true, - DefaultBreakdownResultOutputTypes.toSet, - collectAdapterAudit = false + breakdownTypes = DefaultBreakdownResultOutputTypes.toSet, + collectAdapterAudit = false, + botCountTimeThresholds = Map.empty //todo this might be the right place to test bot defense ) val resp = adapter.parseShrineErrorResponseWithFallback(xml).asInstanceOf[ErrorResponse] resp should not be (null) resp.errorMessage should not be ("") } @Test def testObfuscateBreakdowns { val breakdown1 = I2b2ResultEnvelope(PATIENT_AGE_COUNT_XML, Map.empty) val breakdown2 = I2b2ResultEnvelope(PATIENT_GENDER_COUNT_XML, Map("foo" -> 123, "bar" -> 345)) val breakdown3 = I2b2ResultEnvelope(PATIENT_RACE_COUNT_XML, Map("x" -> 999, "y" -> 888)) val original = Map.empty ++ Seq(breakdown1, breakdown2, breakdown3).map(env => (env.resultType, env)) val obfuscated = RunQueryAdapter.obfuscateBreakdowns(original) original.keySet should equal(obfuscated.keySet) original.keySet.forall(resultType => original(resultType).data.keySet == obfuscated(resultType).data.keySet) should be(true) val localTerms = Set("local1a", "local1b") for { (resultType, origBreakdown) <- original mappings = Map("network" -> localTerms) translator = new QueryDefinitionTranslator(new ExpressionTranslator(mappings)) obfscBreakdown <- obfuscated.get(resultType) key <- origBreakdown.data.keySet } { (origBreakdown eq obfscBreakdown) should be(false) ObfuscatorTest.within3(origBreakdown.data(key), obfscBreakdown.data(key)) should be(true) } } @Test def testTranslateNetworkToLocalDoesntLeakCredentialsViaException: Unit = { val mappings = Map.empty[String, Set[String]] val translator = new QueryDefinitionTranslator(new ExpressionTranslator(mappings)) - val adapter = new RunQueryAdapter( - Poster("crc-url", MockHttpClient), - null, - null, - translator, - adapterLockoutThreshold, + val adapter = RunQueryAdapter( + poster = Poster("crc-url", MockHttpClient), + dao = null, + hiveCredentials = null, + conceptTranslator = translator, + adapterLockoutAttemptsThreshold = adapterLockoutThreshold, doObfuscation = false, runQueriesImmediately = true, - DefaultBreakdownResultOutputTypes.toSet, - collectAdapterAudit = false + breakdownTypes = DefaultBreakdownResultOutputTypes.toSet, + collectAdapterAudit = false, + botCountTimeThresholds = Map.empty ) val queryDefinition = QueryDefinition("foo", Term("blah")) val authn = AuthenticationInfo("d", "u", Credential("p", false)) val req = RunQueryRequest("projectId", Duration.Inf, authn, otherNetworkId, None, None, Set.empty, queryDef) try { adapter.translateNetworkToLocal(req) fail("Expected an AdapterMappingException") } catch { case e: AdapterMappingException => { e.getMessage.contains(authn.rawToString) should be(false) e.getMessage.contains(AuthenticationInfo.elided.toString) should be(true) } } } @Test def testTranslateQueryDefinitionXml { val localTerms = Set("local1a", "local1b") val mappings = Map("network" -> localTerms) val translator = new QueryDefinitionTranslator(new ExpressionTranslator(mappings)) - val adapter = new RunQueryAdapter( - Poster("crc-url", MockHttpClient), - null, - null, - translator, - adapterLockoutThreshold, + val adapter = RunQueryAdapter( + poster = Poster("crc-url", MockHttpClient), + dao = null, + hiveCredentials = null, + conceptTranslator = translator, + adapterLockoutAttemptsThreshold = adapterLockoutThreshold, doObfuscation = false, runQueriesImmediately = true, - DefaultBreakdownResultOutputTypes.toSet, - collectAdapterAudit = false + breakdownTypes = DefaultBreakdownResultOutputTypes.toSet, + collectAdapterAudit = false, + botCountTimeThresholds = Map.empty //todo this might be the right place to test bot defense ) val queryDefinition = QueryDefinition("10-17 years old@14:39:20", OccuranceLimited(1, Term("network"))) val newDef = adapter.conceptTranslator.translate(queryDefinition) val expected = QueryDefinition("10-17 years old@14:39:20", Or(Term("local1a"), Term("local1b"))) newDef should equal(expected) } @Test def testQueuedRegularCountQuery: Unit = afterCreatingTables { val adapter = RunQueryAdapter( - Poster("crc-url", MockHttpClient), - dao, - null, - null, - adapterLockoutThreshold, + poster = Poster("crc-url", MockHttpClient), + dao = dao, + hiveCredentials = null, + conceptTranslator = null, + adapterLockoutAttemptsThreshold = adapterLockoutThreshold, doObfuscation = false, runQueriesImmediately = false, - DefaultBreakdownResultOutputTypes.toSet, - collectAdapterAudit = false + breakdownTypes = DefaultBreakdownResultOutputTypes.toSet, + collectAdapterAudit = false, + botCountTimeThresholds = Map.empty //todo this might be the right place to test bot defense ) val networkAuthn = AuthenticationInfo("nd", "nu", Credential("np", false)) import scala.concurrent.duration._ val req = RunQueryRequest(projectId, 1.second, authn, expectedNetworkQueryId, Option(topicId), Option(topicName), Set(PATIENT_COUNT_XML), queryDef) val broadcastMessage = BroadcastMessage(queryId, networkAuthn, req) val resp = adapter.processRequest(broadcastMessage).asInstanceOf[RunQueryResponse] resp.groupId should equal(networkAuthn.domain) resp.createDate should not be (null) // :\ resp.queryId should equal(-1L) resp.queryInstanceId should equal(-1L) resp.requestXml should equal(queryDef) resp.userId should equal(networkAuthn.username) resp.singleNodeResult.breakdowns should equal(Map.empty) resp.singleNodeResult.description.isDefined should be(true) resp.singleNodeResult.elapsed should equal(Some(0L)) resp.singleNodeResult.endDate.isDefined should be(true) resp.singleNodeResult.startDate.isDefined should be(true) resp.singleNodeResult.instanceId should equal(-1L) resp.singleNodeResult.isError should be(false) resp.singleNodeResult.resultId should equal(-1L) resp.singleNodeResult.resultType should be(Some(PATIENT_COUNT_XML)) resp.singleNodeResult.setSize should equal(-1L) resp.singleNodeResult.statusMessage.isDefined should be(true) resp.singleNodeResult.statusType should be(QueryResult.StatusType.Held) resp.singleNodeResult.endDate.isDefined should be(true) val Some(storedQuery) = dao.findQueryByNetworkId(expectedNetworkQueryId) storedQuery.dateCreated should not be (null) // :\ storedQuery.domain should equal(networkAuthn.domain) storedQuery.isFlagged should equal(false) storedQuery.localId should equal(-1L.toString) storedQuery.name should equal(queryDef.name) storedQuery.networkId should equal(expectedNetworkQueryId) storedQuery.queryDefinition should equal(queryDef) storedQuery.username should equal(networkAuthn.username) } private def doTestRegularCountQuery(status: QueryResult.StatusType, count: Long) = afterCreatingTables { require(!status.isError) val countQueryResultToUse = countQueryResult.copy(statusType = status, setSize = count) val outputTypes = justCounts val resp = doQuery(outputTypes) { import RawCrcRunQueryResponse.toQueryResultMap RawCrcRunQueryResponse(queryId, now, userId, groupId, queryDef, instanceId, toQueryResultMap(Seq(countQueryResultToUse))).toI2b2String }.asInstanceOf[RunQueryResponse] doBasicRunQueryResponseTest(resp) val firstResult = resp.results.head resp.results should equal(Seq(firstResult)) val Some(savedQuery) = dao.findResultsFor(expectedNetworkQueryId) savedQuery.wasRun should equal(true) savedQuery.isFlagged should equal(false) savedQuery.networkQueryId should equal(expectedNetworkQueryId) savedQuery.breakdowns should equal(Nil) savedQuery.count.creationDate should not be (null) savedQuery.count.localId should equal(countQueryResultToUse.resultId) //savedQuery.count.resultId should equal(resultId) TODO: REVISIT savedQuery.count.statusType should equal(status) if (status.isDone && !status.isError) { savedQuery.count.data.get.startDate should not be (null) savedQuery.count.data.get.endDate should not be (null) savedQuery.count.data.get.originalValue should be(count) ObfuscatorTest.within3(savedQuery.count.data.get.obfuscatedValue, count) should be(true) } else { savedQuery.count.data should be(None) } } @Test def testRegularCountQuery = doTestRegularCountQuery(QueryResult.StatusType.Finished, countQueryResult.setSize) @Test def testRegularCountQueryComesBackProcessing = doTestRegularCountQuery(QueryResult.StatusType.Processing, -1L) @Test def testRegularCountQueryComesBackQueued = doTestRegularCountQuery(QueryResult.StatusType.Queued, -1L) @Test def testRegularCountQueryComesBackError = afterCreatingTables { val errorQueryResult = QueryResult.errorResult(Some("some-description"), "some-status-message",TestProblem()) val outputTypes = justCounts val resp = doQuery(outputTypes) { import RawCrcRunQueryResponse.toQueryResultMap RawCrcRunQueryResponse(queryId, now, userId, groupId, queryDef, instanceId, toQueryResultMap(Seq(errorQueryResult))).toI2b2String } doBasicRunQueryResponseTest(resp) //TODO: Why are status and description messages from CRC dropped when unmarshalling QueryResults? //resp.results should equal(Seq(errorQueryResult)) resp.asInstanceOf[RunQueryResponse].results.head.statusType should be(QueryResult.StatusType.Error) dao.findResultsFor(expectedNetworkQueryId) should be(None) val Some(savedQueryRow) = dao.findQueryByNetworkId(expectedNetworkQueryId) val Seq(queryResultRow: QueryResultRow) = { import SquerylEntryPoint._ implicit val breakdownTypes = DefaultBreakdownResultOutputTypes.toSet inTransaction { from(tables.queryResults) { row => where(row.queryId === savedQueryRow.id). select(row.toQueryResultRow) }.toSeq } } val Seq(errorRow: ShrineError) = { import SquerylEntryPoint._ inTransaction { from(tables.errorResults) { row => where(row.resultId === queryResultRow.id). select(row.toShrineError) }.toSeq } } errorRow should not be (null) //TODO: ErrorMessage //errorRow.message should equal(errorQueryResult.statusMessage) } private def doTestBreakdownsAreObfuscated(result: QueryResult): Unit = { result.breakdowns.values.map(_.data).foreach { actualBreakdowns => actualBreakdowns.keySet should equal(dummyBreakdownData.keySet) for { breakdownName <- actualBreakdowns.keySet } { within3(actualBreakdowns(breakdownName), dummyBreakdownData(breakdownName)) should be(true) } } } @Test def testGetBreakdownsWithRegularCountQuery { val breakdowns = DefaultBreakdownResultOutputTypes.values.map(breakdownFor) val resp = doTestGetBreakdowns(breakdowns) val firstResult = resp.results.head firstResult.resultType should equal(Some(PATIENT_COUNT_XML)) firstResult.setSize should equal(setSize) firstResult.description should equal(None) firstResult.breakdowns.keySet should equal(DefaultBreakdownResultOutputTypes.toSet) //NB: Verify that breakdowns are obfuscated doTestBreakdownsAreObfuscated(firstResult) resp.results.size should equal(1) } @Test def testGetBreakdownsSomeFailures { val resultTypesExpectedToSucceed = Seq(PATIENT_AGE_COUNT_XML, PATIENT_GENDER_COUNT_XML) val breakdowns = resultTypesExpectedToSucceed.map(breakdownFor) val resp = doTestGetBreakdowns(breakdowns) val firstResult = resp.results.head firstResult.resultType should equal(Some(PATIENT_COUNT_XML)) firstResult.setSize should equal(setSize) firstResult.description should equal(None) firstResult.breakdowns.keySet should equal(resultTypesExpectedToSucceed.toSet) //NB: Verify that breakdowns are obfuscated doTestBreakdownsAreObfuscated(firstResult) resp.results.size should equal(1) } @Test def testErrorResponsesArePassedThrough: Unit = { val errorResponse = ErrorResponse(TestProblem(summary = "blarg!")) val resp = doQuery(Set(PATIENT_COUNT_XML)) { errorResponse.toI2b2String } resp should equal(errorResponse) } private def breakdownFor(resultType: ResultOutputType) = I2b2ResultEnvelope(resultType, dummyBreakdownData) private def doTestGetBreakdowns(successfulBreakdowns: Seq[I2b2ResultEnvelope]): RunQueryResponse = { val outputTypes = justCounts ++ DefaultBreakdownResultOutputTypes.toSet val resp = doQueryThatReturnsSpecifiedBreakdowns(outputTypes, successfulBreakdowns) doBasicRunQueryResponseTest(resp) resp } private def doBasicRunQueryResponseTest(r: BaseShrineResponse) { val resp = r.asInstanceOf[RunQueryResponse] resp.createDate should equal(now) resp.groupId should equal(groupId) resp.queryId should equal(queryId) resp.queryInstanceId should equal(instanceId) resp.queryName should equal(queryDef.name) resp.requestXml should equal(queryDef) } private def doQueryThatReturnsSpecifiedBreakdowns(outputTypes: Set[ResultOutputType], successfulBreakdowns: Seq[I2b2ResultEnvelope]): RunQueryResponse = afterCreatingTablesReturn { val breakdownQueryResults = DefaultBreakdownResultOutputTypes.values.zipWithIndex.map { case (rt, i) => countQueryResult.withId(resultId + i + 1).withResultType(rt) } //Need this rigamarole to ensure that resultIds line up such that the type of breakdown the adapter asks for //(PATIENT_AGE_COUNT_XML, etc) is what the mock HttpClient actually returns. Here, we build up maps of QueryResults //and I2b2ResultEnvelopes, keyed on resultIds generated in the previous expression, to use to look up values to use //to build ReadResultResponses val successfulBreakdownsByType = successfulBreakdowns.map(e => e.resultType -> e).toMap val successfulBreakdownTypes = successfulBreakdownsByType.keySet val breakdownQueryResultsByResultId = breakdownQueryResults.collect { case qr if successfulBreakdownTypes(qr.resultType.get) => qr.resultId -> qr }.toMap val breakdownsToBeReturnedByResultId = breakdownQueryResultsByResultId.map { case (resultId, queryResult) => (resultId, successfulBreakdownsByType(queryResult.resultType.get)) } val expectedLocalTerm = Term("bar") val httpClient = new HttpClient { override def post(input: String, url: String): HttpResponse = { val resp = CrcRequest.fromI2b2String(DefaultBreakdownResultOutputTypes.toSet)(input) match { case Success(req: RunQueryRequest) => { //NB: Terms should be translated req.queryDefinition.expr.get should equal(expectedLocalTerm) //Credentials should be "translated" req.authn.username should equal(hiveCredentials.username) req.authn.domain should equal(hiveCredentials.domain) //I2b2 Project ID should be translated req.projectId should equal(hiveCredentials.projectId) val queryResultMap = RawCrcRunQueryResponse.toQueryResultMap(countQueryResult +: breakdownQueryResults) RawCrcRunQueryResponse(queryId, now, "userId", "groupId", queryDef, instanceId, queryResultMap) } //NB: return a ReadResultResponse with new breakdown data each time, but will throw if the asked-for breakdown //is not one of the ones passed to the enclosing method, simulating an error calling the CRC case Success(req: ReadResultRequest) => { val resultId = req.localResultId.toLong ReadResultResponse(xmlResultId, breakdownQueryResultsByResultId(resultId), breakdownsToBeReturnedByResultId(resultId)) } case _ => ??? //fail loudly } HttpResponse.ok(resp.toI2b2String) } } val result = doQuery(outputTypes, dao, httpClient) validateDb(successfulBreakdowns, breakdownQueryResultsByResultId) result.asInstanceOf[RunQueryResponse] } private def validateDb(breakdownsReturned: Seq[I2b2ResultEnvelope], breakdownQueryResultsByResultId: Map[Long, QueryResult]) { val expectedNetworkTerm = Term("foo") //We should have one row in the shrine_query table, for the query just performed val queryRow = first(queryRows) { queryRow.dateCreated should not be (null) queryRow.domain should equal(authn.domain) queryRow.name should equal(queryDef.name) queryRow.localId should equal(expectedLocalMasterId) queryRow.networkId should equal(expectedNetworkQueryId) queryRow.username should equal(authn.username) queryRow.queryDefinition.expr.get should equal(expectedNetworkTerm) } list(queryRows).size should equal(1) //We should have one row in the count_result table, with the right obfuscated value, which is within the expected amount from the original count val countRow = first(countResultRows) { countRow.creationDate should not be (null) countRow.originalValue should equal(countQueryResult.setSize) within3(countRow.obfuscatedValue, countQueryResult.setSize) should be(true) within3(countRow.obfuscatedValue, countRow.originalValue) should be(true) } list(countResultRows).size should equal(1) //We should have 5 rows in the query_result table, one for the count result and one for each of the 4 requested breakdown types val queryResults = list(queryResultRows) { val countQueryResultRow = queryResults.find(_.resultType == PATIENT_COUNT_XML).get countQueryResultRow.localId should equal(countQueryResult.resultId) countQueryResultRow.queryId should equal(queryRow.id) val resultIdsByResultType = breakdownQueryResultsByResultId.map { case (resultId, queryResult) => queryResult.resultType.get -> resultId }.toMap for (breakdownType <- DefaultBreakdownResultOutputTypes.values) { val breakdownQueryResultRow = queryResults.find(_.resultType == breakdownType).get breakdownQueryResultRow.queryId should equal(queryRow.id) //We'll have a result id if this breakdown type didn't fail if (resultIdsByResultType.contains(breakdownQueryResultRow.resultType)) { breakdownQueryResultRow.localId should equal(resultIdsByResultType(breakdownQueryResultRow.resultType)) } } } queryResults.size should equal(5) val returnedBreakdownTypes = breakdownsReturned.map(_.resultType).toSet val notReturnedBreakdownTypes = DefaultBreakdownResultOutputTypes.toSet -- returnedBreakdownTypes val errorResults = list(errorResultRows) //We should have a row in the error_result table for each breakdown that COULD NOT be retrieved { for { queryResult <- queryResults if notReturnedBreakdownTypes.contains(queryResult.resultType) resultType = queryResult.resultType resultId = queryResult.id } { errorResults.find(_.resultId == resultId).isDefined should be(true) } } errorResults.size should equal(notReturnedBreakdownTypes.size) //We should have properly-obfuscated rows in the breakdown_result table for each of the breakdown types that COULD be retrieved val breakdownResults = list(breakdownResultRows) val bdrs = breakdownResults.toIndexedSeq { for { queryResult <- queryResults if returnedBreakdownTypes.contains(queryResult.resultType) resultType = queryResult.resultType resultId = queryResult.id } { //Find all the rows for a particular breakdown type val rowsWithType = breakdownResults.filter(_.resultId == resultId) //Combining the rows should give the expected dummy data rowsWithType.map(row => row.dataKey -> row.originalValue).toMap should equal(dummyBreakdownData) for (breakdownRow <- rowsWithType) { within3(breakdownRow.obfuscatedValue, dummyBreakdownData(breakdownRow.dataKey)) should be(true) } } } } private def doQuery(outputTypes: Set[ResultOutputType])(i2b2XmlToReturn: => String): BaseShrineResponse = { doQuery(outputTypes, dao, MockHttpClient(i2b2XmlToReturn)) } private def doQuery(outputTypes: Set[ResultOutputType], adapterDao: AdapterDao, httpClient: HttpClient): BaseShrineResponse = { val translator = new QueryDefinitionTranslator(new ExpressionTranslator(Map("foo" -> Set("bar")))) //NB: Don't obfuscate, for simpler testing - val adapter = new RunQueryAdapter( - Poster("crc-url", httpClient), - adapterDao, - hiveCredentials, - translator, - adapterLockoutThreshold, + val adapter = RunQueryAdapter( + poster = Poster("crc-url", httpClient), + dao = adapterDao, + hiveCredentials = hiveCredentials, + conceptTranslator = translator, + adapterLockoutAttemptsThreshold = adapterLockoutThreshold, doObfuscation = false, runQueriesImmediately = true, - DefaultBreakdownResultOutputTypes.toSet, - collectAdapterAudit = false + breakdownTypes = DefaultBreakdownResultOutputTypes.toSet, + collectAdapterAudit = false, + botCountTimeThresholds = Map.empty ) import scala.concurrent.duration._ val req = RunQueryRequest(projectId, 1.second, authn, expectedNetworkQueryId, Option(topicId), Option(topicName), outputTypes, queryDef) val networkAuthn = AuthenticationInfo("some-domain", "username", Credential("sadasdasdasd", false)) val broadcastMessage = BroadcastMessage(queryId, networkAuthn, req) adapter.processRequest(broadcastMessage) } } \ No newline at end of file diff --git a/adapter/adapter-service/src/test/scala/net/shrine/adapter/dao/MockAdapterDao.scala b/adapter/adapter-service/src/test/scala/net/shrine/adapter/dao/MockAdapterDao.scala index 6082dc874..d649e7013 100644 --- a/adapter/adapter-service/src/test/scala/net/shrine/adapter/dao/MockAdapterDao.scala +++ b/adapter/adapter-service/src/test/scala/net/shrine/adapter/dao/MockAdapterDao.scala @@ -1,61 +1,64 @@ package net.shrine.adapter.dao import net.shrine.adapter.dao.model.ShrineQuery import net.shrine.adapter.dao.model.ShrineQueryResult import net.shrine.protocol.AuthenticationInfo import net.shrine.protocol.I2b2ResultEnvelope import net.shrine.protocol.QueryResult import net.shrine.protocol.ResultOutputType import net.shrine.protocol.query.QueryDefinition +import scala.concurrent.duration.Duration import scala.xml.NodeSeq /** * @author clint * @since Oct 19, 2012 */ object MockAdapterDao extends MockAdapterDao trait MockAdapterDao extends AdapterDao { override def flagQuery(networkQueryId: Long, flagMessage: Option[String]): Unit = () override def unFlagQuery(networkQueryId: Long): Unit = () override def insertQuery(localMasterId: String, networkId: Long, authn: AuthenticationInfo, query: QueryDefinition, isFlagged: Boolean, hasBeenRun: Boolean, flagMessage: Option[String]): Int = 0 override def insertQueryResults(parentQueryId: Int, results: Seq[QueryResult]): Map[ResultOutputType, Seq[Int]] = Map.empty override def insertCountResult(resultId: Int, originalCount: Long, obfuscatedCount: Long): Unit = () override def insertBreakdownResults(parentResultIds: Map[ResultOutputType, Seq[Int]], originalBreakdowns: Map[ResultOutputType, I2b2ResultEnvelope], obfuscatedBreakdowns: Map[ResultOutputType, I2b2ResultEnvelope]): Unit = () override def insertErrorResult(parentResultId: Int, errorMessage: String, codec:String, stampText:String, summary:String, digestDescription:String,detailsXml:NodeSeq) = () override def findQueryByNetworkId(networkQueryId: Long): Option[ShrineQuery] = None override def findQueriesByUserAndDomain(domain: String, username: String, howMany: Int): Seq[ShrineQuery] = Nil override def findQueriesByDomain(domain: String): Seq[ShrineQuery] = Nil override def findResultsFor(networkQueryId: Long): Option[ShrineQueryResult] = None + override def checkIfBot(authn:AuthenticationInfo, botTimeThresholds:Map[Long,Duration]): Unit = {} + override def isUserLockedOut(id: AuthenticationInfo, defaultThreshold: Int): Boolean = false override def renameQuery(networkQueryId: Long, newName: String): Unit = () override def deleteQuery(networkQueryId: Long): Unit = () override def deleteQueryResultsFor(networkQueryId: Long): Unit = () override def findRecentQueries(howMany: Int): Seq[ShrineQuery] = Nil override def storeResults(authn: AuthenticationInfo, masterId: String, networkQueryId: Long, queryDefinition: QueryDefinition, rawQueryResults: Seq[QueryResult], obfuscatedQueryResults: Seq[QueryResult], failedBreakdownTypes: Seq[ResultOutputType], mergedBreakdowns: Map[ResultOutputType, I2b2ResultEnvelope], obfuscatedBreakdowns: Map[ResultOutputType, I2b2ResultEnvelope]): Unit = () } \ No newline at end of file diff --git a/adapter/adapter-service/src/test/scala/net/shrine/adapter/service/I2b2AdminResourceEndToEndJaxrsTest.scala b/adapter/adapter-service/src/test/scala/net/shrine/adapter/service/I2b2AdminResourceEndToEndJaxrsTest.scala index f7fa9d870..194104b32 100644 --- a/adapter/adapter-service/src/test/scala/net/shrine/adapter/service/I2b2AdminResourceEndToEndJaxrsTest.scala +++ b/adapter/adapter-service/src/test/scala/net/shrine/adapter/service/I2b2AdminResourceEndToEndJaxrsTest.scala @@ -1,188 +1,189 @@ package net.shrine.adapter.service import org.junit.Test import net.shrine.adapter.HasI2b2AdminDao import net.shrine.protocol.{HiveCredentials, ReadI2b2AdminPreviousQueriesRequest, ReadI2b2AdminQueryingUsersRequest, ReadI2b2AdminQueryingUsersResponse, I2b2AdminUserWithRole, ErrorResponse, RunHeldQueryRequest, RunQueryResponse, RunQueryRequest, ResultOutputType, QueryResult, BroadcastMessage, AuthenticationInfo, Credential, DefaultBreakdownResultOutputTypes} import net.shrine.client.Poster import net.shrine.adapter.RunQueryAdapter import net.shrine.adapter.translators.QueryDefinitionTranslator import net.shrine.adapter.translators.ExpressionTranslator import net.shrine.client.HttpClient import net.shrine.client.HttpResponse import net.shrine.protocol.query.Term import scala.util.Success import net.shrine.util.XmlDateHelper import net.shrine.protocol.query.QueryDefinition /** * @author clint * @since Apr 12, 2013 * * NB: Ideally we would extend JerseyTest here, but since we have to extend AbstractDependencyInjectionSpringContextTests, * we get into a diamond-problem when extending JerseyTest as well, even when both of them are extended by shim traits. * * We work around this issue by mising in JerseyTestCOmponent, which brings in a JerseyTest by composition, and ensures * that it is set up and torn down properly. */ final class I2b2AdminResourceEndToEndJaxrsTest extends AbstractI2b2AdminResourceJaxrsTest with HasI2b2AdminDao { private[this] val dummyUrl = "http://example.com" private[this] val dummyText = "This is dummy text" private[this] val dummyMasterId = 873456L private[this] val dummyInstanceId = 99L private[this] val dummyResultId = 42L private[this] val dummySetSize = 12345L private[this] val networkAuthn = AuthenticationInfo("network-domain", "network-username", Credential("network-password", false)) private lazy val runQueryAdapter: RunQueryAdapter = { val translator = new QueryDefinitionTranslator(new ExpressionTranslator(Map("n1" -> Set("l1")))) val poster = new Poster(dummyUrl, new HttpClient { override def post(input: String, url: String): HttpResponse = { RunQueryRequest.fromI2b2String(DefaultBreakdownResultOutputTypes.toSet)(input) match { case Success(req) => { val queryResult = QueryResult(dummyResultId, dummyInstanceId, Some(ResultOutputType.PATIENT_COUNT_XML), dummySetSize, Some(XmlDateHelper.now), Some(XmlDateHelper.now), Some("desc"), QueryResult.StatusType.Finished, Some("status")) val resp = RunQueryResponse(dummyMasterId, XmlDateHelper.now, networkAuthn.username, networkAuthn.domain, req.queryDefinition, 123L, queryResult) HttpResponse.ok(resp.toI2b2String) } case _ => ??? } } }) RunQueryAdapter( - poster, - dao, - HiveCredentials("d", "u", "pwd", "pid"), - translator, - 1000, - false, - true, - DefaultBreakdownResultOutputTypes.toSet, - collectAdapterAudit = false + poster = poster, + dao = dao, + hiveCredentials = HiveCredentials("d", "u", "pwd", "pid"), + conceptTranslator = translator, + adapterLockoutAttemptsThreshold = 1000, + doObfuscation = false, + runQueriesImmediately = true, + breakdownTypes = DefaultBreakdownResultOutputTypes.toSet, + collectAdapterAudit = false, + botCountTimeThresholds = Map.empty ) } override def makeHandler = new I2b2AdminService(dao, i2b2AdminDao, Poster(dummyUrl, AlwaysAuthenticatesMockPmHttpClient), runQueryAdapter) @Test def testReadQueryDefinition = afterLoadingTestData { doTestReadQueryDefinition(networkQueryId1, Some((queryName1, queryDef1))) } @Test def testReadQueryDefinitionUnknownQueryId = afterLoadingTestData { doTestReadQueryDefinition(87134682364L, None) } import ReadI2b2AdminPreviousQueriesRequest.{Username, Category, SortOrder} import Username._ @Test def testReadI2b2AdminPreviousQueries = afterLoadingTestData { val searchString = queryName1 val maxResults = 123 val sortOrder = ReadI2b2AdminPreviousQueriesRequest.SortOrder.Ascending val categoryToSearchWithin = ReadI2b2AdminPreviousQueriesRequest.Category.All val searchStrategy = ReadI2b2AdminPreviousQueriesRequest.Strategy.Exact val request = ReadI2b2AdminPreviousQueriesRequest(projectId, waitTime, authn, All, searchString, maxResults, None, sortOrder, searchStrategy, categoryToSearchWithin) doTestReadI2b2AdminPreviousQueries(request, Seq(queryMaster1)) } @Test def testReadI2b2AdminPreviousQueriesNoResultsExpected = afterLoadingTestData { //A request that won't return anything val request = ReadI2b2AdminPreviousQueriesRequest(projectId, waitTime, authn, All, "askjdhakfgkafgkasf", 123, None) doTestReadI2b2AdminPreviousQueries(request, Nil) } @Test def testReadI2b2AdminPreviousQueriesExcludeUser: Unit = afterLoadingTestData { val request = ReadI2b2AdminPreviousQueriesRequest(projectId, waitTime, authn, Except(authn2.username), "", 10, None) doTestReadI2b2AdminPreviousQueries(request, Seq(queryMaster2, queryMaster1)) } @Test def testReadI2b2AdminPreviousQueriesOnlyFlagged: Unit = afterLoadingTestData { val request = ReadI2b2AdminPreviousQueriesRequest(projectId, waitTime, authn, All, "", 10, None, categoryToSearchWithin = Category.Flagged) doTestReadI2b2AdminPreviousQueries(request, Seq(queryMaster4, queryMaster1)) } @Test def testReadPreviousQueriesOnlyFlaggedExcludingUser: Unit = afterLoadingTestData { val request = ReadI2b2AdminPreviousQueriesRequest(projectId, waitTime, authn, Except(authn.username), "", 10, None, categoryToSearchWithin = Category.Flagged) doTestReadI2b2AdminPreviousQueries(request, Seq(queryMaster4)) } @Test def testReadPreviousQueriesExcludingUserWithSearchString: Unit = afterLoadingTestData { val request = ReadI2b2AdminPreviousQueriesRequest(projectId, waitTime, authn, All, queryName1, 10, None, categoryToSearchWithin = Category.Flagged) doTestReadI2b2AdminPreviousQueries(request, Seq(queryMaster1)) } @Test def testReadI2b2QueryingUsers = afterLoadingTestData { val request = ReadI2b2AdminQueryingUsersRequest(projectId, waitTime, authn, "foo") val ReadI2b2AdminQueryingUsersResponse(users) = adminClient.readI2b2AdminQueryingUsers(request) users.toSet should equal(Set(I2b2AdminUserWithRole(shrineProjectId, authn.username, "USER"), I2b2AdminUserWithRole(shrineProjectId, authn2.username, "USER"))) } @Test def testReadI2b2QueryingUsersNoResultsExpected = afterCreatingTables { val request = ReadI2b2AdminQueryingUsersRequest(projectId, waitTime, authn, "foo") val ReadI2b2AdminQueryingUsersResponse(users) = adminClient.readI2b2AdminQueryingUsers(request) //DB is empty, so no users will be returned users should equal(Nil) } @Test def testRunHeldQueryUnknownQuery = afterCreatingTables { val request = RunHeldQueryRequest(projectId, waitTime, authn, 12345L) val resp = adminClient.runHeldQuery(request) resp.isInstanceOf[ErrorResponse] should be(true) } @Test def testRunHeldQueryKnownQuery = afterCreatingTables { val networkQueryId = 12345L val request = RunHeldQueryRequest(projectId, waitTime, authn, networkQueryId) val queryName = "aslkdjasljkd" val queryExpr = Term("n1") val runQueryReq = RunQueryRequest(projectId, waitTime, authn, networkQueryId, None, None, Set(ResultOutputType.PATIENT_COUNT_XML), QueryDefinition(queryName, queryExpr)) runQueryAdapter.copy(runQueriesImmediately = false).processRequest(BroadcastMessage(networkAuthn, runQueryReq)) val resp = adminClient.runHeldQuery(request) val runQueryResp = resp.asInstanceOf[RunQueryResponse] runQueryResp.createDate should not be(null) runQueryResp.groupId should be(networkAuthn.domain) runQueryResp.userId should equal(networkAuthn.username) runQueryResp.queryId should equal(dummyMasterId) runQueryResp.singleNodeResult.setSize should equal(dummySetSize) runQueryResp.singleNodeResult.resultType should equal(Some(ResultOutputType.PATIENT_COUNT_XML)) //TODO runQueryResp.requestXml.name should equal(queryName) runQueryResp.requestXml.expr.get should equal(Term("l1")) } } diff --git a/integration/src/test/scala/net/shrine/integration/NetworkSimulationTest.scala b/integration/src/test/scala/net/shrine/integration/NetworkSimulationTest.scala index ca8f92f88..5f0a3c99b 100644 --- a/integration/src/test/scala/net/shrine/integration/NetworkSimulationTest.scala +++ b/integration/src/test/scala/net/shrine/integration/NetworkSimulationTest.scala @@ -1,337 +1,338 @@ package net.shrine.integration import java.net.URL import net.shrine.log.Loggable import scala.concurrent.Future import scala.concurrent.duration.DurationInt import org.junit.Test import net.shrine.util.ShouldMatchersForJUnit import net.shrine.adapter.AdapterMap import net.shrine.adapter.DeleteQueryAdapter import net.shrine.adapter.client.AdapterClient import net.shrine.adapter.dao.squeryl.AbstractSquerylAdapterTest import net.shrine.adapter.service.AdapterRequestHandler import net.shrine.adapter.service.AdapterService import net.shrine.broadcaster.AdapterClientBroadcaster import net.shrine.broadcaster.NodeHandle import net.shrine.crypto.DefaultSignerVerifier import net.shrine.crypto.TestKeystore import net.shrine.protocol.{HiveCredentials, AuthenticationInfo, BroadcastMessage, Credential, DeleteQueryRequest, DeleteQueryResponse, NodeId, Result, RunQueryRequest, CertId, RequestType, FlagQueryRequest, FlagQueryResponse, RawCrcRunQueryResponse, ResultOutputType, QueryResult, RunQueryResponse, AggregatedRunQueryResponse, UnFlagQueryRequest, UnFlagQueryResponse, DefaultBreakdownResultOutputTypes} import net.shrine.qep.QepService import net.shrine.broadcaster.SigningBroadcastAndAggregationService import net.shrine.broadcaster.InJvmBroadcasterClient import net.shrine.adapter.FlagQueryAdapter import net.shrine.protocol.query.Term import net.shrine.adapter.RunQueryAdapter import net.shrine.client.Poster import net.shrine.client.HttpClient import net.shrine.client.HttpResponse import net.shrine.adapter.translators.QueryDefinitionTranslator import net.shrine.adapter.translators.ExpressionTranslator import net.shrine.util.XmlDateHelper import net.shrine.adapter.ReadQueryResultAdapter import net.shrine.protocol.query.QueryDefinition import net.shrine.adapter.UnFlagQueryAdapter import net.shrine.crypto.SigningCertStrategy /** * @author clint * @since Nov 27, 2013 * * An in-JVM simulation of a Shrine network with one hub and 4 downstream adapters. * * The hub and adapters are wired up with mock AdapterClients that do in-JVM communication via method calls * instead of remotely. * * The adapters are configured to respond with valid results for DeleteQueryRequests * only. Other requests could be handled, but that would not provide benefit to offset the effort of wiring * up more and more-complex Adapters. * * The test network is queried, and the final result, as well as the state of each adapter, is inspected to * ensure that the right messages were sent between elements of the system. * */ final class NetworkSimulationTest extends AbstractSquerylAdapterTest with ShouldMatchersForJUnit { private val certCollection = TestKeystore.certCollection private lazy val myCertId: CertId = certCollection.myCertId.get private lazy val signerVerifier = new DefaultSignerVerifier(certCollection) private val domain = "test-domain" private val username = "test-username" private val password = "test-password" import NetworkSimulationTest._ import scala.concurrent.duration._ private def deleteQueryAdapter: DeleteQueryAdapter = new DeleteQueryAdapter(dao) private def flagQueryAdapter: FlagQueryAdapter = new FlagQueryAdapter(dao) private def unFlagQueryAdapter: UnFlagQueryAdapter = new UnFlagQueryAdapter(dao) private def mockPoster = Poster("http://example.com", new HttpClient { override def post(input: String, url: String): HttpResponse = ??? }) private val hiveCredentials = HiveCredentials("d", "u", "pwd", "pid") private def queuesQueriesRunQueryAdapter: RunQueryAdapter = { val translator = new QueryDefinitionTranslator(new ExpressionTranslator(Map("n1" -> Set("l1")))) RunQueryAdapter( - mockPoster, - dao, - hiveCredentials, - translator, - 10000, + poster = mockPoster, + dao = dao, + hiveCredentials = hiveCredentials, + conceptTranslator = translator, + adapterLockoutAttemptsThreshold = 10000, doObfuscation = false, runQueriesImmediately = false, - DefaultBreakdownResultOutputTypes.toSet, - collectAdapterAudit = false + breakdownTypes = DefaultBreakdownResultOutputTypes.toSet, + collectAdapterAudit = false, + botCountTimeThresholds = Map.empty //todo this might be the right place to test bot defense ) } private def immediatelyRunsQueriesRunQueryAdapter(setSize: Long): RunQueryAdapter = { val mockCrcPoster = Poster("http://example.com", new HttpClient { override def post(input: String, url: String): HttpResponse = { val req = RunQueryRequest.fromI2b2String(DefaultBreakdownResultOutputTypes.toSet)(input).get val now = XmlDateHelper.now val queryResult = QueryResult(1L, 42L, Some(ResultOutputType.PATIENT_COUNT_XML), setSize, Some(now), Some(now), Some("desc"), QueryResult.StatusType.Finished, Some("status")) val mockCrcXml = RawCrcRunQueryResponse(req.networkQueryId, XmlDateHelper.now, req.authn.username, req.projectId, req.queryDefinition, 42L, Map(ResultOutputType.PATIENT_COUNT_XML -> Seq(queryResult))).toI2b2String HttpResponse.ok(mockCrcXml) } }) queuesQueriesRunQueryAdapter.copy(poster = mockCrcPoster, runQueriesImmediately = true) } private def readQueryResultAdapter(setSize: Long): ReadQueryResultAdapter = { new ReadQueryResultAdapter( mockPoster, hiveCredentials, dao, doObfuscation = false, DefaultBreakdownResultOutputTypes.toSet, collectAdapterAudit = false ) } private lazy val adaptersByNodeId: Seq[(NodeId, MockAdapterRequestHandler)] = { import NodeName._ import RequestType.{ MasterDeleteRequest => MasterDeleteRequestRT, FlagQueryRequest => FlagQueryRequestRT, QueryDefinitionRequest => RunQueryRT, GetQueryResult => ReadQueryResultRT, UnFlagQueryRequest => UnFlagQueryRequestRT } (for { (childName, setSize) <- Seq((A, 1L), (B, 2L), (C, 3L), (D, 4L)) } yield { val nodeId = NodeId(childName.name) val maxSignatureAge = 1.hour val adapterMap = AdapterMap(Map( MasterDeleteRequestRT -> deleteQueryAdapter, FlagQueryRequestRT -> flagQueryAdapter, UnFlagQueryRequestRT -> unFlagQueryAdapter, RunQueryRT -> queuesQueriesRunQueryAdapter, ReadQueryResultRT -> readQueryResultAdapter(setSize))) nodeId -> MockAdapterRequestHandler(new AdapterService(nodeId, signerVerifier, maxSignatureAge, adapterMap)) }) } private lazy val shrineService: QepService = { val destinations: Set[NodeHandle] = { (for { (nodeId, adapterRequestHandler) <- adaptersByNodeId } yield { NodeHandle(nodeId, MockAdapterClient(nodeId, adapterRequestHandler)) }).toSet } QepService( "example.com", MockAuditDao, MockAuthenticator, MockQueryAuthorizationService, true, SigningBroadcastAndAggregationService(InJvmBroadcasterClient(AdapterClientBroadcaster(destinations, MockHubDao)), signerVerifier, SigningCertStrategy.Attach), 1.hour, DefaultBreakdownResultOutputTypes.toSet, false) } @Test def testSimulatedNetwork = afterCreatingTables { val authn = AuthenticationInfo(domain, username, Credential(password, false)) val masterId = 12345L import scala.concurrent.duration._ val req = DeleteQueryRequest("some-project-id", 1.second, authn, masterId) val resp = shrineService.deleteQuery(req, true) for { (nodeId, mockAdapter) <- adaptersByNodeId } { mockAdapter.lastMessage.networkAuthn.domain should equal(authn.domain) mockAdapter.lastMessage.networkAuthn.username should equal(authn.username) mockAdapter.lastMessage.request should equal(req) mockAdapter.lastResult.response should equal(DeleteQueryResponse(masterId)) } resp should equal(DeleteQueryResponse(masterId)) } @Test def testQueueQuery = afterCreatingTables { val authn = AuthenticationInfo(domain, username, Credential(password, false)) val topicId = "askldjlkas" val topicName = "Topic Name" val queryName = "lsadj3028940" import scala.concurrent.duration._ val runQueryReq = RunQueryRequest("some-project-id", 1.second, authn, 12345L, Some(topicId), Some(topicName), Set(ResultOutputType.PATIENT_COUNT_XML), QueryDefinition(queryName, Term("n1"))) val aggregatedRunQueryResp = shrineService.runQuery(runQueryReq, true).asInstanceOf[AggregatedRunQueryResponse] var broadcastMessageId: Option[Long] = None //Broadcast the original run query request; all nodes should queue the query for { (nodeId, mockAdapter) <- adaptersByNodeId } { broadcastMessageId = Option(mockAdapter.lastMessage.requestId) mockAdapter.lastMessage.networkAuthn.domain should equal(authn.domain) mockAdapter.lastMessage.networkAuthn.username should equal(authn.username) val lastReq = mockAdapter.lastMessage.request.asInstanceOf[RunQueryRequest] lastReq.authn should equal(runQueryReq.authn) lastReq.requestType should equal(runQueryReq.requestType) lastReq.waitTime should equal(runQueryReq.waitTime) //todo what to do with this check? lastReq.networkQueryId should equal(mockAdapter.lastMessage.requestId) lastReq.outputTypes should equal(runQueryReq.outputTypes) lastReq.projectId should equal(runQueryReq.projectId) lastReq.queryDefinition should equal(runQueryReq.queryDefinition) lastReq.topicId should equal(runQueryReq.topicId) val runQueryResp = mockAdapter.lastResult.response.asInstanceOf[RunQueryResponse] runQueryResp.queryId should equal(-1L) runQueryResp.singleNodeResult.statusType should equal(QueryResult.StatusType.Held) runQueryResp.singleNodeResult.setSize should equal(-1L) } aggregatedRunQueryResp.queryId should equal(broadcastMessageId.get) aggregatedRunQueryResp.results.map(_.setSize) should equal(Seq(-1L, -1L, -1L, -1L, -4L)) } @Test def testFlagQuery = afterCreatingTables { val authn = AuthenticationInfo(domain, username, Credential(password, false)) val masterId = 12345L import scala.concurrent.duration._ val networkQueryId = 9999L val name = "some query" val expr = Term("foo") val fooQuery = QueryDefinition(name,expr) dao.insertQuery(masterId.toString, networkQueryId, authn, fooQuery, isFlagged = false, hasBeenRun = true, flagMessage = None) dao.findQueryByNetworkId(networkQueryId).get.isFlagged should be(false) dao.findQueryByNetworkId(networkQueryId).get.flagMessage should be(None) val req = FlagQueryRequest("some-project-id", 1.second, authn, networkQueryId, Some("foo")) val resp = shrineService.flagQuery(req, true) resp should equal(FlagQueryResponse) dao.findQueryByNetworkId(networkQueryId).get.isFlagged should be(true) dao.findQueryByNetworkId(networkQueryId).get.flagMessage should be(Some("foo")) } @Test def testUnFlagQuery = afterCreatingTables { val authn = AuthenticationInfo(domain, username, Credential(password, false)) val masterId = 12345L import scala.concurrent.duration._ val networkQueryId = 9999L val flagMsg = Some("foo") val name = "some query" val expr = Term("foo") val fooQuery = QueryDefinition(name,expr) dao.insertQuery(masterId.toString, networkQueryId, authn, fooQuery, isFlagged = true, hasBeenRun = true, flagMessage = flagMsg) dao.findQueryByNetworkId(networkQueryId).get.isFlagged should be(true) dao.findQueryByNetworkId(networkQueryId).get.flagMessage should be(flagMsg) val req = UnFlagQueryRequest("some-project-id", 1.second, authn, networkQueryId) val resp = shrineService.unFlagQuery(req, true) resp should equal(UnFlagQueryResponse) dao.findQueryByNetworkId(networkQueryId).get.isFlagged should be(false) dao.findQueryByNetworkId(networkQueryId).get.flagMessage should be(None) } } object NetworkSimulationTest { private final case class MockAdapterClient(nodeId: NodeId, adapter: AdapterRequestHandler) extends AdapterClient with Loggable { import scala.concurrent.ExecutionContext.Implicits.global override def query(message: BroadcastMessage): Future[Result] = Future.successful { debug(s"Invoking Adapter $nodeId with $message") val result = adapter.handleRequest(message) debug(s"Got result from $nodeId: $result") result } override def url: Option[URL] = ??? } private final case class MockAdapterRequestHandler(delegate: AdapterRequestHandler) extends AdapterRequestHandler { @volatile var lastMessage: BroadcastMessage = _ @volatile var lastResult: Result = _ override def handleRequest(request: BroadcastMessage): Result = { lastMessage = request val result = delegate.handleRequest(request) lastResult = result result } } }