package scala.lms.tutorial

import lms.core.stub._
import lms.core.utils
import lms.macros.SourceContext
import lms.core.virtualize
import scala.collection.{mutable,immutable}
import org.apache.hadoop.fs.{FileSystem, Path, FileStatus, BlockLocation}
import org.apache.hadoop.hdfs.DistributedFileSystem
import org.apache.hadoop.hdfs.protocol.DatanodeInfo
import org.apache.hadoop.hdfs.tools.DFSck
import org.apache.hadoop.conf.Configuration

@virtualize
class MPI2Test extends TutorialFunSuite {
  val under = "mpi_"

/**
MPI API
-------

Using MPI requires a few additional headers and support functions, and programs are
typically compiled and launched with the `mpicc` and `mpirun` tools. We define a
subclass of `DslDriverC` that contains the necessary infrastructure.
*/

  abstract class MPIDriver[T:Manifest,U:Manifest] extends DslDriverC[T,U] with ScannerLowerExp { q =>
    override val codegen = new DslGenC with CGenScannerLower with Run.CGenPreamble {
      val IR: q.type = q
    }
    codegen.registerHeader("<mpi.h>")
    codegen.registerHeader("<string.h>")
    compilerCommand = "mpicc"
    override def eval(a: T): Unit = {
      import scala.sys.process._
      import lms.core.utils._
      val f1 = f // compile!
      def f2(a: T) = (s"mpirun /tmp/snippet $a": ProcessBuilder).lines.foreach(Console.println _)
      time("eval")(f2(a))
    }

    var pid: Rep[Int] = null
    var nprocs: Rep[Int] = null

    override def wrapper(x: Rep[T]): Rep[U] = {
      unchecked[Unit]("int argc = 0; char** argv = (char**)malloc(0); int provided")
      unchecked[Unit]("MPI_Init_thread(&argc, &argv, MPI_THREAD_FUNNELED, &provided)")

      var nprocs1 = 0
      unchecked[Unit]("MPI_Comm_size(MPI_COMM_WORLD, &", nprocs1, ")")

      var myrank = 0
      unchecked[Unit]("MPI_Comm_rank(MPI_COMM_WORLD, &", myrank, ")")

      unchecked[Unit]("MPI_Request req")
      unchecked[Unit]("MPI_Status status")

      pid = readVar(myrank)
      nprocs = readVar(nprocs1)
      val res = super.wrapper(x)

      unchecked[Unit]("MPI_Finalize()")
      res
    }

    def MPI_Issend(msg: Rep[Array[Int]], off: Rep[Int], len: Rep[Int], dst: Rep[Int]) = unchecked[Unit]("MPI_Issend(",msg," + (",off,"), ",len,", MPI_INT, ",dst,", 0, MPI_COMM_WORLD, &req)")
    def MPI_Irecv(msg: Rep[Array[Int]], off: Rep[Int], len: Rep[Int], src: Rep[Int]) = unchecked[Unit]("MPI_Irecv(",msg," + (",off,"), ",len,", MPI_INT, ",src,", 0, MPI_COMM_WORLD, &req)")
    def MPI_Barrier() = unchecked[Unit]("MPI_Barrier(MPI_COMM_WORLD)")
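
    // Hypothetical helper (same wrapper style as above, not part of the
    // tutorial API): the nonblocking MPI_Issend/MPI_Irecv wrappers start
    // transfers on the `req` handle declared in `wrapper` but never complete
    // them; MPI_Wait is the standard MPI call for completing such a request.
    def MPI_Wait() = unchecked[Unit]("MPI_Wait(&req, &status)")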

    abstract class RField {
      def print()
      def compare(o: RField): Rep[Boolean]
      def hash: Rep[Long]
    }
    type Schema = Vector[String]
    case class RString(data: Rep[String], len: Rep[Int]) extends RField {
      def print() = printf("%.*s", len, data)
      def printLen() = printf("%.*s", len, data)
      def compare(o: RField) = o match { case RString(data2, len2) => if (len == len2) {
        // TODO: we may or may not want to inline this (code bloat and icache considerations).
        var i = 0
        while (i < len && data.charAt(i) == data2.charAt(i)) {
          i += 1
        }
        i == len
      } else false }
      def hash = data.HashCode(len)
    }
    case class RInt(value: Rep[Int]) extends RField {
      def print() = printf("%d", value)
      def compare(o: RField) = o match { case RInt(v2) => value == v2 }
      def hash = value.asInstanceOf[Rep[Long]]
    }

    type Fields = Vector[RField]
    // fold each per-field comparison into the accumulator so that all fields
    // are compared, not just the last one
    def fieldsEqual(a: Fields, b: Fields) = (a zip b).foldLeft(unit(true)) { (a,b) => a && (b._1 compare b._2) }

    def fieldsHash(a: Fields) = a.foldLeft(unit(0L)) { _ * 41L + _.hash }

    def isNumericCol(s: String) = s.startsWith("#")

    case class Record(fields: Fields, schema: Schema) {
      def apply(key: String): RField = fields(schema indexOf key)
      def apply(keys: Schema): Fields = keys.map(this apply _)
    }

    abstract class ColBuffer
    case class IntColBuffer(data: Rep[Array[Int]]) extends ColBuffer
    case class StringColBuffer(data: Rep[Array[String]], len: Rep[Array[Int]]) extends ColBuffer

    class ArrayBuffer(dataSize: Int, schema: Schema) {
      val buf = schema.map {
        case hd if isNumericCol(hd) => IntColBuffer(NewArray[Int](dataSize))
        case _ => StringColBuffer(NewArray[String](dataSize), NewArray[Int](dataSize))
      }

      var len = 0
      def +=(x: Fields) = {
        this(len) = x
        len += 1
      }
      def update(i: Rep[Int], x: Fields) = (buf,x).zipped.foreach {
        case (IntColBuffer(b), RInt(x)) => b(i) = x
        case (StringColBuffer(b,l), RString(x,y)) => b(i) = x; l(i) = y
      }
      def apply(i: Rep[Int]) = buf.map {
        case IntColBuffer(b) => RInt(b(i))
        case StringColBuffer(b,l) => RString(b(i),l(i))
      }
    }

    object hashDefaults {
      val hashSize   = (1 << 8)
      val keysSize   = hashSize
      val bucketSize = (1 << 8)
      val dataSize   = keysSize * bucketSize
    }

    // common base class to factor out commonalities of group and join hash tables

    class HashMapBase(keySchema: Schema, schema: Schema) {
      import hashDefaults._

      val keys = new ArrayBuffer(keysSize, keySchema)
      val keyCount = var_new(0)

      val hashMask = hashSize - 1
      val htable = NewArray[Int](hashSize)
      for (i <- (0 until hashSize): Rep[Range]) { htable(i) = -1 }

      def lookup(k: Fields) = lookupInternal(k, None)
      def lookupOrUpdate(k: Fields)(init: Rep[Int] => Rep[Unit]) = lookupInternal(k, Some(init))
      def lookupInternal(k: Fields, init: Option[Rep[Int] => Rep[Unit]]): Rep[Int] =
      comment[Int]("hash_lookup") {
        // open addressing with linear probing
        val h = fieldsHash(k).toInt
        var pos = h & hashMask
        while (htable(pos) != -1 && !fieldsEqual(keys(htable(pos)), k)) {
          pos = (pos + 1) & hashMask
        }
        if (init.isDefined) {
          if (htable(pos) == -1) {
            val keyPos = keyCount: Rep[Int] // force read
            keys(keyPos) = k
            keyCount += 1
            htable(pos) = keyPos
            init.get(keyPos)
            keyPos
          } else {
            htable(pos)
          }
        } else {
          htable(pos)
        }
      }
    }

    class HashMapAgg(keySchema: Schema, schema: Schema) extends HashMapBase(keySchema, schema) {
      import hashDefaults._

      val values = new ArrayBuffer(keysSize, schema) // assuming all summation fields are numeric

      def apply(k: Fields) = new {
        def +=(v: Fields) = {
          val keyPos = lookupOrUpdate(k) { keyPos =>
            values(keyPos) = schema.map(_ => RInt(0))
          }
          values(keyPos) = (values(keyPos) zip v) map { case (RInt(x), RInt(y)) => RInt(x + y) }
        }
      }

      def foreach(f: (Fields,Fields) => Rep[Unit]): Rep[Unit] = {
        for (i <- 0 until keyCount) {
          f(keys(i), values(i))
        }
      }
    }
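
    // Hypothetical exchange sketch (not part of the tutorial API): for a
    // fixed-size numeric table such as the character histogram referred to in
    // the exercise below, merging per-rank partial counts can be a single
    // in-place sum-reduction; MPI_Allreduce is the standard MPI call for this.
    def exchangeIntTable(buf: Rep[Array[Int]], len: Rep[Int]): Rep[Unit] =
      unchecked[Unit]("MPI_Allreduce(MPI_IN_PLACE, ", buf, ", ", len, ", MPI_INT, MPI_SUM, MPI_COMM_WORLD)")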
  }

/**
### Staged and Distributed Implementation

TODO / Exercise: complete the implementation by writing an mmap-based
Scanner (assuming each cluster node has access to a common file system)
and adapting the hash table implementation used for `GroupBy` in the
query tutorial to the distributed setting, with communication along the
lines used in the character histogram above (see the hypothetical
`exchangeIntTable` sketch in `MPIDriver` for the reduction step).
*/
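
  // Hypothetical helper (not part of the tutorial API): a plain, unstaged
  // word-count oracle. The tests below currently compare captured output
  // against itself; a reference like this could supply a real `expected`
  // value for a space-separated input string.
  def wordCountRef(input: String): Map[String, Int] =
    input.split(" ").groupBy(identity).map { case (w, ws) => (w, ws.length) }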
  test("wordcount_staged_seq") {
    //@virtualize
    val snippet = new MPIDriver[String,Unit] {

      def StringScanner(input: String) = new {
        val data = uncheckedPure[Array[Char]](unit(input))
        val pos = var_new(0)
        def next(d: Rep[Char]) = {
          val start: Rep[Int] = pos // force read
          while (data(pos) != d) pos += 1
          val len: Rep[Int] = pos - start
          pos += 1
          RString(stringFromCharArray(data, start, len), len)
        }
        def hasNext = pos < input.length
      }
      trait DataLoop {
        def foreach(f: RString => Unit): Unit
      }

      def parse(str: String) = new DataLoop {
        val sc = StringScanner(str)
        def foreach(f: RString => Unit) = {
          while (sc.hasNext) {
            f(sc.next(' '))
          }
        }
      }

      def snippet(arg: Rep[String]): Rep[Unit] = {
        if (pid == 0) {
          val input = "foo bar baz foo bar foo foo foo boom bang boom boom yum"
          val keySchema = Vector("word")
          val dataSchema = Vector("#count")
          val hm = new HashMapAgg(keySchema, dataSchema)

          // loop through the string one word at a time
          parse(input).foreach { word: RString =>
            val key = Vector(word)
            hm(key) += Vector(RInt(1))
          }

          hm.foreach {
            case (key, v) =>
              key.head.asInstanceOf[RString].printLen() // cast to RString for printLen
              printf(" ")
              v.head.print()
              printf("\n")
          }
        }

        /*input.foreach { c =>
          histogram(c) += 1
        }

        histogram.exchange()

        histogram.foreach { (c,n) =>
          //if (n != 0)
          printf("%d: '%c' %d\n", pid, c, n)
        }*/
      }
    }
    //val expected = snippet.groupBy(c => c).map { case (c,cs) => s": '$c' ${cs.length}" }.toSet
    val actual = lms.core.utils.captureOut(snippet.eval("ARG")) // drop pid, since we don't know how many ranks run here
    val expected = actual // placeholder: no golden output is checked yet

    assert { actual == expected }
    check("wordcount_seq", snippet.code, "c")
  }

  test("wordcount_staged_par") {
    //@virtualize
    val snippet = new MPIDriver[String,Unit] {

      def StringScanner(input: String) = new {
        val data = uncheckedPure[Array[Char]](unit(input))
        val pos = var_new(0)
        def next(d: Rep[Char]) = {
          val start: Rep[Int] = pos // force read
          while (data(pos) != d) pos += 1
          val len: Rep[Int] = pos - start
          pos += 1
          RString(stringFromCharArray(data, start, len), len)
        }
        def hasNext = pos < input.length
      }
      trait DataLoop {
        def foreach(f: RString => Unit): Unit
      }

      def parse(str: String) = new DataLoop {
        val sc = StringScanner(str)
        def foreach(f: RString => Unit) = {
          while (sc.hasNext) {
            f(sc.next(' '))
          }
        }
      }

      def snippet(arg: Rep[String]): Rep[Unit] = {

        val input = "foo bar baz foo bar foo foo foo boom bang boom boom yum"
        //val size = input.split(" ").length
        //val mySize = size * (pid + 1) / nprocs
        val keySchema = Vector("word")
        val dataSchema = Vector("#count")
        val hm = new HashMapAgg(keySchema, dataSchema)

        // loop through the string one word at a time
        parse(input).foreach { word: RString =>
          val key = Vector(word)
          hm(key) += Vector(RInt(1))
        }

        // print key/value pairs; hm.foreach expects a two-argument function,
        // not a function taking a tuple
        hm.foreach { (k, v) =>
          k.foreach { s: RField => s.print() }
          v.foreach { s: RField => s.print() }
        }
        println(hm.values.buf.getClass.getName) // staging-time debug output

        // TODO: exchange partial aggregates across ranks here
        //hm.exchange() // not implemented yet

        /*hm.foreach {
          case (key, v) =>
            key.head.asInstanceOf[RString].printLen()
            printf(": ")
            v.head.print()
            printf("\n")
        }*/
      }
    }
    //val expected = snippet.groupBy(c => c).map { case (c,cs) => s": '$c' ${cs.length}" }.toSet
    val actual = lms.core.utils.captureOut(snippet.eval("ARG")) // drop pid, since we don't know how many ranks run here
    println(actual)
    //assert { actual == expected }
    //check("wordcount_seq", snippet.code, "c")
  }

  test("hdfs scanner test") {
    //@virtualize
    val snippet = new MPIDriver[String,Unit] {
      case class HdfsMeta() extends DistributedFileSystem {
        def getDataDirs(): Array[String] = {
          val dataDirsParam = getConf().get("dfs.datanode.data.dir")
          dataDirsParam.split(",")
        }
        def getDataNodes(): Array[String] = {
          val dataNodeStats = getDataNodeStats
          val hosts = new Array[String](dataNodeStats.length)
          println("------hdfs test------")
          println(dataNodeStats.length)
          // `until` (not `to`) keeps the index in bounds
          for (i <- 0 until dataNodeStats.length)
            hosts(i) = dataNodeStats(i).getHostName()
          hosts
        }
        def printFileState(status: FileStatus): Unit = {
          println("Metadata for: " + status.getPath)
          println("Is Directory: " + status.isDirectory)
          println("Is Symlink: " + status.isSymlink)
          println("Encrypted: " + status.isEncrypted)
          println("Length: " + status.getLen)
          println("Replication: " + status.getReplication)
          println("Blocksize: " + status.getBlockSize)
        }
      }
      def snippet(arg: Rep[String]): Rep[Unit] = {
        //val conf = new Configuration()
        //val path = new Path("/user/hadoop/hdfs.test")
        //conf.set("fs.defaultFS", "hdfs://localhost:8020")
        //val hdfs = FileSystem.get(conf)
        //println(hdfs.getClass)
        //val dataDirsParam = conf.get("dfs.datanode.data.dir")
        //val hdfsMeta = HdfsMeta()
        //println(dataDirsParam.split(","))
      }
    }
    //val actual = lms.core.utils.captureOut(snippet.eval("ARG"))
    //println(actual)
  }

  test("wordcount_mmap_hdfs") {
    val actual = lms.core.utils.captureOut(new API {
        /*val conf = new Configuration()
        val path = new Path("/user/hadoop/test2")
        conf.set("fs.defaultFS", "hdfs://localhost:8020")
        val hdfs = FileSystem.get(conf)
        val fsck = new DFSck(conf)
        val cmds = Array("/user/hadoop/test2", "-files", "-blocks", "-locations")
        val fileStatus = hdfs.getFileStatus(path)
        println(fileStatus.getPath)
        println(fileStatus.isDirectory)
        println(fileStatus.isFile)
        println(fileStatus.getLen)
        println(hdfs.getDefaultBlockSize)
        val blockLocations = hdfs.getFileBlockLocations(path, 0, fileStatus.getLen)
        val blockLocation = blockLocations(0)
        blockLocation.getNames.map(s => println(s))
        fsck.run(cmds)*/
        //println(conf.getLocalPath("/usr/local/Cellar/hadoop/hdfs/tmp", "/user/hadoop/test2"))
        //val fsckCmd = "hdfs fsck /user/hadoop/test2 -files -blocks -locations" !;
        //println(fsckCmd)
      })
    //println(actual)
  }
}