diff --git a/README.md b/README.md index 1fc4dbb..6073502 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ #### Commands - :heavy_check_mark: `CAPABILITIES` - :heavy_check_mark: `DATE` -- :construction: `LIST` +- :heavy_check_mark: `LIST ACTIVE` - :heavy_check_mark: `LIST NEWSGROUPS` - :heavy_check_mark: `MODE READER` - :heavy_check_mark: `QUIT` diff --git a/internal/backend/sqlite/migrations/001_init_schema.sql b/internal/backend/sqlite/migrations/001_init_schema.sql index 5cf4eee..add349e 100644 --- a/internal/backend/sqlite/migrations/001_init_schema.sql +++ b/internal/backend/sqlite/migrations/001_init_schema.sql @@ -3,10 +3,11 @@ CREATE TABLE IF NOT EXISTS groups( id INTEGER PRIMARY KEY AUTOINCREMENT, group_name TEXT UNIQUE NOT NULL, - description TEXT + description TEXT, + created_at UNSIGNED BIG INT NOT NULL ); CREATE TABLE IF NOT EXISTS articles( - id TEXT PRIMARY KEY, + id INTEGER PRIMARY KEY AUTOINCREMENT, date INTEGER NOT NULL, path TEXT, reply_to TEXT, @@ -15,7 +16,7 @@ CREATE TABLE IF NOT EXISTS articles( body TEXT NOT NULL ); CREATE TABLE IF NOT EXISTS articles_to_groups( - article_id TEXT NOT NULL, + article_id INTEGER NOT NULL, group_id INTEGER NOT NULL, FOREIGN KEY (article_id) REFERENCES articles(id) ON DELETE CASCADE, FOREIGN KEY (group_id) REFERENCES groups(id) ON DELETE CASCADE diff --git a/internal/backend/sqlite/sqlite.go b/internal/backend/sqlite/sqlite.go index fe98f21..a77b280 100644 --- a/internal/backend/sqlite/sqlite.go +++ b/internal/backend/sqlite/sqlite.go @@ -43,5 +43,15 @@ func (sb *SQLiteBackend) ListGroups() ([]models.Group, error) { func (sb *SQLiteBackend) GetArticlesCount(g models.Group) (int, error) { var count int - return count, sb.db.Select(&count, "SELECT COUNT(*) FROM articles_to_groups WHERE group_id = ?", g.ID) + return count, sb.db.Get(&count, "SELECT COUNT(*) FROM articles_to_groups WHERE group_id = ?", g.ID) +} + +func (sb *SQLiteBackend) GetGroupHighWaterMark(g models.Group) (int, error) { + var waterMark int + return waterMark, sb.db.Get(&waterMark, "SELECT article_id FROM articles_to_groups WHERE group_id = ? ORDER BY article_id DESC LIMIT 1", g.ID) +} + +func (sb *SQLiteBackend) GetGroupLowWaterMark(g models.Group) (int, error) { + var waterMark int + return waterMark, sb.db.Get(&waterMark, "SELECT article_id FROM articles_to_groups WHERE group_id = ? ORDER BY article_id LIMIT 1", g.ID) } diff --git a/internal/backend/storage_backend.go b/internal/backend/storage_backend.go index 2b82015..d1cda95 100644 --- a/internal/backend/storage_backend.go +++ b/internal/backend/storage_backend.go @@ -9,4 +9,6 @@ const ( type StorageBackend interface { ListGroups() ([]models.Group, error) GetArticlesCount(g models.Group) (int, error) + GetGroupLowWaterMark(g models.Group) (int, error) + GetGroupHighWaterMark(g models.Group) (int, error) } diff --git a/internal/models/group.go b/internal/models/group.go index b26baca..fd80d5d 100644 --- a/internal/models/group.go +++ b/internal/models/group.go @@ -4,4 +4,5 @@ type Group struct { ID int `db:"id"` GroupName string `db:"group_name"` Description *string `db:"description"` + CreatedAt uint64 `db:"created_at"` } diff --git a/internal/protocol/constants.go b/internal/protocol/constants.go index 28c7e17..d5b4414 100644 --- a/internal/protocol/constants.go +++ b/internal/protocol/constants.go @@ -29,7 +29,7 @@ const ( const ( MessageNNTPServiceReadyPostingProhibited = "201 YANS NNTP Service Ready, posting prohibited" MessageReaderModePostingProhibited = "201 Reader mode, posting prohibited" - MessageNNTPServiceExitsNormally = "205 NNTP Service exits normally" + MessageNNTPServiceExitsNormally = "205 NNTP Service exits normally, bye!" MessageUnknownCommand = "500 Unknown command" MessageErrorHappened = "403 Failed to process command:" MessageListOfNewsgroupsFollows = "215 list of newsgroups follows" diff --git a/internal/server/handler.go b/internal/server/handler.go index 788de1e..a8435f1 100644 --- a/internal/server/handler.go +++ b/internal/server/handler.go @@ -9,14 +9,14 @@ import ( ) type Handler struct { - handlers map[string]func(s *Session, arguments []string) error + handlers map[string]func(s *Session, arguments []string, id uint) error backend backend.StorageBackend } func NewHandler(b backend.StorageBackend) *Handler { h := &Handler{} h.backend = b - h.handlers = map[string]func(s *Session, arguments []string) error{ + h.handlers = map[string]func(s *Session, arguments []string, id uint) error{ protocol.CommandCapabilities: h.handleCapabilities, protocol.CommandDate: h.handleDate, protocol.CommandQuit: h.handleQuit, @@ -26,21 +26,25 @@ func NewHandler(b backend.StorageBackend) *Handler { return h } -func (h *Handler) handleCapabilities(s *Session, arguments []string) error { +func (h *Handler) handleCapabilities(s *Session, arguments []string, id uint) error { + s.tconn.StartResponse(id) + defer s.tconn.EndResponse(id) return s.tconn.PrintfLine(s.capabilities.String()) } -func (h *Handler) handleDate(s *Session, arguments []string) error { +func (h *Handler) handleDate(s *Session, arguments []string, id uint) error { + s.tconn.StartResponse(id) + defer s.tconn.EndResponse(id) return s.tconn.PrintfLine("111 %s", time.Now().UTC().Format("20060102150405")) } -func (h *Handler) handleQuit(s *Session, arguments []string) error { +func (h *Handler) handleQuit(s *Session, arguments []string, id uint) error { s.tconn.PrintfLine(protocol.MessageNNTPServiceExitsNormally) s.conn.Close() return nil } -func (h *Handler) handleList(s *Session, arguments []string) error { +func (h *Handler) handleList(s *Session, arguments []string, id uint) error { sb := strings.Builder{} listType := "" @@ -48,6 +52,9 @@ func (h *Handler) handleList(s *Session, arguments []string) error { listType = arguments[0] } + s.tconn.StartResponse(id) + defer s.tconn.EndResponse(id) + switch listType { case "": fallthrough @@ -60,7 +67,23 @@ func (h *Handler) handleList(s *Session, arguments []string) error { sb.Write([]byte(protocol.MessageListOfNewsgroupsFollows + protocol.CRLF)) for _, v := range groups { // TODO set high/low mark and posting status to actual values - sb.Write([]byte(fmt.Sprintf("%s 0 0 n"+protocol.CRLF, v.GroupName))) + c, err := h.backend.GetArticlesCount(v) + if err != nil { + return err + } + if c > 0 { + highWaterMark, err := h.backend.GetGroupHighWaterMark(v) + if err != nil { + return err + } + lowWaterMark, err := h.backend.GetGroupLowWaterMark(v) + if err != nil { + return err + } + sb.Write([]byte(fmt.Sprintf("%s %d %d n"+protocol.CRLF, v.GroupName, highWaterMark, lowWaterMark))) + } else { + sb.Write([]byte(fmt.Sprintf("%s 0 0 n"+protocol.CRLF, v.GroupName))) + } } } case "NEWSGROUPS": @@ -90,7 +113,7 @@ func (h *Handler) handleList(s *Session, arguments []string) error { return s.tconn.PrintfLine(sb.String()) } -func (h *Handler) handleModeReader(s *Session, arguments []string) error { +func (h *Handler) handleModeReader(s *Session, arguments []string, id uint) error { if len(arguments) == 0 || arguments[0] != "READER" { return s.tconn.PrintfLine(protocol.MessageSyntaxError) } @@ -104,7 +127,7 @@ func (h *Handler) handleModeReader(s *Session, arguments []string) error { return s.tconn.PrintfLine(protocol.MessageReaderModePostingProhibited) // TODO vary on auth status } -func (h *Handler) Handle(s *Session, message string) error { +func (h *Handler) Handle(s *Session, message string, id uint) error { splittedMessage := strings.Split(message, " ") for i, v := range splittedMessage { splittedMessage[i] = strings.TrimSpace(v) @@ -112,7 +135,9 @@ func (h *Handler) Handle(s *Session, message string) error { cmdName := splittedMessage[0] handler, ok := h.handlers[cmdName] if !ok { + s.tconn.StartResponse(id) + defer s.tconn.EndResponse(id) return s.tconn.PrintfLine(protocol.MessageUnknownCommand) } - return handler(s, splittedMessage[1:]) + return handler(s, splittedMessage[1:], id) } diff --git a/internal/server/session.go b/internal/server/session.go index ce9b159..533609e 100644 --- a/internal/server/session.go +++ b/internal/server/session.go @@ -70,6 +70,8 @@ func (s *Session) loop() { break default: { + id := s.tconn.Next() + s.tconn.StartRequest(id) message, err := s.tconn.ReadLine() if err != nil { if err == io.EOF || err.(*net.OpError).Unwrap() == net.ErrClosed { @@ -80,8 +82,9 @@ func (s *Session) loop() { } return } + s.tconn.EndRequest(id) log.Printf("Received message from %s: %s", s.conn.RemoteAddr().String(), message) // for debugging - err = s.h.Handle(s, message) + err = s.h.Handle(s, message, id) if err != nil { log.Print(err) s.tconn.PrintfLine("%s %s", protocol.MessageErrorHappened, err.Error())