Spark: scanning Redis hashes with a cursor-based RDD (HSCAN)
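The code below builds a custom RDD in which every Redis hash key becomes its own partition; each partition then pages through its hash with HSCAN, using the cursor returned with each batch to fetch the next one until the cursor comes back as 0.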

import java.util
import java.util.Map

import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
import org.apache.spark.{Partition, SparkContext, TaskContext}
import redis.clients.jedis.{Jedis, ScanParams, ScanResult}

// An RDD with one partition per Redis hash key; each partition streams the entries of its hash via HSCAN.
class RedisRDD(sc: SparkContext, val redisHashKeyList: Array[String]) extends RDD[(String, String, String)](sc, Seq.empty) with Logging {

    val redisUrl = sc.getConf.get("redis.host")
    val redisPort = sc.getConf.get("redis.port")

    override def getPartitions: Array[Partition] = redisHashKeyList.zipWithIndex.map(t => new RedisKeyPartition(t._1, t._2))

    override def compute(partition: Partition, context: TaskContext): Iterator[(String, String, String)] = {
        // Host and port come from the SparkConf; 2000 ms connect timeout, 10000 ms socket timeout.
        // Note: the connection is never closed here; see the cleanup sketch after the class.
        val jedis = new Jedis(redisUrl, redisPort.toInt, 2000, 10000)
        new RedisScanIterator(partition.asInstanceOf[RedisKeyPartition], jedis)
    }

    class RedisScanIterator(partition: RedisKeyPartition, jedis: Jedis) extends Iterator[(String, String, String)] {

        // The cursor starts at 0 (ScanParams.SCAN_POINTER_START).
        var cursor: String = ScanParams.SCAN_POINTER_START
        val scanParams: ScanParams = new ScanParams
        // Ask HSCAN for roughly 1000 entries per round trip; the cursor returned
        // with each batch records where the next call should resume.
        scanParams.count(1000)

        var dataIterator: util.Iterator[Map.Entry[String, String]] = null

        // Fetch the first batch eagerly and log the total size of the hash for reference.
        this.readBatch()
        log.info("count = {}", jedis.hlen(partition.getHashKey()))

        def readBatch(): Unit = {
            log.info("RedisScanIterator readBatch .....")
            val scanResult: ScanResult[util.Map.Entry[String, String]] = jedis.hscan(partition.getHashKey(), cursor, scanParams)
            cursor = scanResult.getCursor
            val result = scanResult.getResult
            dataIterator = result.iterator()
            log.info("RedisScanIterator readBatch size = {},  cursor = {} .....", result.size(), cursor)
        }

        override def hasNext: Boolean = {
            // HSCAN may legally return an empty batch together with a non-zero cursor,
            // so keep fetching until an entry shows up or the cursor returns to 0.
            while (!dataIterator.hasNext && !"0".equals(cursor)) {
                this.readBatch()
            }
            dataIterator.hasNext
        }

        override def next(): (String, String, String) = {
            //            log.info("RedisScanIterator next .....")
            val entry = dataIterator.next()
            (partition.getHashKey(), entry.getKey, entry.getValue)
        }
    }

    // A partition simply carries the Redis hash key it is responsible for.
    class RedisKeyPartition(hashKey: String, idx: Int) extends Partition {
        override def index: Int = this.idx
        def getHashKey() = hashKey
    }

}
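One thing the RDD above leaves open is connection cleanup: compute never closes the Jedis client it creates. A minimal sketch of one way to handle this (an addition, not part of the original code), assuming Spark 2.4+ where TaskContext.addTaskCompletionListener takes a function, is to register a completion listener before returning the iterator:

    override def compute(partition: Partition, context: TaskContext): Iterator[(String, String, String)] = {
        val jedis = new Jedis(redisUrl, redisPort.toInt, 2000, 10000)
        // Close the connection when the task finishes, whether it succeeded or failed.
        context.addTaskCompletionListener[Unit](_ => jedis.close())
        new RedisScanIterator(partition.asInstanceOf[RedisKeyPartition], jedis)
    }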


class RedisContext(spark: SparkSession) extends Serializable {
    implicit def fromRedisHash(redisHashKeyList: Array[String]): RedisRDD = new RedisRDD(spark.sparkContext, redisHashKeyList)
}

package object redis {
    implicit def of(spark: SparkSession): RedisContext = new RedisContext(spark)
}
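With the package object in place, importing com.xx.xxx.xxx.redis._ brings the of conversion into scope, so calling spark.fromRedisHash(...) implicitly wraps the SparkSession in a RedisContext and builds a RedisRDD from its SparkContext. That is what the test below relies on.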
object RedisTest {

    def main(args: Array[String]): Unit = {
        import com.xx.xxx.xxx.redis._
        val spark = SparkSession
            .builder()
            .config("redis.host", "10.xx.xx.xx")
            .config("redis.port", "32101")
            .master("local[*]")
            .getOrCreate()

        import spark.implicits._

        // Read the data back. show() prints the rows itself, so only count() needs a println.
        spark.fromRedisHash(Array("hash_ids0")).toDF().show(1000)
        println(spark.fromRedisHash(Array("hash_ids0")).toDF().count())
    }
}
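As a small follow-up (the column names here are illustrative, not from the original post), the tuple RDD can be given readable column names when it is converted to a DataFrame inside the same main method:

        // Hypothetical column names for the (hashKey, field, value) tuples.
        val df = spark.fromRedisHash(Array("hash_ids0")).toDF("hashKey", "field", "value")
        df.show(1000)
        println(df.count())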
    

     
