import java.util
import java.util.Map
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
import org.apache.spark.{Partition, SparkContext, TaskContext}
import redis.clients.jedis.{Jedis, ScanParams, ScanResult}
class RedisRDD(sc: SparkContext, val redisHashKeyList: Array[String])
  extends RDD[(String, String, String)](sc, Seq.empty) with Logging {

  // Redis connection settings are read from the Spark configuration on the
  // driver and shipped to executors as part of the serialized RDD.
  val redisUrl: String = sc.getConf.get("redis.host")
  val redisPort: String = sc.getConf.get("redis.port")

  // One partition per hash key, so each hash is scanned by exactly one task.
  override def getPartitions: Array[Partition] =
    redisHashKeyList.zipWithIndex.map(t => new RedisKeyPartition(t._1, t._2))

  override def compute(partition: Partition, context: TaskContext): Iterator[(String, String, String)] = {
    // connectionTimeout = 2000 ms, soTimeout = 10000 ms
    val jedis = new Jedis(redisUrl, redisPort.toInt, 2000, 10000)
    // Close the connection when the task finishes, whether it succeeds or fails.
    context.addTaskCompletionListener[Unit](_ => jedis.close())
    new RedisScanIterator(partition.asInstanceOf[RedisKeyPartition], jedis)
  }
  class RedisScanIterator(partition: RedisKeyPartition, jedis: Jedis) extends Iterator[(String, String, String)] {
    // The HSCAN cursor starts at "0" and is updated after every batch.
    var cursor: String = ScanParams.SCAN_POINTER_START
    val scanParams: ScanParams = new ScanParams
    // Ask HSCAN for up to 1000 entries per call; the returned cursor records
    // where the next call should resume.
    scanParams.count(1000)
    var dataIterator: util.Iterator[Map.Entry[String, String]] = null

    // Load the first batch eagerly so hasNext/next have data to work with.
    this.readBatch()
    log.info(s"count = ${jedis.hlen(partition.getHashKey())}")

    def readBatch(): Unit = {
      log.info("RedisScanIterator readBatch .....")
      val scanResult: ScanResult[Map.Entry[String, String]] = jedis.hscan(partition.getHashKey(), cursor, scanParams)
      cursor = scanResult.getCursor
      val result = scanResult.getResult
      dataIterator = result.iterator()
      log.info(s"RedisScanIterator readBatch size = ${result.size()}, cursor = $cursor .....")
    }
    override def hasNext: Boolean = {
      // HSCAN may legally return an empty batch with a non-zero cursor, so keep
      // fetching until a batch has data or the scan completes (cursor back to "0").
      while (!dataIterator.hasNext && !"0".equals(cursor)) {
        this.readBatch()
      }
      dataIterator.hasNext
    }

    override def next(): (String, String, String) = {
      val entry = dataIterator.next()
      (partition.getHashKey(), entry.getKey, entry.getValue)
    }
  }
  // A partition is just a hash key plus its index in the input array.
  class RedisKeyPartition(hashKey: String, idx: Int) extends Partition {
    override def index: Int = idx
    def getHashKey(): String = hashKey
  }
}
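For local testing it helps to have a populated hash to scan. A minimal seeding sketch, assuming the same host, port, and key as RedisTest below; the field/value naming is hypothetical:

import redis.clients.jedis.Jedis

object RedisSeed {
  def main(args: Array[String]): Unit = {
    val jedis = new Jedis("10.xx.xx.xx", 32101)
    // Write enough fields that the 1000-per-batch HSCAN loops more than once.
    (0 until 5000).foreach(i => jedis.hset("hash_ids0", s"field_$i", s"value_$i"))
    jedis.close()
  }
}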
class RedisContext(spark: SparkSession) extends Serializable {
  def fromRedisHash(redisHashKeyList: Array[String]): RedisRDD =
    new RedisRDD(spark.sparkContext, redisHashKeyList)
}
package object redis {
  // Implicit conversion: importing this package object lets callers invoke
  // fromRedisHash(...) directly on a SparkSession.
  implicit def of(spark: SparkSession): RedisContext = new RedisContext(spark)
}
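With the package object imported, the compiler rewrites spark.fromRedisHash(...) as of(spark).fromRedisHash(...). The result is an ordinary RDD, so it can also be consumed without DataFrames. A sketch assuming a SparkSession configured as in RedisTest below; hash_ids1 is a hypothetical second key:

import com.xx.xxx.xxx.redis._

// Each hash key becomes one partition of the RDD.
val rdd = spark.fromRedisHash(Array("hash_ids0", "hash_ids1"))
println(rdd.getNumPartitions) // 2
rdd.collect().foreach { case (hashKey, field, value) =>
  println(s"$hashKey: $field -> $value")
}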
object RedisTest {
  def main(args: Array[String]): Unit = {
    import com.xx.xxx.xxx.redis._
    val spark = SparkSession
      .builder()
      .config("redis.host", "10.xx.xx.xx")
      .config("redis.port", "32101")
      .master("local[*]")
      .getOrCreate()
    import spark.implicits._
    // Read the hash into a DataFrame with named columns. show() and count()
    // are separate actions, so each one re-scans the hash.
    val df = spark.fromRedisHash(Array("hash_ids0")).toDF("hashKey", "field", "value")
    df.show(1000)
    println(df.count())
  }
}
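For reference, a build.sbt sketch for compiling this listing; the version numbers are assumptions and should be matched to your cluster (addTaskCompletionListener[Unit] needs Spark 2.4+):

scalaVersion := "2.12.15"

libraryDependencies ++= Seq(
  "org.apache.spark" %% "spark-sql" % "2.4.8",
  "redis.clients" % "jedis" % "3.7.0"
)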