From 0e6cd289253e3c8c39c39e5ca29cbce08385a5cd Mon Sep 17 00:00:00 2001 From: ChronosX88 Date: Tue, 18 Jan 2022 20:26:37 +0300 Subject: [PATCH] Implement session pool, refactor connection/command handling --- go.mod | 1 + internal/protocol/capabilities.go | 14 ++- internal/protocol/constants.go | 2 +- internal/server/handler.go | 115 ++++++++++++++++++++ internal/server/nntp_server.go | 171 ++++++------------------------ internal/server/session.go | 86 +++++++++++++++ 6 files changed, 248 insertions(+), 141 deletions(-) create mode 100644 internal/server/handler.go create mode 100644 internal/server/session.go diff --git a/go.mod b/go.mod index bda1feb..2db322b 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.17 require ( github.com/BurntSushi/toml v1.0.0 + github.com/google/uuid v1.3.0 github.com/jmoiron/sqlx v1.3.4 github.com/mattn/go-sqlite3 v1.14.10 github.com/pressly/goose/v3 v3.5.0 diff --git a/internal/protocol/capabilities.go b/internal/protocol/capabilities.go index d7f04c8..fdf7703 100644 --- a/internal/protocol/capabilities.go +++ b/internal/protocol/capabilities.go @@ -54,13 +54,21 @@ type Capability struct { type Capabilities []Capability -func (cs Capabilities) Add(c Capability) { - for _, v := range cs { +func (cs *Capabilities) Add(c Capability) { + for _, v := range *cs { if v.Type == c.Type { return // allowed only unique items } } - cs = append(cs, c) + *cs = append(*cs, c) +} + +func (cs *Capabilities) Remove(ct CapabilityType) { + for i, v := range *cs { + if v.Type == ct { + *cs = append((*cs)[:i], (*cs)[i+1:]...) + } + } } func (cs Capabilities) String() string { diff --git a/internal/protocol/constants.go b/internal/protocol/constants.go index 68ba9e6..09e76a6 100644 --- a/internal/protocol/constants.go +++ b/internal/protocol/constants.go @@ -31,6 +31,6 @@ const ( MessageReaderModePostingProhibited = "201 Reader mode, posting prohibited" MessageNNTPServiceExitsNormally = "205 NNTP Service exits normally" MessageUnknownCommand = "500 Unknown command" - MessageErrorHappened = "403 Failed to process command: " + MessageErrorHappened = "403 Failed to process command:" MessageListOfNewsgroupsFollows = "215 list of newsgroups follows" ) diff --git a/internal/server/handler.go b/internal/server/handler.go new file mode 100644 index 0000000..a8448d8 --- /dev/null +++ b/internal/server/handler.go @@ -0,0 +1,115 @@ +package server + +import ( + "fmt" + "github.com/ChronosX88/yans/internal/models" + "github.com/ChronosX88/yans/internal/protocol" + "github.com/jmoiron/sqlx" + "strings" + "time" +) + +type Handler struct { + handlers map[string]func(s *Session, arguments []string) error + db *sqlx.DB +} + +func NewHandler(db *sqlx.DB) *Handler { + h := &Handler{} + h.db = db + h.handlers = map[string]func(s *Session, arguments []string) error{ + protocol.CommandCapabilities: h.handleCapabilities, + protocol.CommandDate: h.handleDate, + protocol.CommandQuit: h.handleQuit, + protocol.CommandList: h.handleList, + } + return h +} + +func (h *Handler) handleCapabilities(s *Session, arguments []string) error { + return s.tconn.PrintfLine(Capabilities.String()) +} + +func (h *Handler) handleDate(s *Session, arguments []string) error { + return s.tconn.PrintfLine("111 %s", time.Now().UTC().Format("20060102150405")) +} + +func (h *Handler) handleQuit(s *Session, arguments []string) error { + s.tconn.PrintfLine(protocol.MessageNNTPServiceExitsNormally) + s.conn.Close() + return nil +} + +func (h *Handler) handleList(s *Session, arguments []string) error { + sb := strings.Builder{} + + listType := "" + if len(arguments) != 0 { + listType = arguments[0] + } + + switch listType { + case "": + fallthrough + case "ACTIVE": + { + groups, err := h.listGroups() + if err != nil { + return err + } + sb.Write([]byte(protocol.MessageListOfNewsgroupsFollows + protocol.CRLF)) + for _, v := range groups { + // TODO set high/low mark and posting status to actual values + sb.Write([]byte(fmt.Sprintf("%s 0 0 n"+protocol.CRLF, v.GroupName))) + } + } + case "NEWSGROUPS": + { + groups, err := h.listGroups() + if err != nil { + return err + } + for _, v := range groups { + desc := "" + if v.Description == nil { + desc = "No description" + } else { + desc = *v.Description + } + sb.Write([]byte(fmt.Sprintf("%s %s"+protocol.CRLF, v.GroupName, desc))) + } + } + default: + { + return s.tconn.PrintfLine(protocol.MessageUnknownCommand) + } + } + + sb.Write([]byte(protocol.MultilineEnding)) + + return s.tconn.PrintfLine(sb.String()) +} + +func (h *Handler) Handle(s *Session, message string) error { + splittedMessage := strings.Split(message, " ") + for i, v := range splittedMessage { + splittedMessage[i] = strings.TrimSpace(v) + } + cmdName := splittedMessage[0] + handler, ok := h.handlers[cmdName] + if !ok { + return s.tconn.PrintfLine(protocol.MessageUnknownCommand) + } + return handler(s, splittedMessage[1:]) +} + +// TODO Refactor to "storage backend" entity +func (h *Handler) listGroups() ([]models.Group, error) { + var groups []models.Group + return groups, h.db.Select(&groups, "SELECT * FROM groups") +} + +func (h *Handler) getArticlesCount(g models.Group) (int, error) { + var count int + return count, h.db.Select(&count, "SELECT COUNT(*) FROM articles_to_groups WHERE group_id = ?", g.ID) +} diff --git a/internal/server/nntp_server.go b/internal/server/nntp_server.go index bb4fbda..e82eca0 100644 --- a/internal/server/nntp_server.go +++ b/internal/server/nntp_server.go @@ -6,17 +6,14 @@ import ( "github.com/ChronosX88/yans/internal" "github.com/ChronosX88/yans/internal/common" "github.com/ChronosX88/yans/internal/config" - "github.com/ChronosX88/yans/internal/models" "github.com/ChronosX88/yans/internal/protocol" + "github.com/google/uuid" "github.com/jmoiron/sqlx" _ "github.com/mattn/go-sqlite3" "github.com/pressly/goose/v3" - "io" "log" "net" - "net/textproto" - "strings" - "time" + "sync" ) var ( @@ -36,6 +33,9 @@ type NNTPServer struct { port int db *sqlx.DB + + sessionPool map[string]*Session + sessionPoolMutex sync.Mutex } func NewNNTPServer(cfg config.Config) (*NNTPServer, error) { @@ -55,10 +55,11 @@ func NewNNTPServer(cfg config.Config) (*NNTPServer, error) { ctx, cancel := context.WithCancel(context.Background()) ns := &NNTPServer{ - ctx: ctx, - cancelFunc: cancel, - port: cfg.Port, - db: db, + ctx: ctx, + cancelFunc: cancel, + port: cfg.Port, + db: db, + sessionPool: map[string]*Session{}, } return ns, nil } @@ -81,7 +82,30 @@ func (ns *NNTPServer) Start() error { log.Println(err) } log.Printf("Client %s has connected!", conn.RemoteAddr().String()) - go ns.handleNewConnection(ctx, conn) + + id, _ := uuid.NewUUID() + closed := make(chan bool) + session, err := NewSession(ctx, conn, Capabilities, id.String(), closed, NewHandler(ns.db)) + ns.sessionPoolMutex.Lock() + ns.sessionPool[id.String()] = session + ns.sessionPoolMutex.Unlock() + go func(ctx context.Context, id string, closed chan bool) { + for { + select { + case <-ctx.Done(): + break + case _, ok := <-closed: + { + if !ok { + ns.sessionPoolMutex.Lock() + delete(ns.sessionPool, id) + ns.sessionPoolMutex.Unlock() + return + } + } + } + } + }(ctx, id.String(), closed) } } } @@ -90,133 +114,6 @@ func (ns *NNTPServer) Start() error { return nil } -func (ns *NNTPServer) handleNewConnection(ctx context.Context, conn net.Conn) { - _, err := conn.Write([]byte(protocol.MessageNNTPServiceReadyPostingProhibited + protocol.CRLF)) - if err != nil { - log.Print(err) - conn.Close() - return - } - - tconn := textproto.NewConn(conn) - for { - select { - case <-ctx.Done(): - break - default: - { - message, err := tconn.ReadLine() - if err != nil { - if err == io.EOF || err.(*net.OpError).Unwrap() == net.ErrClosed { - log.Printf("Client %s has diconnected!", conn.RemoteAddr().String()) - } else { - log.Print(err) - conn.Close() - } - return - } - log.Printf("Received message from %s: %s", conn.RemoteAddr().String(), string(message)) - err = ns.handleMessage(tconn, message) - if err != nil { - log.Print(err) - conn.Close() - return - } - } - } - } -} - -func (ns *NNTPServer) handleMessage(conn *textproto.Conn, msg string) error { - splittedMessage := strings.Split(msg, " ") - command := splittedMessage[0] - - reply := "" - quit := false - - switch command { - case protocol.CommandCapabilities: - { - reply = Capabilities.String() - break - } - case protocol.CommandDate: - { - reply = fmt.Sprintf("111 %s", time.Now().UTC().Format("20060102150405")) - break - } - case protocol.CommandQuit: - { - reply = protocol.MessageNNTPServiceExitsNormally - quit = true - break - } - case protocol.CommandMode: - { - if splittedMessage[1] == "READER" { - // TODO actually switch current conn to reader mode - reply = protocol.MessageReaderModePostingProhibited - } else { - reply = protocol.MessageUnknownCommand - } - break - } - case protocol.CommandList: - { - groups, err := ns.listGroups() - if err != nil { - reply = protocol.MessageErrorHappened + err.Error() - log.Println(err) - } - sb := strings.Builder{} - sb.Write([]byte(protocol.MessageListOfNewsgroupsFollows + protocol.CRLF)) - if len(splittedMessage) == 1 || splittedMessage[1] == "ACTIVE" { - for _, v := range groups { - // TODO set high/low mark and posting status to actual values - sb.Write([]byte(fmt.Sprintf("%s 0 0 n"+protocol.CRLF, v.GroupName))) - } - } else if splittedMessage[1] == "NEWSGROUPS" { - for _, v := range groups { - desc := "" - if v.Description == nil { - desc = "No description" - } else { - desc = *v.Description - } - sb.Write([]byte(fmt.Sprintf("%s %s\r\n", v.GroupName, desc))) - } - } else { - reply = protocol.MessageUnknownCommand - break - } - - sb.Write([]byte(".")) - reply = sb.String() - } - default: - { - reply = protocol.MessageUnknownCommand - break - } - } - - err := conn.PrintfLine(reply) - if quit { - conn.Close() - } - return err -} - -func (ns *NNTPServer) listGroups() ([]models.Group, error) { - var groups []models.Group - return groups, ns.db.Select(&groups, "SELECT * FROM groups") -} - -func (ns *NNTPServer) getArticlesCount(g models.Group) (int, error) { - var count int - return count, ns.db.Select(&count, "SELECT COUNT(*) FROM articles_to_groups WHERE group_id = ?", g.ID) -} - func (ns *NNTPServer) Stop() { ns.cancelFunc() } diff --git a/internal/server/session.go b/internal/server/session.go new file mode 100644 index 0000000..1b78d89 --- /dev/null +++ b/internal/server/session.go @@ -0,0 +1,86 @@ +package server + +import ( + "context" + "github.com/ChronosX88/yans/internal/protocol" + "io" + "log" + "net" + "net/textproto" +) + +type Session struct { + ctx context.Context + capabilities protocol.Capabilities + conn net.Conn + tconn *textproto.Conn + id string + closed chan<- bool + h *Handler +} + +func NewSession(ctx context.Context, conn net.Conn, caps protocol.Capabilities, id string, closed chan<- bool, handler *Handler) (*Session, error) { + var err error + defer func() { + if err != nil { + conn.Close() + close(closed) + } + }() + + tconn := textproto.NewConn(conn) + s := &Session{ + ctx: ctx, + conn: conn, + tconn: tconn, + capabilities: caps, + id: id, + closed: closed, + h: handler, + } + + go s.loop() + + return s, nil +} + +func (s *Session) loop() { + defer func() { + close(s.closed) + }() + + err := s.tconn.PrintfLine(protocol.MessageReaderModePostingProhibited) // by default access mode is read-only + if err != nil { + s.conn.Close() + return + } + + for { + select { + case <-s.ctx.Done(): + break + default: + { + message, err := s.tconn.ReadLine() + if err != nil { + if err == io.EOF || err.(*net.OpError).Unwrap() == net.ErrClosed { + log.Printf("Client %s has diconnected!", s.conn.RemoteAddr().String()) + } else { + log.Print(err) + s.conn.Close() + } + return + } + log.Printf("Received message from %s: %s", s.conn.RemoteAddr().String(), message) // for debugging + err = s.h.Handle(s, message) + if err != nil { + log.Print(err) + s.tconn.PrintfLine("%s %s", protocol.MessageErrorHappened, err.Error()) + s.conn.Close() + return + } + } + } + } + +}