package com.datastax.spark.connector.rdd

import org.apache.spark.metrics.InputMetricsUpdater

import com.datastax.driver.core.Session
import com.datastax.spark.connector._
import com.datastax.spark.connector.cql._
import com.datastax.spark.connector.rdd.reader._
import com.datastax.spark.connector.util.CqlWhereParser.{EqPredicate, InListPredicate, InPredicate, RangePredicate}
import com.datastax.spark.connector.util.{CountingIterator, CqlWhereParser}
import com.datastax.spark.connector.writer._
import com.datastax.spark.connector.util.Quote._
import org.apache.spark.rdd.RDD
import org.apache.spark.{Partition, TaskContext}

import scala.reflect.ClassTag
/**
 * An [[org.apache.spark.rdd.RDD RDD]] that performs a selecting left join between the `left` RDD
 * and the specified Cassandra table. Each item of `left` is paired with the (possibly empty)
 * sequence of matching Cassandra rows. The join is performed with individual selects per key and
 * takes advantage of RDDs that have been partitioned with the
 * [[com.datastax.spark.connector.rdd.partitioner.ReplicaPartitioner]].
 *
 * @tparam L item type on the left side of the join (any RDD)
 * @tparam R item type on the right side of the join (fetched from Cassandra)
 */
class CassandraLeftJoinRDD[L, R] private[connector](
    left: RDD[L],
    val keyspaceName: String,
    val tableName: String,
    val connector: CassandraConnector,
    val columnNames: ColumnSelector = AllColumns,
    val joinColumns: ColumnSelector = PartitionKeyColumns,
    val where: CqlWhereClause = CqlWhereClause.empty,
    val limit: Option[Long] = None,
    val clusteringOrder: Option[ClusteringOrder] = None,
    val readConf: ReadConf = ReadConf())(
  implicit
    val leftClassTag: ClassTag[L],
    val rightClassTag: ClassTag[R],
    @transient val rowWriterFactory: RowWriterFactory[L],
    @transient val rowReaderFactory: RowReaderFactory[R])
  extends CassandraRDD[(L, Seq[R])](left.sparkContext, left.dependencies)
  with CassandraTableRowReaderProvider[R] {

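  // A minimal usage sketch (hypothetical: assumes a keyspace "test" with a table
  // kv(key INT PRIMARY KEY, value TEXT); the constructor is private[connector], so this
  // only compiles from within the connector package):
  //
  //   val left = sc.parallelize(1 to 100).map(Tuple1(_))
  //   val joined = new CassandraLeftJoinRDD[Tuple1[Int], (Int, String)](
  //     left, "test", "kv", CassandraConnector(sc.getConf))
  //   joined.collect.foreach { case (Tuple1(k), rows) => println(s"$k -> $rows") }
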
  override type Self = CassandraLeftJoinRDD[L, R]

  override protected val classTag = rightClassTag

  override protected def copy(
    columnNames: ColumnSelector = columnNames,
    where: CqlWhereClause = where,
    limit: Option[Long] = limit,
    clusteringOrder: Option[ClusteringOrder] = clusteringOrder,
    readConf: ReadConf = readConf,
    connector: CassandraConnector = connector): Self = {

    new CassandraLeftJoinRDD[L, R](
      left = left,
      keyspaceName = keyspaceName,
      tableName = tableName,
      connector = connector,
      columnNames = columnNames,
      joinColumns = joinColumns,
      where = where,
      limit = limit,
      clusteringOrder = clusteringOrder,
      readConf = readConf)
  }

  lazy val joinColumnNames: Seq[ColumnRef] = joinColumns match {
    case AllColumns => throw new IllegalArgumentException(
      "Unable to join against all columns in a Cassandra Table. Only primary key columns allowed.")
    case PartitionKeyColumns =>
      tableDef.partitionKey.map(col => col.columnName: ColumnRef)
    case SomeColumns(cs @ _*) =>
      checkColumnsExistence(cs)
      cs.map {
        case c: ColumnRef => c
        case _ => throw new IllegalArgumentException(
          "Unable to join against unnamed columns. No CQL Functions allowed.")
      }
  }

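  // Illustrative note: with the default PartitionKeyColumns selector, a table defined as
  //   CREATE TABLE test.kv2 (pk1 INT, pk2 INT, c1 INT, v TEXT, PRIMARY KEY ((pk1, pk2), c1))
  // yields joinColumnNames == Seq("pk1", "pk2"); SomeColumns("pk1", "pk2", "c1") is also a
  // valid selection, as validated by checkValidJoin below.
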
  override def count(): Long = {
    columnNames match {
      case SomeColumns(_) =>
        logWarning("You are about to count rows but an explicit projection has been specified.")
      case _ =>
    }

    // Push the count down to Cassandra: select count(*) per join key and sum the results.
    val counts =
      new CassandraJoinRDD[L, Long](
        left = left,
        connector = connector,
        keyspaceName = keyspaceName,
        tableName = tableName,
        columnNames = SomeColumns(RowCountRef),
        joinColumns = joinColumns,
        where = where,
        limit = limit,
        clusteringOrder = clusteringOrder,
        readConf = readConf)

    counts.map(_._2).reduce(_ + _)
  }

  /** This method will create the RowWriter required before the RDD is serialized.
    * This is called during getPartitions. */
  protected def checkValidJoin(): Seq[ColumnRef] = {
    val partitionKeyColumnNames = tableDef.partitionKey.map(_.columnName).toSet
    val primaryKeyColumnNames = tableDef.primaryKey.map(_.columnName).toSet
    val colNames = joinColumnNames.map(_.columnName).toSet

    // Initialize RowWriter and Query to be used for accessing Cassandra
    rowWriter.columnNames
    singleKeyCqlQuery.length

    def checkSingleColumn(column: ColumnRef): Unit = {
      require(
        primaryKeyColumnNames.contains(column.columnName),
        s"Can't pushdown join on column $column because it is not part of the PRIMARY KEY")
    }

    // Make sure we have all of the clustering indexes between the 0th position and the max
    // requested in the join:
    val chosenClusteringColumns = tableDef.clusteringColumns
      .filter(cc => colNames.contains(cc.columnName))
    if (!tableDef.clusteringColumns.startsWith(chosenClusteringColumns)) {
      val maxCol = chosenClusteringColumns.last
      val maxIndex = maxCol.componentIndex.get
      val requiredColumns = tableDef.clusteringColumns.takeWhile(_.componentIndex.get <= maxIndex)
      val missingColumns = requiredColumns.toSet -- chosenClusteringColumns.toSet
      throw new IllegalArgumentException(
        s"Can't pushdown join on column $maxCol without also specifying [ $missingColumns ]")
    }
    val missingPartitionKeys = partitionKeyColumnNames -- colNames
    require(
      missingPartitionKeys.isEmpty,
      s"Can't join without the full partition key. Missing: [ $missingPartitionKeys ]")

    joinColumnNames.foreach(checkSingleColumn)
    joinColumnNames
  }

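  // Illustrative example of the clustering-prefix rule above: for a (hypothetical) table with
  //   PRIMARY KEY ((pk), c1, c2, c3)
  // joining on SomeColumns("pk", "c1", "c2") is accepted, while SomeColumns("pk", "c2") fails
  // because the preceding clustering column c1 is missing.
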
  lazy val rowWriter = implicitly[RowWriterFactory[L]].rowWriter(
    tableDef, joinColumnNames.toIndexedSeq)

  def on(joinColumns: ColumnSelector): CassandraLeftJoinRDD[L, R] = {
    new CassandraLeftJoinRDD[L, R](
      left = left,
      connector = connector,
      keyspaceName = keyspaceName,
      tableName = tableName,
      columnNames = columnNames,
      joinColumns = joinColumns,
      where = where,
      limit = limit,
      clusteringOrder = clusteringOrder,
      readConf = readConf)
  }

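  // Example (building on the hypothetical sketch near the top of the class): join on an
  // explicitly named key column instead of the default partition key selection:
  //
  //   val byKey = joined.on(SomeColumns("key"))
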
  // We need to make sure we get selectedColumnRefs before serialization so that our RowReader
  // is built
  lazy val singleKeyCqlQuery: String = {
    val whereClauses = where.predicates.flatMap(CqlWhereParser.parse)
    val joinColumns = joinColumnNames.map(_.columnName)
    val joinColumnPredicates = whereClauses.collect {
      case EqPredicate(c, _) if joinColumns.contains(c) => c
      case InPredicate(c) if joinColumns.contains(c) => c
      case InListPredicate(c, _) if joinColumns.contains(c) => c
      case RangePredicate(c, _, _) if joinColumns.contains(c) => c
    }.toSet

    require(
      joinColumnPredicates.isEmpty,
      s"""Columns specified in both the join on clause and the where clause.
         |Partition key columns are always part of the join clause.
         |Columns in both: ${joinColumnPredicates.mkString(", ")}""".stripMargin
    )

    logDebug("Generating Single Key Query Prepared Statement String")
    logDebug(s"SelectedColumns : $selectedColumnRefs -- JoinColumnNames : $joinColumnNames")
    val columns = selectedColumnRefs.map(_.cql).mkString(", ")
    val joinWhere = joinColumnNames.map(_.columnName).map(name => s"${quote(name)} = :$name")
    val limitClause = limit.map(limit => s"LIMIT $limit").getOrElse("")
    val orderBy = clusteringOrder.map(_.toCql(tableDef)).getOrElse("")
    val filter = (where.predicates ++ joinWhere).mkString(" AND ")
    val quotedKeyspaceName = quote(keyspaceName)
    val quotedTableName = quote(tableName)
    // Note: in CQL, ORDER BY must precede LIMIT.
    val query =
      s"SELECT $columns " +
      s"FROM $quotedKeyspaceName.$quotedTableName " +
      s"WHERE $filter $orderBy $limitClause"
    logDebug(s"Query : $query")
    query
  }

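  // For the hypothetical kv(key INT PRIMARY KEY, value TEXT) table sketched above, with no
  // extra where clause, limit, or ordering, singleKeyCqlQuery generates a statement like:
  //
  //   SELECT "key", "value" FROM "test"."kv" WHERE "key" = :key
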
  /**
   * When computing this RDD the data is selected via single CQL statements from the specified
   * C* keyspace and table. This is performed on whatever data is available in the previous RDD
   * in the chain.
   */
  override def compute(split: Partition, context: TaskContext): Iterator[(L, Seq[R])] = {
    val session = connector.openSession()
    implicit val pv = protocolVersion(session)
    val stmt = session.prepare(singleKeyCqlQuery).setConsistencyLevel(consistencyLevel)
    val bsb = new BoundStatementBuilder[L](rowWriter, stmt, pv, where.values)
    val metricsUpdater = InputMetricsUpdater(context, readConf)
    val rowIterator = fetchIterator(session, bsb, left.iterator(split, context))
    val countingIterator = new CountingIterator(rowIterator, limit)

    // Report metrics and close the session once the task has consumed the iterator.
    context.addTaskCompletionListener { (context) =>
      val duration = metricsUpdater.finish() / 1000000000d
      logDebug(
        f"Fetched ${countingIterator.count} rows " +
        f"from $keyspaceName.$tableName " +
        f"for partition ${split.index} in $duration%.3f s.")
      session.close()
    }
    countingIterator
  }

  private def fetchIterator(
    session: Session,
    bsb: BoundStatementBuilder[L],
    lastIt: Iterator[L]): Iterator[(L, Seq[R])] = {

    val columnNamesArray = selectedColumnRefs.map(_.selectedAs).toArray
    implicit val pv = protocolVersion(session)
    for (leftSide <- lastIt) yield {
      // Run one select per left-side item; an empty result set yields an empty Seq,
      // which gives this RDD its left (outer) join semantics. toList materializes the
      // rows eagerly while the result set is current.
      val rightSide = {
        val rs = session.execute(bsb.bind(leftSide))
        val iterator = new PrefetchingResultSetIterator(rs, fetchSize)
        iterator.map(rowReader.read(_, columnNamesArray)).toList
      }
      (leftSide, rightSide)
    }
  }

  override protected def getPartitions: Array[Partition] = {
    verify()
    checkValidJoin()
    left.partitions
  }

  override def getPreferredLocations(split: Partition): Seq[String] = left.preferredLocations(split)

  override def toEmptyCassandraRDD: EmptyCassandraRDD[(L, Seq[R])] =
    new EmptyCassandraRDD[(L, Seq[R])](
      sc = left.sparkContext,
      keyspaceName = keyspaceName,
      tableName = tableName,
      columnNames = columnNames,
      where = where,
      limit = limit,
      clusteringOrder = clusteringOrder,
      readConf = readConf)
}