[模板] 最近鄰 / K 近鄰查詢(KD Tree)

nth_element 相當於快速排序中的劃分(選取 pivot 並分割)過程:它把當前維度上的中位數放到 mid 位置,使左邊的元素都不大於它、右邊的都不小於它,平均時間複雜度為 O(n)。

二維最近鄰:對每個點,求與它距離最近的另一個點(輸出平方距離)。

#include <bits/stdc++.h>
typedef long long ll;                          // 64-bit type used for squared distances
const int INF = 0x3f3f3f3f;                    // int "infinity" sentinel (not referenced by this program)
const int maxn = 100000 + 50;                  // capacity of p[] / a[]; points occupy indices 1..n
#define mem(a, b) memset((a), (b), sizeof(a))  // fill a whole array; not referenced by this program
using namespace std;

// Axis (0 = x, 1 = y) that Node::operator< compares on; Build sets it right
// before each nth_element call.
int cmpNo;

// A 2-D point stored in the KD tree.
//   x[0], x[1] : coordinates
//   l, r       : indices of the left/right subtree roots, written by Build (0 = empty)
//   id         : 1-based input index, used by Kth to skip the query point itself
struct Node{
    int x[2], l, r, id;

    // Order points by their coordinate on the current split axis.
    bool operator<(const Node &other) const {
        const int axis = cmpNo;
        return x[axis] < other.x[axis];
    }
};

// Squared Euclidean distance between two 2-D points.
// Each coordinate is widened to long long *before* subtracting, so the
// difference itself cannot overflow for any int coordinates. The sum of
// squares still assumes |difference| per axis stays below about 2.1e9.
ll calDis(const Node &l,const Node &r){
    const ll dx=static_cast<ll>(l.x[0])-r.x[0];
    const ll dy=static_cast<ll>(l.x[1])-r.x[1];
    return dx*dx+dy*dy;
}

// Point storage; after Build(1, n, 0) the range p[1..n] is laid out as an
// implicit KD tree whose subtree roots are range midpoints.
Node p[maxn];

// Recursively arranges p[l..r] into a KD tree and returns the index of its
// root (0 for an empty range). The root is always the midpoint (l+r)/2:
// nth_element puts the median along axis d there, with no-greater elements
// before it and no-smaller elements after it. Axes alternate 0,1,0,... by depth.
int Build(int l,int r,int d){
    if(r<l) return 0;
    const int mid=(l+r)/2;
    cmpNo=d;  // Node::operator< reads this global
    nth_element(p+l,p+mid,p+r+1);
    const int child_axis=d^1;
    p[mid].l=Build(l,mid-1,child_axis);
    p[mid].r=Build(mid+1,r,child_axis);
    return mid;
}

ll ansDist;
int ansId;
void Kth(int l,int r,Node &tar,int d){
    if(l>r)return;
    int mid=l+r>>1;
    if(p[mid].id!=tar.id){
        ll tmp=calDis(p[mid],tar);
        if(tmp<ansDist){
            ansDist=tmp;
            ansId=p[mid].id;
        }
    }
    long long t=tar.x[d]-p[mid].x[d];
    if(t<=0){
        Kth(l,mid-1,tar,1-d);
        if(ansDist>t*t)
            Kth(mid+1,r,tar,1-d);
    }
    else{
        Kth(mid+1,r,tar,1-d);
        if(ansDist>t*t)
            Kth(l,mid-1,tar,1-d);
    }
}

Node a[maxn];

int main(){
    int T; scanf("%d",&T);
    for(int cs=1;cs<=T;cs++) {
        int n;  scanf("%d",&n);
        for(int i=1;i<=n;i++){
            scanf("%d %d",&p[i].x[0],&p[i].x[1]);
            p[i].id=i;
            a[i]=p[i];
        }
        Build(1,n,0);
        for(int i=1;i<=n;i++) {
            ansDist=1e18;
            Kth(1,n,a[i],0);
            printf("%lld\n",calDis(a[i],a[ansId]));
        }
    }
    return 0;
}

K 維 Q 近鄰:對每個點,求與它距離最近的 Q 個其他點(輸出平方距離)。

#include <bits/stdc++.h>
typedef long long ll;                          // 64-bit type used for squared distances
const int INF = 0x3f3f3f3f;                    // int "infinity" sentinel (not referenced by this program)
const int maxn = 400000 + 50;                  // capacity of p[] / a[]; points occupy indices 1..n
#define mem(a, b) memset((a), (b), sizeof(a))  // fill a whole array; not referenced by this program
using namespace std;
// Split axis that Node::operator< compares on; Build sets it before each
// nth_element call.
int cmpNo;
// Number of dimensions per point.
const int K=2;

// A K-dimensional point stored in the KD tree.
//   x[0..K-1] : coordinates
//   l, r      : indices of the left/right subtree roots, written by Build (0 = empty)
//   id        : 1-based input index, used by Kth to skip the query point itself
struct Node{
    int x[K], l, r, id;

    // Order points by their coordinate on the current split axis.
    bool operator<(const Node &other) const {
        const int axis = cmpNo;
        return x[axis] < other.x[axis];
    }
};

// Squared Euclidean distance between two K-dimensional points.
// Each per-axis difference is widened to long long before it is squared;
// squaring it in int would overflow (undefined behaviour) once
// |difference| > 46340.
long long Dis(const Node &a,const Node &b){
    long long ret=0;
    for(int i=0;i<K;i++){
        const long long diff=static_cast<long long>(a.x[i])-b.x[i];
        ret+=diff*diff;
    }
    return ret;
}

// Point storage; after Build(1, n, 0) the range p[1..n] is laid out as an
// implicit KD tree whose subtree roots are range midpoints.
Node p[maxn];

// Recursively arranges p[l..r] into a KD tree and returns the index of its
// root (0 for an empty range). The root is the midpoint (l+r)/2, where
// nth_element places the median along axis d. Child levels use the next
// axis, cycling 0,1,...,K-1,0,...
int Build(int l,int r,int d){
    if(r<l) return 0;
    const int mid=(l+r)/2;
    cmpNo=d;  // Node::operator< reads this global
    nth_element(p+l,p+mid,p+r+1);
    const int child_axis=(d+1==K)?0:d+1;
    p[mid].l=Build(l,mid-1,child_axis);
    p[mid].r=Build(mid+1,r,child_axis);
    return mid;
}

priority_queue<pair<ll,int> >q;
void Kth(int l,int r,Node tar,int k,int d){
    if(l>r)return;
    int mid=l+r>>1;
    if(p[mid].id!=tar.id){
        pair<ll,int>v=make_pair(Dis(p[mid],tar),p[mid].id);
        if(q.size()==k && v<q.top())q.pop();
        if(q.size()<k)q.push(v);
    }
    ll t=tar.x[d]-p[mid].x[d];
    if(t<=0){
        Kth(l,mid-1,tar,k,(d+1)%K);
        if(q.top().first>t*t)
            Kth(mid+1,r,tar,k,(d+1)%K);
    }
    else{
        Kth(mid+1,r,tar,k,(d+1)%K);
        if(q.top().first>t*t)
            Kth(l,mid-1,tar,k,(d+1)%K);
    }
}

// Unmodified copy of the input points indexed by id (Build reorders p[]).
Node a[maxn];

// Squared Euclidean distance between two points over all K coordinates.
// Each difference is widened to long long before subtracting and squaring,
// so it cannot overflow int.
ll calDis(const Node &l,const Node &r){
    ll sum=0;
    for(int i=0;i<K;i++){
        const ll diff=static_cast<ll>(l.x[i])-r.x[i];
        sum+=diff*diff;
    }
    return sum;
}

const int Q=2;
int main(){
    int T; scanf("%d",&T);
    for(int cs=1;cs<=T;cs++) {
        int n;  scanf("%d",&n);
        for(int i=1;i<=n;i++){
            scanf("%d %d",&p[i].x[0],&p[i].x[1]);
            p[i].id=i;
            a[i]=p[i];
        }
        Build(1,n,0);
        for(int i=1;i<=n;i++) {
            while(!q.empty()) q.pop();
            for(int j=0;j<Q;j++) q.push(make_pair(1e18,-1));
            Kth(1,n,a[i],Q,0);
            while(!q.empty()){
                printf("%lld\n",calDis(a[i],a[q.top().second]));
                q.pop();
            }
        }
    }
    return 0;
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章