diff --git a/src/main/scala/org/warcbase/spark/matchbox/ExtractEntities.scala b/src/main/scala/org/warcbase/spark/matchbox/ExtractEntities.scala
index 2514e0f..146c51d 100644
--- a/src/main/scala/org/warcbase/spark/matchbox/ExtractEntities.scala
+++ b/src/main/scala/org/warcbase/spark/matchbox/ExtractEntities.scala
@@ -1,50 +1,50 @@
 package org.warcbase.spark.matchbox
 
 import org.apache.spark.SparkContext
 import org.apache.spark.rdd.RDD
 
 /**
   * Extracts entities
   */
 object ExtractEntities {
   /**
-    * @param classifier path of NER3Classifier
+    * @param classifier NER3Classifier
     * @param inputRecordFile path of ARC or WARC file from which to extract entities
     * @param outputFile path of output directory
     */
   def extractFromRecords(classifier: NER3Classifier, inputRecordFile: String, outputFile: String, sc: SparkContext): RDD[(String, String, String)] = {
     val rdd = RecordLoader.loadArc(inputRecordFile, sc)
       .map(r => (r.getCrawldate, r.getUrl, r.getRawBodyContent))
     extractAndOutput(classifier, rdd, outputFile)
   }
 
   /**
-    * @param classifier path of NER3Classifier
+    * @param classifier NER3Classifier
     * @param inputFile path of file with tuples (date: String, url: String, content: String)
     *                  from which to extract entities
     * @param outputFile path of output directory
     */
   def extractFromScrapeText(classifier: NER3Classifier, inputFile: String, outputFile: String, sc: SparkContext): RDD[(String, String, String)] = {
     val rdd = sc.textFile(inputFile)
       .map(line => {
         val ind1 = line.indexOf(",")
         val ind2 = line.indexOf(",", ind1 + 1)
         (line.substring(1, ind1),
           line.substring(ind1 + 1, ind2),
           line.substring(ind2 + 1, line.length - 1))
       })
     extractAndOutput(classifier, rdd, outputFile)
   }
 
   /**
-    * @param classifier path of NER3Classifier
+    * @param classifier NER3Classifier
     * @param rdd with values (date, url, content)
     * @param outputFile path of output directory
     */
   def extractAndOutput(classifier: NER3Classifier, rdd: RDD[(String, String, String)], outputFile: String): RDD[(String, String, String)] = {
     val r = rdd.map(r => (r._1, r._2, classifier.classify(r._3)))
     r.saveAsTextFile(outputFile)
     r
   }
 }
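Reviewer note: a minimal usage sketch of the API above, assuming a Stanford NER model is available locally. All paths here are placeholders, and the driver setup mirrors the test below rather than anything prescribed by this patch.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.warcbase.spark.matchbox.{ExtractEntities, NER3Classifier}

object ExtractEntitiesExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("extract-entities").setMaster("local[4]"))
    // Placeholder paths: substitute a real Stanford NER model, an ARC/WARC input, and an output directory.
    val classifier = new NER3Classifier("/path/to/english.all.3class.distsim.crf.ser.gz")
    val entities = ExtractEntities.extractFromRecords(classifier, "/path/to/example.arc.gz", "/path/to/output", sc)
    // Each element is (crawl date, URL, classifier output for the page body).
    entities.take(3).foreach(println)
    sc.stop()
  }
}
```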
diff --git a/src/test/scala/org/warcbase/spark/matchbox/ExtractEntitiesTest.scala b/src/test/scala/org/warcbase/spark/matchbox/ExtractEntitiesTest.scala
index 0033d04..4674fdd 100644
--- a/src/test/scala/org/warcbase/spark/matchbox/ExtractEntitiesTest.scala
+++ b/src/test/scala/org/warcbase/spark/matchbox/ExtractEntitiesTest.scala
@@ -1,72 +1,71 @@
 package org.warcbase.spark.matchbox
 
 import java.io.File
 
 import com.fasterxml.jackson.databind.ObjectMapper
 import com.fasterxml.jackson.module.scala.DefaultScalaModule
 import com.google.common.io.{Files, Resources}
 import org.apache.commons.io.FileUtils
 import org.apache.commons.logging.LogFactory
 import org.apache.spark.{SparkConf, SparkContext}
-import org.junit.runner.RunWith
-import org.scalatest.junit.JUnitRunner
 import org.scalatest.{BeforeAndAfter, FunSuite}
 
 import scala.collection.mutable
 
-@RunWith(classOf[JUnitRunner])
+// This test requires `iNerClassifierFile` to point to the correct classifier path.
+// @RunWith(classOf[JUnitRunner])
 class ExtractEntitiesTest extends FunSuite with BeforeAndAfter {
   private val LOG = LogFactory.getLog(classOf[ExtractEntitiesTest])
   private val scrapePath = Resources.getResource("ner/example.txt").getPath
   private val arcPath = Resources.getResource("arc/example.arc.gz").getPath
   private val master = "local[4]"
   private val appName = "example-spark"
   private var sc: SparkContext = _
   private var tempDir: File = _
   private val mapper = new ObjectMapper().registerModule(DefaultScalaModule)
-  private val iNerClassfierFile =
+  private val iNerClassifierFile =
     Resources.getResource("ner/classifiers/english.all.3class.distsim.crf.ser.gz").getPath
-  private val classifier = new NER3Classifier(iNerClassfierFile)
+  private val classifier = new NER3Classifier(iNerClassifierFile)
 
   before {
     val conf = new SparkConf()
       .setMaster(master)
       .setAppName(appName)
     sc = new SparkContext(conf)
     tempDir = Files.createTempDir()
     LOG.info("Output can be found in " + tempDir.getPath)
   }
 
   test("extract entities") {
     val classifier_ = classifier
     val e = ExtractEntities.extractFromScrapeText(classifier_, scrapePath, tempDir + "/scrapeTextEntities", sc).take(3).last
     val expectedEntityMap = mutable.Map[NERClassType.Value, List[String]]()
     expectedEntityMap.put(NERClassType.PERSON, List())
     expectedEntityMap.put(NERClassType.LOCATION, List("Teoma"))
     expectedEntityMap.put(NERClassType.ORGANIZATION, List())
     assert(e._1 == "20080430")
     assert(e._2 == "http://www.archive.org/robots.txt")
     val actual = mapper.readValue(e._3, classOf[Map[String, List[String]]])
     expectedEntityMap.toStream.foreach(f => {
       assert(f._2 == actual.get(f._1.toString).get)
     })
   }
 
   test("ner3classifier") {
     val classifier_ = classifier
     val rdd = RecordLoader.loadArc(arcPath, sc)
       .map(r => (r.getCrawldate, r.getUrl, r.getRawBodyContent))
     val entities = rdd.map(r => (r._1, r._2, classifier_.classify(r._3)))
     entities.take(3).foreach(println)
   }
 
   after {
     FileUtils.deleteDirectory(tempDir)
     LOG.info("Removing tmp files in " + tempDir.getPath)
     if (sc != null) {
       sc.stop()
     }
   }
 }
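Reviewer note on `extractFromScrapeText`: the substring arithmetic implies each input line is the `toString` of a `(date, url, content)` tuple, e.g. as written by `saveAsTextFile` on an `RDD[(String, String, String)]`. A sketch of how one line is split, with values taken from the test's assertions (the content string is illustrative):

```scala
val line = "(20080430,http://www.archive.org/robots.txt,some raw body text)"

val ind1 = line.indexOf(",")           // first comma: end of the date field
val ind2 = line.indexOf(",", ind1 + 1) // second comma: end of the URL field

val date    = line.substring(1, ind1)                   // "20080430", skipping the leading '('
val url     = line.substring(ind1 + 1, ind2)            // "http://www.archive.org/robots.txt"
val content = line.substring(ind2 + 1, line.length - 1) // body text, dropping the trailing ')'
```

This assumes the date and URL fields contain no commas; a URL with an embedded comma would shift `ind2` and corrupt the split, which may be worth a defensive check in a follow-up.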