fasttext源碼學習(1)–dictionary
前言
fasttext在文本分類方面很厲害,精度高,速度快,模型小(壓縮後),總之非常值得學習。花了點時間學習了下源碼,本篇主要是與dictionary相關。
dictionary主要存儲詞語和切分詞及對應的id,因爲fasttext能處理超大數據集,如果不使用一些方法,只是加載這些內容,內存就很容易爆掉,我們來看看有哪些關鍵方法。
一 詞語數量控制
該方法在Dictionary::readFromFile中調用,截取關鍵部分如下:
// 簡化版
void Dictionary::readFromFile(std::istream& in) {
int64_t minThreshold = 1; // [1]
while (readWord(in, word)) {
add(word);
if (size_ > 0.75 * MAX_VOCAB_SIZE) { // [2]
minThreshold++;
threshold(minThreshold, minThreshold);
}
}
threshold(args_->minCount, args_->minCountLabel); // [3]
initTableDiscard();
initNgrams();
}
其中threshold函數爲一個過濾方法,即按出現頻次過濾,小於閾值的被刪除,邏輯並不複雜。
// t: 詞語頻次閾值
// t1: label頻次閾值
void Dictionary::threshold(int64_t t, int64_t tl)
這裏用到幾個控制方法:
- 只存儲詞語和標籤
while循環裏的readWord(in, word)函數並沒有做單詞的切分,只讀取了空格等分割的詞語;而add函數內部做了去重,相同的詞語只增加了計數。
2.自動過濾
從[1], [2]中可以看出,超出閾值就會自動進行一次過濾, 觸發自動過濾的條件是大於 MAX_VOCAB_SIZE的3/4,MAX_VOCAB_SIZE值爲3000萬,一般情況下,只存儲去重後的詞語,還是很難觸發自動過濾的。
3. 參數控制過濾
從[3]可以看出文件讀取完成後,會再次調用threshold,控制參數是minCount,minCountLabel, 可以從控制檯輸入進行控制,這樣即使未觸發自動過濾,也會根據用戶需要進行詞語頻次過濾。
使用以上方法可以根據實際需要,將用於存儲(包括存儲至模型中)的詞語數量控制在一定範圍內;其中最主要的是隻存儲詞語,因爲如果加上切詞,可能會使詞表迅速膨脹,那切詞的信息是如何處理的?
二 ngrams處理
從前面可以看出readFromFile函數是在詞語加載完畢之後,才調用了initNgrams函數來初始化ngrams信息,從這個調用邏輯也可以看出dictionary的處理思路,ngrams和words分開處理。
// 簡化版
void Dictionary::initNgrams() {
for (size_t i = 0; i < size_; i++) {
std::string word = BOW + words_[i].word + EOW;
if (words_[i].word != EOS) {
computeSubwords(word, words_[i].subwords);// [1]
}
}
}
void Dictionary::computeSubwords(std::string& word,std::vector<int32_t>& ngrams) const {
for (size_t i = 0; i < word.size(); i++) {
std::string ngram;
for (size_t j = i, n = 1; j < word.size() && n <= args_->maxn; n++) {
ngram.push_back(word[j++]);
if (n >= args_->minn && !(n == 1 && (i == 0 || j == word.size()))) {
int32_t h = hash(ngram) % args_->bucket;
pushHash(ngrams, h); // [2]
}
}
}
}
從[1]、[2]可以看出,ngram的id是存儲在每個word的結構體中的(words_[i].subwords),加載文件時,會計算出該id,供後續的訓練時候使用,但是該id並不會存儲到模型中,那如何保證每次計算都得出正確的id呢?因爲訓練中的矩陣計算可不會區分words和ngrams,它們都在一個矩陣中。
三 hash與id
1 words的hash與id
從第一部分可以知道words的初始化是在add函數中,下面列出相關函數:
std::vector<int32_t> word2int_;
std::vector<entry> words_;
void Dictionary::add(const std::string& w) {
int32_t h = find(w);
ntokens_++;
if (word2int_[h] == -1) {
entry e;
e.word = w;
e.count = 1;
e.type = getType(w);
words_.push_back(e); // [1]
word2int_[h] = size_++; // [2]
} else {
words_[word2int_[h]].count++;
}
}
從[1] [2]以及定義可以看出,words_直接push進新的詞,words2int_存儲hash值與id的對應關係,而新詞的id爲在數組中的下標,而hash值通過find得出:
// 簡化版
int32_t Dictionary::find(const std::string& w) const {
return find(w, hash(w));
}
int32_t Dictionary::find(const std::string& w, uint32_t h) const {
int32_t word2intsize = word2int_.size();
int32_t id = h % word2intsize; // [1]
while (word2int_[id] != -1 && words_[word2int_[id]].word != w) {
id = (id + 1) % word2intsize; // [2]
}
return id;
}
其中[1]保證hash值不超過words2int_的大小(不越界), 而[2]是防碰撞策略,線性增加。
2 ngrams的hash與id
與ngrams有關的函數在computeSubwords中:
int32_t h = hash(ngram) % args_->bucket; // [1]
pushHash(ngrams, h);
void Dictionary::pushHash(std::vector<int32_t>& hashes, int32_t id) const {
hashes.push_back(nwords_ + id); // [2]
}
從[1]處得到id值,[2]中存儲前將id加上了nwords_, 而nwords_是詞語總數,這時候整個邏輯就很清晰了,words的id在nwords_內,而ngrams的id則會始終大於nwords_, 從而保證id不衝突。但是[1]處的args_->bucket又是怎麼回事呢?
四 id與矩陣
訓練時,會調用createRandomMatrix函數創建輸入矩陣:
Matrix FastText::createRandomMatrix() const {
DenseMatrix input = DenseMatrix(dict_->nwords() + args_->bucket, args_->dim); // [1]
input->uniform(1.0 / args_->dim, args_->thread, args_->seed);
return input;
}
從[1]處可以看出, 輸入矩陣的行數爲nwords_ + args_->bucket, 結合上面的id的計算,可以看出整個全貌,words的id在前nwords_中;而ngrams的id散列在args_bucket中,加上nwords_的偏移.
而原始的哈希函數反而存在感是最低的:
uint32_t Dictionary::hash(const std::string& str) const {
uint32_t h = 2166136261;
for (size_t i = 0; i < str.size(); i++) {
h = h ^ uint32_t(int8_t(str[i]));
h = h * 16777619;
}
return h;
}
整個過程,hash是輔助手段,保證同樣的內容得到同樣的hash值,而words和ngrams的id分割則是後續矩陣運算的關鍵。
總結
從源碼閱讀的過程中,可以感覺到fasttext對速度的追求,比如words2int_沒有使用map,而是通過hash+mod的方法,將查詢變爲O(1)的複雜度(近似啦); 而hash的使用,words和ngrams的id的分割,使存儲量大大減小,很巧妙,值得學習。