KM算法講解(含C++代碼)

假設有3個女的要嫁給三個男的,各有各的期望值。
如何讓期望值之和最大?
此時我們就要用到傳說中的km算法了。
這個算法本質上是貪心算法,怎麼算呢?
舉個例子吧
在這裏插入圖片描述
首先看女1,女1與男1間的邊權值+男1期望值=3,而3不等於女一的期望值,所以配對失敗。接着女1與男3間的邊權值+男3的期望值=4,剛好4與女1的期望值相等,配對成功!
在這裏插入圖片描述
接着讓女2找對象,匹配的過程就省略了,最後發現跟男3可以配對,而男3被女1佔了,女2對女1說:“你能不能降低一下期望值啊?”於是女1同意了,但是我們還是得讓女1跟男3可以配對,於是就將女1的期望值降低1,男3的期望值上升1,這樣他們還是能夠配對(男3挑剔了起來)。但是這樣女2又不能配對了,就將女2的期望值也降低。這是發現女1可以和男1配對,就將他們配起來。
在這裏插入圖片描述
接着幫女3找對象,發現女3無法跟任何人配對,就只好將她的期望值降1,讓她可以和男3配對。
在這裏插入圖片描述

此時女3發現男3被女2佔了,於是就勒索讓女2降低期望值,於是女2降低1期望值,爲了以後還能找男3,所以男3的期望值上升1,女3期望值也得隨之降1。這時女1的期望值也得降低1,因爲女1也要保持隨時可以與男3配對。 接着女2找上了男2,於是就跟男2配對。



此時三男三女都有了自己的對象了。(好開心,終於打完字了)


你以爲這就結束了?
不可能 ,還有例題呢:

Description

小W在八中開了一個兼職中心。現在他手下有N個工人。每個工人有N個工作可以選擇,於是每個人做每個工作的效率是不一樣的。做爲CEO的小W的任務就是給每個人分配一個工作,保證所有人效率之和是最大的。N<=200

Input

第一行給出數字N
接下來N行N列,代表每個人工作的效率。

Output

一個數字,代表最大效率之和

Sample Input

4
62 41 86 94
73 58 11 12
69 93 89 88
81 40 69 13

Sample Output

329

HINT



這就是km算法的模板題(雖然跟我舉的例子有點出入,但本質上還是相同的)。
代碼有點長,慢慢看啊
版本1:

#include<bits/stdc++.h>
using namespace std;
const int N=205;
int w[N][N];
int la[N],lb[N];
bool va[N],vb[N];
int match[N];
int delta,n;
void read() {
    scanf("%d",&n);
    for(int i=1;i<=n;i++)
        for(int j=1;j<=n;j++)
            scanf("%d",&w[i][j]);
}
bool dfs(int x) {
    va[x]=1;
    for(int y=1;y<=n;y++)
        if(!vb[y])
            if(la[x]+lb[y]-w[x][y]==0) {
                vb[y]=1;
                if(!match[y]||dfs(match[y])) {
                    match[y]=x;
                    return true;
                }
            }
            else
                delta=min(delta,la[x]+lb[y]-w[x][y]);
    return false;
}
int KM() {
    for(int i=1;i<=n;i++) {
        la[i]=-(1<<30);
        lb[i]=0;
        for(int j=1;j<=n;j++)
            la[i]=max(la[i],w[i][j]);
    }
    for(int i=1;i<=n;i++)
        while(true) {
            memset(va,0,sizeof(va));
            memset(vb,0,sizeof(vb));
            delta=1<<30;
            if(dfs(i))
                break;
            for(int j=1;j<=n;j++) {
                if(va[j])
                    la[j]-=delta;
                if(vb[j])
                    lb[j]+=delta;
            }
        }
    int ans=0;
    for(int i=1;i<=n;i++)
        ans+=w[match[i]][i];
    return ans;
}
void write() {
    printf("%d\n",KM());
}
int main() {
    read();
    write();
}

版本2:

#include <bits/stdc++.h>
using namespace std;
int n;
int w[305][305];
int lx[305],ly[305];
int matched[305];
int slack[305];
bool s[305],t[305];
bool match(int i) {
    s[i]=1;
    for(int j=1; j<=n; j++) {
        int cnt=lx[i]+ly[j]-w[i][j];
        if(cnt==0&&!t[j]) {
            t[j]=1;
            if(!matched[j]||match(matched[j])) {
                matched[j]=i;
                return 1;
            }
        } else {
            slack[j]=min(slack[j],cnt);
        }
    }
    return 0;
}
void update() {
    int a=0x3f3f3f3f;
    for(int i=1; i<=n; i++) {
        if(!t[i])
            a=min(a,slack[i]);
    }
    for(int i=1; i<=n; i++) {
        if(s[i])lx[i]-=a;
        if(t[i])ly[i]+=a;
    }
 
}
void km() {
    memset(matched,0,sizeof(matched));
    memset(lx,0,sizeof(lx));
    memset(ly,0,sizeof(ly));
    for(int i=1; i<=n; i++) {
        for(int j=1; j<=n; j++) {
            lx[i]=max(lx[i],w[i][j]);
        }
    }
    for(int i=1; i<=n; i++) {
        memset(slack,0x3f,sizeof(slack));
        while(1) {
            memset(s,0,sizeof(s));
            memset(t,0,sizeof(t));
            if(match(i))
                break;
            else
                update();
        }
    }
}
int main() {
    scanf("%d",&n);
    for(int i=1; i<=n; i++) {
        for(int j=1; j<=n; j++) {
            scanf("%d",&w[i][j]);
        }
    }
    km();
    int ans=0;
    for(int i=1; i<=n; i++) {
        ans+=lx[i];
        ans+=ly[i];
    }
    printf("%d\n",ans);
 
    return 0;
}

如果有什麼地方沒講好或者是說講錯了,可以在評論區告訴我,我會看到後立馬修改。
如果覺得有什麼想法也可以在評論區說。
感謝觀看。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章