From 10320d169a118bc07e292cc6cd46fa8d4720a9ad Mon Sep 17 00:00:00 2001 From: ChronosX88 Date: Thu, 3 Feb 2022 20:39:08 +0300 Subject: [PATCH] Implement LISTGROUP command --- README.md | 2 +- internal/backend/sqlite/sqlite.go | 32 +++++++++-- internal/backend/storage_backend.go | 7 +-- internal/protocol/constants.go | 2 +- internal/server/handler.go | 83 +++++++++++++++++++++++++---- internal/utils/range.go | 27 ++++++++++ 6 files changed, 135 insertions(+), 18 deletions(-) create mode 100644 internal/utils/range.go diff --git a/README.md b/README.md index 57bc6cf..f96f733 100644 --- a/README.md +++ b/README.md @@ -27,7 +27,7 @@ - :x: `HELP` - :x: `IHAVE` - :x: `LAST` -- :x: `LISTGROUP` +- :heavy_check_mark: `LISTGROUP` - :heavy_check_mark: `NEWGROUPS` - :x: `NEWNEWS` - :x: `NEXT` diff --git a/internal/backend/sqlite/sqlite.go b/internal/backend/sqlite/sqlite.go index 05931f2..d4484bf 100644 --- a/internal/backend/sqlite/sqlite.go +++ b/internal/backend/sqlite/sqlite.go @@ -72,17 +72,17 @@ func (sb *SQLiteBackend) ListGroupsByPattern(pattern string) ([]models.Group, er return groups, sb.db.Select(&groups, "SELECT * FROM groups WHERE group_name REGEXP ?", r.String()) } -func (sb *SQLiteBackend) GetArticlesCount(g models.Group) (int, error) { +func (sb *SQLiteBackend) GetArticlesCount(g *models.Group) (int, error) { var count int return count, sb.db.Get(&count, "SELECT COUNT(*) FROM articles_to_groups WHERE group_id = ?", g.ID) } -func (sb *SQLiteBackend) GetGroupHighWaterMark(g models.Group) (int, error) { +func (sb *SQLiteBackend) GetGroupHighWaterMark(g *models.Group) (int, error) { var waterMark int return waterMark, sb.db.Get(&waterMark, "SELECT article_id FROM articles_to_groups WHERE group_id = ? ORDER BY article_id DESC LIMIT 1", g.ID) } -func (sb *SQLiteBackend) GetGroupLowWaterMark(g models.Group) (int, error) { +func (sb *SQLiteBackend) GetGroupLowWaterMark(g *models.Group) (int, error) { var waterMark int return waterMark, sb.db.Get(&waterMark, "SELECT article_id FROM articles_to_groups WHERE group_id = ? ORDER BY article_id LIMIT 1", g.ID) } @@ -134,3 +134,29 @@ func (sb *SQLiteBackend) GetArticle(messageID string) (models.Article, error) { } return a, json.Unmarshal([]byte(a.HeaderRaw), &a.Header) } + +func (sb *SQLiteBackend) GetArticleNumbers(g *models.Group, low, high int64) ([]int64, error) { + var numbers []int64 + + if high == 0 && low == 0 { + if err := sb.db.Select(&numbers, "SELECT article_id FROM articles_to_groups WHERE group_id = ?", g.ID); err != nil { + return nil, err + } + } else if low == -1 && high != 0 { + if err := sb.db.Select(&numbers, "SELECT article_id FROM articles_to_groups WHERE group_id = ? AND article_id = ?", g.ID, high); err != nil { + return nil, err + } + } else if low != 0 && high == -1 { + if err := sb.db.Select(&numbers, "SELECT article_id FROM articles_to_groups WHERE group_id = ? AND article_id > ?", g.ID, low); err != nil { + return nil, err + } + } else if low == -1 && high == -1 { + return nil, nil + } else { + if err := sb.db.Select(&numbers, "SELECT article_id FROM articles_to_groups WHERE group_id = ? AND article_id > ? AND article_id < ?", g.ID, low, high); err != nil { + return nil, err + } + } + + return numbers, nil +} diff --git a/internal/backend/storage_backend.go b/internal/backend/storage_backend.go index 101ded3..c4b02bd 100644 --- a/internal/backend/storage_backend.go +++ b/internal/backend/storage_backend.go @@ -11,9 +11,10 @@ type StorageBackend interface { ListGroupsByPattern(pattern string) ([]models.Group, error) GetGroup(groupName string) (models.Group, error) GetNewGroupsSince(timestamp int64) ([]models.Group, error) - GetArticlesCount(g models.Group) (int, error) - GetGroupLowWaterMark(g models.Group) (int, error) - GetGroupHighWaterMark(g models.Group) (int, error) + GetArticlesCount(g *models.Group) (int, error) + GetGroupLowWaterMark(g *models.Group) (int, error) + GetGroupHighWaterMark(g *models.Group) (int, error) SaveArticle(article models.Article, groups []string) error GetArticle(messageID string) (models.Article, error) + GetArticleNumbers(g *models.Group, low, high int64) ([]int64, error) } diff --git a/internal/protocol/constants.go b/internal/protocol/constants.go index c42049a..3e0fcfe 100644 --- a/internal/protocol/constants.go +++ b/internal/protocol/constants.go @@ -55,6 +55,7 @@ const ( CommandGroup = "GROUP" CommandNewGroups = "NEWGROUPS" CommandPost = "POST" + CommandListGroup = "LISTGROUP" ) const ( @@ -74,7 +75,6 @@ const ( MessageNNTPServiceReadyPostingProhibited = "201 YANS NNTP Service Ready, posting prohibited" MessageReaderModePostingProhibited = "201 Reader mode, posting prohibited" MessageNNTPServiceExitsNormally = "205 NNTP Service exits normally, bye!" - MessageUnknownCommand = "500 Unknown command" MessageErrorHappened = "403 Failed to process command:" MessageListOfNewsgroupsFollows = "215 list of newsgroups follows" MessageNoSuchGroup = "411 No such newsgroup" diff --git a/internal/server/handler.go b/internal/server/handler.go index 5e459f4..9208f69 100644 --- a/internal/server/handler.go +++ b/internal/server/handler.go @@ -7,9 +7,11 @@ import ( "github.com/ChronosX88/yans/internal/backend" "github.com/ChronosX88/yans/internal/models" "github.com/ChronosX88/yans/internal/protocol" + "github.com/ChronosX88/yans/internal/utils" "github.com/google/uuid" "io" "net/mail" + "strconv" "strings" "time" ) @@ -32,6 +34,7 @@ func NewHandler(b backend.StorageBackend, serverDomain string) *Handler { protocol.CommandGroup: h.handleGroup, protocol.CommandNewGroups: h.handleNewGroups, protocol.CommandPost: h.handlePost, + protocol.CommandListGroup: h.handleListgroup, } h.serverDomain = serverDomain return h @@ -85,16 +88,16 @@ func (h *Handler) handleList(s *Session, arguments []string, id uint) error { sb.Write([]byte(protocol.MessageListOfNewsgroupsFollows + protocol.CRLF)) for _, v := range groups { // TODO set actual post permission status - c, err := h.backend.GetArticlesCount(v) + c, err := h.backend.GetArticlesCount(&v) if err != nil { return err } if c > 0 { - highWaterMark, err := h.backend.GetGroupHighWaterMark(v) + highWaterMark, err := h.backend.GetGroupHighWaterMark(&v) if err != nil { return err } - lowWaterMark, err := h.backend.GetGroupLowWaterMark(v) + lowWaterMark, err := h.backend.GetGroupLowWaterMark(&v) if err != nil { return err } @@ -169,15 +172,15 @@ func (h *Handler) handleGroup(s *Session, arguments []string, id uint) error { return err } } - highWaterMark, err := h.backend.GetGroupHighWaterMark(g) + highWaterMark, err := h.backend.GetGroupHighWaterMark(&g) if err != nil && err != sql.ErrNoRows { return err } - lowWaterMark, err := h.backend.GetGroupLowWaterMark(g) + lowWaterMark, err := h.backend.GetGroupLowWaterMark(&g) if err != nil && err != sql.ErrNoRows { return err } - articlesCount, err := h.backend.GetArticlesCount(g) + articlesCount, err := h.backend.GetArticlesCount(&g) if err != nil && err != sql.ErrNoRows { return err } @@ -230,16 +233,16 @@ func (h *Handler) handleNewGroups(s *Session, arguments []string, id uint) error dw.Write([]byte(protocol.NNTPResponse{Code: 231, Message: "list of new newsgroups follows"}.String() + protocol.CRLF)) for _, v := range g { // TODO set actual post permission status - c, err := h.backend.GetArticlesCount(v) + c, err := h.backend.GetArticlesCount(&v) if err != nil { return err } if c > 0 { - highWaterMark, err := h.backend.GetGroupHighWaterMark(v) + highWaterMark, err := h.backend.GetGroupHighWaterMark(&v) if err != nil { return err } - lowWaterMark, err := h.backend.GetGroupLowWaterMark(v) + lowWaterMark, err := h.backend.GetGroupLowWaterMark(&v) if err != nil { return err } @@ -321,6 +324,66 @@ func (h *Handler) handlePost(s *Session, arguments []string, id uint) error { return s.tconn.PrintfLine(protocol.MessageArticleReceived) } +func (h *Handler) handleListgroup(s *Session, arguments []string, id uint) error { + s.tconn.StartResponse(id) + defer s.tconn.EndResponse(id) + + currentGroup := s.currentGroup + var low, high int64 + if len(arguments) == 1 { + g, err := h.backend.GetGroup(arguments[0]) + if err != nil { + return s.tconn.PrintfLine(protocol.NNTPResponse{Code: 411, Message: "No such newsgroup"}.String()) + } + currentGroup = &g + } else if len(arguments) == 2 { + g, err := h.backend.GetGroup(arguments[0]) + if err != nil { + return s.tconn.PrintfLine(protocol.NNTPResponse{Code: 411, Message: "No such newsgroup"}.String()) + } + currentGroup = &g + + low, high, err = utils.ParseRange(arguments[1]) + if err != nil { + low = 0 + high = 0 + } + if high != -1 && low > high { + low = -1 + high = -1 + } + } + + if currentGroup == nil { + return s.tconn.PrintfLine(protocol.NNTPResponse{Code: 412, Message: "No newsgroup selected"}.String()) + } + + highWaterMark, err := h.backend.GetGroupHighWaterMark(currentGroup) + if err != nil && err != sql.ErrNoRows { + return err + } + lowWaterMark, err := h.backend.GetGroupLowWaterMark(currentGroup) + if err != nil && err != sql.ErrNoRows { + return err + } + articlesCount, err := h.backend.GetArticlesCount(currentGroup) + if err != nil && err != sql.ErrNoRows { + return err + } + + nums, err := h.backend.GetArticleNumbers(currentGroup, low, high) + if err != nil && err != sql.ErrNoRows { + return err + } + + dw := s.tconn.DotWriter() + dw.Write([]byte(protocol.NNTPResponse{Code: 211, Message: fmt.Sprintf("%d %d %d %s list follows%s", articlesCount, lowWaterMark, highWaterMark, currentGroup.GroupName, protocol.CRLF)}.String())) + for _, v := range nums { + dw.Write([]byte(strconv.FormatInt(v, 10) + protocol.CRLF)) + } + return dw.Close() +} + func (h *Handler) Handle(s *Session, message string, id uint) error { splittedMessage := strings.Split(message, " ") for i, v := range splittedMessage { @@ -331,7 +394,7 @@ func (h *Handler) Handle(s *Session, message string, id uint) error { if !ok { s.tconn.StartResponse(id) defer s.tconn.EndResponse(id) - return s.tconn.PrintfLine(protocol.MessageUnknownCommand) + return s.tconn.PrintfLine(protocol.NNTPResponse{Code: 500, Message: "Unknown command"}.String()) } return handler(s, splittedMessage[1:], id) } diff --git a/internal/utils/range.go b/internal/utils/range.go new file mode 100644 index 0000000..d9c9fcd --- /dev/null +++ b/internal/utils/range.go @@ -0,0 +1,27 @@ +package utils + +import ( + "fmt" + "strconv" + "strings" +) + +func ParseRange(spec string) (int64, int64, error) { + if spec == "" { + return 0, 0, fmt.Errorf("no range specified") + } + parts := strings.Split(spec, "-") + if len(parts) == 1 { + h, err := strconv.ParseInt(parts[0], 10, 64) + return -1, h, err + } + l, err := strconv.ParseInt(parts[0], 10, 64) + if err != nil { + return 0, 0, err + } + if parts[1] == "" { + return l, -1, nil + } + h, err := strconv.ParseInt(parts[1], 10, 64) + return l, h, err +}