From 163f2bcba66823ed93183fec495cbdd00db3185f Mon Sep 17 00:00:00 2001 From: ChronosX88 Date: Thu, 3 Feb 2022 19:44:08 +0300 Subject: [PATCH] Implement POST command --- README.md | 2 +- config.sample.toml | 1 + .../sqlite/migrations/001_init_schema.sql | 8 +- internal/backend/sqlite/sqlite.go | 43 ++++++- internal/backend/storage_backend.go | 2 + internal/config/config.go | 1 + internal/models/article.go | 81 +++++++++++++ internal/models/group.go | 10 +- internal/protocol/constants.go | 46 +++++++- internal/protocol/nntp_response.go | 12 ++ internal/server/handler.go | 109 +++++++++++++++--- internal/server/nntp_server.go | 2 +- 12 files changed, 287 insertions(+), 30 deletions(-) create mode 100644 internal/models/article.go create mode 100644 internal/protocol/nntp_response.go diff --git a/README.md b/README.md index 044bbaa..57bc6cf 100644 --- a/README.md +++ b/README.md @@ -32,7 +32,7 @@ - :x: `NEWNEWS` - :x: `NEXT` - :x: `OVER` -- :x: `POST` +- :construction: `POST` - :x: `STAT` ## License diff --git a/config.sample.toml b/config.sample.toml index fe3fe47..84dc02a 100644 --- a/config.sample.toml +++ b/config.sample.toml @@ -1,6 +1,7 @@ address = "localhost" port = 1119 backend_type = "sqlite" +domain = "localhost" [sqlite] path = "yans.db" \ No newline at end of file diff --git a/internal/backend/sqlite/migrations/001_init_schema.sql b/internal/backend/sqlite/migrations/001_init_schema.sql index add349e..92d19a4 100644 --- a/internal/backend/sqlite/migrations/001_init_schema.sql +++ b/internal/backend/sqlite/migrations/001_init_schema.sql @@ -4,15 +4,13 @@ CREATE TABLE IF NOT EXISTS groups( id INTEGER PRIMARY KEY AUTOINCREMENT, group_name TEXT UNIQUE NOT NULL, description TEXT, - created_at UNSIGNED BIG INT NOT NULL + created_at DATETIME DEFAULT CURRENT_TIMESTAMP ); CREATE TABLE IF NOT EXISTS articles( id INTEGER PRIMARY KEY AUTOINCREMENT, - date INTEGER NOT NULL, - path TEXT, - reply_to TEXT, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + header TEXT, thread TEXT, - subject TEXT NOT NULL, body TEXT NOT NULL ); CREATE TABLE IF NOT EXISTS articles_to_groups( diff --git a/internal/backend/sqlite/sqlite.go b/internal/backend/sqlite/sqlite.go index b07c414..05931f2 100644 --- a/internal/backend/sqlite/sqlite.go +++ b/internal/backend/sqlite/sqlite.go @@ -3,6 +3,8 @@ package sqlite import ( "database/sql" "embed" + "encoding/json" + "fmt" "github.com/ChronosX88/yans/internal/config" "github.com/ChronosX88/yans/internal/models" "github.com/ChronosX88/yans/internal/utils" @@ -11,6 +13,7 @@ import ( "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3" "github.com/pressly/goose/v3" + "strings" ) //go:embed migrations/*.sql @@ -91,5 +94,43 @@ func (sb *SQLiteBackend) GetGroup(groupName string) (models.Group, error) { func (sb *SQLiteBackend) GetNewGroupsSince(timestamp int64) ([]models.Group, error) { var groups []models.Group - return groups, sb.db.Select(&groups, "SELECT * FROM groups WHERE created_at > ?", timestamp) + return groups, sb.db.Select(&groups, "SELECT * FROM groups WHERE created_at > datetime(?, 'unixepoch')", timestamp) +} + +func (sb *SQLiteBackend) SaveArticle(a models.Article, groups []string) error { + res, err := sb.db.Exec("INSERT INTO articles (header, body, thread) VALUES (?, ?, ?)", a.HeaderRaw, a.Body, a.Thread) + articleID, err := res.LastInsertId() + if err != nil { + return err + } + + var groupIDs []int + for _, v := range groups { + v = strings.TrimSpace(v) + g, err := sb.GetGroup(v) + if err != nil { + if err == sql.ErrNoRows { + return fmt.Errorf("no such newsgroup") + } else { + return err + } + } + groupIDs = append(groupIDs, g.ID) + } + + for _, v := range groupIDs { + _, err = sb.db.Exec("INSERT INTO articles_to_groups (article_id, group_id) VALUES (?, ?)", articleID, v) + if err != nil { + return err + } + } + return err +} + +func (sb *SQLiteBackend) GetArticle(messageID string) (models.Article, error) { + var a models.Article + if err := sb.db.Get(&a, "SELECT * FROM articles WHERE json_extract(articles.header, '$.Message-ID[0]') = ?", messageID); err != nil { + return a, err + } + return a, json.Unmarshal([]byte(a.HeaderRaw), &a.Header) } diff --git a/internal/backend/storage_backend.go b/internal/backend/storage_backend.go index 112aa5b..101ded3 100644 --- a/internal/backend/storage_backend.go +++ b/internal/backend/storage_backend.go @@ -14,4 +14,6 @@ type StorageBackend interface { GetArticlesCount(g models.Group) (int, error) GetGroupLowWaterMark(g models.Group) (int, error) GetGroupHighWaterMark(g models.Group) (int, error) + SaveArticle(article models.Article, groups []string) error + GetArticle(messageID string) (models.Article, error) } diff --git a/internal/config/config.go b/internal/config/config.go index f9b5aea..0c915ea 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -13,6 +13,7 @@ type Config struct { Address string `toml:"address"` Port int `toml:"port"` BackendType string `toml:"backend_type"` + Domain string `toml:"domain"` SQLite SQLiteBackendConfig `toml:"sqlite"` } diff --git a/internal/models/article.go b/internal/models/article.go new file mode 100644 index 0000000..39e04e4 --- /dev/null +++ b/internal/models/article.go @@ -0,0 +1,81 @@ +package models + +import ( + "database/sql" + "net/textproto" + "time" +) + +type Article struct { + ID int `db:"id"` + CreatedAt time.Time `db:"created_at"` + HeaderRaw string `db:"header"` + Header textproto.MIMEHeader `db:"-"` + Body string `db:"body"` + Thread sql.NullString `db:"thread"` +} + +//func ParseArticle(lines []string) (Article, error) { +// article := Article{} +// headerBlock := true +// for _, v := range lines { +// if v == "" { +// headerBlock = false +// } +// +// if headerBlock { +// kv := strings.Split(v, ":") +// if len(kv) < 2 { +// return Article{}, fmt.Errorf("invalid header format") +// } +// +// kv[0] = strings.TrimSpace(kv[0]) +// kv[1] = strings.TrimSpace(kv[1]) +// +// if !protocol.IsMessageHeaderAllowed(kv[0]) { +// return Article{}, fmt.Errorf("invalid header element") +// } +// if kv[1] == "" { +// return Article{}, fmt.Errorf("header value should not be empty") +// } +// +// switch kv[0] { +// case "Archive": +// { +// if kv[1] == "yes" { +// article.Archive = true +// } else { +// article.Archive = false +// } +// } +// case "Injection-Date": +// { +// date, err := mail.ParseDate(kv[1]) +// if err != nil { +// return Article{}, err +// } +// article.InjectionDate = date +// } +// case "Date": +// { +// date, err := mail.ParseDate(kv[1]) +// if err != nil { +// return Article{}, err +// } +// article.Date = date +// } +// case "Expires": +// { +// date, err := mail.ParseDate(kv[1]) +// if err != nil { +// return Article{}, err +// } +// article.Expires = date +// } +// } +// +// } else { +// } +// } +// return article, nil +//} diff --git a/internal/models/group.go b/internal/models/group.go index fd80d5d..cc121ea 100644 --- a/internal/models/group.go +++ b/internal/models/group.go @@ -1,8 +1,10 @@ package models +import "time" + type Group struct { - ID int `db:"id"` - GroupName string `db:"group_name"` - Description *string `db:"description"` - CreatedAt uint64 `db:"created_at"` + ID int `db:"id"` + GroupName string `db:"group_name"` + Description *string `db:"description"` + CreatedAt time.Time `db:"created_at"` } diff --git a/internal/protocol/constants.go b/internal/protocol/constants.go index 015f0cd..c42049a 100644 --- a/internal/protocol/constants.go +++ b/internal/protocol/constants.go @@ -1,5 +1,46 @@ package protocol +var ( + ErrSyntaxError = NNTPResponse{Code: 501, Message: "Syntax Error"} +) + +func IsMessageHeaderAllowed(headerName string) bool { + switch headerName { + case + "Date", + "From", + "Message-ID", + "Newsgroups", + "Path", + "Subject", + "Comments", + "Keywords", + "In-Reply-To", + "Sender", + "MIME-Version", + "Content-Type", + "Content-Transfer-Encoding", + "Content-Disposition", + "Content-Language", + "Approved", + "Archive", + "Control", + "Distribution", + "Expires", + "Followup-To", + "Injection-Date", + "Injection-Info", + "Organization", + "References", + "Summary", + "Supersedes", + "User-Agent", + "Xref": + return true + } + return false +} + const ( CRLF = "\r\n" MultilineEnding = "." @@ -13,6 +54,7 @@ const ( CommandList = "LIST" CommandGroup = "GROUP" CommandNewGroups = "NEWGROUPS" + CommandPost = "POST" ) const ( @@ -35,7 +77,7 @@ const ( MessageUnknownCommand = "500 Unknown command" MessageErrorHappened = "403 Failed to process command:" MessageListOfNewsgroupsFollows = "215 list of newsgroups follows" - MessageNewGroupsListOfNewsgroupsFollows = "231 list of new newsgroups follows" - MessageSyntaxError = "501 Syntax Error" MessageNoSuchGroup = "411 No such newsgroup" + MessageInputArticle = "340 Input article; end with ." + MessageArticleReceived = "240 Article received OK" ) diff --git a/internal/protocol/nntp_response.go b/internal/protocol/nntp_response.go new file mode 100644 index 0000000..64eddcd --- /dev/null +++ b/internal/protocol/nntp_response.go @@ -0,0 +1,12 @@ +package protocol + +import "fmt" + +type NNTPResponse struct { + Code int + Message string +} + +func (nr NNTPResponse) String() string { + return fmt.Sprintf("%d %s", nr.Code, nr.Message) +} diff --git a/internal/server/handler.go b/internal/server/handler.go index 2143567..5e459f4 100644 --- a/internal/server/handler.go +++ b/internal/server/handler.go @@ -2,20 +2,25 @@ package server import ( "database/sql" + "encoding/json" "fmt" "github.com/ChronosX88/yans/internal/backend" "github.com/ChronosX88/yans/internal/models" "github.com/ChronosX88/yans/internal/protocol" + "github.com/google/uuid" + "io" + "net/mail" "strings" "time" ) type Handler struct { - handlers map[string]func(s *Session, arguments []string, id uint) error - backend backend.StorageBackend + handlers map[string]func(s *Session, arguments []string, id uint) error + backend backend.StorageBackend + serverDomain string } -func NewHandler(b backend.StorageBackend) *Handler { +func NewHandler(b backend.StorageBackend, serverDomain string) *Handler { h := &Handler{} h.backend = b h.handlers = map[string]func(s *Session, arguments []string, id uint) error{ @@ -26,7 +31,9 @@ func NewHandler(b backend.StorageBackend) *Handler { protocol.CommandMode: h.handleModeReader, protocol.CommandGroup: h.handleGroup, protocol.CommandNewGroups: h.handleNewGroups, + protocol.CommandPost: h.handlePost, } + h.serverDomain = serverDomain return h } @@ -123,7 +130,7 @@ func (h *Handler) handleList(s *Session, arguments []string, id uint) error { } default: { - return s.tconn.PrintfLine(protocol.MessageSyntaxError) + return s.tconn.PrintfLine(protocol.ErrSyntaxError.String()) } } @@ -134,7 +141,7 @@ func (h *Handler) handleList(s *Session, arguments []string, id uint) error { func (h *Handler) handleModeReader(s *Session, arguments []string, id uint) error { if len(arguments) == 0 || arguments[0] != "READER" { - return s.tconn.PrintfLine(protocol.MessageSyntaxError) + return s.tconn.PrintfLine(protocol.ErrSyntaxError.String()) } (&s.capabilities).Remove(protocol.ModeReaderCapability) @@ -151,7 +158,7 @@ func (h *Handler) handleGroup(s *Session, arguments []string, id uint) error { defer s.tconn.EndResponse(id) if len(arguments) == 0 || len(arguments) > 1 { - return s.tconn.PrintfLine(protocol.MessageSyntaxError) + return s.tconn.PrintfLine(protocol.ErrSyntaxError.String()) } g, err := h.backend.GetGroup(arguments[0]) @@ -177,7 +184,10 @@ func (h *Handler) handleGroup(s *Session, arguments []string, id uint) error { s.currentGroup = &g - return s.tconn.PrintfLine("211 %d %d %d %s", articlesCount, lowWaterMark, highWaterMark, g.GroupName) + return s.tconn.PrintfLine(protocol.NNTPResponse{ + Code: 211, + Message: fmt.Sprintf("%d %d %d %s", articlesCount, lowWaterMark, highWaterMark, g.GroupName), + }.String()) } func (h *Handler) handleNewGroups(s *Session, arguments []string, id uint) error { @@ -185,7 +195,7 @@ func (h *Handler) handleNewGroups(s *Session, arguments []string, id uint) error defer s.tconn.EndResponse(id) if len(arguments) < 2 || len(arguments) > 3 { - return s.tconn.PrintfLine(protocol.MessageSyntaxError) + return s.tconn.PrintfLine(protocol.ErrSyntaxError.String()) } dateString := arguments[0] + " " + arguments[1] @@ -208,7 +218,7 @@ func (h *Handler) handleNewGroups(s *Session, arguments []string, id uint) error return err } } else { - return s.tconn.PrintfLine(protocol.MessageSyntaxError) + return s.tconn.PrintfLine(protocol.ErrSyntaxError.String()) } g, err := h.backend.GetNewGroupsSince(date.Unix()) @@ -216,9 +226,8 @@ func (h *Handler) handleNewGroups(s *Session, arguments []string, id uint) error return err } - var sb strings.Builder - - sb.Write([]byte(protocol.MessageListOfNewsgroupsFollows + protocol.CRLF)) + dw := s.tconn.DotWriter() + dw.Write([]byte(protocol.NNTPResponse{Code: 231, Message: "list of new newsgroups follows"}.String() + protocol.CRLF)) for _, v := range g { // TODO set actual post permission status c, err := h.backend.GetArticlesCount(v) @@ -234,14 +243,82 @@ func (h *Handler) handleNewGroups(s *Session, arguments []string, id uint) error if err != nil { return err } - sb.Write([]byte(fmt.Sprintf("%s %d %d n"+protocol.CRLF, v.GroupName, highWaterMark, lowWaterMark))) + dw.Write([]byte(fmt.Sprintf("%s %d %d n"+protocol.CRLF, v.GroupName, highWaterMark, lowWaterMark))) } else { - sb.Write([]byte(fmt.Sprintf("%s 0 1 n"+protocol.CRLF, v.GroupName))) + dw.Write([]byte(fmt.Sprintf("%s 0 1 n"+protocol.CRLF, v.GroupName))) } } - sb.Write([]byte(protocol.MultilineEnding)) - return s.tconn.PrintfLine(sb.String()) + return dw.Close() +} + +func (h *Handler) handlePost(s *Session, arguments []string, id uint) error { + s.tconn.StartResponse(id) + defer s.tconn.EndResponse(id) + + if len(arguments) != 0 { + return s.tconn.PrintfLine(protocol.ErrSyntaxError.String()) + } + + if err := s.tconn.PrintfLine(protocol.MessageInputArticle); err != nil { + return err + } + + headers, err := s.tconn.ReadMIMEHeader() + if err != nil { + return err + } + if err != nil { + return err + } + + // generate message id + messageID := fmt.Sprintf("<%s@%s>", uuid.New().String(), h.serverDomain) + headers["Message-ID"] = []string{messageID} + + headerJson, err := json.Marshal(headers) + if err != nil { + return err + } + + a := models.Article{} + a.HeaderRaw = string(headerJson) + a.Header = headers + + dr := s.tconn.DotReader() + // TODO handle multipart message + body, err := io.ReadAll(dr) + if err != nil { + return err + } + a.Body = string(body) + + // set thread property + if headers.Get("In-Reply-To") != "" { + parentMessage, err := h.backend.GetArticle(headers.Get("In-Reply-To")) + if err != nil { + if err == sql.ErrNoRows { + return s.tconn.PrintfLine(protocol.NNTPResponse{Code: 441, Message: "no such message you are replying to"}.String()) + } else { + return err + } + } + if !parentMessage.Thread.Valid { + var parentHeader mail.Header + err = json.Unmarshal([]byte(parentMessage.HeaderRaw), &parentHeader) + parentMessageID := parentHeader["Message-ID"] + a.Thread = sql.NullString{String: parentMessageID[0], Valid: true} + } else { + a.Thread = parentMessage.Thread + } + } + + err = h.backend.SaveArticle(a, strings.Split(a.Header.Get("Newsgroups"), ",")) + if err != nil { + return s.tconn.PrintfLine(protocol.NNTPResponse{Code: 441, Message: err.Error()}.String()) + } + + return s.tconn.PrintfLine(protocol.MessageArticleReceived) } func (h *Handler) Handle(s *Session, message string, id uint) error { diff --git a/internal/server/nntp_server.go b/internal/server/nntp_server.go index 71673d6..d397212 100644 --- a/internal/server/nntp_server.go +++ b/internal/server/nntp_server.go @@ -96,7 +96,7 @@ func (ns *NNTPServer) Start() error { id, _ := uuid.NewUUID() closed := make(chan bool) - session, err := NewSession(ctx, conn, Capabilities, id.String(), closed, NewHandler(ns.backend)) + session, err := NewSession(ctx, conn, Capabilities, id.String(), closed, NewHandler(ns.backend, ns.cfg.Domain)) ns.sessionPoolMutex.Lock() ns.sessionPool[id.String()] = session ns.sessionPoolMutex.Unlock()