Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- package main
import (
	"bufio"
	"encoding/binary"
	"fmt"
	"io"
	"net"
	"os"
	"strconv"
	"strings"
	"sync"
	"time"
)
- type Client struct {
- socket net.Conn
- username string
- inbox chan string
- }
- type ActiveClients struct {
- mu sync.RWMutex
- clients map[string]*Client
- }
- func (ac *ActiveClients) add(c *Client) bool {
- ac.mu.Lock()
- defer ac.mu.Unlock()
- if ac.clients == nil {
- ac.clients = make(map[string]*Client, 0)
- }
- if c.inbox == nil {
- c.inbox = make(chan string, 100)
- }
- for username := range ac.clients {
- if username == c.username {
- return false
- }
- }
- ac.clients[c.username] = c
- return true
- }
- func (ac *ActiveClients) remove(c Client) bool {
- ac.mu.Lock()
- defer ac.mu.Unlock()
- debug("Removing %v\n", c.username)
- if ac.clients == nil {
- return false
- }
- _, ok := ac.clients[c.username]
- if ok {
- delete(ac.clients, c.username)
- }
- return ok
- }
- func (ac *ActiveClients) forwardMessage(message string, from string) {
- ac.mu.RLock()
- defer ac.mu.RUnlock()
- msgWithFrom := from + ": " + message
- for username, c := range ac.clients {
- if username != from {
- debug("Attempting to forward to %v...", username)
- select {
- case c.inbox <- msgWithFrom:
- debug("Forwarding message to %v\n", username)
- default:
- // if buffer is full, drop the message
- debug("Message dropped because %v inbox is full.\n", c.username)
- }
- }
- }
- }
var (
	// wg is waited on by server (its acceptor goroutines) and by benchmark
	// (its per-connection goroutines).
	wg sync.WaitGroup
	// shouldDebug turns on the diagnostics printed by debug; server mode
	// enables it.
	shouldDebug bool
)

// debug prints a Printf-style diagnostic, but only while shouldDebug is
// set. Callers include their own trailing newline when they want one.
func debug(message string, vals ...any) {
	if !shouldDebug {
		return
	}
	fmt.Printf(message, vals...)
}
// uint16ToBytes encodes i as a two-byte big-endian slice (length 2,
// capacity 2); sendMessage uses it for the frame length header.
func uint16ToBytes(i uint16) []byte {
	var buf [2]byte
	binary.BigEndian.PutUint16(buf[:], i)
	return buf[:]
}
// bytesToUint16 decodes the first two bytes of b as a big-endian uint16.
// b is resliced to length 2 first, so it panics if b's capacity is below 2.
func bytesToUint16(b []byte) uint16 {
	b = b[:2]
	return uint16(b[0])<<8 | uint16(b[1])
}
- func sendMessage(message string, connection net.Conn, milliseconds int) (int, error) {
- var timeout time.Duration
- msg_bytes := []byte(message)
- msg_size := len(msg_bytes)
- msg := append(uint16ToBytes(uint16(msg_size)), msg_bytes...)
- if milliseconds > 0 {
- timeout = time.Millisecond * time.Duration(milliseconds)
- connection.SetWriteDeadline(time.Now().Add(timeout))
- defer connection.SetWriteDeadline(time.Time{})
- }
- size, err := connection.Write(msg)
- if err != nil {
- if !strings.Contains(err.Error(), "i/o timeout") {
- fmt.Println(err)
- }
- }
- debug("Sent %v bytes to %v\n", size, connection.RemoteAddr())
- return size, err
- }
- func receiveMessage(connection net.Conn, milliseconds int) (string, int, error) {
- var msg_size [2]byte
- var msg []byte
- var timeout time.Duration
- if milliseconds > 0 {
- timeout = time.Millisecond * time.Duration(milliseconds)
- connection.SetReadDeadline(time.Now().Add(timeout))
- defer connection.SetReadDeadline(time.Time{})
- }
- size, err := connection.Read(msg_size[:])
- if err != nil {
- if !strings.Contains(err.Error(), "i/o timeout") {
- fmt.Println(err)
- }
- }
- if err != nil || size == 0 {
- return "", size, err
- }
- size = int(bytesToUint16(msg_size[:]))
- msg = make([]byte, size)
- size, err = connection.Read(msg)
- if err != nil {
- if !strings.Contains(err.Error(), "i/o timeout") {
- fmt.Println(err)
- }
- }
- if err != nil || size == 0 {
- return "", size, err
- }
- debug("Received message (%v bytes) from %v\n", size+2, connection.RemoteAddr())
- return string(msg), size, err
- }
- func server(numberOfClients int, hostInterface string) {
- shouldDebug = true
- users := ActiveClients{}
- listener, err := net.Listen("tcp", hostInterface)
- if err != nil {
- fmt.Println(err)
- os.Exit(1)
- }
- for i := 0; i < numberOfClients; i++ {
- wg.Add(1)
- go listen(listener, &users)
- }
- wg.Wait()
- }
- func listen(listener net.Listener, users *ActiveClients) {
- defer wg.Done()
- var username string
- var thisClient Client
- var lg sync.WaitGroup
- signalShouldEnd := make(chan bool, 1)
- loop:
- connection, err := listener.Accept()
- if err != nil {
- debug(err.Error() + "\n")
- return
- }
- debug("Received connection from %v\n", connection.RemoteAddr())
- hasSetUsername := false
- lg.Add(2)
- go func() {
- defer lg.Done()
- for {
- // wait up to 100ms for the client to send a message
- message, size, err := receiveMessage(connection, 100)
- if err != nil {
- if strings.Contains(err.Error(), "i/o timeout") {
- // timed out, so go back to start of loop
- // fmt.Println("i/o timeout")
- continue
- }
- debug(err.Error())
- select {
- case signalShouldEnd <- true:
- if users.remove(thisClient) {
- users.forwardMessage("Goodbye "+thisClient.username, "server")
- }
- default:
- }
- return
- }
- if size > 0 {
- if !hasSetUsername {
- username = message
- if username == "server" {
- sendMessage("server: error: username is invalid; please send new username", connection, 0)
- break
- }
- thisClient = Client{
- socket: connection,
- username: username,
- }
- ok := users.add(&thisClient)
- if !ok {
- sendMessage("server: error: username is invalid; please send new username", connection, 0)
- break
- } else {
- users.forwardMessage("Welcome "+username, "server")
- }
- hasSetUsername = true
- } else {
- users.forwardMessage(message, username)
- }
- }
- select {
- case <-signalShouldEnd:
- return
- default:
- }
- }
- }()
- go func() {
- defer lg.Done()
- for {
- select {
- case incoming := <-thisClient.inbox:
- debug("Message found in inbox; sending to %v\n", thisClient.username)
- _, err := sendMessage(incoming, thisClient.socket, 0)
- if err != nil {
- if !strings.Contains(err.Error(), "i/o timeout") {
- debug(err.Error() + "\n")
- select {
- case signalShouldEnd <- true:
- if users.remove(thisClient) {
- users.forwardMessage("Goodbye "+thisClient.username, "server")
- }
- default:
- }
- return
- }
- continue
- }
- case <-signalShouldEnd:
- return
- default:
- time.Sleep(time.Millisecond * 10)
- }
- }
- }()
- lg.Wait()
- debug("Recycling server socket for new client\n")
- goto loop
- }
- func client(host string, username string) {
- connection, err := net.Dial("tcp", host+":1337")
- endSignals := []string{"bye", "quit", "exit"}
- if err != nil {
- debug("%v\n", err)
- return
- }
- go func() {
- for {
- message, size, err := receiveMessage(connection, 100)
- if err != nil {
- if strings.Contains(err.Error(), "i/o timeout") {
- continue
- }
- debug("%v\n", err)
- os.Exit(0)
- }
- if size == 0 {
- continue
- }
- // fmt.Printf("Received %v bytes: ", size)
- fmt.Println(message)
- }
- }()
- fmt.Println("Connected to server.")
- _, err = sendMessage(username, connection, 0)
- if err != nil {
- fmt.Println(err)
- return
- }
- for {
- message := askInput("")
- if contains(endSignals, message) {
- return
- }
- _, err := sendMessage(message, connection, 0)
- if err != nil {
- fmt.Println(err)
- return
- }
- }
- }
- func benchmark(host string, numberOfClients int, numberOfMessages int, baseMessage string) {
- wg.Add(numberOfClients)
- for i := 0; i < numberOfClients; i++ {
- go func(i int) {
- defer wg.Done()
- var bg sync.WaitGroup
- username := "benchmark" + strconv.Itoa(i)
- connection, err := net.Dial("tcp", host+":1337")
- if err != nil {
- fmt.Println(err)
- return
- }
- // current limitation: must wait for username to process
- _, _ = sendMessage(username, connection, 100)
- time.Sleep(time.Millisecond * 100)
- bg.Add(numberOfMessages)
- for d := 0; d < numberOfMessages; d++ {
- go func(message string) {
- defer bg.Done()
- _, err = sendMessage(message, connection, 100)
- if err != nil {
- fmt.Println(err)
- return
- }
- }(baseMessage + strconv.Itoa(d))
- }
- bg.Add(1)
- go func() {
- defer bg.Done()
- for {
- message, size, err := receiveMessage(connection, 2000)
- if err != nil {
- fmt.Println(err)
- return
- }
- if size == 0 {
- fmt.Println("aborting from empty receive")
- return
- }
- fmt.Println(message)
- }
- }()
- bg.Wait()
- }(i)
- }
- wg.Wait()
- }
// stdinReader is shared by all askInput calls. A bufio.Reader may read
// ahead several lines, so a fresh reader per call would silently discard
// input that an earlier call had already buffered.
var stdinReader = bufio.NewReader(os.Stdin)

// askInput prints query followed by a space (nothing when query is empty),
// then reads one line from standard input and returns it without its line
// ending; any text after the first carriage return is dropped as well. Read
// errors are not reported: whatever was read before the error is returned,
// which is "" at end of input.
func askInput(query string) string {
	if query != "" {
		fmt.Printf("%v ", query)
	}
	line, _ := stdinReader.ReadString('\n') // best effort, see above
	line, _, _ = strings.Cut(line, "\n")
	line, _, _ = strings.Cut(line, "\r")
	return line
}
// contains reports whether query occurs in list.
// Adapted from https://stackoverflow.com/a/10485970.
func contains[T comparable](list []T, query T) bool {
	for i := range list {
		if list[i] == query {
			return true
		}
	}
	return false
}
// usage prints the command-line syntax of the server, client and
// benchmark modes, using programName as the command name.
func usage(programName string) {
	fmt.Printf("usage: %v server numberOfThreads [hostInterface]\n", programName)
	fmt.Println("\thostInterface is of form hostname:port or IP:port")
	fmt.Println("\tuse 0.0.0.0:port to listen on every available IP interface")
	fmt.Println("\tdefault is localhost:1337")
	fmt.Printf("usage: %v client hostIP username\n", programName)
	fmt.Printf("usage: %v benchmark hostIP numberOfClients numberOfMessages baseMessage\n", programName)
	fmt.Println("arguments must be in order shown")
}
- func main() {
- programName := os.Args[0]
- nameParts := strings.Split(programName, "\\")
- programName = nameParts[len(nameParts)-1]
- if len(os.Args) < 2 {
- usage(programName)
- return
- }
- mode := os.Args[1]
- switch mode {
- case "server", "serve":
- var numberOfClients int
- if len(os.Args) < 3 {
- usage(programName)
- return
- }
- n, _ := strconv.ParseInt(os.Args[2], 10, 16)
- numberOfClients = int(n)
- var hostInterface string
- if len(os.Args) > 3 {
- hostInterface = os.Args[3]
- } else {
- hostInterface = "localhost:1337"
- }
- server(numberOfClients, hostInterface)
- case "client", "connect":
- var host string
- var username string
- if len(os.Args) < 4 {
- usage(programName)
- return
- }
- host = os.Args[2]
- username = os.Args[3]
- client(host, username)
- case "benchmark":
- var host string
- var nClients string
- var numberOfClients int
- var nMsgs string
- var numberOfMessages int
- var message string
- if len(os.Args) < 6 {
- usage(programName)
- return
- }
- host = os.Args[2]
- nClients = os.Args[3]
- nMsgs = os.Args[4]
- message = os.Args[5]
- n, _ := strconv.ParseInt(nClients, 10, 16)
- numberOfClients = int(n)
- n, _ = strconv.ParseInt(nMsgs, 10, 16)
- numberOfMessages = int(n)
- benchmark(host, numberOfClients, numberOfMessages, message)
- default:
- usage(programName)
- }
- }
- // ISC License
- // Copyleft (c) 2023 k98kurz
- // Permission to use, copy, modify, and/or distribute this software
- // for any purpose with or without fee is hereby granted, provided
- // that the above copyleft notice and this permission notice appear in
- // all copies.
- // THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL
- // WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED
- // WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE
- // AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, INDIRECT, OR
- // CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS
- // OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT,
- // NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
- // CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement