Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- package main
- import (
- "encoding/json"
- "fmt"
- "io/ioutil"
- "net"
- "crypto/sha256"
- "io"
- "crypto/rand"
- "crypto/rc4"
- "bytes"
- "encoding/binary"
- "sync"
- "time"
- "runtime"
- pseudoRand "math/rand"
- "os"
- "os/signal"
- "crypto/md5"
- )
// errorPanic aborts the program when err is non-nil.
//
// The printf-style message (format, a...) is written to stderr first so the
// operator can see which step failed; err itself is then re-panicked
// unchanged. A nil err is a no-op.
func errorPanic(err error, format string, a ...interface{}) {
	if err == nil {
		return
	}
	// fmt.Errorf would only build (and discard) an error value; print instead,
	// and spread a so each argument fills its own verb.
	fmt.Fprintf(os.Stderr, format, a...)
	panic(err)
}
// vxlanHeader is the 8-byte header that newRc4 serializes with
// encoding/binary (big-endian, fields in declaration order). The field order
// and widths therefore define the wire layout and must not change.
type vxlanHeader struct {
	flags   uint16 // first 16 bits; newRc4 sets 0x0800 (presumably the RFC 7348 "I" flag — confirm)
	group   uint16 // next 16 bits; newRc4 sets 0
	vxlanId uint32 // newRc4 stores the VNI shifted left by 8 (VNI in the upper 24 bits)
}
// Config is the tunnel configuration, decoded from a JSON file by load.
//
// Uplink: datagrams arriving on UpLinkPort are encrypted and forwarded to
// UpLinkRemote:UpLinkRemotePort through a pool of LocalSocketPoolSize
// ephemeral-port sockets that is rebuilt every LocalPoolUpdateTime seconds.
//
// Downlink: datagrams arriving on DownLinkPort are decrypted and forwarded to
// DownLinkRemote:DownLinkRemotePort from the uplink's listening socket.
//
// Key is hashed with SHA-256 to form the shared cipher key; VxlanId is the
// VNI written into the VXLAN header of decrypted datagrams.
type Config struct {
	UpLinkPort          int    `json:"upLocalPort"`
	UpLinkRemote        string `json:"upRemoteIp"`
	UpLinkRemotePort    int    `json:"upRemotePort"`
	LocalSocketPoolSize int    `json:"poolSize"`
	LocalPoolUpdateTime int    `json:"poolUpdate"`
	DownLinkPort        int    `json:"downPort"`
	DownLinkRemote      string `json:"downRemoteIp"`
	DownLinkRemotePort  int    `json:"downRemotePort"`
	Key                 string `json:"key"`
	VxlanId             uint32 `json:"vxlanId"`
}

// load reads the JSON file f into c and validates the result.
//
// It panics if the file cannot be read or parsed, or if validation fails:
// the program cannot do anything useful without a usable configuration, and
// failing here gives a clear message instead of a crash deep in a worker.
func (c *Config) load(f string) {
	jsonBytes, err := ioutil.ReadFile(f)
	if err != nil {
		panic(fmt.Errorf("open file %s: %w", f, err))
	}
	if err := json.Unmarshal(jsonBytes, c); err != nil {
		panic(fmt.Errorf("unmarshal file %s: %w", f, err))
	}
	if err := c.validate(); err != nil {
		panic(fmt.Errorf("invalid config %s: %w", f, err))
	}
}

// validate rejects settings that would otherwise misbehave at runtime.
// It returns nil when the configuration is usable.
func (c *Config) validate() error {
	switch {
	case c.LocalSocketPoolSize <= 0:
		// The sender picks a pool socket with rand.Intn, which panics for n <= 0.
		return fmt.Errorf("poolSize must be positive, got %d", c.LocalSocketPoolSize)
	case c.LocalPoolUpdateTime <= 0:
		// A zero or negative interval makes the pool-refresh timer fire back to back.
		return fmt.Errorf("poolUpdate must be a positive number of seconds, got %d", c.LocalPoolUpdateTime)
	case c.VxlanId > 0xFFFFFF:
		// The VNI field of the VXLAN header is 24 bits wide; larger ids were truncated silently.
		return fmt.Errorf("vxlanId must fit in 24 bits, got %d", c.VxlanId)
	case net.ParseIP(c.UpLinkRemote) == nil:
		// Remotes are parsed with net.ParseIP, so only literal IPs work; a nil IP
		// would silently send to the unspecified address.
		return fmt.Errorf("upRemoteIp %q is not an IP address", c.UpLinkRemote)
	case net.ParseIP(c.DownLinkRemote) == nil:
		return fmt.Errorf("downRemoteIp %q is not an IP address", c.DownLinkRemote)
	}
	return nil
}
// Framing sizes for an encrypted datagram: it begins with a nonceSize-byte
// random nonce and ends with a sumSize-byte MD5 checksum of the payload.
const (
	nonceSize int = 8
	sumSize   int = md5.Size
)
- type NotSafeCipher struct {
- key []byte
- vxlanBytes []byte
- }
- func newRc4(k string, vxlanId uint32) (*NotSafeCipher) {
- ret := NotSafeCipher{}
- keyHash := sha256.Sum256([]byte(k))
- ret.key = keyHash[:]
- h := &vxlanHeader{flags: 0x0800, group: 0, vxlanId: vxlanId * 256}
- buf := new(bytes.Buffer)
- err := binary.Write(buf, binary.BigEndian, h)
- errorPanic(err, "binary.Write error \n")
- binary.Write(buf, binary.BigEndian, h)
- ret.vxlanBytes = buf.Bytes()
- return &ret
- }
- func (c *NotSafeCipher) getNonceKey(nonce []byte) []byte {
- sha := sha256.New()
- sha.Write(nonce)
- sha.Write(c.key)
- return sha.Sum(nil)
- }
- func (c *NotSafeCipher) Encrypt(udpPackage []byte) error {
- plainData := udpPackage[nonceSize:]
- nonce := udpPackage[:nonceSize]
- udpData := udpPackage[nonceSize:len(udpPackage)-sumSize]
- sumData := udpPackage[len(udpPackage)-sumSize:]
- if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
- fmt.Errorf("rand.Reader error %+v \n", err)
- return err
- }
- nonceRc4, err := rc4.NewCipher(c.getNonceKey(nonce))
- if err != nil {
- fmt.Errorf("rc4.NewCipher error %+v \n", err)
- return err
- }
- checkSum := md5.Sum(udpData)
- copy(sumData, checkSum[:])
- nonceRc4.XORKeyStream(plainData, plainData)
- return nil
- }
// checkSumError is the error Decrypt returns when the checksum carried in a
// packet does not match the checksum of its decrypted payload.
type checkSumError struct{}

// Error implements the error interface with a fixed message.
func (checkSumError) Error() string { return "checkSumError" }
- func (c *NotSafeCipher) Decrypt(udpPackage []byte) error {
- cipherData := udpPackage[nonceSize:]
- nonce := udpPackage[:nonceSize]
- nonceRc4, err := rc4.NewCipher(c.getNonceKey(nonce))
- if err != nil {
- fmt.Errorf("rc4.NewCipher error %+v \n", err)
- return err
- }
- nonceRc4.XORKeyStream(cipherData, cipherData)
- udpData := udpPackage[nonceSize:len(udpPackage)-sumSize]
- sumData := udpPackage[len(udpPackage)-sumSize:]
- checkSum := md5.Sum(udpData)
- if bytes.Equal(sumData, checkSum[:]) == false {
- fmt.Errorf("check sum error\n")
- return checkSumError{}
- }
- copy(nonce, c.vxlanBytes)
- return nil
- }
// UpLinkSrever is the uplink half of the tunnel. It reads datagrams from
// listenConn, encrypts them with rc4, and forwards them to the configured
// remote through a randomly chosen socket from clientPool.
type UpLinkSrever struct {
	listenConn *net.UDPConn   // bound to conf.UpLinkPort; also used by the downlink as its send socket
	clientPool []*net.UDPConn // outbound sockets, replaced periodically by upDatePool
	// guards clientPool: senders take RLock, upDatePool takes Lock
	sync.RWMutex
	rc4        *NotSafeCipher
	conf       *Config
	plainChan  chan []byte      // listener -> encrypt workers
	cipherChan chan []byte      // encrypt workers -> sender
	stopChan   chan interface{} // closed by stop to make the goroutines return
}
// getUdpConn opens a UDP socket bound to the given port on the unspecified
// address 0.0.0.0. Port 0 lets the operating system choose an ephemeral port.
func getUdpConn(port int) (ret *net.UDPConn, err error) {
	local := &net.UDPAddr{
		IP:   net.ParseIP("0.0.0.0"),
		Port: port,
	}
	return net.ListenUDP("udp", local)
}
- func newUpServer(c *Config) *UpLinkSrever {
- var err error = nil
- ret := &UpLinkSrever{}
- ret.conf = c
- ret.listenConn, err = getUdpConn(c.UpLinkPort)
- errorPanic(err, "creat UpLinkPort listenConn error \n")
- err = ret.upDatePool()
- errorPanic(err, "creat pool error \n")
- ret.rc4 = newRc4(c.Key, c.VxlanId)
- ret.stopChan = make(chan interface{}, 2)
- ret.plainChan = make(chan []byte, 5000)
- ret.cipherChan = make(chan []byte, 5000)
- go ret.handleSend()
- for i := 0; i < runtime.NumCPU(); i++ {
- go ret.handleEncrypt()
- }
- go ret.handleListen()
- go ret.updateTimer()
- return ret
- }
- func (s *UpLinkSrever) handleSend() {
- cipherChan := s.cipherChan
- poolSize := s.conf.LocalSocketPoolSize
- remoteAddr := &net.UDPAddr{IP: net.ParseIP(s.conf.UpLinkRemote), Port: s.conf.UpLinkRemotePort}
- stopChan := s.stopChan
- for {
- select {
- case udpPackage, ok := <-cipherChan:
- if ok == false {
- fmt.Errorf("upLink cipherChan broken, exit\n")
- return
- }
- pos := pseudoRand.Intn(poolSize)
- s.RLock()
- conn := s.clientPool[pos]
- if _, err := conn.WriteToUDP(udpPackage, remoteAddr); err != nil {
- fmt.Errorf("up link conn.WriteToUDP error %+v \n", err)
- }
- s.RUnlock()
- case <-stopChan:
- return
- }
- }
- }
- func (s *UpLinkSrever) handleEncrypt() {
- plainChan := s.plainChan
- cipherChan := s.cipherChan
- rc4Cipher := s.rc4
- stopChan := s.stopChan
- for {
- select {
- case udpPackage, ok := <-plainChan:
- if ok == false {
- fmt.Errorf("upLink plainChan broken, exit\n")
- return
- }
- if err := rc4Cipher.Encrypt(udpPackage); err != nil {
- fmt.Errorf("Encrypt error, drop package \n")
- } else {
- cipherChan <- udpPackage
- }
- case <-stopChan:
- return
- }
- }
- }
- func (s *UpLinkSrever) handleListen() {
- listenConn := s.listenConn
- plainChan := s.plainChan
- for {
- data := make([]byte, 1500)
- n, _, err := listenConn.ReadFromUDP(data)
- if err != nil {
- fmt.Errorf("error in upLink listenConn.ReadFromUDP %+v \n", err)
- return
- }
- if (n + sumSize) > 1500 {
- fmt.Errorf("package is too large, pls reduce your vxlan mtu\n")
- continue
- }
- plainChan <- data[:n + sumSize]
- }
- }
- func (s *UpLinkSrever) updateTimer() {
- delay := time.Duration(s.conf.LocalPoolUpdateTime)
- t := time.NewTimer(delay * time.Second)
- stopChan := s.stopChan
- for {
- select {
- case <-t.C:
- if err := s.upDatePool(); err != nil {
- fmt.Errorf("error in upDatePool %+v \n", err)
- t.Reset(delay / 2 * time.Second)
- } else {
- t.Reset(delay * time.Second)
- }
- case <-stopChan:
- return
- }
- }
- }
- func (s *UpLinkSrever) upDatePool() (err error) {
- err = nil
- newPool := make([]*net.UDPConn, s.conf.LocalSocketPoolSize)
- defer func() {
- if newPool == nil {
- return
- }
- for _, conn := range newPool {
- if conn != nil {
- conn.Close()
- }
- }
- }()
- for i := range newPool {
- newPool[i], err = getUdpConn(0)
- if err != nil {
- return
- }
- }
- s.Lock()
- defer s.Unlock()
- s.clientPool, newPool = newPool, s.clientPool
- return
- }
- func (s *UpLinkSrever) getListenConn() *net.UDPConn {
- return s.listenConn
- }
- func (s *UpLinkSrever) stop() {
- close(s.stopChan)
- close(s.plainChan)
- close(s.cipherChan)
- s.listenConn.Close()
- for _, conn := range s.clientPool {
- conn.Close()
- }
- }
// DownLinkSrever is the downlink half of the tunnel. It reads encrypted
// datagrams from listenConn, decrypts and verifies them with rc4, and
// forwards the result to the configured remote through sendConn.
type DownLinkSrever struct {
	listenConn *net.UDPConn     // bound to conf.DownLinkPort
	sendConn   *net.UDPConn     // supplied by the caller (main passes the uplink's listening socket)
	rc4        *NotSafeCipher
	conf       *Config
	plainChan  chan []byte      // decrypt workers -> sender
	cipherChan chan []byte      // listener -> decrypt workers
	stopChan   chan interface{} // closed by stop to make the goroutines return
}
- func newDownServer(c *Config, sendConn *net.UDPConn) *DownLinkSrever {
- ret := &DownLinkSrever{}
- var err error = nil
- ret.sendConn = sendConn
- ret.listenConn, err = getUdpConn(c.DownLinkPort)
- errorPanic(err, "creat DownLinkPort listenConn error \n")
- ret.rc4 = newRc4(c.Key, c.VxlanId)
- ret.conf = c
- ret.plainChan = make(chan []byte, 5000)
- ret.cipherChan = make(chan []byte, 5000)
- ret.stopChan = make(chan interface{}, 2)
- go ret.handleSend()
- for i := 0; i < runtime.NumCPU(); i++ {
- go ret.handleEncrypt()
- }
- go ret.handleListen()
- return ret
- }
- func (s *DownLinkSrever) handleSend() {
- plainChan := s.plainChan
- stopChan := s.stopChan
- conn := s.sendConn
- remoteAddr := &net.UDPAddr{IP: net.ParseIP(s.conf.DownLinkRemote), Port: s.conf.DownLinkRemotePort}
- for {
- select {
- case udpPackage, ok := <-plainChan:
- if ok == false {
- fmt.Errorf("downLink plainChan broken \n")
- return
- }
- if _, err := conn.WriteToUDP(udpPackage, remoteAddr); err != nil {
- fmt.Errorf("up link conn.WriteToUDP error %+v \n", err)
- }
- case <- stopChan:
- return
- }
- }
- }
- func (s *DownLinkSrever) handleEncrypt() {
- plainChan := s.plainChan
- cipherChan := s.cipherChan
- rc4Cipher := s.rc4
- stopChan := s.stopChan
- for {
- select {
- case udpPackage, ok := <-cipherChan:
- if ok == false {
- fmt.Errorf("upLink plainChan broken\n")
- return
- }
- if len(udpPackage) < nonceSize + sumSize + 16 {
- fmt.Errorf("size error, drop package \n")
- continue
- }
- if err := rc4Cipher.Decrypt(udpPackage); err != nil {
- fmt.Errorf("Encrypt error, drop package \n")
- } else {
- plainChan <- udpPackage[:len(udpPackage) - sumSize]
- }
- case <-stopChan:
- return
- }
- }
- }
- func (s *DownLinkSrever) handleListen() {
- listenConn := s.listenConn
- cipherChan := s.cipherChan
- for {
- data := make([]byte, 1500)
- n, _, err := listenConn.ReadFromUDP(data)
- if err != nil {
- fmt.Errorf("error in downLink listenConn.ReadFromUDP %+v \n", err)
- return
- }
- cipherChan <- data[:n]
- }
- }
- func (s *DownLinkSrever) stop() {
- close(s.stopChan)
- close(s.plainChan)
- close(s.cipherChan)
- s.listenConn.Close()
- }
- func main() {
- pseudoRand.Seed(time.Now().Unix())
- c := &Config{}
- if len(os.Args) > 1 {
- c.load(os.Args[1])
- } else {
- c.load("./conf.json")
- }
- up := newUpServer(c)
- down := newDownServer(c, up.getListenConn())
- signalChan := make(chan os.Signal, 2)
- signal.Notify(signalChan, os.Interrupt)
- <-signalChan
- up.stop()
- down.stop()
- }
Add Comment
Please, Sign In to add comment