Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- package main
- import (
- "bufio"
- "flag"
- "fmt"
- "log"
- "net"
- "strings"
- "sync"
- )
- func main() {
- log.SetPrefix("chat server: ")
- addr := flag.String("addr", ":4000", "listen address")
- flag.Parse()
- server := new(server)
- log.Fatal(server.listenAndServe(*addr))
- }
- type server struct {
- addToChannel chan *client
- addUsername chan *client
- remUsername chan *client
- remFromChannel chan *client
- changeChannel chan *client
- remChannel chan bool
- usernameList map[string]*client
- channelList map[string]*channel
- }
- type channel struct {
- addClient chan *client
- remClient chan *client
- broadcast chan string
- name string
- server *server
- }
- type client struct {
- username string
- conn *net.Conn
- currentChannel *channel
- newChannel string
- reader *bufio.Reader
- message chan string
- server *server
- id string
- }
- func (server *server) listenAndServe(addr string) error {
- ln, err := net.Listen("tcp", addr)
- if err != nil {
- return err
- }
- defer ln.Close()
- go server.manageServer()
- log.Printf("listening on %s\n", addr)
- for {
- conn, err := ln.Accept()
- if err != nil {
- return err
- }
- log.Printf("new connection from %s\n", conn.RemoteAddr().String())
- go initializeClient(&conn, server)
- }
- }
- func initializeClient(conn *net.Conn, server *server) {
- log.Printf("initalizing %s\n", (*conn).RemoteAddr().String())
- client := client{conn: conn, reader: bufio.NewReader(*conn), server: server, id: "@" + (*conn).RemoteAddr().String(), message: make(chan string)}
- _, err := fmt.Fprint(*client.conn, "welcome to the chat server\n")
- if err != nil {
- client.shutdown()
- return
- }
- client.getUsername()
- }
- func (client *client) shutdown() {
- if client.currentChannel != nil {
- client.server.remFromChannel <- client
- }
- if client.username != "" {
- client.server.remUsername <- client
- }
- log.Printf("%s shutting down", client.id)
- (*client.conn).Close()
- log.Printf("%s shut down", client.id)
- }
- func getTrimmed(client *client, what string) (fromUser string, err error) {
- _, err = fmt.Fprint(*client.conn, what+": ")
- if err != nil {
- client.shutdown()
- return
- }
- fromUser, err = client.reader.ReadString('\n')
- if err != nil {
- client.shutdown()
- return
- }
- return strings.TrimSpace(fromUser), err
- }
- func (client *client) getUsername() {
- var err error
- log.Printf("getting username for %s\n", client.id)
- client.username, err = getTrimmed(client, "username")
- if err != nil {
- return
- }
- client.server.addUsername <- client
- }
- func (client *client) joinChannel(channelName string) {
- log.Printf("getting channel for %s\n", client.id)
- if channelName == "" {
- var err error
- channelName, err = getTrimmed(client, "channel name")
- if err != nil {
- return
- }
- }
- client.currentChannel = &channel{name: channelName, server: client.server}
- client.server.addToChannel <- client
- }
- func (server *server) manageServer() {
- server.addUsername = make(chan *client)
- server.addToChannel = make(chan *client)
- server.remUsername = make(chan *client)
- server.remFromChannel = make(chan *client)
- server.remChannel = make(chan bool)
- server.changeChannel = make(chan *client)
- server.usernameList = make(map[string]*client)
- server.channelList = make(map[string]*channel)
- remFromChannel := func(client *client) {
- client.currentChannel.remClient <- client
- if true, _ := <-server.remChannel; true {
- log.Printf("channel %s shutting down", client.currentChannel.name)
- delete(server.channelList, client.currentChannel.name)
- }
- }
- for {
- select {
- case c := <-server.addUsername:
- if _, used := server.usernameList[c.username]; used {
- _, err := fmt.Fprintf(*c.conn, "username %s is not available\n", c.username)
- if err != nil {
- c.shutdown()
- break
- }
- go c.getUsername()
- break
- }
- server.usernameList[c.username] = c
- c.id = c.username + c.id
- log.Printf("got username for %s", c.id)
- go c.joinChannel("")
- case c := <-server.addToChannel:
- log.Printf("got channel %s for %s\n", c.currentChannel.name, c.id)
- if _, exists := server.channelList[c.currentChannel.name]; exists {
- c.currentChannel = server.channelList[c.currentChannel.name]
- c.currentChannel.addClient <- c
- break
- }
- log.Printf("creating channel %s", c.currentChannel.name)
- server.channelList[c.currentChannel.name] = c.currentChannel
- c.currentChannel.addClient = make(chan *client)
- go c.currentChannel.manageChannel()
- c.currentChannel.addClient <- c
- log.Printf("created channel %s", c.currentChannel.name)
- case c := <-server.remUsername:
- delete(server.usernameList, c.username)
- case c := <-server.remFromChannel:
- remFromChannel(c)
- case c := <-server.changeChannel:
- remFromChannel(c)
- go c.joinChannel(c.newChannel)
- }
- }
- }
- type clientList struct {
- m (map[string]*client)
- sync.RWMutex
- }
- func (channel *channel) manageChannel() {
- channel.broadcast = make(chan string)
- channel.remClient = make(chan *client)
- clientList := clientList{m: make(map[string]*client)}
- broadcastLoop := func() {
- for {
- message := <-channel.broadcast
- clientList.RLock()
- for _, c := range clientList.m {
- c.message <- message
- }
- clientList.RUnlock()
- }
- }
- go broadcastLoop()
- for {
- select {
- case client := <-channel.addClient:
- fmt.Fprintf(*client.conn, "joining channel %s\n", channel.name)
- client.currentChannel = channel
- clientList.m[client.username] = client
- log.Printf("%s has joined channel %s", client.id, channel.name)
- go client.writeLoop()
- channel.broadcast <- "+++ "+client.username+" has joined the channel\n"
- go client.readLoop()
- case client := <-channel.remClient:
- fmt.Fprintf(*client.conn, "leaving channel %s\n", channel.name)
- delete(clientList.m, client.username)
- log.Printf("%s has left channel %s", client.id, channel.name)
- channel.broadcast <- "--- "+client.username+" has left the channel\n"
- if len(clientList.m) == 0 {
- channel.server.remChannel <- true
- } else {
- channel.server.remChannel <- false
- }
- }
- }
- }
- func (client *client) readLoop() {
- for {
- payload, err := client.reader.ReadString('\n')
- if err != nil {
- client.shutdown()
- return
- }
- if strings.Contains(payload, "/chch") {
- log.Printf("changing channel for %s", client.id)
- client.newChannel = strings.TrimSpace(payload[6:])
- client.server.changeChannel <- client
- return
- }
- log.Printf("broadcast: >>> %s: %s", client.id, payload)
- client.currentChannel.broadcast <- ">>> " + client.username + ": " + payload
- }
- }
- func (client *client) writeLoop() {
- for {
- message := <-client.message
- _, err := fmt.Fprintf(*client.conn, message)
- if err != nil {
- client.shutdown()
- return
- }
- }
- }
Advertisement
Add Comment
Please, Sign In to add comment