用C++实现Python标准库中的Counter

《The Python Standard Library by Example》一书中提到了一个Counter容器(collections.Counter)。它像multiset一样维护一个元素集合,能在常数时间内插入元素、查询某个元素的出现次数,而且还提供了most_common(n)方法,用于返回频数最大的n个元素,这在读取文本并统计词频的时候非常实用。

考虑用C++实现的时候,我查到了LFU(Least Frequently Used,最不经常使用)算法:https://en.wikipedia.org/wiki/Least_frequently_used。它是一种缓存(例如磁盘缓存)的淘汰策略,基本思想与这个Counter相似:都需要按访问频数来组织元素。

http://dhruvbird.com/lfu.pdf 这里有相关的实现。

#include <iostream>
#include <list>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
using namespace std;
// Key node: one distinct element, linked into the key list of the countNode
// that holds its current frequency. Every key list starts with a sentinel
// keyNode created by the default constructor.
template<typename T>
struct keyNode{
        typedef T value_type;
        keyNode(){}
        keyNode(const T& v, keyNode* p, keyNode* n) :val(v), prev(p), next(n){}
        T val{};                   // the stored element (value-initialized in sentinels)
        keyNode* prev = nullptr;   // previous node in the list (the sentinel for the first key)
        keyNode* next = nullptr;   // next node in the list, or nullptr at the end
};

// Frequency node: owns the list of all keys whose current count equals `freq`.
// Counter chains these nodes in ascending order of `freq`.
template<typename T>
struct countNode{
        countNode() : keyhead(new keyNode<T>) {}
        countNode(int f, countNode* p, countNode* n) :
                freq(f), keyhead(new keyNode<T>), prev(p), next(n) {}
        // Owns raw pointers to its key nodes; a copy would delete them twice.
        countNode(const countNode&) = delete;
        countNode& operator=(const countNode&) = delete;
        ~countNode(){
                while (keyhead->next != nullptr){
                        keyNode<T>* p = keyhead->next;
                        keyhead->next = p->next;
                        delete p;
                }
                delete keyhead;
        }
        // Pushes a copy of v at the front of this node's key list and returns
        // the newly allocated key node.
        keyNode<T>* insertKey(const T& v){
                keyNode<T>* node = new keyNode<T>(v, keyhead, keyhead->next);
                if (keyhead->next != nullptr)
                        keyhead->next->prev = node;
                keyhead->next = node;
                return node;
        }
        // Unlinks and frees `node`, which must belong to this node's key list.
        void eraseKey(keyNode<T>* node){
                node->prev->next = node->next;
                if (node->next != nullptr)
                        node->next->prev = node->prev;
                delete node;
        }
        // True when only the sentinel remains in the key list.
        bool empty() const{
                return keyhead->next == nullptr;
        }

        int freq = 0;
        keyNode<T>* keyhead;          // sentinel head of the key list (owned)
        countNode* prev = nullptr;
        countNode* next = nullptr;
};

// Counter container, modelled on Python's collections.Counter.
// Distinct keys are grouped by frequency into a doubly linked list of
// countNodes kept in ascending order of freq (starting after the sentinel
// `head`, ending at `tail`). `dict` maps each key to its countNode and keyNode,
// so an update only relinks neighbouring nodes.
// Supported operations:
//   insert(v)         average O(1)  (one hash lookup plus pointer relinking)
//   lookup(v)         average O(1)  (returns 0 for keys never inserted)
//   most_common(n)    O(n)
//   least_common(n)   O(n)
// Keys with equal counts are reported in the order they reached that count,
// most recent first.
template<typename T>
class Counter{
public:
        Counter() : head(new countNode<T>(0, nullptr, nullptr)), tail(nullptr) {}
        ~Counter(){
                while (head->next != nullptr){
                        countNode<T>* p = head->next;
                        head->next = p->next;
                        delete p;
                }
                delete head;
        }
        // Owns its nodes through raw pointers; a copy would delete them twice.
        Counter(const Counter&) = delete;
        Counter& operator=(const Counter&) = delete;

        // Counts one occurrence of v: a new key starts at 1, an existing key
        // moves from its current frequency node to the one for freq + 1.
        void insert(const T& v){
                auto it = dict.find(v);
                if (it == dict.end()){
                        countNode<T>* first = head->next;
                        countNode<T>* bucket = (first != nullptr && first->freq == 1)
                                ? first
                                : linkAfter(head, 1);
                        dict.emplace(v, std::make_pair(bucket, bucket->insertKey(v)));
                        return;
                }
                countNode<T>* current = it->second.first;
                const int freq = current->freq;
                current->eraseKey(it->second.second);
                countNode<T>* following = current->next;
                countNode<T>* bucket = (following != nullptr && following->freq == freq + 1)
                        ? following
                        : linkAfter(current, freq + 1);
                it->second = std::make_pair(bucket, bucket->insertKey(v));
                // Drop frequency nodes that became empty so every node after
                // `head` holds at least one key (keeps traversals O(n)).
                if (current->empty())
                        unlink(current);
        }
        // Returns how many times v was inserted (0 if it never was).
        int lookup(const T& v) const{
                auto it = dict.find(v);
                return it == dict.end() ? 0 : it->second.first->freq;
        }
        /** Returns up to n (key, count) pairs with the highest counts, highest
            first. n <= 0 yields an empty vector. **/
        std::vector<std::pair<T, int>> most_common(int n) const{
                // The list is ordered low -> high, so walk backward from tail.
                return collect(tail, &countNode<T>::prev, n);
        }
        /** Returns up to n (key, count) pairs with the lowest counts, lowest
            first. n <= 0 yields an empty vector. **/
        std::vector<std::pair<T, int>> least_common(int n) const{
                return collect(head->next, &countNode<T>::next, n);
        }
private:
        // Allocates a node with the given freq directly after `pos`, fixing the
        // successor's back link (or `tail` when `pos` was last).
        countNode<T>* linkAfter(countNode<T>* pos, int freq){
                countNode<T>* node = new countNode<T>(freq, pos, pos->next);
                if (pos->next != nullptr)
                        pos->next->prev = node;
                else
                        tail = node;
                pos->next = node;
                return node;
        }
        // Removes a (non-head) node from the frequency list and frees it.
        void unlink(countNode<T>* node){
                node->prev->next = node->next;
                if (node->next != nullptr)
                        node->next->prev = node->prev;
                else
                        tail = (node->prev == head) ? nullptr : node->prev;
                delete node;
        }
        // Gathers up to n (key, count) pairs starting at `start` and following
        // the `step` link (&countNode::next or &countNode::prev).
        std::vector<std::pair<T, int>> collect(const countNode<T>* start,
                                               countNode<T>* countNode<T>::*step,
                                               int n) const{
                std::vector<std::pair<T, int>> result;
                for (const countNode<T>* c = start; n > 0 && c != nullptr; c = c->*step){
                        for (const keyNode<T>* k = c->keyhead->next; n > 0 && k != nullptr; k = k->next){
                                result.emplace_back(k->val, c->freq);
                                --n;
                        }
                }
                return result;
        }

        countNode<T>* head;   // sentinel (freq 0, never holds keys)
        countNode<T>* tail;   // highest-frequency node, or nullptr when empty
        std::unordered_map<T, std::pair<countNode<T>*, keyNode<T>*>> dict;
};
int main(){
        {
                // Inner scope: the Counter is destroyed before main returns.
                Counter<char> wordCount;
                std::string s("jfoaedfrerlkmgvj9ejajiokl;fdaks");
                for (char ch : s){
                        wordCount.insert(ch);
                }
                // Prints a label followed by each (key, count) pair.
                auto print = [](const char* label, const std::vector<std::pair<char, int>>& items){
                        std::cout << label << ':';
                        for (const auto& kv : items)
                                std::cout << " (" << kv.first << ", " << kv.second << ')';
                        std::cout << '\n';
                };
                print("least_common(3)", wordCount.least_common(3));
                print("most_common(3)", wordCount.most_common(3));
        }
        return 0;
}