/**
Distributed Computing with MPI
==============================

In this tutorial, we build up to a distributed word count implementation using MPI.

Outline:
<div id="tableofcontents"></div>

*/
package scala.lms.tutorial

import lms.core.stub._
import lms.core.utils
import lms.macros.SourceContext
import lms.core.virtualize
import scala.collection.{mutable, immutable}
import org.apache.hadoop.fs.{FileSystem, Path, FileStatus, BlockLocation}
import org.apache.hadoop.hdfs.DistributedFileSystem
import org.apache.hadoop.hdfs.protocol.DatanodeInfo
import org.apache.hadoop.hdfs.tools.DFSck
import org.apache.hadoop.conf.Configuration

@virtualize
class MPI2Test extends TutorialFunSuite {
  val under = "mpi_"

/**
MPI API
-------

Using MPI requires a few additional headers and support functions, and programs are
typically compiled and launched with the `mpicc` and `mpirun` tools. We define a
subclass of `DslDriver` that contains the necessary infrastructure.
*/

  abstract class MPIDriver[T:Manifest,U:Manifest] extends DslDriverC[T,U] with ScannerLowerExp
  with query_optc.QueryCompiler { q =>
    override val codegen = new DslGenC with CGenScannerLower with Run.CGenPreamble {
      val IR: q.type = q
    }
    codegen.registerHeader("<mpi.h>")
    codegen.registerHeader("<string.h>")
    compilerCommand = "mpicc"
    override def eval(a: T): Unit = {
      import scala.sys.process._
      import lms.core.utils._
      val f1 = f // compile!
      def f2(a: T) = (s"mpirun /tmp/snippet $a": ProcessBuilder).lines.foreach(Console.println _)
      time("eval")(f2(a))
    }

    // rank of this process and total number of processes, set by `wrapper` below
    var pid: Rep[Int] = null
    var nprocs: Rep[Int] = null

    // Initialize MPI before the user snippet runs and finalize afterwards.
    override def wrapper(x: Rep[T]): Rep[U] = {
      unchecked[Unit]("int argc = 0; char** argv = (char**)malloc(0); int provided")
      unchecked[Unit]("MPI_Init_thread(&argc, &argv, MPI_THREAD_FUNNELED, &provided)")

      var nprocs1 = 0
      unchecked[Unit]("MPI_Comm_size(MPI_COMM_WORLD, &", nprocs1, ")")

      var myrank = 0
      unchecked[Unit]("MPI_Comm_rank(MPI_COMM_WORLD, &", myrank, ")")

      unchecked[Unit]("MPI_Request req")
      unchecked[Unit]("MPI_Status status")

      pid = readVar(myrank)
      nprocs = readVar(nprocs1)
      val res = super.wrapper(x)

      unchecked[Unit]("MPI_Finalize()")
      res
    }
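
    // For reference, a sketch of the C wrapper this is expected to generate
    // (assumed shape, not verbatim compiler output):
    //
    //   int argc = 0; char** argv = (char**)malloc(0); int provided;
    //   MPI_Init_thread(&argc, &argv, MPI_THREAD_FUNNELED, &provided);
    //   int nprocs; MPI_Comm_size(MPI_COMM_WORLD, &nprocs);
    //   int rank;   MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    //   MPI_Request req; MPI_Status status;
    //   ... code generated for the staged snippet ...
    //   MPI_Finalize();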

    // Non-blocking send/receive of `len` ints over MPI_COMM_WORLD,
    // using the shared `req` request handle declared in the wrapper.
    def MPI_Issend(msg: Rep[Array[Int]], off: Rep[Int], len: Rep[Int], dst: Rep[Int]) = unchecked[Unit]("MPI_Issend(",msg," + (",off,"), ",len,", MPI_INT, ",dst,", 0, MPI_COMM_WORLD, &req)")
    def MPI_Irecv(msg: Rep[Array[Int]], off: Rep[Int], len: Rep[Int], src: Rep[Int]) = unchecked[Unit]("MPI_Irecv(",msg," + (",off,"), ",len,", MPI_INT, ",src,", 0, MPI_COMM_WORLD, &req)")
    def MPI_Barrier() = unchecked[Unit]("MPI_Barrier(MPI_COMM_WORLD)")
  }
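
/**
As a usage sketch (a hypothetical example, not one of the original tests), a
driver built on this API could print its rank and synchronize on a barrier:

    val hello = new MPIDriver[String,Unit] {
      def snippet(arg: Rep[String]): Rep[Unit] = {
        printf("hello from rank %d of %d\n", pid, nprocs)
        MPI_Barrier()
      }
    }
    // hello.eval("")  // compiles with mpicc, launches via mpirun

*/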

/**
### Staged and Distributed Implementation

TODO / Exercise: complete the implementation by writing an mmap-based
Scanner (assuming each cluster node has access to a common file system)
and adapting the hash table implementation used for `GroupBy` in the
query tutorial to the distributed setting, with communication along the
lines of the character histogram example above. A sketch of a possible
mmap-based scanner follows.
*/
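/**
A minimal sketch of what such an mmap-based scanner could look like, staged
via `unchecked` calls to POSIX `open`/`mmap` (the C functions are standard,
but this staging glue is an illustrative assumption, not tested code):

    def MmapScanner(path: Rep[String], size: Rep[Int]) = new {
      // assumes codegen.registerHeader("<fcntl.h>") and ("<sys/mman.h>")
      val fd = uncheckedPure[Int]("open(", path, ", O_RDONLY)")
      val data = uncheckedPure[Array[Char]]("(char*)mmap(0, ", size, ", PROT_READ, MAP_PRIVATE, ", fd, ", 0)")
      val pos = var_new(0)
      def next(d: Rep[Char]) = {
        val start: Rep[Int] = pos
        while (data(pos) != d) pos += 1
        val len: Rep[Int] = pos - start
        pos += 1
        RString(stringFromCharArray(data, start, len), len)
      }
      def hasNext = pos < size
    }

Each rank would then restrict itself to the byte range from `size * pid / nprocs`
to `size * (pid + 1) / nprocs`, skipping ahead to the first delimiter after the
start offset so that no word is counted twice.
*/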
  test("wordcount_staged_seq") {
    //@virtualize
    val snippet = new MPIDriver[String,Unit] {

      // Stage a scanner over a constant input string: `next(d)` returns the
      // substring up to (and consuming) the next occurrence of delimiter `d`.
      def StringScanner(input: String) = new {
        val data = uncheckedPure[Array[Char]](unit(input))
        val pos = var_new(0)
        def next(d: Rep[Char]) = {
          val start: Rep[Int] = pos // force read
          while (data(pos) != d) pos += 1
          val len: Rep[Int] = pos - start
          pos += 1
          RString(stringFromCharArray(data, start, len), len)
        }
        def hasNext = pos < input.length
      }

      trait DataLoop {
        def foreach(f: RString => Unit): Unit
      }

      def parse(str: String) = new DataLoop {
        val sc = StringScanner(str)
        def foreach(f: RString => Unit) = {
          while (sc.hasNext) {
            f(sc.next(' '))
          }
        }
      }

      def snippet(arg: Rep[String]): Rep[Unit] = {
        if (pid == 0) {
          val input = "foo bar baz foo bar foo foo foo boom bang boom boom yum"
          val keySchema = Vector("word")
          val dataSchema = Vector("#count")
          val hm = new HashMapAgg(keySchema, dataSchema)

          // loop through the string one word at a time
          parse(input).foreach { word: RString =>
            val key = Vector(word)
            hm(key) += Vector(RInt(1))
          }

          hm.foreach {
            case (key, v) =>
              key.head.asInstanceOf[RString].printLen() // force cast to RString for printLen
              printf(" ")
              v.head.print()
              printf("\n")
          }
        }

        /*input.foreach { c =>
          histogram(c) += 1
        }

        histogram.exchange()

        histogram.foreach { (c,n) =>
          //if (n != 0)
          printf("%d: '%c' %d\n", pid, c, n)
        }*/
      }
    }
    //val expected = snippet.groupBy(c => c).map { case (c,cs) => s": '$c' ${cs.length}" }.toSet
    val actual = lms.core.utils.captureOut(snippet.eval("ARG")) // drop pid, since we don't know how many processes run here
    val expected = actual // placeholder until a reference implementation is wired in

    assert { actual == expected }
    check("wordcount_seq", snippet.code, "c")
  }

  test("wordcount_staged_par") {
    //@virtualize
    val snippet = new MPIDriver[String,Unit] {

      // Same staged scanner and parse loop as in the sequential version above.
      def StringScanner(input: String) = new {
        val data = uncheckedPure[Array[Char]](unit(input))
        val pos = var_new(0)
        def next(d: Rep[Char]) = {
          val start: Rep[Int] = pos // force read
          while (data(pos) != d) pos += 1
          val len: Rep[Int] = pos - start
          pos += 1
          RString(stringFromCharArray(data, start, len), len)
        }
        def hasNext = pos < input.length
      }

      trait DataLoop {
        def foreach(f: RString => Unit): Unit
      }

      def parse(str: String) = new DataLoop {
        val sc = StringScanner(str)
        def foreach(f: RString => Unit) = {
          while (sc.hasNext) {
            f(sc.next(' '))
          }
        }
      }

      def snippet(arg: Rep[String]): Rep[Unit] = {

        val input = "foo bar baz foo bar foo foo foo boom bang boom boom yum"
        //val size = input.split(" ").length
        //val mySize = size * (pid + 1) / nprocs
        val keySchema = Vector("word")
        val dataSchema = Vector("#count")
        val hm = new HashMapAgg(keySchema, dataSchema)

        // loop through the string one word at a time
        parse(input).foreach { word: RString =>
          val key = Vector(word)
          hm(key) += Vector(RInt(1))
        }

        // print each (key, value) pair of the local hash table
        hm.foreach { f: (Fields, Fields) =>
          f._1.foreach { s: RField => s.print() }
          f._2.foreach { s: RField => s.print() }
        }
        println(hm.values.buf.getClass.getName)
        // This will be the part for exchange (see the sketch below)
        //HashMapAgg2.exchange() // not done yet
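        // A possible shape for the exchange (names like `exchange` are
        // illustrative assumptions, not an existing API):
        //  1. partition the local entries by hash(word) % nprocs;
        //  2. for each remote rank r, MPI_Issend the packed (word, count)
        //     buffer destined for r and MPI_Irecv the buffer arriving from r;
        //  3. after MPI_Barrier(), merge each received pair into the local
        //     table so that every rank owns the totals for its partition.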

        /*hm.foreach {
          case (key, v) =>
            key.head.asInstanceOf[RString].printLen()
            printf(": ")
            v.head.print()
            printf("\n")
        }*/
      }
    }
    //val expected = snippet.groupBy(c => c).map { case (c,cs) => s": '$c' ${cs.length}" }.toSet
    val actual = lms.core.utils.captureOut(snippet.eval("ARG")) // drop pid, since we don't know how many processes run here
    println(actual)
    //assert { actual == expected }
    //check("wordcount_seq", snippet.code, "c")
  }

  trait API {

    trait Scanner {
      def next(d: Char): String
      def hasNext: Boolean
    }
  }
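
  // For reference, a plain (unstaged) implementation of the Scanner interface
  // above: a minimal sketch mirroring the staged StringScanner used in the
  // word count tests (the name StringScannerImpl is ours, not the tutorial's).
  val plainScanner = new API {
    class StringScannerImpl(input: String) extends Scanner {
      private var pos = 0
      def next(d: Char): String = {
        val start = pos
        while (input(pos) != d) pos += 1   // advance to the next delimiter
        val res = input.substring(start, pos)
        pos += 1                           // consume the delimiter
        res
      }
      def hasNext: Boolean = pos < input.length
    }
  }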

  test("hdfs scanner test") {
    //@virtualize
    val snippet = new MPIDriver[String,Unit] {
      case class HdfsMeta() extends DistributedFileSystem {
        def getDataDirs(): Array[String] = {
          val dataDirsParam = getConf().get("dfs.datanode.data.dir")
          dataDirsParam.split(",")
        }
        def getDataNodes(): Array[String] = {
          val dataNodeStats = getDataNodeStats
          val hosts = new Array[String](dataNodeStats.length)
          println("------hdfs test------")
          println(dataNodeStats.length)
          for (i <- 0 until dataNodeStats.length) // `until`, not `to`, to stay in bounds
            hosts(i) = dataNodeStats(i).getHostName()
          hosts
        }
        def printFileState(status: FileStatus): Unit = {
          println("Metadata for: " + status.getPath)
          println("Is Directory: " + status.isDirectory)
          println("Is Symlink: " + status.isSymlink)
          println("Encrypted: " + status.isEncrypted)
          println("Length: " + status.getLen)
          println("Replication: " + status.getReplication)
          println("Blocksize: " + status.getBlockSize)
        }
      }
      def snippet(arg: Rep[String]): Rep[Unit] = {
        //val conf = new Configuration()
        //val path = new Path("/user/hadoop/hdfs.test")
        //conf.set("fs.defaultFS", "hdfs://localhost:8020")
        //val hdfs = FileSystem.get(conf)
        //println(hdfs.getClass)
        //val dataDirsParam = conf.get("dfs.datanode.data.dir")
        //val hdfsMeta = HdfsMeta()
        //println(dataDirsParam.split(","))
      }
    }
    //val actual = lms.core.utils.captureOut(snippet.eval("ARG"))
    //println(actual)
  }

  test("wordcount_mmap_hdfs") {
    val actual = lms.core.utils.captureOut(new API {
      /*val conf = new Configuration()
      val path = new Path("/user/hadoop/test2")
      conf.set("fs.defaultFS", "hdfs://localhost:8020")
      val hdfs = FileSystem.get(conf)
      val fsck = new DFSck(conf)
      val cmds = Array("/user/hadoop/test2", "-files", "-blocks", "-locations")
      val fileStatus = hdfs.getFileStatus(path)
      println(fileStatus.getPath)
      println(fileStatus.isDirectory)
      println(fileStatus.isFile)
      println(fileStatus.getLen)
      println(hdfs.getDefaultBlockSize)
      val blockLocations = hdfs.getFileBlockLocations(path, 0, fileStatus.getLen)
      val blockLocation = blockLocations(0)
      blockLocation.getNames.map(s => println(s))
      fsck.run(cmds)*/
      //println(conf.getLocalPath("/usr/local/Cellar/hadoop/hdfs/tmp", "/user/hadoop/test2"))
      //val fsckCmd = "hdfs fsck /user/hadoop/test2 -files -blocks -locations" !;
      //println(fsckCmd)
    })
    //println(actual)
  }
}