Transformer-XL: vocabulary

data_dir is the directory that holds the raw data; the converted TFRecords are written to a tfrecords subdirectory underneath it.

def main(unused_argv):
    del unused_argv  # Unused

    corpus = get_lm_corpus(FLAGS.data_dir, FLAGS.dataset)  # load (or build) the cached corpus

    save_dir = os.path.join(FLAGS.data_dir, "tfrecords")
    if not tf.gfile.Exists(save_dir):
        tf.gfile.MakeDirs(save_dir)

    # test mode
    if FLAGS.per_host_test_bsz > 0:
        corpus.convert_to_tfrecords("test", save_dir, FLAGS.per_host_test_bsz,
                                    FLAGS.tgt_len, FLAGS.num_core_per_host,
                                    FLAGS=FLAGS)
        return

    for split, batch_size in zip(
            ["train", "valid"],
            [FLAGS.per_host_train_bsz, FLAGS.per_host_valid_bsz]):

        if batch_size <= 0: continue
        print("Converting {} set...".format(split))
        corpus.convert_to_tfrecords(split, save_dir, batch_size, FLAGS.tgt_len,
                                    FLAGS.num_core_per_host, FLAGS=FLAGS)
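The main() above reads everything from FLAGS. For reference, here is a minimal sketch of how those flags and the entry point might be declared; the flag names come from the code above, but the default values are illustrative assumptions, not the script's actual defaults:

import tensorflow as tf

flags = tf.app.flags

flags.DEFINE_string("data_dir", None, "Directory holding the raw dataset.")
flags.DEFINE_string("dataset", None, "Dataset name passed to get_lm_corpus.")
flags.DEFINE_integer("tgt_len", 70, "Number of tokens per training segment.")
flags.DEFINE_integer("per_host_train_bsz", 60, "Train batch size per host; <= 0 skips the split.")
flags.DEFINE_integer("per_host_valid_bsz", 60, "Valid batch size per host; <= 0 skips the split.")
flags.DEFINE_integer("per_host_test_bsz", 0, "Test batch size per host; > 0 switches main() to test mode.")
flags.DEFINE_integer("num_core_per_host", 8, "Number of cores per host used when writing tfrecords.")

FLAGS = flags.FLAGS

if __name__ == "__main__":
    tf.app.run(main)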

get_lm_corpus loads the vocabulary: the corpus is serialized to disk with pickle, so only the first call actually builds it; later calls just load the cache. Building the Corpus involves four main steps (see the sketch after this list):

1. count_file: read each line of the raw text, strip leading/trailing whitespace and the trailing newline \n, split the line character by character into a list, append an <eos> marker, and record every token's frequency in counter = Counter().
2. build_vocab: build the vocabulary by sorting all counted tokens by their character code and dropping tokens whose frequency falls below min_freq.
3. add_symbol: maintain the symbol-to-index mapping sym2idx and the index-to-symbol mapping idx2sym (symbols are appended in order, so the list position is the index).
4. encode_file: convert each line's symbols into an int64 index array with convert_to_nparray.
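A rough sketch of how these four steps fit together for a character-level dataset, using the Vocab class shown below (the helper name and file paths are illustrative, not the actual Corpus implementation):

import os

def build_corpus_sketch(data_dir):
    vocab = Vocab(special=['<eos>'], min_freq=0, lower_case=False)

    # step 1: tally character frequencies into vocab.counter (one <eos> per line)
    vocab.count_file(os.path.join(data_dir, 'train.txt'))

    # steps 2-3: sort the counted symbols, drop rare ones, and fill
    # idx2sym / sym2idx via add_special / add_symbol
    vocab.build_vocab()

    # step 4: map every line to an int64 index array
    train = vocab.encode_file(os.path.join(data_dir, 'train.txt'), ordered=True)
    valid = vocab.encode_file(os.path.join(data_dir, 'valid.txt'), ordered=True)
    return vocab, train, valid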

def get_lm_corpus(data_dir, dataset):
    fn = os.path.join(data_dir, "cache.pkl")

    if tf.gfile.Exists(fn):
        print("Loading cached dataset...")
        with open(fn, "rb") as fp:
            corpus = pickle.load(fp)
    else:
        print("Producing dataset...")
        kwargs = {}

        kwargs["special"] = ["<eos>"]
        kwargs["lower_case"] = False

        corpus = Corpus(data_dir, dataset, **kwargs)

        print("Saving dataset...")
        with open(fn, "wb") as fp:
            pickle.dump(corpus, fp, protocol=2)

        corpus_info = {
            "vocab_size": len(corpus.vocab),
            "cutoffs": corpus.cutoffs,
            "dataset": corpus.dataset
        }
        with open(os.path.join(data_dir, "corpus-info.json"), "w") as fp:
            json.dump(corpus_info, fp)

    return corpus
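A minimal usage sketch (the path and dataset name are illustrative):

corpus = get_lm_corpus('./data/zhihu', 'zhihu')  # builds and pickles the corpus on the first call
print(len(corpus.vocab))                         # the vocab size recorded in corpus-info.json

Note that corpus-info.json is written only on the cache-miss branch, so deleting cache.pkl is the way to force both files to be regenerated.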
class Vocab(object):
    def __init__(self, special=[], min_freq=0, max_size=None, lower_case=True,
                 delimiter=None, vocab_file=None):
        self.counter = Counter()
        self.special = special
        self.min_freq = min_freq
        self.max_size = max_size
        self.lower_case = lower_case
        self.delimiter = delimiter
        self.vocab_file = vocab_file
        self.idx2sym = []
        self.sym2idx = OrderedDict()           # TODO: double-check that this is correct

        # for the zhihu dataset
        # TODO: delete this block when testing other datasets
        # self.min_freq = 100
        # self.add_symbol('<UNK>')
        # self.unk_idx = self.get_idx('<UNK>')

    def tokenize(self, line, add_eos=False, add_double_eos=False):
        line = line.strip()
        symbols = list(line)

        if add_double_eos:  # lm1b
            # make sure <S> can be looked up in the symbol list
            self.add_symbol('<S>')
            return ['<S>'] + symbols + ['<S>']
        elif add_eos:
            return symbols + ['<eos>']
        else:
            return symbols
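    # Example (illustrative): tokenize('你好嗎\n', add_eos=True) returns
    # ['你', '好', '嗎', '<eos>'] -- splitting is character-level via list(line),
    # and self.delimiter is not used in this version.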

    # Read the sentences in a file and count their symbols
    def count_file(self, path, verbose=False, add_eos=False):
        if verbose: print('counting file {} ...'.format(path))
        assert tf.gfile.Exists(path)

        sents = []
        with open(path, 'r', encoding='UTF-8') as f:
            # read each line of the file
            for idx, line in enumerate(f):
                if verbose and idx > 0 and idx % 500000 == 0:
                    print('  line {}'.format(idx))
                symbols = self.tokenize(line, add_eos=True)
                self.counter.update(symbols)
                sents.append(symbols)

        return sents

    # Update the counter with already-tokenized sentences
    def count_sents(self, sents, verbose=False):
        """
          sents : a list of sentences, each a list of tokenized symbols
        """
        if verbose: print('counting {} sents ...'.format(len(sents)))
        for idx, symbols in enumerate(sents):
            if verbose and idx > 0 and idx % 500000 == 0:
                print('  line {}'.format(idx))
            self.counter.update(symbols)

    def _build_from_file(self, vocab_file):
        # self.idx2sym = []
        # self.sym2idx = OrderedDict()

        with open(vocab_file, 'r') as f:
            for line in f:
                symb = line.strip().split()[0]
                self.add_symbol(symb)
        self.unk_idx = self.sym2idx['<UNK>']

    # Build the vocab: register the counted symbols
    def build_vocab(self):
        if self.vocab_file:
            print('building vocab from {}'.format(self.vocab_file))
            self._build_from_file(self.vocab_file)
            print('final vocab size {}'.format(len(self)))
        else:
            print('building vocab with min_freq={}, max_size={}'.format(
                self.min_freq, self.max_size))

            self.add_special("<eos>")

            # NOTE: big pitfall here -- the original most_common() + break loop
            # below was replaced with a sort by symbol plus a skip of rare tokens:
            # for sym, cnt in self.counter.most_common(self.max_size):
            #     if cnt < self.min_freq:
            #         break
            tmp = sorted(self.counter.items(), key=lambda item: item[0])
            for sym, cnt in tmp:
                if cnt < self.min_freq:
                    continue
                self.add_symbol(sym)
            print('final vocab size {} from {} unique tokens'.format(
                len(self), len(self.counter)))
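            # Example (illustrative): with counter = Counter({'你': 5, '好': 3, '嗎': 1})
            # and min_freq=2, the vocab becomes ['<eos>', '你', '好']: symbols are
            # sorted by code point and '嗎' is skipped as a low-frequency token.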

    # The core is convert_to_nparray, i.e. turning symbols into index arrays
    def encode_file(self, path, ordered=False, verbose=False,
                    add_double_eos=False):
        if verbose: print('encoding file {} ...'.format(path))
        assert tf.gfile.Exists(path)
        encoded = []
        with open(path, 'r', encoding='utf-8') as f:
            for idx, line in enumerate(f):
                if verbose and idx > 0 and idx % 500000 == 0:
                    print('  line {}'.format(idx))
                symbols = self.tokenize(line, add_eos=True, add_double_eos=add_double_eos)

                encoded.append(self.convert_to_nparray(symbols))

        if ordered:
            encoded = np.concatenate(encoded)

        return encoded

    # Encode already-tokenized sentences (each a list of symbols) into index arrays.
    def encode_sents(self, sents, ordered=False, verbose=False):
        if verbose: print('encoding {} sents ...'.format(len(sents)))
        encoded = []

        for idx, symbols in enumerate(sents):
            if verbose and idx > 0 and idx % 500000 == 0:
                print('  line {}'.format(idx))
            encoded.append(self.convert_to_nparray(symbols))

        if ordered:
            encoded = np.concatenate(encoded)

        return encoded

    def add_special(self, sym):
        if sym not in self.sym2idx:
            self.idx2sym.append(sym)
            self.sym2idx[sym] = len(self.idx2sym) - 1
            setattr(self, '{}_idx'.format(sym.strip('<>')), self.sym2idx[sym])

    def add_symbol(self, sym):
        if sym not in self.sym2idx:
            self.idx2sym.append(sym)
            self.sym2idx[sym] = len(self.idx2sym) - 1

    def get_sym(self, idx):
        assert 0 <= idx < len(self.idx2sym), 'Index {} out of range'.format(idx)
        return self.idx2sym[idx]

    def get_idx(self, sym):
        if sym in self.sym2idx:
            return self.sym2idx[sym]
        else:
            assert hasattr(self, 'unk_idx')
            return self.sym2idx.get(sym, self.unk_idx)

    def get_symbols(self, indices):
        return [self.get_sym(idx) for idx in indices]

    def get_indices(self, symbols):
        return [self.get_idx(sym) for sym in symbols]

    # symbols -> index array
    def convert_to_nparray(self, symbols):
        nparray = np.array(self.get_indices(symbols), dtype=np.int64)
        return nparray

    # index array -> space-joined symbols
    def convert_to_sent(self, indices, exclude=None):
        if exclude is None:
            return ' '.join([self.get_sym(idx) for idx in indices])
        else:
            return ' '.join([self.get_sym(idx) for idx in indices if idx not in exclude])

    def __len__(self):
        return len(self.idx2sym)
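Putting it together, a minimal usage sketch of Vocab on its own (the sample file name is an assumption):

vocab = Vocab(special=['<eos>'], min_freq=0, lower_case=False)
vocab.count_file('train.txt')                 # character-level counts, one <eos> per line
vocab.build_vocab()                           # sort symbols, drop low-frequency ones

encoded = vocab.encode_file('train.txt', ordered=True)   # one flat int64 array
print(encoded[:10])
print(vocab.convert_to_sent(encoded[:10]))    # indices back to space-joined characters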