Implement session pool, refactor connection/command handling

This commit is contained in:
ChronosX88 2022-01-18 20:26:37 +03:00
parent 9f14adfb95
commit 0e6cd28925
Signed by: ChronosXYZ
GPG Key ID: 085A69A82C8C511A
6 changed files with 248 additions and 141 deletions

1
go.mod
View File

@ -4,6 +4,7 @@ go 1.17
require (
github.com/BurntSushi/toml v1.0.0
github.com/google/uuid v1.3.0
github.com/jmoiron/sqlx v1.3.4
github.com/mattn/go-sqlite3 v1.14.10
github.com/pressly/goose/v3 v3.5.0

View File

@ -54,13 +54,21 @@ type Capability struct {
type Capabilities []Capability
func (cs Capabilities) Add(c Capability) {
for _, v := range cs {
func (cs *Capabilities) Add(c Capability) {
for _, v := range *cs {
if v.Type == c.Type {
return // allowed only unique items
}
}
cs = append(cs, c)
*cs = append(*cs, c)
}
func (cs *Capabilities) Remove(ct CapabilityType) {
for i, v := range *cs {
if v.Type == ct {
*cs = append((*cs)[:i], (*cs)[i+1:]...)
}
}
}
func (cs Capabilities) String() string {

View File

@ -31,6 +31,6 @@ const (
MessageReaderModePostingProhibited = "201 Reader mode, posting prohibited"
MessageNNTPServiceExitsNormally = "205 NNTP Service exits normally"
MessageUnknownCommand = "500 Unknown command"
MessageErrorHappened = "403 Failed to process command: "
MessageErrorHappened = "403 Failed to process command:"
MessageListOfNewsgroupsFollows = "215 list of newsgroups follows"
)

115
internal/server/handler.go Normal file
View File

@ -0,0 +1,115 @@
package server
import (
"fmt"
"github.com/ChronosX88/yans/internal/models"
"github.com/ChronosX88/yans/internal/protocol"
"github.com/jmoiron/sqlx"
"strings"
"time"
)
type Handler struct {
handlers map[string]func(s *Session, arguments []string) error
db *sqlx.DB
}
func NewHandler(db *sqlx.DB) *Handler {
h := &Handler{}
h.db = db
h.handlers = map[string]func(s *Session, arguments []string) error{
protocol.CommandCapabilities: h.handleCapabilities,
protocol.CommandDate: h.handleDate,
protocol.CommandQuit: h.handleQuit,
protocol.CommandList: h.handleList,
}
return h
}
func (h *Handler) handleCapabilities(s *Session, arguments []string) error {
return s.tconn.PrintfLine(Capabilities.String())
}
func (h *Handler) handleDate(s *Session, arguments []string) error {
return s.tconn.PrintfLine("111 %s", time.Now().UTC().Format("20060102150405"))
}
func (h *Handler) handleQuit(s *Session, arguments []string) error {
s.tconn.PrintfLine(protocol.MessageNNTPServiceExitsNormally)
s.conn.Close()
return nil
}
func (h *Handler) handleList(s *Session, arguments []string) error {
sb := strings.Builder{}
listType := ""
if len(arguments) != 0 {
listType = arguments[0]
}
switch listType {
case "":
fallthrough
case "ACTIVE":
{
groups, err := h.listGroups()
if err != nil {
return err
}
sb.Write([]byte(protocol.MessageListOfNewsgroupsFollows + protocol.CRLF))
for _, v := range groups {
// TODO set high/low mark and posting status to actual values
sb.Write([]byte(fmt.Sprintf("%s 0 0 n"+protocol.CRLF, v.GroupName)))
}
}
case "NEWSGROUPS":
{
groups, err := h.listGroups()
if err != nil {
return err
}
for _, v := range groups {
desc := ""
if v.Description == nil {
desc = "No description"
} else {
desc = *v.Description
}
sb.Write([]byte(fmt.Sprintf("%s %s"+protocol.CRLF, v.GroupName, desc)))
}
}
default:
{
return s.tconn.PrintfLine(protocol.MessageUnknownCommand)
}
}
sb.Write([]byte(protocol.MultilineEnding))
return s.tconn.PrintfLine(sb.String())
}
func (h *Handler) Handle(s *Session, message string) error {
splittedMessage := strings.Split(message, " ")
for i, v := range splittedMessage {
splittedMessage[i] = strings.TrimSpace(v)
}
cmdName := splittedMessage[0]
handler, ok := h.handlers[cmdName]
if !ok {
return s.tconn.PrintfLine(protocol.MessageUnknownCommand)
}
return handler(s, splittedMessage[1:])
}
// TODO Refactor to "storage backend" entity
func (h *Handler) listGroups() ([]models.Group, error) {
var groups []models.Group
return groups, h.db.Select(&groups, "SELECT * FROM groups")
}
func (h *Handler) getArticlesCount(g models.Group) (int, error) {
var count int
return count, h.db.Select(&count, "SELECT COUNT(*) FROM articles_to_groups WHERE group_id = ?", g.ID)
}

View File

@ -6,17 +6,14 @@ import (
"github.com/ChronosX88/yans/internal"
"github.com/ChronosX88/yans/internal/common"
"github.com/ChronosX88/yans/internal/config"
"github.com/ChronosX88/yans/internal/models"
"github.com/ChronosX88/yans/internal/protocol"
"github.com/google/uuid"
"github.com/jmoiron/sqlx"
_ "github.com/mattn/go-sqlite3"
"github.com/pressly/goose/v3"
"io"
"log"
"net"
"net/textproto"
"strings"
"time"
"sync"
)
var (
@ -36,6 +33,9 @@ type NNTPServer struct {
port int
db *sqlx.DB
sessionPool map[string]*Session
sessionPoolMutex sync.Mutex
}
func NewNNTPServer(cfg config.Config) (*NNTPServer, error) {
@ -55,10 +55,11 @@ func NewNNTPServer(cfg config.Config) (*NNTPServer, error) {
ctx, cancel := context.WithCancel(context.Background())
ns := &NNTPServer{
ctx: ctx,
cancelFunc: cancel,
port: cfg.Port,
db: db,
ctx: ctx,
cancelFunc: cancel,
port: cfg.Port,
db: db,
sessionPool: map[string]*Session{},
}
return ns, nil
}
@ -81,7 +82,30 @@ func (ns *NNTPServer) Start() error {
log.Println(err)
}
log.Printf("Client %s has connected!", conn.RemoteAddr().String())
go ns.handleNewConnection(ctx, conn)
id, _ := uuid.NewUUID()
closed := make(chan bool)
session, err := NewSession(ctx, conn, Capabilities, id.String(), closed, NewHandler(ns.db))
ns.sessionPoolMutex.Lock()
ns.sessionPool[id.String()] = session
ns.sessionPoolMutex.Unlock()
go func(ctx context.Context, id string, closed chan bool) {
for {
select {
case <-ctx.Done():
break
case _, ok := <-closed:
{
if !ok {
ns.sessionPoolMutex.Lock()
delete(ns.sessionPool, id)
ns.sessionPoolMutex.Unlock()
return
}
}
}
}
}(ctx, id.String(), closed)
}
}
}
@ -90,133 +114,6 @@ func (ns *NNTPServer) Start() error {
return nil
}
func (ns *NNTPServer) handleNewConnection(ctx context.Context, conn net.Conn) {
_, err := conn.Write([]byte(protocol.MessageNNTPServiceReadyPostingProhibited + protocol.CRLF))
if err != nil {
log.Print(err)
conn.Close()
return
}
tconn := textproto.NewConn(conn)
for {
select {
case <-ctx.Done():
break
default:
{
message, err := tconn.ReadLine()
if err != nil {
if err == io.EOF || err.(*net.OpError).Unwrap() == net.ErrClosed {
log.Printf("Client %s has diconnected!", conn.RemoteAddr().String())
} else {
log.Print(err)
conn.Close()
}
return
}
log.Printf("Received message from %s: %s", conn.RemoteAddr().String(), string(message))
err = ns.handleMessage(tconn, message)
if err != nil {
log.Print(err)
conn.Close()
return
}
}
}
}
}
func (ns *NNTPServer) handleMessage(conn *textproto.Conn, msg string) error {
splittedMessage := strings.Split(msg, " ")
command := splittedMessage[0]
reply := ""
quit := false
switch command {
case protocol.CommandCapabilities:
{
reply = Capabilities.String()
break
}
case protocol.CommandDate:
{
reply = fmt.Sprintf("111 %s", time.Now().UTC().Format("20060102150405"))
break
}
case protocol.CommandQuit:
{
reply = protocol.MessageNNTPServiceExitsNormally
quit = true
break
}
case protocol.CommandMode:
{
if splittedMessage[1] == "READER" {
// TODO actually switch current conn to reader mode
reply = protocol.MessageReaderModePostingProhibited
} else {
reply = protocol.MessageUnknownCommand
}
break
}
case protocol.CommandList:
{
groups, err := ns.listGroups()
if err != nil {
reply = protocol.MessageErrorHappened + err.Error()
log.Println(err)
}
sb := strings.Builder{}
sb.Write([]byte(protocol.MessageListOfNewsgroupsFollows + protocol.CRLF))
if len(splittedMessage) == 1 || splittedMessage[1] == "ACTIVE" {
for _, v := range groups {
// TODO set high/low mark and posting status to actual values
sb.Write([]byte(fmt.Sprintf("%s 0 0 n"+protocol.CRLF, v.GroupName)))
}
} else if splittedMessage[1] == "NEWSGROUPS" {
for _, v := range groups {
desc := ""
if v.Description == nil {
desc = "No description"
} else {
desc = *v.Description
}
sb.Write([]byte(fmt.Sprintf("%s %s\r\n", v.GroupName, desc)))
}
} else {
reply = protocol.MessageUnknownCommand
break
}
sb.Write([]byte("."))
reply = sb.String()
}
default:
{
reply = protocol.MessageUnknownCommand
break
}
}
err := conn.PrintfLine(reply)
if quit {
conn.Close()
}
return err
}
func (ns *NNTPServer) listGroups() ([]models.Group, error) {
var groups []models.Group
return groups, ns.db.Select(&groups, "SELECT * FROM groups")
}
func (ns *NNTPServer) getArticlesCount(g models.Group) (int, error) {
var count int
return count, ns.db.Select(&count, "SELECT COUNT(*) FROM articles_to_groups WHERE group_id = ?", g.ID)
}
func (ns *NNTPServer) Stop() {
ns.cancelFunc()
}

View File

@ -0,0 +1,86 @@
package server
import (
"context"
"github.com/ChronosX88/yans/internal/protocol"
"io"
"log"
"net"
"net/textproto"
)
type Session struct {
ctx context.Context
capabilities protocol.Capabilities
conn net.Conn
tconn *textproto.Conn
id string
closed chan<- bool
h *Handler
}
func NewSession(ctx context.Context, conn net.Conn, caps protocol.Capabilities, id string, closed chan<- bool, handler *Handler) (*Session, error) {
var err error
defer func() {
if err != nil {
conn.Close()
close(closed)
}
}()
tconn := textproto.NewConn(conn)
s := &Session{
ctx: ctx,
conn: conn,
tconn: tconn,
capabilities: caps,
id: id,
closed: closed,
h: handler,
}
go s.loop()
return s, nil
}
func (s *Session) loop() {
defer func() {
close(s.closed)
}()
err := s.tconn.PrintfLine(protocol.MessageReaderModePostingProhibited) // by default access mode is read-only
if err != nil {
s.conn.Close()
return
}
for {
select {
case <-s.ctx.Done():
break
default:
{
message, err := s.tconn.ReadLine()
if err != nil {
if err == io.EOF || err.(*net.OpError).Unwrap() == net.ErrClosed {
log.Printf("Client %s has diconnected!", s.conn.RemoteAddr().String())
} else {
log.Print(err)
s.conn.Close()
}
return
}
log.Printf("Received message from %s: %s", s.conn.RemoteAddr().String(), message) // for debugging
err = s.h.Handle(s, message)
if err != nil {
log.Print(err)
s.tconn.PrintfLine("%s %s", protocol.MessageErrorHappened, err.Error())
s.conn.Close()
return
}
}
}
}
}