基於NETLINK的內核與用戶空間共享內存的實現

 

轉此貼目的只是方便自己參考,感謝作者。

//TODO:
//
// 把內核傳出來的SHM地址傳遞給GUI,然後GUI與內核直接通過SHM通信。
//

================================================

author:bripengandre Email:[email protected]
一、前言
    前些日子,開發中用到了netlink來實現內核與用戶空間共享內存,寫點筆記與大家分享。因爲我對這塊也不瞭解,寫出來的東西一定存在很多錯誤,請大家批評指正~
    內核與用戶空間共享內存的關鍵是,用戶空間必須得知共享內存的起始地址,這就要求內核空間應該有一種通信機制來通知用戶空間。已經有Godbach版主等人 用proc文件系統實現了(可以google '共享內存 內核 用戶空間'),很顯然任何內核空間與用戶空間的通信方法都可資利用。本文主要講基於NETLINK機制的實現。
 
二、NETLINK簡介
    netlink在linux的內核與用戶空間通信中用得很多(但具體例子我舉不出,因爲我不清楚~~請google之),其最大優勢是接口與網絡編程中的socket相似,且內核要主動發信息給用戶空間很方便。
   但通過實踐,我發現netlink通信機制的最大弊病在於其在各內核版本中接口變化太大,讓人難以適從(可從後文列出的源碼中的kernel_receive的聲明窺一斑)。
   既然涉及到內核與用戶空間兩個空間,就應該在兩個空間各有一套接口。用戶空間的接口很簡單,與一般的socket接口相似,內核空間則稍先複雜,但簡單的 應用只需簡單地瞭解即可:首先也是建立描述符,建立描述符時會註冊一個回調函數(源碼中的kernel_receive即是),然後當用戶空間有消息發過 來時,我們的函數將被調用,顯然在這個函數裏我們可做相應的處理;當內核要主動發消息給用戶進程時,直接調用一個類send函數即可 (netlink_unicast系列函數)。當然這個過程中,有很多結構體變量需要填充。具體用法請google,我差不多忘光了~。
 
三、基於netlink的共享內存   
   這裏的共享內存是指內核與用戶空間,而不是通常的用戶進程間的。
   大概流程如下。
   內核:__get_free__pages分配連續的物理內存頁(貌似返回的其實是虛擬地址)-->SetPageReserved每一頁(每一頁 都需這個操作,參見源碼)-->如果用戶空間通過netlink要求獲取共享內存的起始物理地址,將__get_free__pages返回的地址 __pa下發給用戶空間。
   用戶空間:open "/dev/shm"(一個讀寫物理內存的設備,具體請google"linux讀寫物理內存")-->發netlink消息給內核,得到共享內存的起始物理地址-->mmap上步得到的物理地址。
 
四、源碼說明
   正如二中提到的,netlink接口在各版本中接口變化很大,本人懶惰及時間緊,只實驗了比較新的內核2.6.25,源碼如需移植到老版本上,需要一些改動,敬請原諒。
   另外,由於本源碼是從一個比較大的程序裏摳出來的,所以命名什麼的可能有點怪異~
   源碼包括shm_k.c(內核模塊)和用戶空間程序(shm_u.c)兩部分。shm_k.c模塊的工作是:分配8KB內存供共享,並把前幾十個字節置爲 “hello, use share memory with netlink"字樣。shm_u.c工作是:讀取共享內存的前幾十個字節,將內容輸出在stdout上。
    特別說明: 該程序只適用於2.6.25左右的新版本! 用__get_free_pages分配連續內存時,宜先用get_order獲取頁數,然後需將各頁都SetPageReserved下,同樣地,釋放內存時,需要對每一頁調用ClearPageReserved。
    我成功用該程序分配了4MB共享內存,運行還比較穩定。 因爲linux內核的默認設置,一般情況下用get_free_pages只能分配到4MB內存左右,如需增大,可能需改相應的參數並重新編譯內核。
 
五、內核源碼
 
   1、common.h(內核與用戶空間都用到的頭文件)
#ifndef _COMMON_H_
#define _COMMON_H_
/* protocol type */
#define SHM_NETLINK 30
/* message type */
#define SHM_GET_SHM_INFO 1
/* you can add othe message type here */
#define SHM_WITH_NETLINK "hello, use share memory with netlink"

typedef struct _nlk_msg
{
    union _data
    {
        struct _shm_info
        {
            uint32_t mem_addr;
            uint32_t mem_size;
        }shm_info;
       
        /* you can add other content here */
    }data;
}nlk_msg_t;

#endif /* _COMMON_H_ */
 
2、shm_k.c(內核模塊)
#include <linux/init.h>
#include <linux/module.h>
#include <linux/version.h>
#include <linux/types.h>
#include <linux/skbuff.h>
#include <linux/netlink.h>
#include <net/sock.h>
#include <linux/spinlock.h>
#include "common.h"
#define SHM_TEST_DEBUG
#ifdef SHM_TEST_DEBUG
#define SHM_DBG(args...) printk(KERN_DEBUG "SHM_TEST: " args)
#else
#define SHM_DBG(args...)
#endif
#define SHM_ERR(args...) printk(KERN_ERR "SHM_TEST: " args)

static struct _glb_para
{
    struct _shm_para
    {
        uint32_t mem_addr; /* memory starting address */
        uint32_t mem_size; /* memory size */
        uint32_t page_cnt; /* memory page count*/
        uint16_t order;
        uint8_t mem_init_flag; /* 0, init failed; 1, init successful */
    }shm_para;
   
    struct sock *nlfd; /* netlink descriptor */
    uint32_t pid; /* user-space process's pid */
    rwlock_t lock;
}glb_para;
 
static void init_glb_para(void);
static int init_netlink(void);
static void kernel_receive(struct sk_buff* __skb);
static int nlk_get_mem_addr(struct nlmsghdr *pnhdr);
static void clean_netlink(void);
static int init_shm(void);
static void clean_shm(void);
static int  __init init_shm_test(void);
static void clean_shm_test(void);

static void init_glb_para(void)
{
    memset(&glb_para, 0, sizeof(glb_para));
}
 
static int init_netlink(void)
{
    rwlock_init(&glb_para.lock);
    SHM_DBG("linux version:%08x/n", LINUX_VERSION_CODE);
#if(LINUX_VERSION_CODE < KERNEL_VERSION(2,6,18))
    glb_para.nlfd = netlink_kernel_create(SHM_NETLINK, kernel_receive);
#elif(LINUX_VERSION_CODE < KERNEL_VERSION(2,6,24))
    glb_para.nlfd = netlink_kernel_create(SHM_NETLINK, 0, kernel_receive, THIS_MODULE));
#else
    glb_para.nlfd = netlink_kernel_create(&init_net, SHM_NETLINK, 0, kernel_receive, NULL, THIS_MODULE);
#endif
   
    if(glb_para.nlfd == NULL)
    {
        SHM_ERR("init_netlink::netlink_kernel_create error/n");
        return (-1);
    }
   
    return (0);
}

static void kernel_receive(struct sk_buff* __skb)
{
 struct sk_buff *skb;
    struct nlmsghdr *nlh = NULL;
    int invalid;
   
 SHM_DBG("begin kernel_receive/n");
 skb = skb_get(__skb);
 invalid = 0;
 if(skb->len >= sizeof(struct nlmsghdr))
    {
        nlh = (struct nlmsghdr *)skb->data;
        if((nlh->nlmsg_len >= sizeof(struct nlmsghdr))
            && (skb->len >= nlh->nlmsg_len))
        {
            switch(nlh->nlmsg_type)
            {
                case SHM_GET_SHM_INFO:
                    SHM_DBG("receiv TA_GET_SHM_INFO/n");
              nlk_get_mem_addr(nlh);
                    break;
                default:
                    break;
            }
     }
 }
    kfree_skb(skb);
}
 
static int nlk_get_mem_addr(struct nlmsghdr *pnhdr)
{
    int ret, size;
    unsigned char *old_tail;
    struct sk_buff *skb;
    struct nlmsghdr *nlh;
    struct _nlk_msg *p;
   
   
    glb_para.pid = pnhdr->nlmsg_pid; /* get the user-space process's pid */
   
    size = NLMSG_SPACE(sizeof(struct _nlk_msg)); /* compute the needed memory size */
    if( (skb = alloc_skb(size, GFP_ATOMIC)) == NULL) /* allocate memory */
    {
        SHM_DBG("nlk_hello_test::alloc_skb error./n");
        return (-1);
    }
   
    old_tail = skb->tail;
    nlh = NLMSG_PUT(skb, 0, 0, SHM_GET_SHM_INFO, size-sizeof(struct nlmsghdr)); /* put netlink message structure into memory */
   
    p = NLMSG_DATA(nlh); /* get netlink message body pointer */
    p->data.shm_info.mem_addr = __pa(glb_para.shm_para.mem_addr); /* __pa:convert virtual address to physical address, which needed by /dev/mem */
    p->data.shm_info.mem_size = glb_para.shm_para.mem_size;
   
    nlh->nlmsg_len = skb->tail - old_tail;
    NETLINK_CB(skb).pid = 0;      /* from kernel */
    NETLINK_CB(skb).dst_group = 0;
    read_lock_bh(&glb_para.lock);
    ret = netlink_unicast(glb_para.nlfd, skb, glb_para.pid, MSG_DONTWAIT); /* send message to user-space process */
    read_unlock_bh(&glb_para.lock);
    SHM_DBG("nlk_get_mem_addr ok./n");
    return (ret);
   
nlmsg_failure:
    SHM_DBG("nlmsg_failure/n");
    if(skb)
    {
        kfree_skb(skb);
    }
    return (-1);
}

static void clean_netlink(void)
{
    if(glb_para.nlfd != NULL)
    {
        sock_release(glb_para.nlfd->sk_socket);
    }
}

static int init_shm(void)
{
    int i;
    char *p;
    uint32_t page_addr;
   
    glb_para.shm_para.order = get_order(1024*8); /* allocate 8kB */
    glb_para.shm_para.mem_addr = __get_free_pages(GFP_KERNEL, glb_para.shm_para.order);
    if(glb_para.shm_para.mem_addr == 0)
    {
        SHM_ERR("init_mem_pool::__get_free_pages error./n");
        glb_para.shm_para.mem_init_flag = 0;
        return (-1);
    }
    else
    {
        glb_para.shm_para.page_cnt = (1<<glb_para.shm_para.order);
        glb_para.shm_para.mem_size = glb_para.shm_para.page_cnt*PAGE_SIZE;
        glb_para.shm_para.mem_init_flag = 1;
        page_addr = glb_para.shm_para.mem_addr;
        SHM_DBG("size=%08x, page_cnt=%d/n", glb_para.shm_para.mem_size, glb_para.shm_para.page_cnt);
        for(i = 0; i <  glb_para.shm_para.page_cnt; i++)
        {
            SetPageReserved(virt_to_page(page_addr)); /* reserved for used */
            page_addr += PAGE_SIZE;
        }
       
        p = (char *)glb_para.shm_para.mem_addr;
        strcpy(p, SHM_WITH_NETLINK); /* write */
        SHM_DBG("__get_free_pages ok./n");
    }
   
    return (0);
}

static void clean_shm(void)
{
    int i;
    uint32_t page_addr;
   
    if(glb_para.shm_para.mem_init_flag == 1)
    {
        page_addr = glb_para.shm_para.mem_addr;
        for(i = 0; i < glb_para.shm_para.page_cnt; i++)
        {
            ClearPageReserved(virt_to_page(page_addr));
            page_addr += PAGE_SIZE;
        }
        free_pages(glb_para.shm_para.mem_addr, glb_para.shm_para.order);
    }
}

static int  __init init_shm_test(void)
{
    init_glb_para();
    if(init_netlink() < 0)
    {
        SHM_ERR("init_shm_test::init_netlink error./n");
        return (-1);
    }
    SHM_DBG("init_netlink ok./n");
   
    if(init_shm() < 0)
    {
        SHM_ERR("init_shm_test::init_mem_pool error./n");
        clean_shm_test();
        return (-1);
    }
    SHM_DBG("init_mem_pool ok./n");
   
    return (0);
}

static void clean_shm_test(void)
{
    clean_shm();
    clean_netlink();
   
    SHM_DBG("ta_exit ok./n");
}
module_init(init_shm_test);
module_exit(clean_shm_test);
MODULE_LICENSE("GPL");
MODULE_AUTHOR("bripengandre (
[email protected] )");
MODULE_DESCRIPTION("Memory Share between user-space and kernel-space with netlink.");
 
3、shm_u.c(用戶進程)
#include <stdio.h>
#include <sys/socket.h>
#include <arpa/inet.h>
#include <linux/netlink.h>
#include <stdlib.h>
#include <string.h>
#include <errno.h>
#include <sys/types.h>
#include <unistd.h>
#include <fcntl.h>
#include <sys/mman.h>
#include "common.h"

/* netlink */
#define MAX_SEND_BUF_SIZE 2500
#define MAX_RECV_BUF_SIZE 2500
#define SHM_TEST_DEBUG
#ifdef SHM_TEST_DEBUG
#define SHM_DBG(args...) fprintf(stderr, "SHM_TEST: " args)
#else
#define SHM_DBG(args...)
#endif
#define SHM_ERR(args...) fprintf(stderr, "SHM_TEST: " args)
struct _glb_para
{
    struct _shm_para
    {
        uint32_t mem_addr;
        uint32_t mem_size;
    }shm_para;
   
    int nlk_fd;
    char send_buf[MAX_SEND_BUF_SIZE];
    char recv_buf[MAX_RECV_BUF_SIZE];
}glb_para;
 

static void init_glb_para(void);
static  int create_nlk_connect(void);
static int nlk_get_shm_info(void);
static int init_mem_pool(void);

int main(int argc ,char *argv[])

    char *p;
   
    init_glb_para();
    if(create_nlk_connect() < 0)
    {
        SHM_ERR("main::create_nlk_connect error./n");
        return (1);
    }
  
    if(nlk_get_shm_info() < 0)
    {
        SHM_ERR("main::nlk_get_shm_info error./n");
        return (1);
    }
   
    init_mem_pool();
    /* printf the first 30 bytes */
    p = (char *)glb_para.shm_para.mem_addr;
    p[strlen(SHM_WITH_NETLINK)] = '/0';
    printf("the first 30 bytes of shm are: %s/n", p);
   
    return (0);
}

static void init_glb_para(void)
{
    memset(&glb_para, 0, sizeof(glb_para));
}

static  int create_nlk_connect(void)
{
    int sockfd;
    struct sockaddr_nl local;
   
    sockfd = socket(PF_NETLINK, SOCK_RAW, SHM_NETLINK);
    if(sockfd < 0)
    {
        SHM_ERR("create_nlk_connect::socket error:%s/n", strerror(errno));
        return (-1);
    }
    memset(&local, 0, sizeof(local));
    local.nl_family = AF_NETLINK;
    local.nl_pid = getpid();
    local.nl_groups = 0;
    if(bind(sockfd, (struct sockaddr*)&local, sizeof(local)) != 0)
    {
        SHM_ERR("create_nlk_connect::bind error: %s/n", strerror(errno));
        return -1;
    }
   
    glb_para.nlk_fd = sockfd;
   
    return (sockfd);
}

static int nlk_get_shm_info(void)
 {
    struct nlmsghdr *nlh;
    struct _nlk_msg *p;
    struct sockaddr_nl kpeer;
    int recv_len, kpeerlen;
   
    memset(&kpeer, 0, sizeof(kpeer));
    kpeer.nl_family = AF_NETLINK;
    kpeer.nl_pid = 0;
    kpeer.nl_groups = 0;
   
    memset(glb_para.send_buf, 0, sizeof(glb_para.send_buf));
    nlh = (struct nlmsghdr *)glb_para.send_buf;
    nlh->nlmsg_len = NLMSG_SPACE(0);
    nlh->nlmsg_flags = 0;
    nlh->nlmsg_type = SHM_GET_SHM_INFO;
    nlh->nlmsg_pid = getpid();
    sendto(glb_para.nlk_fd, nlh, nlh->nlmsg_len, 0, (struct sockaddr*)&kpeer, sizeof(kpeer));
    memset(glb_para.send_buf, 0, sizeof(glb_para.send_buf));
    kpeerlen = sizeof(struct sockaddr_nl);
 recv_len = recvfrom(glb_para.nlk_fd, glb_para.recv_buf, sizeof(glb_para.recv_buf), 0, (struct sockaddr*)&kpeer, &kpeerlen);
 p = NLMSG_DATA((struct nlmsghdr *) glb_para.recv_buf);
 SHM_DBG("%d, errno=%d.%s, %08x, %08x/n", recv_len, errno, strerror(errno), p->data.shm_info.mem_addr,  p->data.shm_info.mem_size);
 glb_para.shm_para.mem_addr = p->data.shm_info.mem_addr;
 glb_para.shm_para.mem_size = p->data.shm_info.mem_size;
 
 return (0);
 }
 
 
static int init_mem_pool(void)
{
    int map_fd;
    void *map_addr;
   
    map_fd = open("/dev/mem", O_RDWR);
    if(map_fd < 0)
    {
        SHM_ERR("init_mem_pool::open %s error: %s/n", "/dev/mem", strerror(errno));
        return (-1);
    }
     
    map_addr = mmap(0, glb_para.shm_para.mem_size, PROT_READ|PROT_WRITE, MAP_SHARED, map_fd, glb_para.shm_para.mem_addr);
    if(map_addr == NULL)
    {
        SHM_ERR("init_mem_pool::mmap error: %s/n", strerror(errno));
        return (-1);
    }
    glb_para.shm_para.mem_addr = (uint32_t)map_addr;
    return (0);
}
4、Makefile
#PREFIX = powerpc-e300c3-linux-gnu-
CC  ?= $(PREFIX)gcc
KERNELDIR ?= /lib/modules/`uname -r`/build

all: modules app
obj-m:= shm_k.o
module-objs := shm_k.c
modules:
 make -C $(KERNELDIR) M=`pwd` modules
app: shm_u.o
 $(CC) -o shm_u shm_u.c

clean:
 rm -rf *.o Module.symvers modules.order shm_u shm_k.ko shm_k.mod.c .tmp_versions .shm_k.*

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章