golang實現文件傳輸

服務端代碼(Go):

package main

import (
	"bytes"
	"fmt"
	"io/ioutil"
	"net"
	"os"
	"os/exec"
	"runtime"
	"strconv"
	"strings"
	"syscall"
)

// initEnv sets runtime.GOMAXPROCS and prints the value it chose.
//
// With an argument, i[0] is used. Without one, the count comes from the number
// of "processor" lines in /proc/cpuinfo on Linux, and from the
// number_of_processors environment variable elsewhere (NOTE(review):
// presumably relying on Windows' NUMBER_OF_PROCESSORS — confirm). Any value
// that is missing, unparsable or < 1 falls back to runtime.NumCPU().
func initEnv(i ...int) {
	c := 0
	if len(i) > 0 {
		c = i[0]
	} else if runtime.GOOS == "linux" {
		out, _ := exec.Command("/bin/bash", "-c", "cat /proc/cpuinfo | grep 'processor' | sort | uniq | wc -l").Output()
		// wc -l terminates its output with a newline; without trimming, Atoi
		// always failed and c stayed 0.
		c, _ = strconv.Atoi(strings.TrimSpace(string(out)))
	} else {
		c, _ = strconv.Atoi(strings.TrimSpace(os.Getenv("number_of_processors")))
	}
	if c < 1 {
		// GOMAXPROCS(n) with n < 1 only queries the current setting, so a
		// failed lookup used to be a silent no-op.
		c = runtime.NumCPU()
	}
	fmt.Println(c)
	runtime.GOMAXPROCS(c)
}

// sendAll writes all of data to conn and returns the first write error.
//
// The previous version ignored errors. A write to a closed peer returns
// (0, err), so it looped forever. Callers that ignore the result still compile.
func sendAll(conn net.Conn, data []byte) error {
	for len(data) > 0 {
		n, err := conn.Write(data)
		if err != nil {
			return err
		}
		if n == 0 {
			// Guard against a misbehaving writer; io.Writer forbids (0, nil)
			// for a non-empty slice, and retrying would spin.
			return fmt.Errorf("sendAll: write made no progress")
		}
		data = data[n:]
	}
	return nil
}

// getIndex returns the byte offset in str where rune c first appears,
// or -1 when str does not contain c.
func getIndex(str string, c rune) int {
	idx := -1
	for pos, r := range str {
		if r == c {
			idx = pos
			break
		}
	}
	return idx
}

// rIndex returns the byte index of the last occurrence of c in str,
// or -1 if c does not occur.
func rIndex(str string, c uint8) int {
	pos := len(str)
	for pos > 0 {
		pos--
		if str[pos] == c {
			return pos
		}
	}
	return -1
}

// getDir returns the directory part of path, after normalising Windows '\'
// separators to '/'.
//
// A path with no separator yields "." (the current directory), and a file
// directly under the root yields "/". Previously the first case panicked
// (slice index -1), and the second returned "", which is not a creatable
// directory.
func getDir(path string) string {
	path = strings.ReplaceAll(path, "\\", "/")
	switch i := strings.LastIndexByte(path, '/'); {
	case i < 0:
		return "."
	case i == 0:
		return "/"
	default:
		return path[:i]
	}
}

func handle(conn net.Conn) {
	//_ = conn.SetReadDeadline(time.Now().Add(time.Second * 60*5))
	//_ = conn.SetWriteDeadline(time.Now().Add(time.Second * 60*5))
	defer func() {
		if err := recover(); err != nil {
			fmt.Println(err)
			fmt.Println("exit")
			return
		}
	}()
	defer conn.Close()
	data := make([]byte, 40960)
	n, err := conn.Read(data)
	if err != nil {
		return
	}
	res := bytes.Split(data[:n], []byte("|"))
	if len(res) < 2 {
		return
	}
	mode := string(res[0])
	file := string(res[1])
	down := "download"
	upload := "upload"
	if mode == down {
		//下載處理
		fmt.Println("download.....", string(file))
		f, _ := os.Open(string(file))
		defer f.Close()
		if f != nil { //file is not nil , and then send file data
			var buffer = make([]byte, 1024)
			n, _ := f.Read(buffer)
			for n != 0 {
				sendAll(conn, buffer[:n])
				n, err = f.Read(buffer)
			}
		} else {
			fmt.Println("讀取文件失敗", string(file))
		}

	} else if mode == upload {
		//上傳處理
		fmt.Println("upload...", file)
		//開始創建保存目錄
		dir := getDir(file)
		fmt.Println("dir:", dir)
		_, err = os.Stat(dir)
		if os.IsNotExist(err) {
			e := os.MkdirAll(dir, os.ModePerm)
			if e != nil {
				fmt.Printf("不能創建目錄")
				panic(e)
			}
		}
		var f  *os.File
		var err error
		if b,_:=pathExists(file);b{
			fmt.Println(b,"exist")
			f, err = os.OpenFile(file,syscall.O_WRONLY,0)
		}else {
			f, err = os.Create(file) //創建文件
		}
		defer func() {
			if f != nil {
				f.Close()
			}
		}()
		if err != nil {
			fmt.Println("文件創建失敗",file)
			panic(fmt.Sprintf("%v%s",err,"文件創建失敗"))
		}
		buffer := make([]byte, 1024)
		for {
			n, err := conn.Read(buffer)
			if err != nil {
				break
			}
			_, _ = f.Write(buffer[:n])
		}
		fmt.Println("upload ok")
	}
	fmt.Println("done")

}

// pathExists reports whether path can be stat'ed.
// A "does not exist" error yields (false, nil). Any other Stat failure yields
// (false, err), so the caller can tell "absent" from "unknown".
func pathExists(path string) (bool, error) {
	_, statErr := os.Stat(path)
	switch {
	case statErr == nil:
		return true, nil
	case os.IsNotExist(statErr):
		return false, nil
	default:
		return false, statErr
	}
}

func run(host,port,sysType,interpreter,arg1,cmdGetPid string)  {
start:
	tcpServer, _ := net.ResolveTCPAddr("tcp4", fmt.Sprintf("%s:%s",host,port))
	listener, _ := net.ListenTCP("tcp", tcpServer)
	fmt.Println("run server ... ")
	retry:=20
	for {
		//當有新的客戶端請求來的時候,拿到與客戶端的連接
		conn, err := listener.Accept()
		if err != nil {
			if retry<=0{
				panic("端口被佔用")
			}
			if sysType == "linux"{
				_ = exec.Command(interpreter,arg1,cmdGetPid).Run()
				retry--
				goto start
			}
			getPID:=exec.Command(interpreter,arg1,fmt.Sprintf(cmdGetPid,port))
			byteRes, _ :=getPID.Output()
			res:=string(byteRes)
			split:="\n"
			if strings.Contains(res,"\r\n"){
				split="\r\n"
			}
			pidInfoList :=strings.Split(res,split)
			pidSet :=make(map[string]bool)
			for _,line:=range pidInfoList {
				if len(line)<3{
					continue
				}
				pid:=line[rIndex(line,' ')+1:]
				if _,exist:= pidSet[pid];!exist{
					kill:=exec.Command("cmd.exe","/c",fmt.Sprintf("taskkill /f /t /im %s",pid))
					_ = kill.Run()
					fmt.Println(pid)
					pidSet[pid]=true
				}
			}
			retry--
			goto start
		}
		go handle(conn)
	}
}

// GetAllFile recursively collects the path of every non-directory entry under
// pathname. Each path is built as pathname + "/" + name and appended to the
// slice that *res points to (*res is re-pointed at the grown slice).
//
// Only the error from reading pathname itself is returned. Failures inside
// subdirectories are ignored, so a partial listing is still produced.
func GetAllFile(res **[]string, pathname string) error {
	entries, readErr := ioutil.ReadDir(pathname)
	for _, entry := range entries {
		child := pathname + "/" + entry.Name()
		if entry.IsDir() {
			_ = GetAllFile(res, child)
			continue
		}
		grown := append(**res, child)
		*res = &grown
	}
	return readErr
}

// toList renders files as a bracketed list of double-quoted string literals,
// e.g. ["G:/a/x.txt","G:/a/y.txt"], prints it, and returns it.
//
// Each element is escaped with strconv.Quote. The old '%s' formatting let a
// quote in a file name end the literal early, and a backslash be read as an
// escape (e.g. "G:\test" turned "\t" into a tab). Since the Python client
// eval()s this reply, an unescaped name could even inject code.
func toList(files []string) string {
	var b strings.Builder
	b.WriteByte('[')
	for i, s := range files {
		if i > 0 {
			b.WriteByte(',')
		}
		b.WriteString(strconv.Quote(s))
	}
	b.WriteByte(']')
	res := b.String()
	fmt.Println(res)
	return res
}

// main starts the file-transfer server (data port, default 8080; see run) in
// the background, then serves the control port 8081 in the foreground.
//
// Usage: program [host port] — host and port set the data server's bind
// address. The control port is always 8081 on the same host.
//
// Control requests are one per connection, "<code>|<args>":
//   - get_files|<dir>: reply with every file under <dir> (see toList), then close.
//   - exit: stop the process.
//
// NOTE(security): with the default empty host both servers listen on every
// interface and act on paths sent by any client. Bind to 127.0.0.1 or
// firewall the ports outside a trusted machine.
func main() {
	initEnv()
	host := ""
	port := "8080"
	portControl := "8081"
	// The old check (len > 3) ignored "program host port"; a spurious third
	// argument was needed before the overrides took effect.
	if len(os.Args) >= 3 {
		host = os.Args[1]
		port = os.Args[2]
	}
	sysType := runtime.GOOS
	fmt.Println(sysType)
	interpreter := "cmd.exe"
	arg1 := "/c"
	cmdGetPid := "netstat -ano|findstr %s"
	if sysType == "linux" {
		// Linux: %s is the port. "do;" in the old script was a bash syntax
		// error. $PPID (bash's parent, i.e. this process) is skipped so the
		// server never kills itself.
		interpreter = "/bin/bash"
		arg1 = "-c"
		cmdGetPid = "a=$(netstat -anp|grep %s|awk -F '[ /]' '{print $(NF-1)}');for i in ${a[@]};do [ \"$i\" = \"$PPID\" ] || kill $i;done"
	}
	go run(host, port, sysType, interpreter, arg1, cmdGetPid)

	tcpServer, err := net.ResolveTCPAddr("tcp4", fmt.Sprintf("%s:%s", host, portControl))
	if err != nil {
		fmt.Println(err)
		return
	}
	listener, err := net.ListenTCP("tcp", tcpServer)
	if err != nil {
		fmt.Println(err)
		return
	}
	defer listener.Close()
	buffer := make([]byte, 4096)
	for {
		conn, err := listener.Accept()
		if err != nil {
			fmt.Println(err)
			return
		}
		if !serveControl(conn, buffer) {
			return
		}
	}
}

// serveControl handles one control request on conn (reading it into buffer)
// and always closes conn. It returns false when the client asked to exit.
func serveControl(conn net.Conn, buffer []byte) bool {
	defer conn.Close()
	n, err := conn.Read(buffer)
	if n == 0 {
		if err != nil {
			fmt.Println(err)
		}
		return true
	}
	res := string(buffer[:n])
	// SplitN keeps any '|' inside the path argument.
	list := strings.SplitN(res, "|", 2)
	fmt.Println(list, len(list))
	code := list[0]
	var args string
	if len(list) > 1 {
		fmt.Println("len(list)>1")
		args = list[1]
	}
	fmt.Println(res, code, list)
	switch code {
	case "get_files":
		fmt.Println(args)
		t := make([]string, 0)
		files := &t
		// Best effort: an unreadable root still yields a (possibly empty) list.
		if err := GetAllFile(&files, args); err != nil {
			fmt.Println(err)
		}
		// Best effort too: the client simply sees a short reply on failure.
		sendAll(conn, []byte(toList(*files)))
	case "exit":
		fmt.Println("exit ----------------------")
		return false
	}
	return true
}

客戶端代碼(Python):


import os
import socket
import threading


def download(target, local, host='127.0.0.1', port=8080):
    """Fetch the server-side file ``target`` and save it as ``local``.

    Sends ``download|<target>``; the server streams the file and then closes
    the connection. Missing parent directories of ``local`` are created.
    Connecting is retried (with a short pause) until the server accepts.

    NOTE(review): ``os.path.isfile(target)`` checks the *local* filesystem, so
    this only works because the server is on the same machine; against a
    remote ``host`` it would skip every file.
    """
    import time

    if not os.path.isfile(target):
        return
    print(target, local)
    while True:
        try:
            # A fresh socket per attempt: a socket whose connect() failed
            # cannot portably be reused for another connect().
            client = socket.create_connection((host, port))
            break
        except OSError as e:
            print(e)
            time.sleep(0.1)  # avoid a hot spin while the server is down/busy
    with client:
        client.sendall(bytes("download|%s" % target, encoding='utf8'))
        dir_name = os.path.dirname(local)
        # dirname is '' for a bare file name, and os.makedirs('') raises.
        if dir_name:
            os.makedirs(dir_name, exist_ok=True)
        with open(local, 'wb') as f:
            fragment = client.recv(65536)
            while fragment:
                f.write(fragment)
                fragment = client.recv(65536)
            print('退出了')


def upload(local, target, host='127.0.0.1', port=8080):
    """Send the local file ``local`` to the server, to be stored as ``target``.

    Sends ``upload|<target>`` followed by the raw file bytes. Closing the
    socket (end of stream) marks the end of the file.

    NOTE(review): the protocol has no header terminator. TCP may merge the
    header with the first file bytes into the server's single header read,
    which corrupts the path. A real fix needs a protocol change on both sides.
    """
    with socket.create_connection((host, port)) as client:
        client.sendall(bytes('upload|%s' % target, encoding='utf8'))
        with open(local, 'rb') as f:
            # Uses os.sendfile where available, else a send() loop.
            client.sendfile(f)


def get_files(path, host='127.0.0.1', port=8081):
    """Ask the control server for every file under ``path`` (recursively).

    Returns the raw reply bytes: the server's rendering of the file list as a
    bracketed list of quoted path strings. Parse it with
    ``ast.literal_eval(reply.decode('utf8'))``, not ``eval``.
    """
    with socket.create_connection((host, port)) as client:
        client.sendall(("get_files|%s" % path).encode('utf8'))
        chunks = []
        chunk = client.recv(4096)
        while chunk:
            chunks.append(chunk)
            chunk = client.recv(4096)
    return b''.join(chunks)


def close(host='127.0.0.1', port=8081):
    """Ask the control server to shut down by sending ``exit``.

    The socket is closed explicitly. The original left it open for the
    garbage collector.
    """
    with socket.create_connection((host, port)) as client:
        client.sendall(b'exit')


# Mirror the server-side tree G:/test000 onto drive F: (same path, drive
# letter replaced) by downloading every file in parallel. Then print the
# listing and ask the server to shut down.
import ast

# file_list = get_files('G:/Driver驅動')
# get_files returns the raw bytes of a list literal. Parse it with
# ast.literal_eval, which (unlike eval) cannot run code sent by the server.
# Iterating the raw bytes, as before, yielded ints and crashed on x[1:].
file_list = ast.literal_eval(get_files('G:/test000').decode('utf8'))
save_list = ["F" + x[1:] for x in file_list]
threads = [threading.Thread(target=download, args=pair)
           for pair in zip(file_list, save_list)]
for t in threads:
    t.daemon = False
    t.start()
# Wait for every download before sending "exit": exit stops the server
# process, which would cut off transfers still in flight.
for t in threads:
    t.join()

res = get_files('G:/test000')
for i in ast.literal_eval(res.decode('utf8')):
    print(i)
close()

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章