Advertisement
Guest User

Untitled

a guest
Aug 15th, 2017
74
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Scala 2.87 KB | None | 0 0
  1. import java.sql.{Connection, DriverManager, SQLException}
  2.  
  3. @throws(classOf[SQLException])
  4. class MarkovGenerator(private var dbUrl: String, private val table: String = "markov"){
  5.  
  6.  
  7.   private var connection: Connection = _
  8.  
  9.   if(!dbUrl.startsWith("jdbc:sqlite:")){
  10.     dbUrl = "jdbc:sqlite:" + dbUrl
  11.   }
  12.  
  13.   try{
  14.     connection = DriverManager.getConnection(dbUrl)
  15.     val sql = s"CREATE TABLE IF NOT EXISTS $table(" +
  16.       " key TEXT PRIMARY KEY," +
  17.       " value TEXT);"
  18.     val stmt = connection.createStatement()
  19.     stmt.execute(sql)
  20.   }
  21.   catch{
  22.     case e: SQLException =>
  23.       throw e
  24.   }
  25.  
  26.   /**
  27.     * Resets the sql table
  28.     */
  29.   def resetTable(): Unit ={
  30.     var sql = s"DROP TABLE IF EXISTS $table"
  31.     var stmt = connection.createStatement()
  32.     stmt.execute(sql)
  33.  
  34.     sql = s"CREATE TABLE IF NOT EXISTS $table(" +
  35.       " key TEXT PRIMARY KEY," +
  36.       " value TEXT);"
  37.     stmt = connection.createStatement()
  38.     stmt.execute(sql)
  39.   }
  40.  
  41.   /**
  42.     * Prints the sql table
  43.     */
  44.   def printTable(): Unit ={
  45.     val sql = s"SELECT key, value FROM $table"
  46.     val rs = connection.createStatement().executeQuery(sql)
  47.     while (rs.next()){
  48.       println(rs.getString(1) + " === " + rs.getString(2) )
  49.     }
  50.   }
  51.  
  52.   private def hasChainedWords(wordPair: String): Boolean = {
  53.     val sql = s"SELECT key FROM $table WHERE key = ?"
  54.     val pstmt = connection.prepareStatement(sql)
  55.     pstmt.setString(1, wordPair)
  56.     val rs = pstmt.executeQuery()
  57.     rs.next()
  58.     rs.getRow > 0
  59.   }
  60.  
  61.   private def getChainedWordsFromPair(wordPair: String): Array[String] = {
  62.     val sql = s"SELECT key, value FROM $table WHERE key = ?"
  63.     val pstmt = connection.prepareStatement(sql)
  64.     pstmt.setString(1, wordPair)
  65.     val rs = pstmt.executeQuery()
  66.     rs.getString("value").split("\\s+")
  67.   }
  68.  
  69.   private def chainWordToPair(wordPair: String, word: String): Unit ={
  70.     val hasChained = hasChainedWords(wordPair)
  71.     val sql = if(hasChained){
  72.       s"UPDATE $table SET value = ? WHERE key = ? "
  73.     }
  74.     else{
  75.       s"INSERT INTO $table(value, key) VALUES(?,?)"
  76.     }
  77.  
  78.     var newWords = word
  79.  
  80.     if(hasChained){
  81.       val chained = getChainedWordsFromPair(wordPair)
  82.       if(chained.contains(word)) return
  83.       for(chainedWord <- chained){
  84.         newWords = newWords +  " " + chainedWord
  85.       }
  86.  
  87.     }
  88.     val pstmt = connection.prepareStatement(sql)
  89.     pstmt.setString(1, newWords)
  90.     pstmt.setString(2, wordPair)
  91.     pstmt.executeUpdate()
  92.   }
  93.  
  94.  
  95.  
  96.   /**
  97.     * Parses a sentence to be included in the chain
  98.     */
  99.   def parseSentence(sentence: String): Unit ={
  100.     val split = sentence.split("\\s+")
  101.     if(split.length > 1){
  102.       for(i <- split.indices){
  103.         if(i < split.length - 2){
  104.           val wordPair = split(i) + " " + split(i+1)
  105.           chainWordToPair(wordPair, split(i+2))
  106.         }
  107.       }
  108.  
  109.     }
  110.  
  111.  
  112.   }
  113.  
  114.  
  115. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement