Flink DataStream API之State

DataStream API 之State

無論StructuredStreaming還是Flink在流處理的過程中都有一個“有狀態計算“的概念,那麼到底什麼是有狀態計算,有狀態計算應用到什麼場景,在Flink的DataStream API中如何使用狀態,以及在Flink中狀態是如何管理的,在這篇文章中,我們一點一點來學習。

1 什麼是有狀態計算

在講什麼是有狀態計算之前,先簡單說一下什麼是無狀態計算,在我理解,無狀態計算是指本次計算結果與之前輸出無關的計算。比如說,設備開關量的問題,假設我消息隊列中存放的消息是每個設備的開關量信息,包含:設備ID,以及設備的開關狀態(開啓狀態爲1,關閉狀態爲0),我們需求是隻要設備狀態爲0我們就標記爲設備異常了需要告警。

輸入:

{
    "id": "divice-1",
    "status": "1"
}
{
    "id": "divice-2",
    "status": "0"
}

輸出:

{
    "id": "divice-1",
    "alarm": "false"
}
{
    "id": "divice-2",
    "alarm": "true"
}

可以發現,無論輸入有多少條,它的輸出只與當前輸出數據有關,這樣的計算就是無狀態計算。

那麼什麼是有狀態計算呢,再舉個例子,同樣是設備開關量的問題,現在需求是,假設數據是時間有序的,如果設備之前的開啓狀態,現在的處於關閉狀態,即由1變爲0,我們認爲該設備異常了,需要告警了。那麼我們在計算當前輸出的時候,怎麼拿到之前的輸出呢?這時候就需要狀態了,我們可以把之前的輸入作爲狀態保存下來,在每次計算的時候拿出之前的狀態做比較,然後進行輸出。

輸入:

{
    "id": "divice-1",
    "status": "1"
}
{
    "id": "divice-1",
    "status": "0"
}
{
    "id": "divice-2",
    "status": "0"
}
{
    "id": "divice-2",
    "status": "0"
}

輸出:

{
    "id": "divice-1",
    "alarm": "true"
}
{
    "id": "divice-2",
    "alarm": "false"
}

2 有狀態計算的應用場景

下面舉幾個常見的狀態計算的應用場景

  • 流式去重:上游系統中存在重複數據,需要先進行重複過濾,最簡單的,在狀態中記錄所有主鍵,然後根據狀態中是否包含主鍵信息,來判斷是否重複。
  • 窗口計算:以10分鐘爲一個窗口,進行詞頻統計,我們需要把這10分鐘的數據統計結果先保存下來,等到窗口計算結束被觸發之後,再將結果輸出。
  • 機器學習/深度學習:如訓練的模型以及當前模型的參數也是一種狀態,機器學習可能每次都用有一個數據集,需要在數據集上進行學習,對模型進行一個反饋。
  • 訪問歷史數據:需要與之前的數據進行對比,上面舉得設備開關量的問題,將歷史數據放到狀態裏,與之對比。

3 Flink的DataStream中使用狀態

3.1 Flink中的狀態類型

Flink中的狀態有兩種:Managed State、Raw State。Managed State 是有Flink Runtime自動管理的State,而Raw State是原生態State,兩者區別如下表所示:

Managed State Raw State
狀態管理方式 Flin Runtime管理,自動存儲,自動恢復,在內存管理上有優化 需要用戶自己管理,自己序列化
狀態數據結構 Value、List、Map等 byte[]
推薦使用場景 大多數情況都可以使用 當 Managed State 不夠用時,比如需要自定義 Operator 時,推薦使用 Raw State

3.2 Keyed State & Operator State

Flink提供兩種基本狀態:Keyed State、Operator State

Keyed State Operator State
使用 只能在KeyedStream上的算子中 可在所有算子中使用,常用於source,例如FlinkKafkaConsumer
state對應關係 每個Key對應一個state,一個Operatory實例處理多個Key,訪問相應的多個State 一個Operator實例對應一個State
併發改變,分配方式 State隨着Key在實例間遷移 均勻分配、合併得全量
訪問方式 通過 RuntimeContext 訪問,這需要 Operator 是一個Rich Function 自己實現 CheckpointedFunction 或 ListCheckpointed 接口
支持的數據結構 ValueState、ListState、ReducingState、AggregatingState 和 MapState ListState

3.3 使用Managed Keyed State

keyed state需要在KeyedStream算子中使用,支持ValueState、ListState、ReducingState、AggregatingState 和 MapState這幾種數據類型,這幾種狀態數據類型的差異如下表所示:

狀態數據類型 訪問接口 差異體現
ValueState 單個值 update(T)
T value()
儲存單個值,值類型不限定
MapState Map put(UK key,UV value)
putAll(Map<UK,UV> map)
remove(UK key)
boolean contains(UK key)
UV get(UK key)
Iterable<Map.Entry> entries()
Iterator<Map.Entry> iterator()
Iterable<UK> keys()
Iterable<UV> values()
儲存類型爲Map,需要注意的是在 MapState 中的 key 和 Keyed state 中的 key 不是同一個
ListState List add(T)
addAll(List<T>)
update<UK> keys()
Iterable<UK> values()
儲存類型爲List
ReducingState 單個值 add(T)
addAll(List<T>)
update<UK> keys()
T get()
繼承ListState但狀態數據類型上是單個值,原因在於其中的 add 方法不是把當前的元素追加到列表中,而是把當前元素直接更新進了 Reducing 的結果中。輸入輸出類型相同。
AggregatingStatte 單個值 add(IN)
OUT get()
類似ReducingState,但是輸入輸出類型可以不同

爲方便演示這幾種狀態類型的實際運用,下面將分別舉幾個例子,有些應用場景有些牽強,只要領會其用意即可。

3.3.1 ValueState

ValueState爲單值類型,我們可以通過update(T)方法更新值,通過value()方法獲取該值。

3.3.1.1 獲取 ValueState

要使用ValueState,需要從RuntimeContext中獲取,所以需要實現RichFunction,在open()方法中通過getRuntimeContext獲取RuntimeContext,最後通過getState()獲取ValueState。

    override def open(parameters: Configuration): Unit = {
      // get state from RuntimeContext
      state = getRuntimeContext
        .getState(new ValueStateDescriptor[AvgState]("avgState", createTypeInformation[AvgState]))
    }

getState裏需要傳入ValueStateDescriptor實例,無論是ValueState、MapState、ListState、ReducingState、還是AggregatingState,它們的Descriptor都繼承自StateDescriptor,構造器方法相同的。如上代碼,我們是通過name和typeInfo構建的實例,ValueStateDescriptor有共有三種構造器方法:

構造器一:傳入name,以及typeClass

	public ValueStateDescriptor(String name, Class<T> typeClass) {
		super(name, typeClass, null);
	}

假如我們的狀態數據類型爲case class,如下所示定義

case class AvgState(count: Int, sum: Double)

我們可以通過classOf[AvgState]獲取typeClass,使用此構造器創建實例如下:

new ValueStateDescriptor("avgState",classOf[AvgState])

構造器二:傳入name,以及typeInfo

	public ValueStateDescriptor(String name, TypeInformation<T> typeInfo) {
		super(name, typeInfo, null);
	}

typeInfo我們可以通過import org.apache.flink.streaming.api.scala.createTypeInformation方法創建

new ValueStateDescriptor[AvgState]("avgState", createTypeInformation[AvgState])

構造器三:傳入name,以及typeSerializer

	public ValueStateDescriptor(String name, TypeSerializer<T> typeSerializer) {
		super(name, typeSerializer, null);
	}

serializer可以通過繼承TypeSerializer自定義實現,可以通過內置的KryoSerializer以及其它TypeSerializer創建

new ValueStateDescriptor[AvgState]("avgState", new KryoSerializer(classOf[AvgState], getRuntimeContext.getExecutionConfig))

3.3.1.2 使用ValueState實現移動平均

需求:

不考慮數據時序亂序問題,實現簡單移動平均,每來到一個數就計算其整體平均值。

思路:

使用ValueState保存中間狀態AvgState,該狀態包含兩個值,sum:目前所有數據的總和,count:目前所有數據的個數,然後sum/count求出平均值,數據進入後狀態count+1,狀態sum+當前數據,然後求其均值。

實現:

定義輸入輸出格式都爲case class

輸入數據格式

  /**
   * 設備事件
   *
   * @param id    設備ID
   * @param value 設備數據
   */
  case class DeviceEvent(id: String, value: Double)

輸出數據格式

  /**
   * 設備移動均值
   *
   * @param id  設備ID
   * @param avg 設備均值
   */
  case class DeviceAverage(id: String, avg: Double)

狀態存儲格式

  /**
   * 均值狀態
   *
   * @param count 數據個數
   * @param sum   數據總和
   */
  case class AvgState(count: Int, sum: Double)

繼承RichMapFunction獲取狀態,並實現map方法

  /**
   * 繼承 RichMapFunction 實現map方法
   */
  class MoveAverage extends RichMapFunction[DeviceEvent, DeviceAverage] {
    private var state: ValueState[AvgState] = _

    override def open(parameters: Configuration): Unit = {
      // get state from RuntimeContext
      state = getRuntimeContext
        .getState(new ValueStateDescriptor[AvgState]("avgState", new KryoSerializer(classOf[AvgState], getRuntimeContext.getExecutionConfig)))
    }

    override def map(value: DeviceEvent): DeviceAverage = {
      // get or init state value.
      val stateValue = Option(state.value()).getOrElse(AvgState(0, 0.0))
      // update newStateValue to runtime
      val newStateValue = AvgState(stateValue.count + 1, stateValue.sum + value.value)
      state.update(newStateValue)
      DeviceAverage(value.id, newStateValue.sum / newStateValue.count)
    }
  }

從socket獲取實時數據,將數據轉換爲DeviceEvent格式,然後根據id分組,最後執行自定義map方法

  def main(args: Array[String]): Unit = {
    val params: ParameterTool = ParameterTool.fromArgs(args)

    // set up execution environment
    val env: StreamExecutionEnvironment = StreamExecutionEnvironment.getExecutionEnvironment

    // make parameters available in the web interface
    env.getConfig.setGlobalJobParameters(params)

    // get input data
    val streamText: DataStream[String] = env.socketTextStream(
      Option(params.get("hostname")).getOrElse("localhost"),
      Option(params.get("port")).getOrElse("9090").toInt)

    val streamData: DataStream[DeviceEvent] = streamText.map(text => {
      val token = text.split(" ")
      DeviceEvent(token(0), token(1).toDouble)
    })

    streamData.keyBy(_.id).map(new MoveAverage()).print("Moving avg")

    env.execute("ManagedKeyedValueStateExample")
  }

上述使用的自定義RichMap方法也可以簡單的使用mapWithState實現

    // simple
    streamData.keyBy(_.id).mapWithState[DeviceAverage, AvgState] {
      {
        case (in: DeviceEvent, None) => (DeviceAverage(in.id, in.value), Some(AvgState(1, in.value)))
        case (in: DeviceEvent, state: Some[AvgState]) =>
          val newStateValue = AvgState(state.get.count + 1, state.get.sum + in.value)
          (DeviceAverage(in.id, newStateValue.sum / newStateValue.count), Some(newStateValue))
      }
    }.print("Simple moving avg")

3.3.2 MapState

MapState存儲類型爲Map,我們可以使用Map特有的方法,比如put、get、keys() 、putAll等。

3.3.2.1 獲取MapState

MapState的獲取方式與ValueState一樣,在RuntimeContext裏通過getMapState獲取,其中需要創建MapStateDescriptor實例,該實例同樣有三種方式構建:typeClass、typeInfo、typeSerializer。

    override def open(parameters: Configuration): Unit = {
      state = getRuntimeContext.getMapState(
        new MapStateDescriptor[Long, DeviceEvent](
          "alarmMapState",
          createTypeInformation[Long],
          createTypeInformation[DeviceEvent]))
    }

3.3.2.2 使用MapState實現開關量異常判別

需求:

假設設備信息包含id、timestamp、status,設備數據存在時序亂序的問題,需要實時判別設備狀態是否異常,判別依據是,如果當前時間狀態爲0,上一條時間狀態爲1,即狀態從1變爲0了,我們則判斷此設備變爲異常。

思路:

由於考慮亂序問題,這裏需要緩存過去一定量的數據,簡單起見,我們保存10個數據,這10條數據以timstamp爲key,status爲value保存到MapState中。當時間戳爲t的新數據到達之後,獲取緩存數據的keys()轉爲TreeSet,如果當前數據的狀態爲0,則查找出keys中t的前一個值,如果存在前一個值,且前一個值狀態爲1,則轉爲異常事件發送給下游。如果當前數據的狀態爲1,則查找出keys中t的後一個值,如果存在後一個值,且後一個值狀態爲0,則轉換下一個值爲異常事件發送給下游。

實現:

定義輸入事件格式:

  /**
   * 設備事件數據結構
   *
   * @param id        設備ID
   * @param timestamp 事件時間
   * @param status    設備狀態
   */
  case class DeviceEvent(id: String, timestamp: Long, status: Int)

定義輸出事件格式

  /**
   * 設備告警數據結構
   *
   * @param id            設備ID
   * @param timestamp     事件時間
   * @param lastTimestamp 上一條記錄時間
   */
  case class DeviceAlarm(id: String, timestamp: Long, lastTimestamp: Long)

繼承RichFlatMapFunction實現flatmap方法,實現開關量判別邏輯

class AlarmAnalyzer extends RichFlatMapFunction[DeviceEvent, DeviceAlarm] {
    private var state: MapState[Long, DeviceEvent] = _


    override def open(parameters: Configuration): Unit = {
      state = getRuntimeContext.getMapState(
        new MapStateDescriptor[Long, DeviceEvent](
          "alarmMapState",
          createTypeInformation[Long],
          createTypeInformation[DeviceEvent]))
    }

    override def flatMap(value: DeviceEvent, out: Collector[DeviceAlarm]): Unit = {
      // get all keys and transform to tree set.
      val keys: util.TreeSet[Long] = new util.TreeSet[Long](state.keys().asInstanceOf[util.Collection[Long]])
      // clear
      clear(keys)

      val currentKey = value.timestamp
      keys.add(currentKey)
      state.put(currentKey, value)
      // 如果當前事件狀態爲0,查找是否包含上一個事件,如果上一個事件狀態爲1,則轉換爲異常事件將其發送給下游
      if (value.status == 0) {
        val lastKey = Some(keys.lower(currentKey))
        if (lastKey.get!=null && state.get(lastKey.get).status == 1) {
          out.collect(DeviceAlarm(value.id, currentKey, lastKey.get))
        }
      } else {
        // 查找下一個事件,如果下一個事件爲0,則轉換爲異常事件發送給下游
        val nextKey = Some(keys.higher(currentKey))
        if (nextKey.get!=null && state.get(nextKey.get).status == 0) {
          out.collect(DeviceAlarm(value.id, nextKey.get, currentKey))
        }
      }

    }

    def clear(keys: util.TreeSet[Long], size: Int = 10): Unit = {
      if (keys.size() == size) {
        val firstKey = keys.first()
        state.remove(firstKey)
        keys.remove(keys.first())
      }
    }
  }

從Socket中實時獲取數據,轉換爲DeviceEvent類型,然後根據id進行分組,執行flatmap函數

  def main(args: Array[String]): Unit = {
    val params: ParameterTool = ParameterTool.fromArgs(args)

    // set up execution environment
    val env: StreamExecutionEnvironment = StreamExecutionEnvironment.getExecutionEnvironment

    // make parameters available in the web interface
    env.getConfig.setGlobalJobParameters(params)

    // get input data
    val streamText: DataStream[String] = env.socketTextStream(Option(params.get("hostname")).getOrElse("localhost"),
      Option(params.get("port")).getOrElse("9090").toInt)

    val streamData: DataStream[DeviceEvent] = streamText.map(text => {
      val token = text.split(" ")
      DeviceEvent(token(0), token(1).toLong, token(2).toInt)
    })

    streamData.keyBy(_.id).flatMap(new AlarmAnalyzer()).print()

    env.execute("ManagedKeyedMapStateExample")
  }

輸入數據:

device-1 1 1
device-1 2 0
device-1 3 1
device-1 5 1
device-1 4 0

結果如下所示:

3.3.3 ListState

ListState顧名思義,存儲結構爲List,可以存儲多個值。我們可以使用List的特有方法,如add,values()等

3.3.3.1 獲取ListState

ListState方法異曲同工,在RuntimeContext裏通過getListState方法獲取,需要傳入ListStateDescriptor實例,

ListStateDescriptor也有三種。

    override def open(parameters: Configuration): Unit = {
      state = getRuntimeContext
        .getListState(new ListStateDescriptor[Double](
          "varianceState",
          createTypeInformation[Double]))
    }

3.3.3.2 使用ListState實現累計方差計算

需求:

實時累計5條數據後做一次方差,然後輸出

思路:

使用ListState存儲歷史數據,當數據達到5條之後,將其全部取出,計算方法,然後輸出到下游。

實現:

輸入數據格式

  /**
   * 設備事件
   *
   * @param id    設備ID
   * @param value 設備值
   */
  case class DeviceEvent(id: String, value: Double)

輸出數據格式

  /**
   * 設備方差事件
   *
   * @param id       設備ID
   * @param values   累計所有值
   * @param variance 方差
   */
  case class DeviceVariance(id: String, values: List[Double], variance: Double)

繼承RichFlatMapFunction實現flatmap方法,完成計算方差邏輯。

  class VarianceCalculator extends RichFlatMapFunction[DeviceEvent, DeviceVariance] {
    private var state: ListState[Double] = _
    private val countSize: Int = 5

    override def open(parameters: Configuration): Unit = {
      state = getRuntimeContext
        .getListState(new ListStateDescriptor[Double](
          "varianceState",
          createTypeInformation[Double]))
    }

    override def flatMap(value: DeviceEvent, out: Collector[DeviceVariance]): Unit = {
      import scala.collection.JavaConverters._
      state.add(value.value)
      val currentStateList: Iterable[Double] = state.get().asScala
      if (currentStateList.size == countSize) {
        out.collect(DeviceVariance(value.id, currentStateList.toList, variance(currentStateList)))
        state.clear()
      }
    }

    /**
     * 計算方差
     * @param values 數據列表
     * @return 方差
     */
    def variance(values: Iterable[Double]): Double = {
      val avg = values.sum / values.size.toDouble
      math.sqrt(values.map(x => math.pow(x - avg, 2)).sum / values.size)
    }

  }

從socket裏獲取數據,並轉換爲DeviceEvent,根據id分組之後,調用flatmap方法。

  def main(args: Array[String]): Unit = {
    val params: ParameterTool = ParameterTool.fromArgs(args)

    // set up execution environment
    val env: StreamExecutionEnvironment = StreamExecutionEnvironment.getExecutionEnvironment

    // make parameters available in the web interface
    env.getConfig.setGlobalJobParameters(params)

    // get input data
    val streamText: DataStream[String] = env.socketTextStream(Option(params.get("hostname")).getOrElse("localhost"),
      Option(params.get("port")).getOrElse("9090").toInt)

    val streamData: DataStream[DeviceEvent] = streamText.map(text => {
      val token = text.split(" ")
      DeviceEvent(token(0), token(1).toDouble)
    })

    streamData.keyBy(_.id).flatMap(new VarianceCalculator()).print()
    env.execute("ManagedKeyedListStateExample")
  }

nc -lk 9090 輸入數據:

device-1 1
device-1 2
device-1 3
device-1 4
device-1 5
device-1 6
device-1 7
device-1 8
device-1 9
device-1 10

結果:

3.3.3 ReducingState

ReductingState的存儲類型也爲單個值,需要用戶實現reduce方法,當調用add()添加數據時,會指定自定義的reduce方法。

3.3.3.1 獲取ReducingState

在RuntimeContext中通過getReducingState()方法獲取,需要構建ReducingStateDescriptor實例,構造器不同於之前,除了name、typeinfo之前還需要傳入自定義的reduce實例。

    override def open(parameters: Configuration): Unit = {
      // get state from runtime context
      state = getRuntimeContext
        .getReducingState(new ReducingStateDescriptor[Double](
          "sumAccumulatorState",
          new SumReducing(),
          createTypeInformation[Double]))
    }

3.3.3.2 使用ReducingState計算累加和

需求:

使用ReducingState實時計算數據總和

思路:

實現ReducFunction,將最近兩個狀態相加。

實現:

輸入數據格式

  /**
   * 設備事件
   *
   * @param id    設備ID
   * @param value 設備值
   */
  case class DeviceEvent(id: String, value: Double)

輸出數據格式

  /**
   * 設備累加和
   *
   * @param id  設備ID
   * @param sum 設備值
   */
  case class DeviceSum(id: String, sum: Double)

繼承ReduceFunction實現reduce方法

  class SumReducing extends ReduceFunction[Double] {
    override def reduce(value1: Double, value2: Double): Double = value1 + value2
  }

繼承RichMapFunction實現map方法,完成累加和的邏輯

  class SumAccumulator extends RichMapFunction[DeviceEvent, DeviceSum] {
    private var state: ReducingState[Double] = _

    override def open(parameters: Configuration): Unit = {
      // get state from runtime context
      state = getRuntimeContext
        .getReducingState(new ReducingStateDescriptor[Double](
          "sumAccumulatorState",
          new SumReducing(),
          createTypeInformation[Double]))
    }

    override def map(value: DeviceEvent): DeviceSum = {
      state.add(value.value)
      DeviceSum(value.id, state.get())
    }
  }

從socket中獲取數據,並轉換爲DeviceEvent,然後根據id分組,調用自定義map方法。

def main(args: Array[String]): Unit = {
  val params: ParameterTool = ParameterTool.fromArgs(args)

  // set up execution environment
  val env: StreamExecutionEnvironment = StreamExecutionEnvironment.getExecutionEnvironment

  // make parameters available in the web interface
  env.getConfig.setGlobalJobParameters(params)

  // get input data
  val streamText: DataStream[String] = env.socketTextStream(Option(params.get("hostname")).getOrElse("localhost"),
    Option(params.get("port")).getOrElse("9090").toInt)

  val streamData: DataStream[DeviceEvent] = streamText.map(text => {
    val token = text.split(" ")
    DeviceEvent(token(0), token(1).toDouble)
  })

  streamData.keyBy(_.id).map(new SumAccumulator()).print()
  env.execute("ManagedKeyedReducingStateExample")
}

輸入:

device-1 1
device-1 2
device-1 2.2

結果:

3.3.4 AggregatingState

AggregatingState與ReducingState類似,也是一種單個值的聚合狀態。具有以下特點:

  • 可以對輸入值,中間聚合和結果類型使用不同類型,以支持各種聚合類型
  • 支持分佈式聚合:可以將不同的中間聚合合併在一起,以允許預聚合/最終聚合優化。

3.3.4.1 獲取AggregatingState

AggregatingState也是通過RuntimeContext的getAggregatingStata方法獲取,同樣需要傳入AggregatingStateDescriptor實例,構建AggregatingStateDescriptor實例時需要傳入自定義的AggregatingFunction。

    override def open(parameters: Configuration): Unit = {
      state = getRuntimeContext.getAggregatingState(new AggregatingStateDescriptor[Long, AverageAccumulator, Double](
        "rateAccumulatorState",
        new AvgAggregating(),
        createTypeInformation[AverageAccumulator]
      ))
    }

3.3.4.2 使用AggregatingState實現移動平均

需求:

利用AggregatingState實時計算設備均值

思路:

思路與ValueState的均值計算相同

實現:

輸入數據類型

 /**
   * 設備事件
   *
   * @param id    設備ID
   * @param value 設備值
   */
  case class DeviceEvent(id: String, value: Long)

輸出數據類型

  /**
   * 設備均值
   *
   * @param id  設備ID
   * @param avg 平均值
   */
  case class DeviceAvg(id: String, avg: Double)

聚合累加器定義

case class AverageAccumulator(sum: Long, count: Int)

實現自定義的聚合方法

  class AvgAggregating extends AggregateFunction[Long, AverageAccumulator, Double] {

    override def createAccumulator(): AverageAccumulator = AverageAccumulator(0L, 0)

    override def add(value: Long, accumulator: AverageAccumulator): AverageAccumulator =
      AverageAccumulator(accumulator.sum + value, accumulator.count + 1)

    override def getResult(accumulator: AverageAccumulator): Double = accumulator.sum.toDouble / accumulator.count.toDouble

    override def merge(a: AverageAccumulator, b: AverageAccumulator): AverageAccumulator =
      AverageAccumulator(a.sum + b.sum, a.count + b.count)
  }

實現自定義的RichMapFunction

  class MovingAvg extends RichMapFunction[DeviceEvent, DeviceAvg] {
    private var state: AggregatingState[Long, Double] = _

    override def open(parameters: Configuration): Unit = {
      state = getRuntimeContext.getAggregatingState(new AggregatingStateDescriptor[Long, AverageAccumulator, Double](
        "rateAccumulatorState",
        new AvgAggregating(),
        createTypeInformation[AverageAccumulator]
      ))
    }

    override def map(value: DeviceEvent): DeviceAvg = {
      state.add(value.value)
      DeviceAvg(value.id, state.get())
    }
  }

從Socket中獲取數據,轉換爲DeviceEvent類型,然後根據id分組,調用自定義map方法。

  def main(args: Array[String]): Unit = {
    val params: ParameterTool = ParameterTool.fromArgs(args)

    // set up execution environment
    val env: StreamExecutionEnvironment = StreamExecutionEnvironment.getExecutionEnvironment

    // make parameters available in the web interface
    env.getConfig.setGlobalJobParameters(params)

    // get input data
    val streamText: DataStream[String] = env.socketTextStream(Option(params.get("hostname")).getOrElse("localhost"),
      Option(params.get("port")).getOrElse("9090").toInt)

    val streamData: DataStream[DeviceEvent] = streamText.map(text => {
      val token = text.split(" ")
      DeviceEvent(token(0), token(1).toLong)
    })

    streamData.keyBy(_.id).map(new MovingAvg()).print()

    env.execute("ManagedKeyedAggregatingStateExample")
  }

輸入:

device-1 1
device-1 2
device-1 3

3.3.5 狀態生命週期

在流處理的過程中,如果狀態不斷累積,很容易造成OOM,所以我們需要一種機制,來及時清理掉不需要的狀態。對於Keyed State來說,自Flink 1.6之後引入了Time-To-Live (TTL)機制,能夠友好的幫助我們自動清理掉過期狀態。關於狀態生命週期更多的內容可以參考:《如何應對飛速增長的狀態?Flink State TTL 概述》

3.3.5.1 StateTtlConfig

爲了使用狀態TTL,必須先構建StateTtlConfig配置對象。然後可以通過傳遞配置在任何狀態描述符中啓用TTL功能。

  val ttlConfig: StateTtlConfig = StateTtlConfig
    // 設置過期時間,10s後過期
    .newBuilder(Time.seconds(10))
    // ttl 刷新機制,默認在創建和寫狀態時刷新ssl
    .setUpdateType(StateTtlConfig.UpdateType.OnReadAndWrite)
    // 表示對已過期但還未被清理掉的狀態如何處理
    .setStateVisibility(StateTtlConfig.StateVisibility.NeverReturnExpired)
    //過期對象的清理策略
    .cleanupIncrementally(1, true)
    .build

StateTtlConfig參數說明:

下面根據StateTtlConfig構造器參數,分別描述一下參數作用。

	private StateTtlConfig(
		UpdateType updateType,
		StateVisibility stateVisibility,
		TimeCharacteristic timeCharacteristic,
		Time ttl,
		CleanupStrategies cleanupStrategies) {
		this.updateType = Preconditions.checkNotNull(updateType);
		this.stateVisibility = Preconditions.checkNotNull(stateVisibility);
		this.timeCharacteristic = Preconditions.checkNotNull(timeCharacteristic);
		this.ttl = Preconditions.checkNotNull(ttl);
		this.cleanupStrategies = cleanupStrategies;
		Preconditions.checkArgument(ttl.toMilliseconds() > 0,
			"TTL is expected to be positive");
	}
  • updateType: 表示狀態時間戳的更新的時機,是一個 Enum 對象。如果設置爲 Disabled,則表明不更新時間戳;如果設置爲 OnCreateAndWrite,則表明當狀態創建或每次寫入時都會更新時間戳;如果設置爲 OnReadAndWrite,則除了在狀態創建和寫入時更新時間戳外,讀取也會更新狀態的時間戳。
  • stateVisibility: 表示對已過期但還未被清理掉的狀態如何處理,也是 Enum 對象。如果設置爲 ReturnExpiredIfNotCleanedUp,那麼即使這個狀態的時間戳表明它已經過期了,但是隻要還未被真正清理掉,就會被返回給調用方;如果設置爲 NeverReturnExpired,那麼一旦這個狀態過期了,那麼永遠不會被返回給調用方,只會返回空狀態,避免了過期狀態帶來的干擾。
  • TimeCharacteristic 以及 TtlTimeCharacteristic:表示 State TTL 功能所適用的時間模式,仍然是 Enum 對象。前者已經被標記爲 Deprecated(廢棄),推薦新代碼採用新的 TtlTimeCharacteristic 參數。截止到 Flink 1.8,只支持 ProcessingTime 一種時間模式,對 EventTime 模式的 State TTL 支持還在開發中
  • CleanupStrategies:表示過期對象的清理策略,目前來說有三種 Enum 值。當設置爲 FULL_STATE_SCAN_SNAPSHOT 時,對應的是 EmptyCleanupStrategy 類,表示對過期狀態不做主動清理,當執行完整快照(Snapshot / Checkpoint)時,會生成一個較小的狀態文件,但本地狀態並不會減小。唯有當作業重啓並從上一個快照點恢復後,本地狀態纔會實際減小,因此可能仍然不能解決內存壓力的問題。爲了應對這個問題,Flink 還提供了增量清理的枚舉值,分別是針對 Heap StateBackend 的 INCREMENTAL_CLEANUP(對應 IncrementalCleanupStrategy 類),以及對 RocksDB StateBackend 有效的 ROCKSDB_COMPACTION_FILTER(對應 RocksdbCompactFilterCleanupStrategy 類). 對於增量清理功能,Flink 可以被配置爲每讀取若干條記錄就執行一次清理操作,而且可以指定每次要清理多少條失效記錄;對於 RocksDB 的狀態清理,則是通過 JNI 來調用 C++ 語言編寫的 FlinkCompactionFilter 來實現,底層是通過 RocksDB 提供的後臺 Compaction 操作來實現對失效狀態過濾的。

3.3.5.2 開啓TTL

想要在狀態中啓用TTL,需要在構建的StateDescriptor實例中,調用enableTimeToLive方法

      val listStateDescriptor = new ListStateDescriptor("listState", createTypeInformation[Long])
      listStateDescriptor.enableTimeToLive(ttlConfig)
      state = getRuntimeContext.getListState(listStateDescriptor)

3.3.5.3 使用TTL例子

package com.hollysys.flink.streaming.state.managed.keyed

import org.apache.flink.api.common.functions.RichMapFunction
import org.apache.flink.api.common.state.{ListState, ListStateDescriptor, StateTtlConfig}
import org.apache.flink.api.common.time.Time
import org.apache.flink.api.java.utils.ParameterTool
import org.apache.flink.configuration.Configuration
import org.apache.flink.streaming.api.scala.{DataStream, StreamExecutionEnvironment, createTypeInformation}

/**
 * Created by shirukai on 2019/8/27 4:23 下午
 * 帶有生命週期的狀態,我們可以給狀態設置過期時間
 * https://cloud.tencent.com/developer/article/1452844
 */
object TimeToLiveStateExample {
  val ttlConfig: StateTtlConfig = StateTtlConfig
    // 設置過期時間,10s後過期
    .newBuilder(Time.seconds(10))
    // ttl 刷新機制,默認在創建和寫狀態時刷新ttl
    // 枚舉類型。有三種機制:Disabled、OnReadAndWrite、OnReadAndWrite
    .setUpdateType(StateTtlConfig.UpdateType.OnReadAndWrite)
    // 表示對已過期但還未被清理掉的狀態如何處理
    .setStateVisibility(StateTtlConfig.StateVisibility.NeverReturnExpired)
    //過期對象的清理策略
    .cleanupIncrementally(1, true)
    .build


  case class DeviceEvent(id: String, value: Long)

  case class DeviceList(id: String, list: List[Long])


  class ListCollector extends RichMapFunction[DeviceEvent, DeviceList] {
    private var state: ListState[Long] = _

    override def open(parameters: Configuration): Unit = {
      val listStateDescriptor = new ListStateDescriptor("listState", createTypeInformation[Long])
      listStateDescriptor.enableTimeToLive(ttlConfig)
      state = getRuntimeContext.getListState(listStateDescriptor)
    }

    override def map(value: DeviceEvent): DeviceList = {
      import scala.collection.JavaConverters._
      state.add(value.value)
      DeviceList(value.id, state.get().asScala.toList)
    }
  }

  def main(args: Array[String]): Unit = {
    val params: ParameterTool = ParameterTool.fromArgs(args)

    // set up execution environment
    val env: StreamExecutionEnvironment = StreamExecutionEnvironment.getExecutionEnvironment

    // make parameters available in the web interface
    env.getConfig.setGlobalJobParameters(params)

    // get input data
    val streamText: DataStream[String] = env.socketTextStream(Option(params.get("hostname")).getOrElse("localhost"),
      Option(params.get("port")).getOrElse("9090").toInt)

    val streamData: DataStream[DeviceEvent] = streamText.map(text => {
      val token = text.split(" ")
      DeviceEvent(token(0), token(1).toLong)
    })

    streamData.keyBy(_.id).map(new ListCollector()).print()

    env.execute("TimeToLiveStateExample")
  }
}

3.4 使用Managed Operator State

上面我們介紹瞭如何使用Managed Keyed State,通過RuntimeContext的getXXXState方法可以獲取到不同的KeyedState,這必須要在KeyedDataStream中使用,如果在DataStream中使用的話會報如下異常:

那麼在普通的Operator中我們如何使用狀態呢?官方提供了兩種Operator State使用方法,繼承CheckpointedFunction和ListCheckpointed<T extends Serializable>接口。

3.4.1 繼承CheckpointedFunction實現有狀態Operator

package com.hollysys.flink.streaming.state.managed.operator


import org.apache.flink.api.common.state.{ListState, ListStateDescriptor}
import org.apache.flink.api.java.utils.ParameterTool
import org.apache.flink.runtime.state.{FunctionInitializationContext, FunctionSnapshotContext}
import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction
import org.apache.flink.streaming.api.functions.sink.SinkFunction
import org.apache.flink.streaming.api.scala.{DataStream, StreamExecutionEnvironment}
import org.apache.flink.streaming.api.scala._

import scala.collection.mutable.ListBuffer

/**
 * Created by shirukai on 2019/8/29 10:06 上午
 * 繼承CheckpointedFunction獲取狀態
 * 實現有狀態的Sink
 */
object StateByCheckpointedExample {

  case class DeviceEvent(id: String, value: Double)


  class BufferSink(threshold: Int = 2) extends SinkFunction[DeviceEvent] with CheckpointedFunction {

    @transient
    private var checkpointedState: ListState[DeviceEvent] = _
    private val bufferedElements = ListBuffer[DeviceEvent]()

    override def invoke(value: DeviceEvent, context: SinkFunction.Context[_]): Unit = {
      bufferedElements += value
      println(bufferedElements)
      if (bufferedElements.size == threshold) {
        for (element <- bufferedElements) {
          // send it to the sink
          println(s"BufferSink: $element")
        }
        bufferedElements.clear()
      }
    }

    // 當檢查點被請求快照時調用,用以保存當前狀態
    override def snapshotState(context: FunctionSnapshotContext): Unit = {
      checkpointedState.clear()
      for (element <- bufferedElements) {
        checkpointedState.add(element)
      }
    }

    // 當並行實例被創建時調用,用以初始化狀態
    override def initializeState(context: FunctionInitializationContext): Unit = {
      val descriptor = new ListStateDescriptor[DeviceEvent](
        "buffered-elements",
        createTypeInformation[DeviceEvent])

      // 通過getOperatorStateStore方法獲取operator狀態
      // getListState
      // getUnionListState 獲取全量狀態,會合並所有並行實例狀態
      checkpointedState = context.getOperatorStateStore.getListState(descriptor)
      import scala.collection.JavaConverters._
      // 如果從先前的快照恢復狀態,則返回true
      if (context.isRestored) {
        // 將恢復後的狀態刷到ListBuffer裏
        for (element <- checkpointedState.get().asScala) {
          bufferedElements += element
        }
      }
    }
  }

  def main(args: Array[String]): Unit = {
    val params: ParameterTool = ParameterTool.fromArgs(args)

    // set up execution environment
    val env: StreamExecutionEnvironment = StreamExecutionEnvironment.getExecutionEnvironment

    env.enableCheckpointing(1000)

    // make parameters available in the web interface
    env.getConfig.setGlobalJobParameters(params)

    // get input data
    val streamText: DataStream[String] = env.socketTextStream(Option(params.get("hostname")).getOrElse("localhost"),
      Option(params.get("port")).getOrElse("9090").toInt)

    val streamData: DataStream[DeviceEvent] = streamText.map(text => {
      val token = text.split(" ")
      DeviceEvent(token(0), token(1).toDouble)
    })

    streamData.addSink(new BufferSink(2))


    env.execute("StateByCheckpointedExample")
  }
}

3.4.2 繼承ListCheckpointed實現有狀態Operator

package com.hollysys.flink.streaming.state.managed.operator

import java.util
import java.util.Collections
import java.util.concurrent.TimeUnit

import org.apache.flink.api.java.utils.ParameterTool
import org.apache.flink.runtime.state.StateBackend
import org.apache.flink.runtime.state.filesystem.FsStateBackend
import org.apache.flink.streaming.api.CheckpointingMode
import org.apache.flink.streaming.api.checkpoint.ListCheckpointed
import org.apache.flink.streaming.api.environment.CheckpointConfig
import org.apache.flink.streaming.api.functions.source.{RichParallelSourceFunction, SourceFunction}
import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
import org.apache.flink.streaming.api.scala._

/**
 * Created by shirukai on 2019/8/29 1:48 下午
 * 繼承ListCheckpointedExample獲取狀態
 * 實現有狀態的Source
 */
object StateByListCheckpointedExample {

  case class DeviceEvent(id: String, value: Long)

  case class Offset(value: Long) extends Serializable


  class CounterSource extends RichParallelSourceFunction[DeviceEvent] with ListCheckpointed[Offset] {

    @volatile
    private var isRunning = true

    private var offset = 0L

    override def run(ctx: SourceFunction.SourceContext[DeviceEvent]): Unit = {
      val lock = ctx.getCheckpointLock
      while (isRunning) {
        // output and state update are atomic
        lock.synchronized({
          ctx.collect(DeviceEvent(s"Device-$offset", offset))
          offset += 1
          TimeUnit.SECONDS.sleep(1)
        })
      }
    }

    override def cancel(): Unit = isRunning = false


    // 恢復到之前檢查點的狀態
    override def restoreState(state: util.List[Offset]): Unit = {
      if (!state.isEmpty) offset = state.get(0).value
    }

    // 返回當前狀態用以保存到快照中
    override def snapshotState(checkpointId: Long, timestamp: Long): util.List[Offset] =
      Collections.singletonList(Offset(offset))

  }

  def main(args: Array[String]): Unit = {
    val params: ParameterTool = ParameterTool.fromArgs(args)

    // set up execution environment
    val env: StreamExecutionEnvironment = StreamExecutionEnvironment.getExecutionEnvironment

    env.enableCheckpointing(1000)
        .setStateBackend(new FsStateBackend("file:///Users/shirukai/hollysys/repository/learn-demo-flink/data/checkpoint").asInstanceOf[StateBackend])
    env.getCheckpointConfig.enableExternalizedCheckpoints(CheckpointConfig.ExternalizedCheckpointCleanup.RETAIN_ON_CANCELLATION)
    // make parameters available in the web interface
    env.getConfig.setGlobalJobParameters(params)

    env.addSource(new CounterSource()).setParallelism(1).print()

    env.execute("StateByListCheckpointedExample")
  }

}

3.5 廣播狀態模式

以下關於“什麼是廣播狀態”內容引用於文章《Apache Flink 中廣播狀態的實用指南》

廣播狀態可以用於通過一個特定的方式來組合並共同處理兩個事件流。第一個流的事件被廣播到另一個 operator 的所有併發實例,這些事件將被保存爲狀態。另一個流的事件不會被廣播,而是發送給同一個 operator 的各個實例,並與廣播流的事件一起處理。廣播狀態非常適合兩個流中一個吞吐大,一個吞吐小,或者需要動態修改處理邏輯的情況。

package com.hollysys.flink.streaming.state.broadcast

import org.apache.flink.api.common.state.MapStateDescriptor
import org.apache.flink.api.java.utils.ParameterTool
import org.apache.flink.streaming.api.functions.co.KeyedBroadcastProcessFunction
import org.apache.flink.streaming.api.scala.{DataStream, StreamExecutionEnvironment}
import org.apache.flink.streaming.api.scala._
import org.apache.flink.util.Collector

import scala.collection.mutable

/**
 * Created by shirukai on 2019/8/29 4:18 下午
 * 廣播狀態示例
 */
object BroadcastStateExample {

  val ruleStateDescriptor = new MapStateDescriptor("rule-state",
    createTypeInformation[String],
    createTypeInformation[mutable.Map[String, RuleEvent]])

  case class DeviceEvent(id: String, value: Double)

  case class RuleEvent(id: String, ruleType: String, bind: String)

  case class DeviceWithRule(device: DeviceEvent, rule: RuleEvent)

  class DeviceWithRuleProcess extends KeyedBroadcastProcessFunction[String, DeviceEvent, RuleEvent, DeviceWithRule] {



    override def processElement(value: DeviceEvent, ctx: KeyedBroadcastProcessFunction[String, DeviceEvent, RuleEvent,
      DeviceWithRule]#ReadOnlyContext, out: Collector[DeviceWithRule]): Unit = {
      val ruleState = ctx.getBroadcastState(ruleStateDescriptor)
      // 如果數據包含規則
      if (ruleState.contains(value.id)) {
        val rules = ruleState.get(value.id)
        rules.foreach(rule => {
          out.collect(DeviceWithRule(value,rule._2))
        })
      }
    }

    override def processBroadcastElement(value: RuleEvent, ctx: KeyedBroadcastProcessFunction[String, DeviceEvent,
      RuleEvent, DeviceWithRule]#Context, out: Collector[DeviceWithRule]): Unit = {
      val ruleState = ctx.getBroadcastState(ruleStateDescriptor)
      val bindKey = value.bind
      if (ruleState.contains(bindKey)) {
        val bindRules = ruleState.get(bindKey)
        bindRules.put(value.id, value)
      } else {
        ruleState.put(bindKey, mutable.Map(value.id -> value))
      }
    }
  }

  def main(args: Array[String]): Unit = {
    val params: ParameterTool = ParameterTool.fromArgs(args)

    // set up execution environment
    val env: StreamExecutionEnvironment = StreamExecutionEnvironment.getExecutionEnvironment

    // make parameters available in the web interface
    env.getConfig.setGlobalJobParameters(params)

    val deviceText: DataStream[String] = env.socketTextStream(
      Option(params.get("device-hostname")).getOrElse("localhost"),
      Option(params.get("device-port")).getOrElse("9090").toInt)

    val ruleText: DataStream[String] = env.socketTextStream(
      Option(params.get("rule-hostname")).getOrElse("localhost"),
      Option(params.get("rule-port")).getOrElse("9091").toInt)

    val deviceEvents = deviceText.map(x => {
      val token = x.split(" ")
      DeviceEvent(token(0), token(1).toDouble)
    })

    val ruleEvents = ruleText.map(x => {
      val token = x.split(" ")
      RuleEvent(token(0), token(1), token(2))
    })


    val ruleBroadcastStream = ruleEvents.broadcast(ruleStateDescriptor)

    deviceEvents.keyBy(_.id).connect(ruleBroadcastStream).process(new DeviceWithRuleProcess()).print()

    env.execute("BroadcastStateExample")
  }
}

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章