Implement POST command

This commit is contained in:
ChronosX88 2022-02-03 19:44:08 +03:00
parent e8aa350c53
commit 163f2bcba6
Signed by: ChronosXYZ
GPG Key ID: 085A69A82C8C511A
12 changed files with 287 additions and 30 deletions

View File

@ -32,7 +32,7 @@
- :x: `NEWNEWS` - :x: `NEWNEWS`
- :x: `NEXT` - :x: `NEXT`
- :x: `OVER` - :x: `OVER`
- :x: `POST` - :construction: `POST`
- :x: `STAT` - :x: `STAT`
## License ## License

View File

@ -1,6 +1,7 @@
address = "localhost" address = "localhost"
port = 1119 port = 1119
backend_type = "sqlite" backend_type = "sqlite"
domain = "localhost"
[sqlite] [sqlite]
path = "yans.db" path = "yans.db"

View File

@ -4,15 +4,13 @@ CREATE TABLE IF NOT EXISTS groups(
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
group_name TEXT UNIQUE NOT NULL, group_name TEXT UNIQUE NOT NULL,
description TEXT, description TEXT,
created_at UNSIGNED BIG INT NOT NULL created_at DATETIME DEFAULT CURRENT_TIMESTAMP
); );
CREATE TABLE IF NOT EXISTS articles( CREATE TABLE IF NOT EXISTS articles(
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
date INTEGER NOT NULL, created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
path TEXT, header TEXT,
reply_to TEXT,
thread TEXT, thread TEXT,
subject TEXT NOT NULL,
body TEXT NOT NULL body TEXT NOT NULL
); );
CREATE TABLE IF NOT EXISTS articles_to_groups( CREATE TABLE IF NOT EXISTS articles_to_groups(

View File

@ -3,6 +3,8 @@ package sqlite
import ( import (
"database/sql" "database/sql"
"embed" "embed"
"encoding/json"
"fmt"
"github.com/ChronosX88/yans/internal/config" "github.com/ChronosX88/yans/internal/config"
"github.com/ChronosX88/yans/internal/models" "github.com/ChronosX88/yans/internal/models"
"github.com/ChronosX88/yans/internal/utils" "github.com/ChronosX88/yans/internal/utils"
@ -11,6 +13,7 @@ import (
"github.com/mattn/go-sqlite3" "github.com/mattn/go-sqlite3"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
"github.com/pressly/goose/v3" "github.com/pressly/goose/v3"
"strings"
) )
//go:embed migrations/*.sql //go:embed migrations/*.sql
@ -91,5 +94,43 @@ func (sb *SQLiteBackend) GetGroup(groupName string) (models.Group, error) {
func (sb *SQLiteBackend) GetNewGroupsSince(timestamp int64) ([]models.Group, error) { func (sb *SQLiteBackend) GetNewGroupsSince(timestamp int64) ([]models.Group, error) {
var groups []models.Group var groups []models.Group
return groups, sb.db.Select(&groups, "SELECT * FROM groups WHERE created_at > ?", timestamp) return groups, sb.db.Select(&groups, "SELECT * FROM groups WHERE created_at > datetime(?, 'unixepoch')", timestamp)
}
func (sb *SQLiteBackend) SaveArticle(a models.Article, groups []string) error {
res, err := sb.db.Exec("INSERT INTO articles (header, body, thread) VALUES (?, ?, ?)", a.HeaderRaw, a.Body, a.Thread)
articleID, err := res.LastInsertId()
if err != nil {
return err
}
var groupIDs []int
for _, v := range groups {
v = strings.TrimSpace(v)
g, err := sb.GetGroup(v)
if err != nil {
if err == sql.ErrNoRows {
return fmt.Errorf("no such newsgroup")
} else {
return err
}
}
groupIDs = append(groupIDs, g.ID)
}
for _, v := range groupIDs {
_, err = sb.db.Exec("INSERT INTO articles_to_groups (article_id, group_id) VALUES (?, ?)", articleID, v)
if err != nil {
return err
}
}
return err
}
func (sb *SQLiteBackend) GetArticle(messageID string) (models.Article, error) {
var a models.Article
if err := sb.db.Get(&a, "SELECT * FROM articles WHERE json_extract(articles.header, '$.Message-ID[0]') = ?", messageID); err != nil {
return a, err
}
return a, json.Unmarshal([]byte(a.HeaderRaw), &a.Header)
} }

View File

@ -14,4 +14,6 @@ type StorageBackend interface {
GetArticlesCount(g models.Group) (int, error) GetArticlesCount(g models.Group) (int, error)
GetGroupLowWaterMark(g models.Group) (int, error) GetGroupLowWaterMark(g models.Group) (int, error)
GetGroupHighWaterMark(g models.Group) (int, error) GetGroupHighWaterMark(g models.Group) (int, error)
SaveArticle(article models.Article, groups []string) error
GetArticle(messageID string) (models.Article, error)
} }

View File

@ -13,6 +13,7 @@ type Config struct {
Address string `toml:"address"` Address string `toml:"address"`
Port int `toml:"port"` Port int `toml:"port"`
BackendType string `toml:"backend_type"` BackendType string `toml:"backend_type"`
Domain string `toml:"domain"`
SQLite SQLiteBackendConfig `toml:"sqlite"` SQLite SQLiteBackendConfig `toml:"sqlite"`
} }

View File

@ -0,0 +1,81 @@
package models
import (
"database/sql"
"net/textproto"
"time"
)
type Article struct {
ID int `db:"id"`
CreatedAt time.Time `db:"created_at"`
HeaderRaw string `db:"header"`
Header textproto.MIMEHeader `db:"-"`
Body string `db:"body"`
Thread sql.NullString `db:"thread"`
}
//func ParseArticle(lines []string) (Article, error) {
// article := Article{}
// headerBlock := true
// for _, v := range lines {
// if v == "" {
// headerBlock = false
// }
//
// if headerBlock {
// kv := strings.Split(v, ":")
// if len(kv) < 2 {
// return Article{}, fmt.Errorf("invalid header format")
// }
//
// kv[0] = strings.TrimSpace(kv[0])
// kv[1] = strings.TrimSpace(kv[1])
//
// if !protocol.IsMessageHeaderAllowed(kv[0]) {
// return Article{}, fmt.Errorf("invalid header element")
// }
// if kv[1] == "" {
// return Article{}, fmt.Errorf("header value should not be empty")
// }
//
// switch kv[0] {
// case "Archive":
// {
// if kv[1] == "yes" {
// article.Archive = true
// } else {
// article.Archive = false
// }
// }
// case "Injection-Date":
// {
// date, err := mail.ParseDate(kv[1])
// if err != nil {
// return Article{}, err
// }
// article.InjectionDate = date
// }
// case "Date":
// {
// date, err := mail.ParseDate(kv[1])
// if err != nil {
// return Article{}, err
// }
// article.Date = date
// }
// case "Expires":
// {
// date, err := mail.ParseDate(kv[1])
// if err != nil {
// return Article{}, err
// }
// article.Expires = date
// }
// }
//
// } else {
// }
// }
// return article, nil
//}

View File

@ -1,8 +1,10 @@
package models package models
import "time"
type Group struct { type Group struct {
ID int `db:"id"` ID int `db:"id"`
GroupName string `db:"group_name"` GroupName string `db:"group_name"`
Description *string `db:"description"` Description *string `db:"description"`
CreatedAt uint64 `db:"created_at"` CreatedAt time.Time `db:"created_at"`
} }

View File

@ -1,5 +1,46 @@
package protocol package protocol
var (
ErrSyntaxError = NNTPResponse{Code: 501, Message: "Syntax Error"}
)
func IsMessageHeaderAllowed(headerName string) bool {
switch headerName {
case
"Date",
"From",
"Message-ID",
"Newsgroups",
"Path",
"Subject",
"Comments",
"Keywords",
"In-Reply-To",
"Sender",
"MIME-Version",
"Content-Type",
"Content-Transfer-Encoding",
"Content-Disposition",
"Content-Language",
"Approved",
"Archive",
"Control",
"Distribution",
"Expires",
"Followup-To",
"Injection-Date",
"Injection-Info",
"Organization",
"References",
"Summary",
"Supersedes",
"User-Agent",
"Xref":
return true
}
return false
}
const ( const (
CRLF = "\r\n" CRLF = "\r\n"
MultilineEnding = "." MultilineEnding = "."
@ -13,6 +54,7 @@ const (
CommandList = "LIST" CommandList = "LIST"
CommandGroup = "GROUP" CommandGroup = "GROUP"
CommandNewGroups = "NEWGROUPS" CommandNewGroups = "NEWGROUPS"
CommandPost = "POST"
) )
const ( const (
@ -35,7 +77,7 @@ const (
MessageUnknownCommand = "500 Unknown command" MessageUnknownCommand = "500 Unknown command"
MessageErrorHappened = "403 Failed to process command:" MessageErrorHappened = "403 Failed to process command:"
MessageListOfNewsgroupsFollows = "215 list of newsgroups follows" MessageListOfNewsgroupsFollows = "215 list of newsgroups follows"
MessageNewGroupsListOfNewsgroupsFollows = "231 list of new newsgroups follows"
MessageSyntaxError = "501 Syntax Error"
MessageNoSuchGroup = "411 No such newsgroup" MessageNoSuchGroup = "411 No such newsgroup"
MessageInputArticle = "340 Input article; end with <CR-LF>.<CR-LF>"
MessageArticleReceived = "240 Article received OK"
) )

View File

@ -0,0 +1,12 @@
package protocol
import "fmt"
type NNTPResponse struct {
Code int
Message string
}
func (nr NNTPResponse) String() string {
return fmt.Sprintf("%d %s", nr.Code, nr.Message)
}

View File

@ -2,10 +2,14 @@ package server
import ( import (
"database/sql" "database/sql"
"encoding/json"
"fmt" "fmt"
"github.com/ChronosX88/yans/internal/backend" "github.com/ChronosX88/yans/internal/backend"
"github.com/ChronosX88/yans/internal/models" "github.com/ChronosX88/yans/internal/models"
"github.com/ChronosX88/yans/internal/protocol" "github.com/ChronosX88/yans/internal/protocol"
"github.com/google/uuid"
"io"
"net/mail"
"strings" "strings"
"time" "time"
) )
@ -13,9 +17,10 @@ import (
type Handler struct { type Handler struct {
handlers map[string]func(s *Session, arguments []string, id uint) error handlers map[string]func(s *Session, arguments []string, id uint) error
backend backend.StorageBackend backend backend.StorageBackend
serverDomain string
} }
func NewHandler(b backend.StorageBackend) *Handler { func NewHandler(b backend.StorageBackend, serverDomain string) *Handler {
h := &Handler{} h := &Handler{}
h.backend = b h.backend = b
h.handlers = map[string]func(s *Session, arguments []string, id uint) error{ h.handlers = map[string]func(s *Session, arguments []string, id uint) error{
@ -26,7 +31,9 @@ func NewHandler(b backend.StorageBackend) *Handler {
protocol.CommandMode: h.handleModeReader, protocol.CommandMode: h.handleModeReader,
protocol.CommandGroup: h.handleGroup, protocol.CommandGroup: h.handleGroup,
protocol.CommandNewGroups: h.handleNewGroups, protocol.CommandNewGroups: h.handleNewGroups,
protocol.CommandPost: h.handlePost,
} }
h.serverDomain = serverDomain
return h return h
} }
@ -123,7 +130,7 @@ func (h *Handler) handleList(s *Session, arguments []string, id uint) error {
} }
default: default:
{ {
return s.tconn.PrintfLine(protocol.MessageSyntaxError) return s.tconn.PrintfLine(protocol.ErrSyntaxError.String())
} }
} }
@ -134,7 +141,7 @@ func (h *Handler) handleList(s *Session, arguments []string, id uint) error {
func (h *Handler) handleModeReader(s *Session, arguments []string, id uint) error { func (h *Handler) handleModeReader(s *Session, arguments []string, id uint) error {
if len(arguments) == 0 || arguments[0] != "READER" { if len(arguments) == 0 || arguments[0] != "READER" {
return s.tconn.PrintfLine(protocol.MessageSyntaxError) return s.tconn.PrintfLine(protocol.ErrSyntaxError.String())
} }
(&s.capabilities).Remove(protocol.ModeReaderCapability) (&s.capabilities).Remove(protocol.ModeReaderCapability)
@ -151,7 +158,7 @@ func (h *Handler) handleGroup(s *Session, arguments []string, id uint) error {
defer s.tconn.EndResponse(id) defer s.tconn.EndResponse(id)
if len(arguments) == 0 || len(arguments) > 1 { if len(arguments) == 0 || len(arguments) > 1 {
return s.tconn.PrintfLine(protocol.MessageSyntaxError) return s.tconn.PrintfLine(protocol.ErrSyntaxError.String())
} }
g, err := h.backend.GetGroup(arguments[0]) g, err := h.backend.GetGroup(arguments[0])
@ -177,7 +184,10 @@ func (h *Handler) handleGroup(s *Session, arguments []string, id uint) error {
s.currentGroup = &g s.currentGroup = &g
return s.tconn.PrintfLine("211 %d %d %d %s", articlesCount, lowWaterMark, highWaterMark, g.GroupName) return s.tconn.PrintfLine(protocol.NNTPResponse{
Code: 211,
Message: fmt.Sprintf("%d %d %d %s", articlesCount, lowWaterMark, highWaterMark, g.GroupName),
}.String())
} }
func (h *Handler) handleNewGroups(s *Session, arguments []string, id uint) error { func (h *Handler) handleNewGroups(s *Session, arguments []string, id uint) error {
@ -185,7 +195,7 @@ func (h *Handler) handleNewGroups(s *Session, arguments []string, id uint) error
defer s.tconn.EndResponse(id) defer s.tconn.EndResponse(id)
if len(arguments) < 2 || len(arguments) > 3 { if len(arguments) < 2 || len(arguments) > 3 {
return s.tconn.PrintfLine(protocol.MessageSyntaxError) return s.tconn.PrintfLine(protocol.ErrSyntaxError.String())
} }
dateString := arguments[0] + " " + arguments[1] dateString := arguments[0] + " " + arguments[1]
@ -208,7 +218,7 @@ func (h *Handler) handleNewGroups(s *Session, arguments []string, id uint) error
return err return err
} }
} else { } else {
return s.tconn.PrintfLine(protocol.MessageSyntaxError) return s.tconn.PrintfLine(protocol.ErrSyntaxError.String())
} }
g, err := h.backend.GetNewGroupsSince(date.Unix()) g, err := h.backend.GetNewGroupsSince(date.Unix())
@ -216,9 +226,8 @@ func (h *Handler) handleNewGroups(s *Session, arguments []string, id uint) error
return err return err
} }
var sb strings.Builder dw := s.tconn.DotWriter()
dw.Write([]byte(protocol.NNTPResponse{Code: 231, Message: "list of new newsgroups follows"}.String() + protocol.CRLF))
sb.Write([]byte(protocol.MessageListOfNewsgroupsFollows + protocol.CRLF))
for _, v := range g { for _, v := range g {
// TODO set actual post permission status // TODO set actual post permission status
c, err := h.backend.GetArticlesCount(v) c, err := h.backend.GetArticlesCount(v)
@ -234,14 +243,82 @@ func (h *Handler) handleNewGroups(s *Session, arguments []string, id uint) error
if err != nil { if err != nil {
return err return err
} }
sb.Write([]byte(fmt.Sprintf("%s %d %d n"+protocol.CRLF, v.GroupName, highWaterMark, lowWaterMark))) dw.Write([]byte(fmt.Sprintf("%s %d %d n"+protocol.CRLF, v.GroupName, highWaterMark, lowWaterMark)))
} else { } else {
sb.Write([]byte(fmt.Sprintf("%s 0 1 n"+protocol.CRLF, v.GroupName))) dw.Write([]byte(fmt.Sprintf("%s 0 1 n"+protocol.CRLF, v.GroupName)))
} }
} }
sb.Write([]byte(protocol.MultilineEnding))
return s.tconn.PrintfLine(sb.String()) return dw.Close()
}
func (h *Handler) handlePost(s *Session, arguments []string, id uint) error {
s.tconn.StartResponse(id)
defer s.tconn.EndResponse(id)
if len(arguments) != 0 {
return s.tconn.PrintfLine(protocol.ErrSyntaxError.String())
}
if err := s.tconn.PrintfLine(protocol.MessageInputArticle); err != nil {
return err
}
headers, err := s.tconn.ReadMIMEHeader()
if err != nil {
return err
}
if err != nil {
return err
}
// generate message id
messageID := fmt.Sprintf("<%s@%s>", uuid.New().String(), h.serverDomain)
headers["Message-ID"] = []string{messageID}
headerJson, err := json.Marshal(headers)
if err != nil {
return err
}
a := models.Article{}
a.HeaderRaw = string(headerJson)
a.Header = headers
dr := s.tconn.DotReader()
// TODO handle multipart message
body, err := io.ReadAll(dr)
if err != nil {
return err
}
a.Body = string(body)
// set thread property
if headers.Get("In-Reply-To") != "" {
parentMessage, err := h.backend.GetArticle(headers.Get("In-Reply-To"))
if err != nil {
if err == sql.ErrNoRows {
return s.tconn.PrintfLine(protocol.NNTPResponse{Code: 441, Message: "no such message you are replying to"}.String())
} else {
return err
}
}
if !parentMessage.Thread.Valid {
var parentHeader mail.Header
err = json.Unmarshal([]byte(parentMessage.HeaderRaw), &parentHeader)
parentMessageID := parentHeader["Message-ID"]
a.Thread = sql.NullString{String: parentMessageID[0], Valid: true}
} else {
a.Thread = parentMessage.Thread
}
}
err = h.backend.SaveArticle(a, strings.Split(a.Header.Get("Newsgroups"), ","))
if err != nil {
return s.tconn.PrintfLine(protocol.NNTPResponse{Code: 441, Message: err.Error()}.String())
}
return s.tconn.PrintfLine(protocol.MessageArticleReceived)
} }
func (h *Handler) Handle(s *Session, message string, id uint) error { func (h *Handler) Handle(s *Session, message string, id uint) error {

View File

@ -96,7 +96,7 @@ func (ns *NNTPServer) Start() error {
id, _ := uuid.NewUUID() id, _ := uuid.NewUUID()
closed := make(chan bool) closed := make(chan bool)
session, err := NewSession(ctx, conn, Capabilities, id.String(), closed, NewHandler(ns.backend)) session, err := NewSession(ctx, conn, Capabilities, id.String(), closed, NewHandler(ns.backend, ns.cfg.Domain))
ns.sessionPoolMutex.Lock() ns.sessionPoolMutex.Lock()
ns.sessionPool[id.String()] = session ns.sessionPool[id.String()] = session
ns.sessionPoolMutex.Unlock() ns.sessionPoolMutex.Unlock()