點到三角形的最近點--二次規劃

問題如下在三角形中找一個點p到三角形內部最近的點(可能在邊上,頂點上或者內部)
在這裏插入圖片描述
這個當然可以用幾何的方法用條件判斷,這裏介紹建模爲優化問題

其實這個問題可以表示爲
argminxAxp22,s.t.x=1,x>=0 arg min _x \| Ax - p \|_2^2, \quad s.t. \sum_x=1, x>=0
就是三角形內部的點可以用頂點差值表示(重心座標),這樣就轉換成了最小二乘問題。

有趣的是其實這個問題的弱化版爲 找線段外一點到這個線段最近的位置。我就是通過這個簡單的例子最終發現單純型法算不出來。

求解這個其實可以經過一些推導轉化成類似線性規劃的問題,但是有個ux=0的情況,表示如果用單純性算法時,互補鬆弛變量不能同時作爲基變量 (推了很長時間就是無法實現)。

後面展示了用庫求解和用有效集的方法求解

1. 用二次規劃的庫求解

def problem(tri_p, px):
    At = tri_p
    A = At.T
    Atb = At.dot(px)
    AtA = At.dot(A)
    return AtA, Atb
from cvxopt  import solvers, matrix 
tri_p=np.array([
    [0, 0], [2, 0], [1, 2]
], np.float64)
px = np.array([1, 3], np.float64)

# 似乎必須是 float64
P, q = problem(tri_p, px)
P = matrix(P)
q = matrix(-q)
G = matrix(-np.eye(3))
h = matrix(np.zeros(3))
A = matrix(np.ones((1, 3)))
b = matrix(np.array([1.]))

sol = solvers.qp(P,q,G,h,A,b)

#print(sol)
print(sol['x'])

輸出如下
在這裏插入圖片描述

2.(失敗) 嘗試的單純型法

這個就不要看了,主要是推導了很長時間,換了很多種方法都沒有弄出來。
大體思路就是KKT條件後,構造等式約束,並構造一個目標函數。最後發現達不到最優點

def f5(tri_p, px):
    _A, _b = problem(tri_p, px)
    neg  = _b<0
    if neg.sum()>0:
        _A[neg] *= -1
        _b[neg] *= -1
        
    nall = 3+3+4+2
    A=np.zeros( (4, nall) )
    b=np.zeros( (4) )
    
    A[:3, :3] = _A
    A[:3, 3:3+3] = -np.eye(3)
    
    A[3, :3] = 1
    
    A[:3, 10] = -1
    A[:3, 11] = 1
    A[:4, 6:10] = np.eye(4)
    
    b[:3] = _b
    b[3] = 1
    
    enters = np.arange(6, 10)
    
    slack = np.arange(nall)
    
    slack[:3], slack[3:3+3] = slack[3:3+3].copy(), slack[:3].copy()
    slack[10:12] = 9
    
    z = np.zeros( nall + 1 )
    
    arg_type="max"
    #z[nall-4:-1] = -1
    z[6:10] = -1
    #z[nall-5:-1] = -1
    #z[nall-6] = 1
    
    # p(enters)
    # p(slack)
    # p(A)
    # p(z)
    
    for i in range(4):
        change_A_b(A, b, z, i, enters[i])
    
    # p(A)
    # p(b)
    # p(z)
    # exit()
    f5_solve(A, b, z, enters, (slack), arg_type, n_it=20)
    x = np.zeros( nall )
    x[enters] = b
    p("mx", x)
    return x[:3]


def f5_solve(A, b, z, enters, slack, arg_type, n_it=20):
    for i in  range(n_it):
        if not sel_enter_outer(A, b, z, enters, arg_type, slack): break
        

def sel_enter_outer(A, b, z, enters, arg_type="max", slack=None):
    r, c = A.shape
    slack_mp = None
    if len(slack)==2:
        slack, slack_mp = slack
    vc = np.zeros(c)
    for _c in range(c):
        ci = sel_enter_idx(z[:-1], arg_type, enters, vc)
        if ci < 0: return False
        vc[ci] = 1
        ri = sel_outer_idx(A[:, ci], b)
        if ri < 0: continue
        #p("ci, ri", ci, enters[ri])
        if slack is None: break
        else:
            if slack[ci]==enters[ri] or slack[ci] not in enters: 
                if slack_mp and ci in slack_mp:
                    if slack[ci]==enters[ri]:
                        if slack_mp[ci] not in enters: break
                        else: continue
                    else:
                        if slack_mp[ci]==enters[ri] or slack_mp[ci] not in enters: break
                else: break
    if _c==c: return False
    p("ci, ri, z", ci, enters[ri], z)
    change_A_b(A, b, z, ri, ci)
    enters[ri] = ci
    p("ck---")
    p(enters)
    #p(A)
    #p(b)
    p(z)
    
    return True

def change_A_b(A, b, z, ri, ci):
    b[ri] /= A[ri, ci]
    A[ri] /= A[ri, ci].copy()
    for i in range(len(b)):
        if i!=ri:
            b[i] -= A[i, ci]*b[ri]
            A[i] -= A[i, ci]*A[ri]
    z[-1] -= z[ci]*b[ri]
    z[:-1] -= z[ci]*A[ri]
    
    
# 0 是可以找的    
def sel_enter_idx(z, arg_type, enters, vc):
    if arg_type=="max": mm = -1
    elif arg_type=="min": mm = 1
    _i = -1
    for i in range(len(z)):
        if vc[i]: continue # 不合法的直接不要
        if i in enters: continue
        if arg_type=="max":
            if z[i]<=0: continue
            if mm < z[i]:
                mm = z[i]
                _i = i
        else:
            if z[i]>=0: continue
            if mm > z[i]:
                mm = z[i]
                _i = i
    return _i
    
def sel_outer_idx(a_c, b):
    mx=np.inf
    _i = -1
    for i in range(len(b)):
        if a_c[i]>0 and mx > (b[i]/a_c[i]):
            mx=b[i]/a_c[i]
            _i = i
    return _i

3. 有效集法 【本次最值得記錄的】

https://wenku.baidu.com/view/ad34f36079563c1ec5da71fc.html
其實之前感覺弄不出來的時候就想到找了,一開始感覺很複雜,所以不想看。後面仔細看了一下,感覺也很好了解。
整體思路就類似,高斯牛頓法,給個初始值,然後求dp,但是這裏主要多了個保證在可行域之中。

裏面應該求解那裏有重複計算的地方,這裏就不優化了。

def f7(tri_p, px):
    A, b = problem(tri_p, px)
    n_x = 3
    n_eq = 1
    n_neq = n_x
    n_max_Q = n_x + n_eq + n_neq # 2 個變量, 1個等式變量,最多2個不等式變量
    valid = np.zeros(n_max_Q, np.bool)
    Q = np.zeros( (n_max_Q, n_max_Q) )
    Ae = np.zeros( (n_eq, n_x) ) + 1 # task 有關
    be = np.zeros( (n_eq) ) + 1
    #Ane = np.zeros( (n_neq, n_x) )
    Ane = np.eye( n_x )
    bne = np.zeros( (n_neq) )
    
    Q[:n_x, :n_x] = A
    Q[:n_x, n_x:n_x+n_eq] = -Ae.T
    Q[n_x:n_x+n_eq, :n_x] = Ae
    Q[:n_x, n_x+n_eq:] = -Ane.T
    Q[n_x+n_eq:, :n_x] = Ane
    
    q = np.zeros( n_max_Q )
    q[:n_x] = b
    q[n_x:n_x+n_eq] = be
    q[n_x+n_eq:] = bne
    
    # p(Q)
    # p(q)
    # exit()
    
    valid[:n_x+n_eq] = True
    
    n_must = n_x+n_eq
    # task 有關
    x_init=np.zeros(n_x) + 1/n_x
    oA, ob = A, b
    
    #q[n_x:] = 0 # 才發現根本沒有用
    bi = n_x + n_eq
    _cnt=0
    while True:
        p("x:", x_init)
        d_f_x_k = oA.dot(x_init) - ob
        q[:n_x] = -d_f_x_k
        
        A = Q[valid][:, valid]
        b = q[valid].copy()
        b[n_x:] = 0
        
        # 有重複計算
        AtA = A.T.dot(A)
        Atb = A.T.dot(b)
        
        d = np.linalg.solve(AtA, Atb)
        #_cnt+=1
        #if _cnt>10: exit()
        p(_cnt, np.abs(d[:n_x]).sum())
        if np.abs(d[:n_x]).sum() < 1e-6:
            if d.shape[0] == bi: break
            mm = np.min(d[bi:])
            if mm>=0: break
            # 把最不能滿足不等式約束的去掉
            valid[bi:][valid[bi:]][np.argmin(d[bi:])] = False
            
        else:
            # 不爲0,基本就說明約束不夠, 達不到極值
            # step 1:保證更新之後,未參與的約束合法
            alpha = 1
            for i in range(bi, bi+n_neq):
                if not valid[i]:
                    # 對於不在可行解之中的
                    fm = Q[i, :n_x].dot(d[:n_x])
                    if fm<0: # 可能會出了可行域, 因爲不等式是 >= b, 所以如果fm小於0則可能 出現<b
                        # a^T(x+\alpah d) = b -> \alpha = (b-a^Tx)/
                        alpha = min(alpha, (q[i]-Q[i, :n_x].dot(x_init))/fm)
                        assert alpha!=0

            x_init += alpha*d[:n_x]
            
            # step2: 發現有新的有效約束則加入
            for i in range(bi, bi+n_neq):
                if not valid[i]:
                    if np.abs(q[i]-Q[i, :n_x].dot(x_init))<1e-9:
                        valid[i] = True
            
            #p("alpha", alpha, d)
            
    p(x_init) 
    return x_init

4. 調試時的其他代碼

4.1 雜記

def plot_triangle(A, B, C):
    x = [A[0], B[0], C[0], A[0]]
    y = [A[1], B[1], C[1], A[1]]
    plt.plot(x, y, linewidth=2)

def f_cmp(tri_p, px, x1, x2):
    r1 = _cac(tri_p, px, x1)
    r2 = _cac(tri_p, px, x2)
    p(r1, r2)
    abs = np.abs( r1 - r2 ).sum()
    return abs<1e-6

def plot_line(A, B):
    x = [A[0], B[0]]
    y = [A[1], B[1]]
    plt.plot(x, y, linewidth=2)

4.2 關於線段和點的

在這裏插入圖片描述

plt.gca().set_aspect('equal', adjustable='box')
plot_line(tri_p[0], tri_p[1])
plt.plot( px[0], px[1], "r." )
plt.plot( tri_p[0][0], tri_p[0][1], "g*" )
plt.plot( tri_p[1][0], tri_p[1][1], "y+" )
p1 = tri_p.T.dot(x1)
p2 = tri_p.T.dot(x2)
plt.plot( p1[0], p1[1], "b*" )
plt.plot( p2[0], p2[1], "r*" )
plt.show()
# 可以解析求解
def f6(tri_p, px):
    p1 = tri_p[0]
    p2 = tri_p[1]
    p21 = p2 - p1
    px1 = px - p1
    x = (p21*px1).sum()/(p21*p21).sum()
    if x<0: x=0
    if x>1: x=1
    return np.array([1-x, x])

# 也可以用有效集的方法
def f7(tri_p, px):
    A, b = problem(tri_p, px)
    n_x = 2
    n_eq = 1
    n_neq = n_x
    n_max_Q = n_x + n_eq + n_neq # 2 個變量, 1個等式變量,最多2個不等式變量
    valid = np.zeros(n_max_Q, np.bool)
    Q = np.zeros( (n_max_Q, n_max_Q) )
    Ae = np.zeros( (n_eq, n_x) ) + 1 # task 有關
    be = np.zeros( (n_eq) ) + 1
    #Ane = np.zeros( (n_neq, n_x) )
    Ane = np.eye( n_x )
    bne = np.zeros( (n_neq) )
    
    Q[:n_x, :n_x] = A
    Q[:n_x, n_x:n_x+n_eq] = -Ae.T
    Q[n_x:n_x+n_eq, :n_x] = Ae
    Q[:n_x, n_x+n_eq:] = -Ane.T
    Q[n_x+n_eq:, :n_x] = Ane
    
    q = np.zeros( n_max_Q )
    q[:n_x] = b
    q[n_x:n_x+n_eq] = be
    q[n_x+n_eq:] = bne
    
    # p(Q)
    # p(q)
    # exit()
    
    valid[:n_x+n_eq] = True
    
    n_must = n_x+n_eq
    # task 有關
    x_init=np.zeros(n_x) + 1/n_x
    oA, ob = A, b
    
    #q[n_x:] = 0 # 才發現根本沒有用
    bi = n_x + n_eq
    _cnt=0
    while True:
        p("x:", x_init)
        d_f_x_k = oA.dot(x_init) - ob
        q[:n_x] = -d_f_x_k
        
        A = Q[valid][:, valid]
        b = q[valid].copy()
        b[n_x:] = 0
        
        # 有重複計算
        AtA = A.T.dot(A)
        Atb = A.T.dot(b)
        
        d = np.linalg.solve(AtA, Atb)
        #_cnt+=1
        #if _cnt>10: exit()
        p(_cnt, np.abs(d[:n_x]).sum())
        if np.abs(d[:n_x]).sum() < 1e-6:
            if d.shape[0] == bi: break
            mm = np.min(d[bi:])
            if mm>=0: break
            # 把最不能滿足不等式約束的去掉
            valid[bi:][valid[bi:]][np.argmin(d[bi:])] = False
            # cnt=0
            # mm = 0
            # sel = None
            # for i in range(bi, bi+n_neq):
                # if valid[i]:
                    # if d[bi+cnt]<mm: 
                        # mm = d[bi+cnt]
                        # sel = i
                    # cnt+=1
            # if mm == 0: break
            # valid[i] = False # 最小的那個去掉
        else:
            # step 1:保證更新之後,未參與的約束合法
            alpha = 1
            for i in range(bi, bi+n_neq):
                if not valid[i]:
                    # 對於不在可行解之中的
                    fm = Q[i, :n_x].dot(d[:n_x])
                    if fm<0: # 可能會出了可行域, 因爲不等式是 >= b, 所以如果fm小於0則可能 出現<b
                        # a^T(x+\alpah d) = b -> \alpha = (b-a^Tx)/
                        alpha = min(alpha, (q[i]-Q[i, :n_x].dot(x_init))/fm)
                        assert alpha!=0

            x_init += alpha*d[:n_x]
            
            # step2: 發現有新的有效約束則加入
            for i in range(bi, bi+n_neq):
                if not valid[i]:
                    if np.abs(q[i]-Q[i, :n_x].dot(x_init))<1e-9:
                        valid[i] = True
            
            #p("alpha", alpha, d)
            
    p(x_init) 
    return x_init

發佈了107 篇原創文章 · 獲贊 34 · 訪問量 20萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章