plda Source Code (12)

LightLDA

The original collapsed Gibbs sampling formula is:
p(z_{di}=k | rest) \propto \frac{(n^{-di}_{kd}+\alpha_k)(n^{-di}_{kw}+\beta_w)}{n^{-di}_k+\overline{\beta}}
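Drawing a token's topic directly from this distribution costs O(K), which is what the proposal-based samplers below avoid. For contrast, a minimal sketch of one naive draw (illustrative only, not plda's code; the count vectors and names are assumptions), in C++:

    // Illustrative only: one collapsed Gibbs draw for a single token,
    // assuming count vectors n_kd, n_kw, n_k with the current token's
    // counts already removed. Cost is O(K) per token.
    #include <random>
    #include <vector>

    int SampleTopicNaive(const std::vector<int>& n_kd,  // doc-topic counts
                         const std::vector<int>& n_kw,  // word-topic counts
                         const std::vector<int>& n_k,   // global topic counts
                         double alpha, double beta, double Vbeta,
                         std::mt19937& rng) {
        const int K = static_cast<int>(n_k.size());
        std::vector<double> p(K);
        double sum = 0.0;
        for (int k = 0; k < K; ++k) {  // O(K) pass to build the distribution
            p[k] = (n_kd[k] + alpha) * (n_kw[k] + beta) / (n_k[k] + Vbeta);
            sum += p[k];
        }
        double u = std::uniform_real_distribution<double>(0.0, sum)(rng);
        for (int k = 0; k < K; ++k) {  // inverse-CDF search
            u -= p[k];
            if (u <= 0.0) return k;
        }
        return K - 1;
    }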

AliasLDA
p(z_{di}=k | rest) \propto \frac{n^{-di}_{kd}(n^{-di}_{kw}+\beta_w)}{n^{-di}_k+\overline{\beta}} + \frac{\alpha_k(n^{-di}_{kw}+\beta_w)}{n^{-di}_k+\overline{\beta}}
The second term can be viewed as a "topic-word" bucket that is independent of the document. It can be sampled in O(1) time using an Alias Table together with Metropolis-Hastings (a Monte Carlo sampling method). The Alias Table was introduced in the previous article.
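For reference, here is a minimal Walker/Vose alias-table sketch in the spirit of the previous article (the interface is an assumption, chosen to match the q[w].sample(...) call used later; the internals are illustrative). Build() is O(K); each Sample() is O(1):

    #include <vector>

    struct AliasTable {
        std::vector<double> prob_;  // acceptance threshold per bucket
        std::vector<int> alias_;    // fallback outcome per bucket

        void Build(const std::vector<double>& weights) {
            const int K = static_cast<int>(weights.size());
            prob_.assign(K, 1.0);   // leftovers default to probability 1
            alias_.assign(K, 0);
            double sum = 0.0;
            for (double x : weights) sum += x;
            std::vector<double> p(K);
            std::vector<int> small, large;
            for (int k = 0; k < K; ++k) {
                p[k] = weights[k] * K / sum;  // scaled so the mean is 1
                (p[k] < 1.0 ? small : large).push_back(k);
            }
            while (!small.empty() && !large.empty()) {
                int s = small.back(); small.pop_back();
                int l = large.back(); large.pop_back();
                prob_[s] = p[s];              // bucket s keeps p[s] of itself
                alias_[s] = l;                // and falls back to l otherwise
                p[l] -= 1.0 - p[s];           // l donated that much mass
                (p[l] < 1.0 ? small : large).push_back(l);
            }
        }

        // bucket ~ uniform{0..K-1}, coin ~ uniform[0,1)
        int Sample(int bucket, double coin) const {
            return coin < prob_[bucket] ? bucket : alias_[bucket];
        }
    };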

LightLDA
p(z_{di}=k | rest) \propto (n^{-di}_{kd}+\alpha_k) \cdot \frac{n^{-di}_{kw}+\beta_w}{n^{-di}_k+\overline{\beta}}
q(z_{di}=k | rest) \propto (n_{kd} + \alpha_k) \cdot \frac{n_{kw} + \beta_w}{n_k + \overline{\beta}}
The first factor is the doc-proposal; the second is the word-proposal. As in AliasLDA, sampling degenerates into a Metropolis-Hastings step with acceptance probability
\min\left\{ 1, \frac{p(t)\,q(t \rightarrow s)}{p(s)\,q(s \rightarrow t)} \right\}
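Here s is the current topic and t the proposed one. Both proposals below are independence proposals (q does not depend on the current state), so q(t → s) is simply the proposal mass of s. A sketch of the correction step the two samplers share (names are illustrative, not plda's):

    // Generic MH correction for an independence proposal. p_* is the true
    // conditional up to a constant, q_* the proposal mass up to a constant.
    int MHStep(int s, int t,            // current and proposed topic
               double p_s, double p_t,
               double q_s, double q_t,
               double u) {              // u ~ uniform[0,1)
        // min{1, .} is implicit: if the ratio exceeds 1, u < ratio always holds.
        double acceptance = (p_t * q_s) / (p_s * q_t);
        return u < acceptance ? t : s;
    }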

doc-proposal
q = p_d(k) \propto n_{kd} + \alpha_k
The acceptance probability is
\pi_d = \frac{(n^{-di}_{td}+\alpha_t)(n^{-di}_{tw}+\beta_w)(n^{-di}_s+\overline{\beta})(n_{sd}+\alpha_s)}{(n^{-di}_{sd}+\alpha_s)(n^{-di}_{sw}+\beta_w)(n^{-di}_t+\overline{\beta})(n_{td}+\alpha_t)}

In the code, the doc-proposal is drawn in O(1) time: with probability n_d / (n_d + K\alpha) the sampler copies the topic of a uniformly chosen token of the document (realizing the n_{kd} term), and otherwise it draws a topic uniformly (the \alpha_k term).

    int K = model_->num_topics();
    // Proposal normalizer: sum_k (n_kd + alpha) = doc length + K * alpha.
    double sumPd = document->GetDocumentLength() + Kalpha;
    for (...) {
        int w = iterator.Word();
        int topic = iterator.Topic();
        int new_topic;
        int old_topic = topic;

        {
            // 1. Draw a topic from the doc-proposal q(k) ~ n_kd + alpha.
            double u = random->RandDouble() * sumPd;
            if (u < document->GetDocumentLength()) {
                // n_kd part: copy the topic of the token at position u,
                // which draws k with probability n_kd / n_d.
                unsigned pos = (unsigned) u;
                new_topic = document->topics().wordtopics(pos);
            } else {
                // alpha part: draw a topic uniformly from {0, ..., K-1}.
                u -= document->GetDocumentLength();
                new_topic = (unsigned short) (u / alpha_);
            }

            if (topic != new_topic) {
                // 2. Compute the acceptance probability pi_d. The -1
                // adjustment removes the current token's own count (the
                // n^{-di} terms) and applies only to its original topic.
                int adjustment_old = topic == old_topic ? -1 : 0;
                int adjustment_new = new_topic == old_topic ? -1 : 0;
                double temp_old = ComputeProbForK(document, w, topic, adjustment_old);
                double temp_new = ComputeProbForK(document, w, new_topic, adjustment_new);
                // The proposal masses use the unadjusted counts.
                double prop_old = N_DK(document, topic) + alpha_;
                double prop_new = N_DK(document, new_topic) + alpha_;
                double acceptance = (temp_new * prop_old) / (temp_old * prop_new);

                // 3. Accept or reject against uniform[0,1].
                if (random->RandDouble() < acceptance) {
                    topic = new_topic;
                }
            }
        } // doc-proposal step; the word-proposal step follows in the same loop

where ComputeProbForK, which evaluates the true conditional p(z_{di}=k | rest) up to a constant, is:

    double ComputeProbForK(LDADocument* document, int w, int topic,
                           int adjustment) {
        // adjustment is -1 when `topic` is the token's current assignment,
        // excluding its own count from all three statistics (the n^{-di} terms).
        return (N_DK(document, topic) + alpha_ + adjustment)
             * (N_WK(w, topic) + beta_ + adjustment)
             / (N_K(topic) + Vbeta + adjustment);
    }

word-proposal
q = p_w(k) \propto \frac{n_{kw} + \beta_w}{n_k + \overline{\beta}}
The acceptance probability is
\pi_w = \frac{(n^{-di}_{td}+\alpha_t)(n^{-di}_{tw}+\beta_w)(n^{-di}_s+\overline{\beta})(n_{sw}+\beta_w)(n_t+\overline{\beta})}{(n^{-di}_{sd}+\alpha_s)(n^{-di}_{sw}+\beta_w)(n^{-di}_t+\overline{\beta})(n_{tw}+\beta_w)(n_s+\overline{\beta})}

The word-proposal is drawn from a per-word alias table that is rebuilt only every qtable_construct_frequency draws; the MH step corrects for the stale weights:

        {
            // 1. Draw a topic from the word-proposal via the alias table,
            // rebuilding it once it has served qtable_construct_frequency draws.
            q[w].noSamples++;
            if (q[w].noSamples > qtable_construct_frequency) {
                GenerateQTable(w);
            }
            new_topic = q[w].sample(random->RandInt(K), random->RandDouble());
            if (topic != new_topic) {
                // 2. Compute the acceptance probability pi_w, using the same
                // stale weights q[w].w the table was built from.
                int adjustment_old = topic == old_topic ? -1 : 0;
                int adjustment_new = new_topic == old_topic ? -1 : 0;
                double temp_old = ComputeProbForK(document, w, topic, adjustment_old);
                double temp_new = ComputeProbForK(document, w, new_topic, adjustment_new);
                double acceptance = (temp_new * q[w].w[topic]) / (temp_old * q[w].w[new_topic]);

                // 3. Accept or reject against uniform[0,1].
                if (random->RandDouble() < acceptance) {
                    topic = new_topic;
                }
            }
        }

where GenerateQTable, which refreshes the proposal weights from the current counts and rebuilds the alias table, is as follows:

    void GenerateQTable(unsigned int w) {
        int num_topics = model_->num_topics();
        q[w].wsum = 0.0;
        const TopicDistribution<int32>& word_distribution = model_->GetWordTopicDistribution(w);
        const TopicDistribution<int32>& n_k = model_->GetGlobalTopicDistribution();
        // Refresh the proposal weights q_w(k) = (n_kw + beta) / (n_k + V*beta)
        // from the current counts, then rebuild the alias table over them.
        for (int k = 0; k < num_topics; ++k) {
            q[w].w[k] = (word_distribution[k] + beta_) / (n_k[k] + Vbeta);
            q[w].wsum += q[w].w[k];
        }
        q[w].constructTable();
    }
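The snippets above assume a per-word proposal entry q[w] with fields noSamples, wsum, and w, plus sample and constructTable operations. A sketch of what it might look like, reusing the AliasTable sketch from earlier (field names are taken from the snippets; the internals are assumptions):

    struct WordProposal {
        int noSamples = 0;       // draws served since the table was last rebuilt
        double wsum = 0.0;       // normalizer of the stale weights
        std::vector<double> w;   // stale proposal weights q_w(k)
        AliasTable table;        // O(1) sampler built from w

        void constructTable() {
            table.Build(w);
            noSamples = 0;       // fresh table, reset the staleness counter
        }
        int sample(int bucket, double coin) const {
            return table.Sample(bucket, coin);
        }
    };

Rebuilding only every qtable_construct_frequency draws amortizes the O(K) table construction across many samples; because the acceptance ratio is evaluated with the same stale weights q[w].w, the chain still targets the exact distribution.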