引言
對於值來說treap是排序二叉樹
對於優先級來說treap是堆
treap的優先級由rand()來保證複雜度
正文
treap的常用操作(名次樹):
1:插入x數
2:刪除x數(若有多個相同的數,因只刪除一個)
3:查詢x數的排名(若有多個相同的數,因輸出最小的排名)
4:查詢排名爲x的數
5:求x的前驅(前驅定義爲小於x,且最大的數)
6:求x的後繼(後繼定義爲大於x,且最小的數)
節點
ch[0]是左子樹
ch[1]是右子樹
s是子樹元素大小
n是值的重複次數
r:rand()
v:值
struct node{
node* ch[2];
int r,v,s,n;
int cmp(int x){
if(x==v)return -1;
return x>v;
}
void maintain(){
s=ch[0]->s+ch[1]->s+n;
}
};
旋轉
d==0時左旋
d==1時右旋
其中的指針o是引用,可以保證指向根節點
void Rotate(node* &o,int d){
node* k=o->ch[d^1]; o->ch[d^1]= k->ch[d]; k->ch[d]=o;
o->maintain(); k->maintain(); o=k;
}
插入
插入時可能會破壞堆的性質
所以要旋轉(從下到上旋轉到根節點)
void Insert(node* &o,int x){
if(o==null){
o=new node();
o->ch[0]=o->ch[1]=null;
o->v=x;
o->r=rand();
o->n=1;
}
else{
int d=o->cmp(x);
if(d==-1) o->n++;
else {
Insert(o->ch[d],x);
if(o->ch[d]->r > o->r)Rotate(o,d^1);
}
}
o->maintain();
}
刪除
刪除時要判斷左右子樹的優先級關係,再旋轉
(從上到下旋轉,旋轉到葉子節點)
void Remove(node* &o,int x){
if(o==null)return;
int d=o->cmp(x);
if(d==-1){
if(o->n > 1) (o->n)--;
else{
if(o->ch[0]==null)o=o->ch[1];
else if(o->ch[1]==null)o=o->ch[0];
else {
int e= ((o->ch[0]->r) > (o->ch[1]->r));
Rotate(o,e);
Remove(o->ch[e],x);
}
}
}
else if(o->ch[d]==null)return;
else Remove(o->ch[d],x);
if(o!=null)o->maintain();
}
細節
node *null=new node();//這裏的null可以比NULL(空指針)更好防止越界
int Query(node* o,int x,int e)//求前驅和求後繼 合併
void Print(node* o)//中途輸出查錯
心得
調了好久~~
其實treap主要是弄清指針ch[]和旋轉的用法(雖然非指針和非旋轉的treap也有)
一遍treap下來有種豁然貫通的感覺
附上我的造數據代碼(有很小很小的概率數據是不嚴謹的)
如果有人能幫我在造數據代碼中寫個最小&&大值線段樹就更完美了
另外有一個小小的缺點,沒有find()函數,不能預先判斷刪除值是否存在
完整代碼
#include<cstdio>
#include<iostream>
#include<cstdlib>
#include<ctime>
#define INF 2147483647
using namespace std;
int n,p,q,ans;
struct node{
node* ch[2];
int r,v,s,n;
int cmp(int x){
if(x==v)return -1;
return x>v;
}
void maintain(){
s=ch[0]->s+ch[1]->s+n;
}
};
node *null=new node();
node* ro=null;
inline int rd(){
char ch;
int p=0,q=1;
while((ch=getchar())<'0'||ch>'9')if(ch=='-')q=-1;
while(ch>='0'&&ch<='9')p=p*10+ch-'0',ch=getchar();
return p*q;
}
void Rotate(node* &o,int d){
node* k=o->ch[d^1]; o->ch[d^1]= k->ch[d]; k->ch[d]=o;
o->maintain(); k->maintain(); o=k;
}
void Insert(node* &o,int x){
if(o==null){
o=new node();
o->ch[0]=o->ch[1]=null;
o->v=x;
o->r=rand();
o->n=1;
}
else{
int d=o->cmp(x);
if(d==-1) o->n++;
else {
Insert(o->ch[d],x);
if(o->ch[d]->r > o->r)Rotate(o,d^1);
}
}
o->maintain();
}
void Remove(node* &o,int x){
if(o==null)return;
int d=o->cmp(x);
if(d==-1){
if(o->n > 1) (o->n)--;
else{
if(o->ch[0]==null)o=o->ch[1];
else if(o->ch[1]==null)o=o->ch[0];
else {
int e= ((o->ch[0]->r) > (o->ch[1]->r));
Rotate(o,e);
Remove(o->ch[e],x);
}
}
}
else if(o->ch[d]==null)return;
else Remove(o->ch[d],x);
if(o!=null)o->maintain();
}
void Kth(node* o,int x){
if(o==null)return;
if(o->ch[0]->s + o->n >= x){
if(o->ch[0]->s < x)printf("%d\n",o->v);
else Kth(o->ch[0],x);
}
else Kth(o->ch[1],x-(o->ch[0]->s + o->n));
}
int Rank(node* o,int x){
if(o==null)return -INF;
int d=o->cmp(x);
if(d==-1)return o->ch[0]->s+1;
if(d==1)return Rank(o->ch[1],x)+o->ch[0]->s+o->n;
return Rank(o->ch[0],x);
}
int Query(node* o,int x,int e){
if(o==null)return ans;
int d=o->cmp(x);
if(d==-1)return Query(o->ch[e],x,e);
if((o->cmp(x))==(e^1))ans=o->v;
return Query(o->ch[d],x,e);
}
void Print(node* o){
if(o!=null){
// printf("%d %d %d %d %d %d %d %d %d\n",o->v,o->r,o->s,o->ch[0]->v,o->ch[0]->r,o->ch[0]->s,o->ch[1]->v,o->ch[1]->r,o->ch[1]->s);
Print(o->ch[0]);
Print(o->ch[1]);
}
}
int main(){
freopen("data.txt","r",stdin);
freopen("1.txt","w",stdout);
srand((unsigned)time(NULL));
n=rd();
for(int i=1;i<=n;i++){
p=rd(); q=rd();
if(p==1)Insert(ro,q);
else if(p==2)Remove(ro,q);
else if(p==3)printf("%d\n",Rank(ro,q));
else if(p==4)Kth(ro,q);
else if(p==5){
ans=-INF; printf("%d\n",Query(ro,q,0));
}
else if(p==6){
ans=-INF; printf("%d\n",Query(ro,q,1));
}
// printf("\n%d %d\n",p,q);
// Print(ro);
}
}
造數據代碼
#include<cstdio>
#include<iostream>
#include<ctime>
#include<cstdlib>
#define MOD 100000
#define INF 2147483647
using namespace std;
int n,p,q,sum,ma,mi;
int b[MOD*2];
int main(){
freopen("data.txt","w",stdout);
mi=INF;
srand((unsigned)time(NULL));
n=rand()%MOD+1;
printf("%d\n",n);
for(int i=1;i<=n/2;i++){
q=rand()%MOD+1;
ma=max(ma,q);
mi=min(mi,q);
sum++;
printf("%d %d\n",1,q);
b[q]++;
}
for(int i=n/2+1;i<=n;i++){
p=rand()%6+1;
q=rand()%MOD+1;
while((p==5&&q<=mi)||(p==6&&q>=ma)||((p==2||p==3)&&b[q]==0)||((p==4)&&(q>sum))){
p=rand()%6+1;
q=rand()%MOD+1;
}
if(p==1){
b[q]++;
sum++;
ma=max(ma,q);
mi=min(mi,q);
}
if(p==2){
sum--;
b[q]--;
}
printf("%d %d\n",p,q);
}
}