Implement POST command

This commit is contained in:
ChronosX88 2022-02-03 19:44:08 +03:00
parent e8aa350c53
commit 163f2bcba6
Signed by: ChronosXYZ
GPG Key ID: 085A69A82C8C511A
12 changed files with 287 additions and 30 deletions

View File

@ -32,7 +32,7 @@
- :x: `NEWNEWS`
- :x: `NEXT`
- :x: `OVER`
- :x: `POST`
- :construction: `POST`
- :x: `STAT`
## License

View File

@ -1,6 +1,7 @@
address = "localhost"
port = 1119
backend_type = "sqlite"
domain = "localhost"
[sqlite]
path = "yans.db"

View File

@ -4,15 +4,13 @@ CREATE TABLE IF NOT EXISTS groups(
id INTEGER PRIMARY KEY AUTOINCREMENT,
group_name TEXT UNIQUE NOT NULL,
description TEXT,
created_at UNSIGNED BIG INT NOT NULL
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
);
CREATE TABLE IF NOT EXISTS articles(
id INTEGER PRIMARY KEY AUTOINCREMENT,
date INTEGER NOT NULL,
path TEXT,
reply_to TEXT,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
header TEXT,
thread TEXT,
subject TEXT NOT NULL,
body TEXT NOT NULL
);
CREATE TABLE IF NOT EXISTS articles_to_groups(

View File

@ -3,6 +3,8 @@ package sqlite
import (
"database/sql"
"embed"
"encoding/json"
"fmt"
"github.com/ChronosX88/yans/internal/config"
"github.com/ChronosX88/yans/internal/models"
"github.com/ChronosX88/yans/internal/utils"
@ -11,6 +13,7 @@ import (
"github.com/mattn/go-sqlite3"
_ "github.com/mattn/go-sqlite3"
"github.com/pressly/goose/v3"
"strings"
)
//go:embed migrations/*.sql
@ -91,5 +94,43 @@ func (sb *SQLiteBackend) GetGroup(groupName string) (models.Group, error) {
func (sb *SQLiteBackend) GetNewGroupsSince(timestamp int64) ([]models.Group, error) {
var groups []models.Group
return groups, sb.db.Select(&groups, "SELECT * FROM groups WHERE created_at > ?", timestamp)
return groups, sb.db.Select(&groups, "SELECT * FROM groups WHERE created_at > datetime(?, 'unixepoch')", timestamp)
}
func (sb *SQLiteBackend) SaveArticle(a models.Article, groups []string) error {
res, err := sb.db.Exec("INSERT INTO articles (header, body, thread) VALUES (?, ?, ?)", a.HeaderRaw, a.Body, a.Thread)
articleID, err := res.LastInsertId()
if err != nil {
return err
}
var groupIDs []int
for _, v := range groups {
v = strings.TrimSpace(v)
g, err := sb.GetGroup(v)
if err != nil {
if err == sql.ErrNoRows {
return fmt.Errorf("no such newsgroup")
} else {
return err
}
}
groupIDs = append(groupIDs, g.ID)
}
for _, v := range groupIDs {
_, err = sb.db.Exec("INSERT INTO articles_to_groups (article_id, group_id) VALUES (?, ?)", articleID, v)
if err != nil {
return err
}
}
return err
}
func (sb *SQLiteBackend) GetArticle(messageID string) (models.Article, error) {
var a models.Article
if err := sb.db.Get(&a, "SELECT * FROM articles WHERE json_extract(articles.header, '$.Message-ID[0]') = ?", messageID); err != nil {
return a, err
}
return a, json.Unmarshal([]byte(a.HeaderRaw), &a.Header)
}

View File

@ -14,4 +14,6 @@ type StorageBackend interface {
GetArticlesCount(g models.Group) (int, error)
GetGroupLowWaterMark(g models.Group) (int, error)
GetGroupHighWaterMark(g models.Group) (int, error)
SaveArticle(article models.Article, groups []string) error
GetArticle(messageID string) (models.Article, error)
}

View File

@ -13,6 +13,7 @@ type Config struct {
Address string `toml:"address"`
Port int `toml:"port"`
BackendType string `toml:"backend_type"`
Domain string `toml:"domain"`
SQLite SQLiteBackendConfig `toml:"sqlite"`
}

View File

@ -0,0 +1,81 @@
package models
import (
"database/sql"
"net/textproto"
"time"
)
type Article struct {
ID int `db:"id"`
CreatedAt time.Time `db:"created_at"`
HeaderRaw string `db:"header"`
Header textproto.MIMEHeader `db:"-"`
Body string `db:"body"`
Thread sql.NullString `db:"thread"`
}
//func ParseArticle(lines []string) (Article, error) {
// article := Article{}
// headerBlock := true
// for _, v := range lines {
// if v == "" {
// headerBlock = false
// }
//
// if headerBlock {
// kv := strings.Split(v, ":")
// if len(kv) < 2 {
// return Article{}, fmt.Errorf("invalid header format")
// }
//
// kv[0] = strings.TrimSpace(kv[0])
// kv[1] = strings.TrimSpace(kv[1])
//
// if !protocol.IsMessageHeaderAllowed(kv[0]) {
// return Article{}, fmt.Errorf("invalid header element")
// }
// if kv[1] == "" {
// return Article{}, fmt.Errorf("header value should not be empty")
// }
//
// switch kv[0] {
// case "Archive":
// {
// if kv[1] == "yes" {
// article.Archive = true
// } else {
// article.Archive = false
// }
// }
// case "Injection-Date":
// {
// date, err := mail.ParseDate(kv[1])
// if err != nil {
// return Article{}, err
// }
// article.InjectionDate = date
// }
// case "Date":
// {
// date, err := mail.ParseDate(kv[1])
// if err != nil {
// return Article{}, err
// }
// article.Date = date
// }
// case "Expires":
// {
// date, err := mail.ParseDate(kv[1])
// if err != nil {
// return Article{}, err
// }
// article.Expires = date
// }
// }
//
// } else {
// }
// }
// return article, nil
//}

View File

@ -1,8 +1,10 @@
package models
import "time"
type Group struct {
ID int `db:"id"`
GroupName string `db:"group_name"`
Description *string `db:"description"`
CreatedAt uint64 `db:"created_at"`
ID int `db:"id"`
GroupName string `db:"group_name"`
Description *string `db:"description"`
CreatedAt time.Time `db:"created_at"`
}

View File

@ -1,5 +1,46 @@
package protocol
var (
ErrSyntaxError = NNTPResponse{Code: 501, Message: "Syntax Error"}
)
func IsMessageHeaderAllowed(headerName string) bool {
switch headerName {
case
"Date",
"From",
"Message-ID",
"Newsgroups",
"Path",
"Subject",
"Comments",
"Keywords",
"In-Reply-To",
"Sender",
"MIME-Version",
"Content-Type",
"Content-Transfer-Encoding",
"Content-Disposition",
"Content-Language",
"Approved",
"Archive",
"Control",
"Distribution",
"Expires",
"Followup-To",
"Injection-Date",
"Injection-Info",
"Organization",
"References",
"Summary",
"Supersedes",
"User-Agent",
"Xref":
return true
}
return false
}
const (
CRLF = "\r\n"
MultilineEnding = "."
@ -13,6 +54,7 @@ const (
CommandList = "LIST"
CommandGroup = "GROUP"
CommandNewGroups = "NEWGROUPS"
CommandPost = "POST"
)
const (
@ -35,7 +77,7 @@ const (
MessageUnknownCommand = "500 Unknown command"
MessageErrorHappened = "403 Failed to process command:"
MessageListOfNewsgroupsFollows = "215 list of newsgroups follows"
MessageNewGroupsListOfNewsgroupsFollows = "231 list of new newsgroups follows"
MessageSyntaxError = "501 Syntax Error"
MessageNoSuchGroup = "411 No such newsgroup"
MessageInputArticle = "340 Input article; end with <CR-LF>.<CR-LF>"
MessageArticleReceived = "240 Article received OK"
)

View File

@ -0,0 +1,12 @@
package protocol
import "fmt"
type NNTPResponse struct {
Code int
Message string
}
func (nr NNTPResponse) String() string {
return fmt.Sprintf("%d %s", nr.Code, nr.Message)
}

View File

@ -2,20 +2,25 @@ package server
import (
"database/sql"
"encoding/json"
"fmt"
"github.com/ChronosX88/yans/internal/backend"
"github.com/ChronosX88/yans/internal/models"
"github.com/ChronosX88/yans/internal/protocol"
"github.com/google/uuid"
"io"
"net/mail"
"strings"
"time"
)
type Handler struct {
handlers map[string]func(s *Session, arguments []string, id uint) error
backend backend.StorageBackend
handlers map[string]func(s *Session, arguments []string, id uint) error
backend backend.StorageBackend
serverDomain string
}
func NewHandler(b backend.StorageBackend) *Handler {
func NewHandler(b backend.StorageBackend, serverDomain string) *Handler {
h := &Handler{}
h.backend = b
h.handlers = map[string]func(s *Session, arguments []string, id uint) error{
@ -26,7 +31,9 @@ func NewHandler(b backend.StorageBackend) *Handler {
protocol.CommandMode: h.handleModeReader,
protocol.CommandGroup: h.handleGroup,
protocol.CommandNewGroups: h.handleNewGroups,
protocol.CommandPost: h.handlePost,
}
h.serverDomain = serverDomain
return h
}
@ -123,7 +130,7 @@ func (h *Handler) handleList(s *Session, arguments []string, id uint) error {
}
default:
{
return s.tconn.PrintfLine(protocol.MessageSyntaxError)
return s.tconn.PrintfLine(protocol.ErrSyntaxError.String())
}
}
@ -134,7 +141,7 @@ func (h *Handler) handleList(s *Session, arguments []string, id uint) error {
func (h *Handler) handleModeReader(s *Session, arguments []string, id uint) error {
if len(arguments) == 0 || arguments[0] != "READER" {
return s.tconn.PrintfLine(protocol.MessageSyntaxError)
return s.tconn.PrintfLine(protocol.ErrSyntaxError.String())
}
(&s.capabilities).Remove(protocol.ModeReaderCapability)
@ -151,7 +158,7 @@ func (h *Handler) handleGroup(s *Session, arguments []string, id uint) error {
defer s.tconn.EndResponse(id)
if len(arguments) == 0 || len(arguments) > 1 {
return s.tconn.PrintfLine(protocol.MessageSyntaxError)
return s.tconn.PrintfLine(protocol.ErrSyntaxError.String())
}
g, err := h.backend.GetGroup(arguments[0])
@ -177,7 +184,10 @@ func (h *Handler) handleGroup(s *Session, arguments []string, id uint) error {
s.currentGroup = &g
return s.tconn.PrintfLine("211 %d %d %d %s", articlesCount, lowWaterMark, highWaterMark, g.GroupName)
return s.tconn.PrintfLine(protocol.NNTPResponse{
Code: 211,
Message: fmt.Sprintf("%d %d %d %s", articlesCount, lowWaterMark, highWaterMark, g.GroupName),
}.String())
}
func (h *Handler) handleNewGroups(s *Session, arguments []string, id uint) error {
@ -185,7 +195,7 @@ func (h *Handler) handleNewGroups(s *Session, arguments []string, id uint) error
defer s.tconn.EndResponse(id)
if len(arguments) < 2 || len(arguments) > 3 {
return s.tconn.PrintfLine(protocol.MessageSyntaxError)
return s.tconn.PrintfLine(protocol.ErrSyntaxError.String())
}
dateString := arguments[0] + " " + arguments[1]
@ -208,7 +218,7 @@ func (h *Handler) handleNewGroups(s *Session, arguments []string, id uint) error
return err
}
} else {
return s.tconn.PrintfLine(protocol.MessageSyntaxError)
return s.tconn.PrintfLine(protocol.ErrSyntaxError.String())
}
g, err := h.backend.GetNewGroupsSince(date.Unix())
@ -216,9 +226,8 @@ func (h *Handler) handleNewGroups(s *Session, arguments []string, id uint) error
return err
}
var sb strings.Builder
sb.Write([]byte(protocol.MessageListOfNewsgroupsFollows + protocol.CRLF))
dw := s.tconn.DotWriter()
dw.Write([]byte(protocol.NNTPResponse{Code: 231, Message: "list of new newsgroups follows"}.String() + protocol.CRLF))
for _, v := range g {
// TODO set actual post permission status
c, err := h.backend.GetArticlesCount(v)
@ -234,14 +243,82 @@ func (h *Handler) handleNewGroups(s *Session, arguments []string, id uint) error
if err != nil {
return err
}
sb.Write([]byte(fmt.Sprintf("%s %d %d n"+protocol.CRLF, v.GroupName, highWaterMark, lowWaterMark)))
dw.Write([]byte(fmt.Sprintf("%s %d %d n"+protocol.CRLF, v.GroupName, highWaterMark, lowWaterMark)))
} else {
sb.Write([]byte(fmt.Sprintf("%s 0 1 n"+protocol.CRLF, v.GroupName)))
dw.Write([]byte(fmt.Sprintf("%s 0 1 n"+protocol.CRLF, v.GroupName)))
}
}
sb.Write([]byte(protocol.MultilineEnding))
return s.tconn.PrintfLine(sb.String())
return dw.Close()
}
func (h *Handler) handlePost(s *Session, arguments []string, id uint) error {
s.tconn.StartResponse(id)
defer s.tconn.EndResponse(id)
if len(arguments) != 0 {
return s.tconn.PrintfLine(protocol.ErrSyntaxError.String())
}
if err := s.tconn.PrintfLine(protocol.MessageInputArticle); err != nil {
return err
}
headers, err := s.tconn.ReadMIMEHeader()
if err != nil {
return err
}
if err != nil {
return err
}
// generate message id
messageID := fmt.Sprintf("<%s@%s>", uuid.New().String(), h.serverDomain)
headers["Message-ID"] = []string{messageID}
headerJson, err := json.Marshal(headers)
if err != nil {
return err
}
a := models.Article{}
a.HeaderRaw = string(headerJson)
a.Header = headers
dr := s.tconn.DotReader()
// TODO handle multipart message
body, err := io.ReadAll(dr)
if err != nil {
return err
}
a.Body = string(body)
// set thread property
if headers.Get("In-Reply-To") != "" {
parentMessage, err := h.backend.GetArticle(headers.Get("In-Reply-To"))
if err != nil {
if err == sql.ErrNoRows {
return s.tconn.PrintfLine(protocol.NNTPResponse{Code: 441, Message: "no such message you are replying to"}.String())
} else {
return err
}
}
if !parentMessage.Thread.Valid {
var parentHeader mail.Header
err = json.Unmarshal([]byte(parentMessage.HeaderRaw), &parentHeader)
parentMessageID := parentHeader["Message-ID"]
a.Thread = sql.NullString{String: parentMessageID[0], Valid: true}
} else {
a.Thread = parentMessage.Thread
}
}
err = h.backend.SaveArticle(a, strings.Split(a.Header.Get("Newsgroups"), ","))
if err != nil {
return s.tconn.PrintfLine(protocol.NNTPResponse{Code: 441, Message: err.Error()}.String())
}
return s.tconn.PrintfLine(protocol.MessageArticleReceived)
}
func (h *Handler) Handle(s *Session, message string, id uint) error {

View File

@ -96,7 +96,7 @@ func (ns *NNTPServer) Start() error {
id, _ := uuid.NewUUID()
closed := make(chan bool)
session, err := NewSession(ctx, conn, Capabilities, id.String(), closed, NewHandler(ns.backend))
session, err := NewSession(ctx, conn, Capabilities, id.String(), closed, NewHandler(ns.backend, ns.cfg.Domain))
ns.sessionPoolMutex.Lock()
ns.sessionPool[id.String()] = session
ns.sessionPoolMutex.Unlock()