Implement command pipelining, finish LIST ACTIVE command implementation

This commit is contained in:
ChronosX88 2022-01-20 00:44:27 +03:00
parent 99a950181d
commit dda82b7916
Signed by: ChronosXYZ
GPG Key ID: 085A69A82C8C511A
8 changed files with 59 additions and 17 deletions

View File

@ -14,7 +14,7 @@
#### Commands #### Commands
- :heavy_check_mark: `CAPABILITIES` - :heavy_check_mark: `CAPABILITIES`
- :heavy_check_mark: `DATE` - :heavy_check_mark: `DATE`
- :construction: `LIST` - :heavy_check_mark: `LIST ACTIVE`
- :heavy_check_mark: `LIST NEWSGROUPS` - :heavy_check_mark: `LIST NEWSGROUPS`
- :heavy_check_mark: `MODE READER` - :heavy_check_mark: `MODE READER`
- :heavy_check_mark: `QUIT` - :heavy_check_mark: `QUIT`

View File

@ -3,10 +3,11 @@
CREATE TABLE IF NOT EXISTS groups( CREATE TABLE IF NOT EXISTS groups(
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
group_name TEXT UNIQUE NOT NULL, group_name TEXT UNIQUE NOT NULL,
description TEXT description TEXT,
created_at UNSIGNED BIG INT NOT NULL
); );
CREATE TABLE IF NOT EXISTS articles( CREATE TABLE IF NOT EXISTS articles(
id TEXT PRIMARY KEY, id INTEGER PRIMARY KEY AUTOINCREMENT,
date INTEGER NOT NULL, date INTEGER NOT NULL,
path TEXT, path TEXT,
reply_to TEXT, reply_to TEXT,
@ -15,7 +16,7 @@ CREATE TABLE IF NOT EXISTS articles(
body TEXT NOT NULL body TEXT NOT NULL
); );
CREATE TABLE IF NOT EXISTS articles_to_groups( CREATE TABLE IF NOT EXISTS articles_to_groups(
article_id TEXT NOT NULL, article_id INTEGER NOT NULL,
group_id INTEGER NOT NULL, group_id INTEGER NOT NULL,
FOREIGN KEY (article_id) REFERENCES articles(id) ON DELETE CASCADE, FOREIGN KEY (article_id) REFERENCES articles(id) ON DELETE CASCADE,
FOREIGN KEY (group_id) REFERENCES groups(id) ON DELETE CASCADE FOREIGN KEY (group_id) REFERENCES groups(id) ON DELETE CASCADE

View File

@ -43,5 +43,15 @@ func (sb *SQLiteBackend) ListGroups() ([]models.Group, error) {
func (sb *SQLiteBackend) GetArticlesCount(g models.Group) (int, error) { func (sb *SQLiteBackend) GetArticlesCount(g models.Group) (int, error) {
var count int var count int
return count, sb.db.Select(&count, "SELECT COUNT(*) FROM articles_to_groups WHERE group_id = ?", g.ID) return count, sb.db.Get(&count, "SELECT COUNT(*) FROM articles_to_groups WHERE group_id = ?", g.ID)
}
func (sb *SQLiteBackend) GetGroupHighWaterMark(g models.Group) (int, error) {
var waterMark int
return waterMark, sb.db.Get(&waterMark, "SELECT article_id FROM articles_to_groups WHERE group_id = ? ORDER BY article_id DESC LIMIT 1", g.ID)
}
func (sb *SQLiteBackend) GetGroupLowWaterMark(g models.Group) (int, error) {
var waterMark int
return waterMark, sb.db.Get(&waterMark, "SELECT article_id FROM articles_to_groups WHERE group_id = ? ORDER BY article_id LIMIT 1", g.ID)
} }

View File

@ -9,4 +9,6 @@ const (
type StorageBackend interface { type StorageBackend interface {
ListGroups() ([]models.Group, error) ListGroups() ([]models.Group, error)
GetArticlesCount(g models.Group) (int, error) GetArticlesCount(g models.Group) (int, error)
GetGroupLowWaterMark(g models.Group) (int, error)
GetGroupHighWaterMark(g models.Group) (int, error)
} }

View File

@ -4,4 +4,5 @@ type Group struct {
ID int `db:"id"` ID int `db:"id"`
GroupName string `db:"group_name"` GroupName string `db:"group_name"`
Description *string `db:"description"` Description *string `db:"description"`
CreatedAt uint64 `db:"created_at"`
} }

View File

@ -29,7 +29,7 @@ const (
const ( const (
MessageNNTPServiceReadyPostingProhibited = "201 YANS NNTP Service Ready, posting prohibited" MessageNNTPServiceReadyPostingProhibited = "201 YANS NNTP Service Ready, posting prohibited"
MessageReaderModePostingProhibited = "201 Reader mode, posting prohibited" MessageReaderModePostingProhibited = "201 Reader mode, posting prohibited"
MessageNNTPServiceExitsNormally = "205 NNTP Service exits normally" MessageNNTPServiceExitsNormally = "205 NNTP Service exits normally, bye!"
MessageUnknownCommand = "500 Unknown command" MessageUnknownCommand = "500 Unknown command"
MessageErrorHappened = "403 Failed to process command:" MessageErrorHappened = "403 Failed to process command:"
MessageListOfNewsgroupsFollows = "215 list of newsgroups follows" MessageListOfNewsgroupsFollows = "215 list of newsgroups follows"

View File

@ -9,14 +9,14 @@ import (
) )
type Handler struct { type Handler struct {
handlers map[string]func(s *Session, arguments []string) error handlers map[string]func(s *Session, arguments []string, id uint) error
backend backend.StorageBackend backend backend.StorageBackend
} }
func NewHandler(b backend.StorageBackend) *Handler { func NewHandler(b backend.StorageBackend) *Handler {
h := &Handler{} h := &Handler{}
h.backend = b h.backend = b
h.handlers = map[string]func(s *Session, arguments []string) error{ h.handlers = map[string]func(s *Session, arguments []string, id uint) error{
protocol.CommandCapabilities: h.handleCapabilities, protocol.CommandCapabilities: h.handleCapabilities,
protocol.CommandDate: h.handleDate, protocol.CommandDate: h.handleDate,
protocol.CommandQuit: h.handleQuit, protocol.CommandQuit: h.handleQuit,
@ -26,21 +26,25 @@ func NewHandler(b backend.StorageBackend) *Handler {
return h return h
} }
func (h *Handler) handleCapabilities(s *Session, arguments []string) error { func (h *Handler) handleCapabilities(s *Session, arguments []string, id uint) error {
s.tconn.StartResponse(id)
defer s.tconn.EndResponse(id)
return s.tconn.PrintfLine(s.capabilities.String()) return s.tconn.PrintfLine(s.capabilities.String())
} }
func (h *Handler) handleDate(s *Session, arguments []string) error { func (h *Handler) handleDate(s *Session, arguments []string, id uint) error {
s.tconn.StartResponse(id)
defer s.tconn.EndResponse(id)
return s.tconn.PrintfLine("111 %s", time.Now().UTC().Format("20060102150405")) return s.tconn.PrintfLine("111 %s", time.Now().UTC().Format("20060102150405"))
} }
func (h *Handler) handleQuit(s *Session, arguments []string) error { func (h *Handler) handleQuit(s *Session, arguments []string, id uint) error {
s.tconn.PrintfLine(protocol.MessageNNTPServiceExitsNormally) s.tconn.PrintfLine(protocol.MessageNNTPServiceExitsNormally)
s.conn.Close() s.conn.Close()
return nil return nil
} }
func (h *Handler) handleList(s *Session, arguments []string) error { func (h *Handler) handleList(s *Session, arguments []string, id uint) error {
sb := strings.Builder{} sb := strings.Builder{}
listType := "" listType := ""
@ -48,6 +52,9 @@ func (h *Handler) handleList(s *Session, arguments []string) error {
listType = arguments[0] listType = arguments[0]
} }
s.tconn.StartResponse(id)
defer s.tconn.EndResponse(id)
switch listType { switch listType {
case "": case "":
fallthrough fallthrough
@ -60,7 +67,23 @@ func (h *Handler) handleList(s *Session, arguments []string) error {
sb.Write([]byte(protocol.MessageListOfNewsgroupsFollows + protocol.CRLF)) sb.Write([]byte(protocol.MessageListOfNewsgroupsFollows + protocol.CRLF))
for _, v := range groups { for _, v := range groups {
// TODO set high/low mark and posting status to actual values // TODO set high/low mark and posting status to actual values
sb.Write([]byte(fmt.Sprintf("%s 0 0 n"+protocol.CRLF, v.GroupName))) c, err := h.backend.GetArticlesCount(v)
if err != nil {
return err
}
if c > 0 {
highWaterMark, err := h.backend.GetGroupHighWaterMark(v)
if err != nil {
return err
}
lowWaterMark, err := h.backend.GetGroupLowWaterMark(v)
if err != nil {
return err
}
sb.Write([]byte(fmt.Sprintf("%s %d %d n"+protocol.CRLF, v.GroupName, highWaterMark, lowWaterMark)))
} else {
sb.Write([]byte(fmt.Sprintf("%s 0 0 n"+protocol.CRLF, v.GroupName)))
}
} }
} }
case "NEWSGROUPS": case "NEWSGROUPS":
@ -90,7 +113,7 @@ func (h *Handler) handleList(s *Session, arguments []string) error {
return s.tconn.PrintfLine(sb.String()) return s.tconn.PrintfLine(sb.String())
} }
func (h *Handler) handleModeReader(s *Session, arguments []string) error { func (h *Handler) handleModeReader(s *Session, arguments []string, id uint) error {
if len(arguments) == 0 || arguments[0] != "READER" { if len(arguments) == 0 || arguments[0] != "READER" {
return s.tconn.PrintfLine(protocol.MessageSyntaxError) return s.tconn.PrintfLine(protocol.MessageSyntaxError)
} }
@ -104,7 +127,7 @@ func (h *Handler) handleModeReader(s *Session, arguments []string) error {
return s.tconn.PrintfLine(protocol.MessageReaderModePostingProhibited) // TODO vary on auth status return s.tconn.PrintfLine(protocol.MessageReaderModePostingProhibited) // TODO vary on auth status
} }
func (h *Handler) Handle(s *Session, message string) error { func (h *Handler) Handle(s *Session, message string, id uint) error {
splittedMessage := strings.Split(message, " ") splittedMessage := strings.Split(message, " ")
for i, v := range splittedMessage { for i, v := range splittedMessage {
splittedMessage[i] = strings.TrimSpace(v) splittedMessage[i] = strings.TrimSpace(v)
@ -112,7 +135,9 @@ func (h *Handler) Handle(s *Session, message string) error {
cmdName := splittedMessage[0] cmdName := splittedMessage[0]
handler, ok := h.handlers[cmdName] handler, ok := h.handlers[cmdName]
if !ok { if !ok {
s.tconn.StartResponse(id)
defer s.tconn.EndResponse(id)
return s.tconn.PrintfLine(protocol.MessageUnknownCommand) return s.tconn.PrintfLine(protocol.MessageUnknownCommand)
} }
return handler(s, splittedMessage[1:]) return handler(s, splittedMessage[1:], id)
} }

View File

@ -70,6 +70,8 @@ func (s *Session) loop() {
break break
default: default:
{ {
id := s.tconn.Next()
s.tconn.StartRequest(id)
message, err := s.tconn.ReadLine() message, err := s.tconn.ReadLine()
if err != nil { if err != nil {
if err == io.EOF || err.(*net.OpError).Unwrap() == net.ErrClosed { if err == io.EOF || err.(*net.OpError).Unwrap() == net.ErrClosed {
@ -80,8 +82,9 @@ func (s *Session) loop() {
} }
return return
} }
s.tconn.EndRequest(id)
log.Printf("Received message from %s: %s", s.conn.RemoteAddr().String(), message) // for debugging log.Printf("Received message from %s: %s", s.conn.RemoteAddr().String(), message) // for debugging
err = s.h.Handle(s, message) err = s.h.Handle(s, message, id)
if err != nil { if err != nil {
log.Print(err) log.Print(err)
s.tconn.PrintfLine("%s %s", protocol.MessageErrorHappened, err.Error()) s.tconn.PrintfLine("%s %s", protocol.MessageErrorHappened, err.Error())