mirror of
https://github.com/ChronosX88/yans.git
synced 2025-01-22 09:26:28 +00:00
Implement session pool, refactor connection/command handling
This commit is contained in:
parent
9f14adfb95
commit
0e6cd28925
1
go.mod
1
go.mod
@ -4,6 +4,7 @@ go 1.17
|
||||
|
||||
require (
|
||||
github.com/BurntSushi/toml v1.0.0
|
||||
github.com/google/uuid v1.3.0
|
||||
github.com/jmoiron/sqlx v1.3.4
|
||||
github.com/mattn/go-sqlite3 v1.14.10
|
||||
github.com/pressly/goose/v3 v3.5.0
|
||||
|
@ -54,13 +54,21 @@ type Capability struct {
|
||||
|
||||
type Capabilities []Capability
|
||||
|
||||
func (cs Capabilities) Add(c Capability) {
|
||||
for _, v := range cs {
|
||||
func (cs *Capabilities) Add(c Capability) {
|
||||
for _, v := range *cs {
|
||||
if v.Type == c.Type {
|
||||
return // allowed only unique items
|
||||
}
|
||||
}
|
||||
cs = append(cs, c)
|
||||
*cs = append(*cs, c)
|
||||
}
|
||||
|
||||
func (cs *Capabilities) Remove(ct CapabilityType) {
|
||||
for i, v := range *cs {
|
||||
if v.Type == ct {
|
||||
*cs = append((*cs)[:i], (*cs)[i+1:]...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (cs Capabilities) String() string {
|
||||
|
@ -31,6 +31,6 @@ const (
|
||||
MessageReaderModePostingProhibited = "201 Reader mode, posting prohibited"
|
||||
MessageNNTPServiceExitsNormally = "205 NNTP Service exits normally"
|
||||
MessageUnknownCommand = "500 Unknown command"
|
||||
MessageErrorHappened = "403 Failed to process command: "
|
||||
MessageErrorHappened = "403 Failed to process command:"
|
||||
MessageListOfNewsgroupsFollows = "215 list of newsgroups follows"
|
||||
)
|
||||
|
115
internal/server/handler.go
Normal file
115
internal/server/handler.go
Normal file
@ -0,0 +1,115 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/ChronosX88/yans/internal/models"
|
||||
"github.com/ChronosX88/yans/internal/protocol"
|
||||
"github.com/jmoiron/sqlx"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Handler struct {
|
||||
handlers map[string]func(s *Session, arguments []string) error
|
||||
db *sqlx.DB
|
||||
}
|
||||
|
||||
func NewHandler(db *sqlx.DB) *Handler {
|
||||
h := &Handler{}
|
||||
h.db = db
|
||||
h.handlers = map[string]func(s *Session, arguments []string) error{
|
||||
protocol.CommandCapabilities: h.handleCapabilities,
|
||||
protocol.CommandDate: h.handleDate,
|
||||
protocol.CommandQuit: h.handleQuit,
|
||||
protocol.CommandList: h.handleList,
|
||||
}
|
||||
return h
|
||||
}
|
||||
|
||||
func (h *Handler) handleCapabilities(s *Session, arguments []string) error {
|
||||
return s.tconn.PrintfLine(Capabilities.String())
|
||||
}
|
||||
|
||||
func (h *Handler) handleDate(s *Session, arguments []string) error {
|
||||
return s.tconn.PrintfLine("111 %s", time.Now().UTC().Format("20060102150405"))
|
||||
}
|
||||
|
||||
func (h *Handler) handleQuit(s *Session, arguments []string) error {
|
||||
s.tconn.PrintfLine(protocol.MessageNNTPServiceExitsNormally)
|
||||
s.conn.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Handler) handleList(s *Session, arguments []string) error {
|
||||
sb := strings.Builder{}
|
||||
|
||||
listType := ""
|
||||
if len(arguments) != 0 {
|
||||
listType = arguments[0]
|
||||
}
|
||||
|
||||
switch listType {
|
||||
case "":
|
||||
fallthrough
|
||||
case "ACTIVE":
|
||||
{
|
||||
groups, err := h.listGroups()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
sb.Write([]byte(protocol.MessageListOfNewsgroupsFollows + protocol.CRLF))
|
||||
for _, v := range groups {
|
||||
// TODO set high/low mark and posting status to actual values
|
||||
sb.Write([]byte(fmt.Sprintf("%s 0 0 n"+protocol.CRLF, v.GroupName)))
|
||||
}
|
||||
}
|
||||
case "NEWSGROUPS":
|
||||
{
|
||||
groups, err := h.listGroups()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, v := range groups {
|
||||
desc := ""
|
||||
if v.Description == nil {
|
||||
desc = "No description"
|
||||
} else {
|
||||
desc = *v.Description
|
||||
}
|
||||
sb.Write([]byte(fmt.Sprintf("%s %s"+protocol.CRLF, v.GroupName, desc)))
|
||||
}
|
||||
}
|
||||
default:
|
||||
{
|
||||
return s.tconn.PrintfLine(protocol.MessageUnknownCommand)
|
||||
}
|
||||
}
|
||||
|
||||
sb.Write([]byte(protocol.MultilineEnding))
|
||||
|
||||
return s.tconn.PrintfLine(sb.String())
|
||||
}
|
||||
|
||||
func (h *Handler) Handle(s *Session, message string) error {
|
||||
splittedMessage := strings.Split(message, " ")
|
||||
for i, v := range splittedMessage {
|
||||
splittedMessage[i] = strings.TrimSpace(v)
|
||||
}
|
||||
cmdName := splittedMessage[0]
|
||||
handler, ok := h.handlers[cmdName]
|
||||
if !ok {
|
||||
return s.tconn.PrintfLine(protocol.MessageUnknownCommand)
|
||||
}
|
||||
return handler(s, splittedMessage[1:])
|
||||
}
|
||||
|
||||
// TODO Refactor to "storage backend" entity
|
||||
func (h *Handler) listGroups() ([]models.Group, error) {
|
||||
var groups []models.Group
|
||||
return groups, h.db.Select(&groups, "SELECT * FROM groups")
|
||||
}
|
||||
|
||||
func (h *Handler) getArticlesCount(g models.Group) (int, error) {
|
||||
var count int
|
||||
return count, h.db.Select(&count, "SELECT COUNT(*) FROM articles_to_groups WHERE group_id = ?", g.ID)
|
||||
}
|
@ -6,17 +6,14 @@ import (
|
||||
"github.com/ChronosX88/yans/internal"
|
||||
"github.com/ChronosX88/yans/internal/common"
|
||||
"github.com/ChronosX88/yans/internal/config"
|
||||
"github.com/ChronosX88/yans/internal/models"
|
||||
"github.com/ChronosX88/yans/internal/protocol"
|
||||
"github.com/google/uuid"
|
||||
"github.com/jmoiron/sqlx"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"github.com/pressly/goose/v3"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/textproto"
|
||||
"strings"
|
||||
"time"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -36,6 +33,9 @@ type NNTPServer struct {
|
||||
port int
|
||||
|
||||
db *sqlx.DB
|
||||
|
||||
sessionPool map[string]*Session
|
||||
sessionPoolMutex sync.Mutex
|
||||
}
|
||||
|
||||
func NewNNTPServer(cfg config.Config) (*NNTPServer, error) {
|
||||
@ -55,10 +55,11 @@ func NewNNTPServer(cfg config.Config) (*NNTPServer, error) {
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
ns := &NNTPServer{
|
||||
ctx: ctx,
|
||||
cancelFunc: cancel,
|
||||
port: cfg.Port,
|
||||
db: db,
|
||||
ctx: ctx,
|
||||
cancelFunc: cancel,
|
||||
port: cfg.Port,
|
||||
db: db,
|
||||
sessionPool: map[string]*Session{},
|
||||
}
|
||||
return ns, nil
|
||||
}
|
||||
@ -81,7 +82,30 @@ func (ns *NNTPServer) Start() error {
|
||||
log.Println(err)
|
||||
}
|
||||
log.Printf("Client %s has connected!", conn.RemoteAddr().String())
|
||||
go ns.handleNewConnection(ctx, conn)
|
||||
|
||||
id, _ := uuid.NewUUID()
|
||||
closed := make(chan bool)
|
||||
session, err := NewSession(ctx, conn, Capabilities, id.String(), closed, NewHandler(ns.db))
|
||||
ns.sessionPoolMutex.Lock()
|
||||
ns.sessionPool[id.String()] = session
|
||||
ns.sessionPoolMutex.Unlock()
|
||||
go func(ctx context.Context, id string, closed chan bool) {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
break
|
||||
case _, ok := <-closed:
|
||||
{
|
||||
if !ok {
|
||||
ns.sessionPoolMutex.Lock()
|
||||
delete(ns.sessionPool, id)
|
||||
ns.sessionPoolMutex.Unlock()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}(ctx, id.String(), closed)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -90,133 +114,6 @@ func (ns *NNTPServer) Start() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ns *NNTPServer) handleNewConnection(ctx context.Context, conn net.Conn) {
|
||||
_, err := conn.Write([]byte(protocol.MessageNNTPServiceReadyPostingProhibited + protocol.CRLF))
|
||||
if err != nil {
|
||||
log.Print(err)
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
tconn := textproto.NewConn(conn)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
break
|
||||
default:
|
||||
{
|
||||
message, err := tconn.ReadLine()
|
||||
if err != nil {
|
||||
if err == io.EOF || err.(*net.OpError).Unwrap() == net.ErrClosed {
|
||||
log.Printf("Client %s has diconnected!", conn.RemoteAddr().String())
|
||||
} else {
|
||||
log.Print(err)
|
||||
conn.Close()
|
||||
}
|
||||
return
|
||||
}
|
||||
log.Printf("Received message from %s: %s", conn.RemoteAddr().String(), string(message))
|
||||
err = ns.handleMessage(tconn, message)
|
||||
if err != nil {
|
||||
log.Print(err)
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (ns *NNTPServer) handleMessage(conn *textproto.Conn, msg string) error {
|
||||
splittedMessage := strings.Split(msg, " ")
|
||||
command := splittedMessage[0]
|
||||
|
||||
reply := ""
|
||||
quit := false
|
||||
|
||||
switch command {
|
||||
case protocol.CommandCapabilities:
|
||||
{
|
||||
reply = Capabilities.String()
|
||||
break
|
||||
}
|
||||
case protocol.CommandDate:
|
||||
{
|
||||
reply = fmt.Sprintf("111 %s", time.Now().UTC().Format("20060102150405"))
|
||||
break
|
||||
}
|
||||
case protocol.CommandQuit:
|
||||
{
|
||||
reply = protocol.MessageNNTPServiceExitsNormally
|
||||
quit = true
|
||||
break
|
||||
}
|
||||
case protocol.CommandMode:
|
||||
{
|
||||
if splittedMessage[1] == "READER" {
|
||||
// TODO actually switch current conn to reader mode
|
||||
reply = protocol.MessageReaderModePostingProhibited
|
||||
} else {
|
||||
reply = protocol.MessageUnknownCommand
|
||||
}
|
||||
break
|
||||
}
|
||||
case protocol.CommandList:
|
||||
{
|
||||
groups, err := ns.listGroups()
|
||||
if err != nil {
|
||||
reply = protocol.MessageErrorHappened + err.Error()
|
||||
log.Println(err)
|
||||
}
|
||||
sb := strings.Builder{}
|
||||
sb.Write([]byte(protocol.MessageListOfNewsgroupsFollows + protocol.CRLF))
|
||||
if len(splittedMessage) == 1 || splittedMessage[1] == "ACTIVE" {
|
||||
for _, v := range groups {
|
||||
// TODO set high/low mark and posting status to actual values
|
||||
sb.Write([]byte(fmt.Sprintf("%s 0 0 n"+protocol.CRLF, v.GroupName)))
|
||||
}
|
||||
} else if splittedMessage[1] == "NEWSGROUPS" {
|
||||
for _, v := range groups {
|
||||
desc := ""
|
||||
if v.Description == nil {
|
||||
desc = "No description"
|
||||
} else {
|
||||
desc = *v.Description
|
||||
}
|
||||
sb.Write([]byte(fmt.Sprintf("%s %s\r\n", v.GroupName, desc)))
|
||||
}
|
||||
} else {
|
||||
reply = protocol.MessageUnknownCommand
|
||||
break
|
||||
}
|
||||
|
||||
sb.Write([]byte("."))
|
||||
reply = sb.String()
|
||||
}
|
||||
default:
|
||||
{
|
||||
reply = protocol.MessageUnknownCommand
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
err := conn.PrintfLine(reply)
|
||||
if quit {
|
||||
conn.Close()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (ns *NNTPServer) listGroups() ([]models.Group, error) {
|
||||
var groups []models.Group
|
||||
return groups, ns.db.Select(&groups, "SELECT * FROM groups")
|
||||
}
|
||||
|
||||
func (ns *NNTPServer) getArticlesCount(g models.Group) (int, error) {
|
||||
var count int
|
||||
return count, ns.db.Select(&count, "SELECT COUNT(*) FROM articles_to_groups WHERE group_id = ?", g.ID)
|
||||
}
|
||||
|
||||
func (ns *NNTPServer) Stop() {
|
||||
ns.cancelFunc()
|
||||
}
|
||||
|
86
internal/server/session.go
Normal file
86
internal/server/session.go
Normal file
@ -0,0 +1,86 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/ChronosX88/yans/internal/protocol"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/textproto"
|
||||
)
|
||||
|
||||
type Session struct {
|
||||
ctx context.Context
|
||||
capabilities protocol.Capabilities
|
||||
conn net.Conn
|
||||
tconn *textproto.Conn
|
||||
id string
|
||||
closed chan<- bool
|
||||
h *Handler
|
||||
}
|
||||
|
||||
func NewSession(ctx context.Context, conn net.Conn, caps protocol.Capabilities, id string, closed chan<- bool, handler *Handler) (*Session, error) {
|
||||
var err error
|
||||
defer func() {
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
close(closed)
|
||||
}
|
||||
}()
|
||||
|
||||
tconn := textproto.NewConn(conn)
|
||||
s := &Session{
|
||||
ctx: ctx,
|
||||
conn: conn,
|
||||
tconn: tconn,
|
||||
capabilities: caps,
|
||||
id: id,
|
||||
closed: closed,
|
||||
h: handler,
|
||||
}
|
||||
|
||||
go s.loop()
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *Session) loop() {
|
||||
defer func() {
|
||||
close(s.closed)
|
||||
}()
|
||||
|
||||
err := s.tconn.PrintfLine(protocol.MessageReaderModePostingProhibited) // by default access mode is read-only
|
||||
if err != nil {
|
||||
s.conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
break
|
||||
default:
|
||||
{
|
||||
message, err := s.tconn.ReadLine()
|
||||
if err != nil {
|
||||
if err == io.EOF || err.(*net.OpError).Unwrap() == net.ErrClosed {
|
||||
log.Printf("Client %s has diconnected!", s.conn.RemoteAddr().String())
|
||||
} else {
|
||||
log.Print(err)
|
||||
s.conn.Close()
|
||||
}
|
||||
return
|
||||
}
|
||||
log.Printf("Received message from %s: %s", s.conn.RemoteAddr().String(), message) // for debugging
|
||||
err = s.h.Handle(s, message)
|
||||
if err != nil {
|
||||
log.Print(err)
|
||||
s.tconn.PrintfLine("%s %s", protocol.MessageErrorHappened, err.Error())
|
||||
s.conn.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user