mirror of
https://github.com/ChronosX88/yans.git
synced 2024-11-21 19:32:17 +00:00
Implement session pool, refactor connection/command handling
This commit is contained in:
parent
9f14adfb95
commit
0e6cd28925
1
go.mod
1
go.mod
@ -4,6 +4,7 @@ go 1.17
|
|||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/BurntSushi/toml v1.0.0
|
github.com/BurntSushi/toml v1.0.0
|
||||||
|
github.com/google/uuid v1.3.0
|
||||||
github.com/jmoiron/sqlx v1.3.4
|
github.com/jmoiron/sqlx v1.3.4
|
||||||
github.com/mattn/go-sqlite3 v1.14.10
|
github.com/mattn/go-sqlite3 v1.14.10
|
||||||
github.com/pressly/goose/v3 v3.5.0
|
github.com/pressly/goose/v3 v3.5.0
|
||||||
|
@ -54,13 +54,21 @@ type Capability struct {
|
|||||||
|
|
||||||
type Capabilities []Capability
|
type Capabilities []Capability
|
||||||
|
|
||||||
func (cs Capabilities) Add(c Capability) {
|
func (cs *Capabilities) Add(c Capability) {
|
||||||
for _, v := range cs {
|
for _, v := range *cs {
|
||||||
if v.Type == c.Type {
|
if v.Type == c.Type {
|
||||||
return // allowed only unique items
|
return // allowed only unique items
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
cs = append(cs, c)
|
*cs = append(*cs, c)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cs *Capabilities) Remove(ct CapabilityType) {
|
||||||
|
for i, v := range *cs {
|
||||||
|
if v.Type == ct {
|
||||||
|
*cs = append((*cs)[:i], (*cs)[i+1:]...)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cs Capabilities) String() string {
|
func (cs Capabilities) String() string {
|
||||||
|
115
internal/server/handler.go
Normal file
115
internal/server/handler.go
Normal file
@ -0,0 +1,115 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"github.com/ChronosX88/yans/internal/models"
|
||||||
|
"github.com/ChronosX88/yans/internal/protocol"
|
||||||
|
"github.com/jmoiron/sqlx"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Handler struct {
|
||||||
|
handlers map[string]func(s *Session, arguments []string) error
|
||||||
|
db *sqlx.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewHandler(db *sqlx.DB) *Handler {
|
||||||
|
h := &Handler{}
|
||||||
|
h.db = db
|
||||||
|
h.handlers = map[string]func(s *Session, arguments []string) error{
|
||||||
|
protocol.CommandCapabilities: h.handleCapabilities,
|
||||||
|
protocol.CommandDate: h.handleDate,
|
||||||
|
protocol.CommandQuit: h.handleQuit,
|
||||||
|
protocol.CommandList: h.handleList,
|
||||||
|
}
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) handleCapabilities(s *Session, arguments []string) error {
|
||||||
|
return s.tconn.PrintfLine(Capabilities.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) handleDate(s *Session, arguments []string) error {
|
||||||
|
return s.tconn.PrintfLine("111 %s", time.Now().UTC().Format("20060102150405"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) handleQuit(s *Session, arguments []string) error {
|
||||||
|
s.tconn.PrintfLine(protocol.MessageNNTPServiceExitsNormally)
|
||||||
|
s.conn.Close()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) handleList(s *Session, arguments []string) error {
|
||||||
|
sb := strings.Builder{}
|
||||||
|
|
||||||
|
listType := ""
|
||||||
|
if len(arguments) != 0 {
|
||||||
|
listType = arguments[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
switch listType {
|
||||||
|
case "":
|
||||||
|
fallthrough
|
||||||
|
case "ACTIVE":
|
||||||
|
{
|
||||||
|
groups, err := h.listGroups()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
sb.Write([]byte(protocol.MessageListOfNewsgroupsFollows + protocol.CRLF))
|
||||||
|
for _, v := range groups {
|
||||||
|
// TODO set high/low mark and posting status to actual values
|
||||||
|
sb.Write([]byte(fmt.Sprintf("%s 0 0 n"+protocol.CRLF, v.GroupName)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "NEWSGROUPS":
|
||||||
|
{
|
||||||
|
groups, err := h.listGroups()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for _, v := range groups {
|
||||||
|
desc := ""
|
||||||
|
if v.Description == nil {
|
||||||
|
desc = "No description"
|
||||||
|
} else {
|
||||||
|
desc = *v.Description
|
||||||
|
}
|
||||||
|
sb.Write([]byte(fmt.Sprintf("%s %s"+protocol.CRLF, v.GroupName, desc)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
{
|
||||||
|
return s.tconn.PrintfLine(protocol.MessageUnknownCommand)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sb.Write([]byte(protocol.MultilineEnding))
|
||||||
|
|
||||||
|
return s.tconn.PrintfLine(sb.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) Handle(s *Session, message string) error {
|
||||||
|
splittedMessage := strings.Split(message, " ")
|
||||||
|
for i, v := range splittedMessage {
|
||||||
|
splittedMessage[i] = strings.TrimSpace(v)
|
||||||
|
}
|
||||||
|
cmdName := splittedMessage[0]
|
||||||
|
handler, ok := h.handlers[cmdName]
|
||||||
|
if !ok {
|
||||||
|
return s.tconn.PrintfLine(protocol.MessageUnknownCommand)
|
||||||
|
}
|
||||||
|
return handler(s, splittedMessage[1:])
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO Refactor to "storage backend" entity
|
||||||
|
func (h *Handler) listGroups() ([]models.Group, error) {
|
||||||
|
var groups []models.Group
|
||||||
|
return groups, h.db.Select(&groups, "SELECT * FROM groups")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) getArticlesCount(g models.Group) (int, error) {
|
||||||
|
var count int
|
||||||
|
return count, h.db.Select(&count, "SELECT COUNT(*) FROM articles_to_groups WHERE group_id = ?", g.ID)
|
||||||
|
}
|
@ -6,17 +6,14 @@ import (
|
|||||||
"github.com/ChronosX88/yans/internal"
|
"github.com/ChronosX88/yans/internal"
|
||||||
"github.com/ChronosX88/yans/internal/common"
|
"github.com/ChronosX88/yans/internal/common"
|
||||||
"github.com/ChronosX88/yans/internal/config"
|
"github.com/ChronosX88/yans/internal/config"
|
||||||
"github.com/ChronosX88/yans/internal/models"
|
|
||||||
"github.com/ChronosX88/yans/internal/protocol"
|
"github.com/ChronosX88/yans/internal/protocol"
|
||||||
|
"github.com/google/uuid"
|
||||||
"github.com/jmoiron/sqlx"
|
"github.com/jmoiron/sqlx"
|
||||||
_ "github.com/mattn/go-sqlite3"
|
_ "github.com/mattn/go-sqlite3"
|
||||||
"github.com/pressly/goose/v3"
|
"github.com/pressly/goose/v3"
|
||||||
"io"
|
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"net/textproto"
|
"sync"
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -36,6 +33,9 @@ type NNTPServer struct {
|
|||||||
port int
|
port int
|
||||||
|
|
||||||
db *sqlx.DB
|
db *sqlx.DB
|
||||||
|
|
||||||
|
sessionPool map[string]*Session
|
||||||
|
sessionPoolMutex sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewNNTPServer(cfg config.Config) (*NNTPServer, error) {
|
func NewNNTPServer(cfg config.Config) (*NNTPServer, error) {
|
||||||
@ -59,6 +59,7 @@ func NewNNTPServer(cfg config.Config) (*NNTPServer, error) {
|
|||||||
cancelFunc: cancel,
|
cancelFunc: cancel,
|
||||||
port: cfg.Port,
|
port: cfg.Port,
|
||||||
db: db,
|
db: db,
|
||||||
|
sessionPool: map[string]*Session{},
|
||||||
}
|
}
|
||||||
return ns, nil
|
return ns, nil
|
||||||
}
|
}
|
||||||
@ -81,7 +82,30 @@ func (ns *NNTPServer) Start() error {
|
|||||||
log.Println(err)
|
log.Println(err)
|
||||||
}
|
}
|
||||||
log.Printf("Client %s has connected!", conn.RemoteAddr().String())
|
log.Printf("Client %s has connected!", conn.RemoteAddr().String())
|
||||||
go ns.handleNewConnection(ctx, conn)
|
|
||||||
|
id, _ := uuid.NewUUID()
|
||||||
|
closed := make(chan bool)
|
||||||
|
session, err := NewSession(ctx, conn, Capabilities, id.String(), closed, NewHandler(ns.db))
|
||||||
|
ns.sessionPoolMutex.Lock()
|
||||||
|
ns.sessionPool[id.String()] = session
|
||||||
|
ns.sessionPoolMutex.Unlock()
|
||||||
|
go func(ctx context.Context, id string, closed chan bool) {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
break
|
||||||
|
case _, ok := <-closed:
|
||||||
|
{
|
||||||
|
if !ok {
|
||||||
|
ns.sessionPoolMutex.Lock()
|
||||||
|
delete(ns.sessionPool, id)
|
||||||
|
ns.sessionPoolMutex.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}(ctx, id.String(), closed)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -90,133 +114,6 @@ func (ns *NNTPServer) Start() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ns *NNTPServer) handleNewConnection(ctx context.Context, conn net.Conn) {
|
|
||||||
_, err := conn.Write([]byte(protocol.MessageNNTPServiceReadyPostingProhibited + protocol.CRLF))
|
|
||||||
if err != nil {
|
|
||||||
log.Print(err)
|
|
||||||
conn.Close()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
tconn := textproto.NewConn(conn)
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
break
|
|
||||||
default:
|
|
||||||
{
|
|
||||||
message, err := tconn.ReadLine()
|
|
||||||
if err != nil {
|
|
||||||
if err == io.EOF || err.(*net.OpError).Unwrap() == net.ErrClosed {
|
|
||||||
log.Printf("Client %s has diconnected!", conn.RemoteAddr().String())
|
|
||||||
} else {
|
|
||||||
log.Print(err)
|
|
||||||
conn.Close()
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
log.Printf("Received message from %s: %s", conn.RemoteAddr().String(), string(message))
|
|
||||||
err = ns.handleMessage(tconn, message)
|
|
||||||
if err != nil {
|
|
||||||
log.Print(err)
|
|
||||||
conn.Close()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ns *NNTPServer) handleMessage(conn *textproto.Conn, msg string) error {
|
|
||||||
splittedMessage := strings.Split(msg, " ")
|
|
||||||
command := splittedMessage[0]
|
|
||||||
|
|
||||||
reply := ""
|
|
||||||
quit := false
|
|
||||||
|
|
||||||
switch command {
|
|
||||||
case protocol.CommandCapabilities:
|
|
||||||
{
|
|
||||||
reply = Capabilities.String()
|
|
||||||
break
|
|
||||||
}
|
|
||||||
case protocol.CommandDate:
|
|
||||||
{
|
|
||||||
reply = fmt.Sprintf("111 %s", time.Now().UTC().Format("20060102150405"))
|
|
||||||
break
|
|
||||||
}
|
|
||||||
case protocol.CommandQuit:
|
|
||||||
{
|
|
||||||
reply = protocol.MessageNNTPServiceExitsNormally
|
|
||||||
quit = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
case protocol.CommandMode:
|
|
||||||
{
|
|
||||||
if splittedMessage[1] == "READER" {
|
|
||||||
// TODO actually switch current conn to reader mode
|
|
||||||
reply = protocol.MessageReaderModePostingProhibited
|
|
||||||
} else {
|
|
||||||
reply = protocol.MessageUnknownCommand
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
case protocol.CommandList:
|
|
||||||
{
|
|
||||||
groups, err := ns.listGroups()
|
|
||||||
if err != nil {
|
|
||||||
reply = protocol.MessageErrorHappened + err.Error()
|
|
||||||
log.Println(err)
|
|
||||||
}
|
|
||||||
sb := strings.Builder{}
|
|
||||||
sb.Write([]byte(protocol.MessageListOfNewsgroupsFollows + protocol.CRLF))
|
|
||||||
if len(splittedMessage) == 1 || splittedMessage[1] == "ACTIVE" {
|
|
||||||
for _, v := range groups {
|
|
||||||
// TODO set high/low mark and posting status to actual values
|
|
||||||
sb.Write([]byte(fmt.Sprintf("%s 0 0 n"+protocol.CRLF, v.GroupName)))
|
|
||||||
}
|
|
||||||
} else if splittedMessage[1] == "NEWSGROUPS" {
|
|
||||||
for _, v := range groups {
|
|
||||||
desc := ""
|
|
||||||
if v.Description == nil {
|
|
||||||
desc = "No description"
|
|
||||||
} else {
|
|
||||||
desc = *v.Description
|
|
||||||
}
|
|
||||||
sb.Write([]byte(fmt.Sprintf("%s %s\r\n", v.GroupName, desc)))
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
reply = protocol.MessageUnknownCommand
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
sb.Write([]byte("."))
|
|
||||||
reply = sb.String()
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
{
|
|
||||||
reply = protocol.MessageUnknownCommand
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
err := conn.PrintfLine(reply)
|
|
||||||
if quit {
|
|
||||||
conn.Close()
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ns *NNTPServer) listGroups() ([]models.Group, error) {
|
|
||||||
var groups []models.Group
|
|
||||||
return groups, ns.db.Select(&groups, "SELECT * FROM groups")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ns *NNTPServer) getArticlesCount(g models.Group) (int, error) {
|
|
||||||
var count int
|
|
||||||
return count, ns.db.Select(&count, "SELECT COUNT(*) FROM articles_to_groups WHERE group_id = ?", g.ID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ns *NNTPServer) Stop() {
|
func (ns *NNTPServer) Stop() {
|
||||||
ns.cancelFunc()
|
ns.cancelFunc()
|
||||||
}
|
}
|
||||||
|
86
internal/server/session.go
Normal file
86
internal/server/session.go
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"github.com/ChronosX88/yans/internal/protocol"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"net"
|
||||||
|
"net/textproto"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Session struct {
|
||||||
|
ctx context.Context
|
||||||
|
capabilities protocol.Capabilities
|
||||||
|
conn net.Conn
|
||||||
|
tconn *textproto.Conn
|
||||||
|
id string
|
||||||
|
closed chan<- bool
|
||||||
|
h *Handler
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSession(ctx context.Context, conn net.Conn, caps protocol.Capabilities, id string, closed chan<- bool, handler *Handler) (*Session, error) {
|
||||||
|
var err error
|
||||||
|
defer func() {
|
||||||
|
if err != nil {
|
||||||
|
conn.Close()
|
||||||
|
close(closed)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
tconn := textproto.NewConn(conn)
|
||||||
|
s := &Session{
|
||||||
|
ctx: ctx,
|
||||||
|
conn: conn,
|
||||||
|
tconn: tconn,
|
||||||
|
capabilities: caps,
|
||||||
|
id: id,
|
||||||
|
closed: closed,
|
||||||
|
h: handler,
|
||||||
|
}
|
||||||
|
|
||||||
|
go s.loop()
|
||||||
|
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) loop() {
|
||||||
|
defer func() {
|
||||||
|
close(s.closed)
|
||||||
|
}()
|
||||||
|
|
||||||
|
err := s.tconn.PrintfLine(protocol.MessageReaderModePostingProhibited) // by default access mode is read-only
|
||||||
|
if err != nil {
|
||||||
|
s.conn.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-s.ctx.Done():
|
||||||
|
break
|
||||||
|
default:
|
||||||
|
{
|
||||||
|
message, err := s.tconn.ReadLine()
|
||||||
|
if err != nil {
|
||||||
|
if err == io.EOF || err.(*net.OpError).Unwrap() == net.ErrClosed {
|
||||||
|
log.Printf("Client %s has diconnected!", s.conn.RemoteAddr().String())
|
||||||
|
} else {
|
||||||
|
log.Print(err)
|
||||||
|
s.conn.Close()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Printf("Received message from %s: %s", s.conn.RemoteAddr().String(), message) // for debugging
|
||||||
|
err = s.h.Handle(s, message)
|
||||||
|
if err != nil {
|
||||||
|
log.Print(err)
|
||||||
|
s.tconn.PrintfLine("%s %s", protocol.MessageErrorHappened, err.Error())
|
||||||
|
s.conn.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user