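// k維最近點對 — k-dimensional nearest-neighbour search with a k-d tree.
// 要求輸出前m近點對,用priority_queue存前幾名就好,注意估價函數剪枝
// For each query, output the m points closest to the query point (nearest first). Keep the best m
// candidates in a max-heap (priority_queue), and prune any subtree whose lower bound (the squared
// distance from the query to the subtree's bounding box, i.e. the estimate function) cannot beat
// the current m-th best.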
#include<cstdio>
#include<cmath>
#include<cstring>
#include<cstdlib>
#include<climits>
#include<algorithm>
#include<queue>
#include<vector>
// Short aliases for pair members and construction.
#define fi first
#define se second
#define pii pair<int,int>
#define MK(a,b) make_pair((a),(b))
// Square of an expression. The outer parentheses keep the result a single operand,
// so e.g. 1/sqr(x) means 1/(x*x) rather than (1/x)*x.
#define sqr(x) ((x)*(x))
using namespace std;
// Capacity of the point array a[]: points live in a[1..n], and a[0] serves as the null node.
const int N = 50005;
// Reads the next signed decimal integer from stdin via getchar().
// Non-digit characters before the number are skipped; a '-' among them makes the result negative.
// The character that terminates the number is consumed.
// Returns 0 if EOF is reached before any digit. (The old char-based loop spun forever at EOF.)
inline int read()
{
    int ch = getchar();   // int, not char, so EOF stays distinguishable from real bytes
    int sign = 1;
    while (ch != EOF && (ch < '0' || ch > '9'))
    {
        if (ch == '-') sign = -1;
        ch = getchar();
    }
    if (ch == EOF) return 0;
    int value = 0;
    while (ch >= '0' && ch <= '9')
    {
        value = value * 10 + (ch - '0');
        ch = getchar();
    }
    return value * sign;
}
// Global state shared by the k-d tree routines.
int rt;   // root node index returned by build()
int n;    // number of points in the current test case
int k;    // number of dimensions per point
int D;    // split axis read by cmp() while build() partitions with nth_element
int m;    // number of nearest points requested by the current query
// k-d tree node: one input point plus the bounding box of the subtree rooted at it.
struct aa
{
    int l, r;    // child indices into a[]; 0 means "no child"
    int d[6];    // the point's coordinates (first k entries used)
    int mi[6];   // per-axis minimum over this subtree
    int mx[6];   // per-axis maximum over this subtree
} a[N];
void up(int u,int l)
{
for (int i=0;i<k;i++)
a[u].mi[i]=min(a[u].mi[i],a[l].mi[i]),
a[u].mx[i]=max(a[u].mx[i],a[l].mx[i]);
}
bool cmp(aa a,aa b) {return a.d[D]<b.d[D];}
int build(int l,int r,int dd)
{
D=dd;int mid=(l+r)>>1;
nth_element(a+l,a+mid,a+r+1,cmp);
for (int i=0;i<k;i++)
a[mid].mi[i]=a[mid].mx[i]=a[mid].d[i];
if (l!=mid) a[mid].l=build(l,mid-1,(dd+1)%k),up(mid,a[mid].l);else a[mid].l=0;
if (r!=mid) a[mid].r=build(mid+1,r,(dd+1)%k),up(mid,a[mid].r);else a[mid].r=0;
return mid;
}
// Coordinates of the current query point (first k entries used).
int x[6];
// Max-heap of (squared distance, node index) holding the best candidates found so far;
// top() is the farthest of them.
priority_queue<pii> ans;
// Squared Euclidean distance between node u's point and the query point x over the first k axes.
int fdis(int u)
{
    const int *p = a[u].d;
    int total = 0;
    for (int i = 0; i < k; ++i)
    {
        const int diff = x[i] - p[i];
        total += diff * diff;
    }
    return total;
}
// Lower bound on the squared distance from the query point x to any point in node u's subtree.
// It is the squared distance from x to the subtree's bounding box [mi, mx], and 0 if x lies inside.
// The null node 0 returns INT_MAX. The old sentinel 1e9 is smaller than squared distances that
// int coordinates can produce, so the prune test `ans.top().fi > bound` could let a missing child
// through. INT_MAX can never pass that strict test, and it orders after any real bound.
int f(int u)
{
    if (u == 0) return INT_MAX;
    int bound = 0;
    for (int i = 0; i < k; i++)
    {
        if (a[u].mi[i] > x[i]) bound += sqr(a[u].mi[i] - x[i]);   // x below the box on axis i
        if (a[u].mx[i] < x[i]) bound += sqr(x[i] - a[u].mx[i]);   // x above the box on axis i
    }
    return bound;
}
void query(int u)
{
int dis=fdis(u);
int dl=f(a[u].l),dr=f(a[u].r);
if (ans.size()<m) ans.push(MK(dis,u));
else if (ans.top().fi>dis) {ans.pop();ans.push(MK(dis,u));}
if (dl<dr)
{if (ans.size()<m||ans.top().fi>dl) query(a[u].l);if (ans.size()<m||ans.top().fi>dr) query(a[u].r);}
else
{if (ans.size()<m||ans.top().fi>dr) query(a[u].r);if (ans.size()<m||ans.top().fi>dl) query(a[u].l);}
}
int tmp[15];   // scratch buffer of node indices (1-based use leaves room for 14 entries)
int cnt;       // number of entries currently stored in tmp
void work()
{
for (int i=1;i<=n;i++)
for (int j=0;j<k;j++) a[i].d[j]=read();
rt=build(1,n,0);
int t=read();
while (t--)
{
while (!ans.empty()) ans.pop();
for (int i=0;i<k;i++) x[i]=read();m=read();
query(rt);
printf("the closest %d points are:\n",m);
cnt=0;
while (!ans.empty()) tmp[++cnt]=ans.top().se,ans.pop();
for (int i=cnt;i>=1;i--)
{
for (int j=0;j<k-1;j++) printf("%d ",a[tmp[i]].d[j]);
printf("%d\n",a[tmp[i]].d[k-1]);
}
}
}
int main()
{
while (scanf("%d%d",&n,&k)!=EOF) work();
return 0;
}