Make the storage backend modular

This commit is contained in:
ChronosX88 2022-01-19 22:51:08 +03:00
parent 99061051e7
commit 3933abd40b
Signed by: ChronosXYZ
GPG Key ID: 085A69A82C8C511A
9 changed files with 118 additions and 49 deletions

View File

@ -30,11 +30,11 @@ func main() {
log.Printf("Starting %s...", common.ServerName) log.Printf("Starting %s...", common.ServerName)
ns, err := server.NewNNTPServer(cfg) ns, err := server.NewNNTPServer(cfg)
if err != nil { if err != nil {
log.Fatal(err) log.Fatalf("Error occurred while starting the server: %s", err)
} }
if err := ns.Start(); err != nil { if err := ns.Start(); err != nil {
log.Fatal(err) log.Fatalf("Error occurred while starting the server: %s", err)
} }
log.Printf("%s has been successfully started!", common.ServerName) log.Printf("%s has been successfully started!", common.ServerName)
log.Printf("Version: %s", common.ServerVersion) log.Printf("Version: %s", common.ServerVersion)

6
config.sample.toml Normal file
View File

@ -0,0 +1,6 @@
address = "localhost"
port = 1119
backend_type = "sqlite"
[sqlite]
path = "yans.db"

View File

@ -0,0 +1,47 @@
package sqlite
import (
"embed"
"github.com/ChronosX88/yans/internal/config"
"github.com/ChronosX88/yans/internal/models"
"github.com/jmoiron/sqlx"
_ "github.com/mattn/go-sqlite3"
"github.com/pressly/goose/v3"
)
//go:embed migrations/*.sql
var migrations embed.FS
type SQLiteBackend struct {
db *sqlx.DB
}
func NewSQLiteBackend(cfg config.SQLiteBackendConfig) (*SQLiteBackend, error) {
db, err := sqlx.Open("sqlite3", cfg.Path)
if err != nil {
return nil, err
}
goose.SetBaseFS(migrations)
if err := goose.SetDialect("sqlite3"); err != nil {
return nil, err
}
if err := goose.Up(db.DB, "migrations"); err != nil {
return nil, err
}
return &SQLiteBackend{
db: db,
}, nil
}
func (sb *SQLiteBackend) ListGroups() ([]models.Group, error) {
var groups []models.Group
return groups, sb.db.Select(&groups, "SELECT * FROM groups")
}
func (sb *SQLiteBackend) GetArticlesCount(g models.Group) (int, error) {
var count int
return count, sb.db.Select(&count, "SELECT COUNT(*) FROM articles_to_groups WHERE group_id = ?", g.ID)
}

View File

@ -0,0 +1,12 @@
package backend
import "github.com/ChronosX88/yans/internal/models"
const (
SupportedBackendList = "sqlite"
)
type StorageBackend interface {
ListGroups() ([]models.Group, error)
GetArticlesCount(g models.Group) (int, error)
}

View File

@ -5,9 +5,19 @@ import (
"os" "os"
) )
const (
SQLiteBackendType = "sqlite"
)
type Config struct { type Config struct {
Port int Address string `toml:"address"`
DatabasePath string Port int `toml:"port"`
BackendType string `toml:"backend_type"`
SQLite SQLiteBackendConfig `toml:"sqlite"`
}
type SQLiteBackendConfig struct {
Path string `toml:"path"`
} }
func ParseConfig(path string) (Config, error) { func ParseConfig(path string) (Config, error) {

View File

@ -1,6 +0,0 @@
package internal
import "embed"
//go:embed migrations/*.sql
var Migrations embed.FS

View File

@ -2,21 +2,20 @@ package server
import ( import (
"fmt" "fmt"
"github.com/ChronosX88/yans/internal/models" "github.com/ChronosX88/yans/internal/backend"
"github.com/ChronosX88/yans/internal/protocol" "github.com/ChronosX88/yans/internal/protocol"
"github.com/jmoiron/sqlx"
"strings" "strings"
"time" "time"
) )
type Handler struct { type Handler struct {
handlers map[string]func(s *Session, arguments []string) error handlers map[string]func(s *Session, arguments []string) error
db *sqlx.DB backend backend.StorageBackend
} }
func NewHandler(db *sqlx.DB) *Handler { func NewHandler(b backend.StorageBackend) *Handler {
h := &Handler{} h := &Handler{}
h.db = db h.backend = b
h.handlers = map[string]func(s *Session, arguments []string) error{ h.handlers = map[string]func(s *Session, arguments []string) error{
protocol.CommandCapabilities: h.handleCapabilities, protocol.CommandCapabilities: h.handleCapabilities,
protocol.CommandDate: h.handleDate, protocol.CommandDate: h.handleDate,
@ -54,7 +53,7 @@ func (h *Handler) handleList(s *Session, arguments []string) error {
fallthrough fallthrough
case "ACTIVE": case "ACTIVE":
{ {
groups, err := h.listGroups() groups, err := h.backend.ListGroups()
if err != nil { if err != nil {
return err return err
} }
@ -66,7 +65,7 @@ func (h *Handler) handleList(s *Session, arguments []string) error {
} }
case "NEWSGROUPS": case "NEWSGROUPS":
{ {
groups, err := h.listGroups() groups, err := h.backend.ListGroups()
if err != nil { if err != nil {
return err return err
} }
@ -117,14 +116,3 @@ func (h *Handler) Handle(s *Session, message string) error {
} }
return handler(s, splittedMessage[1:]) return handler(s, splittedMessage[1:])
} }
// TODO Refactor to "storage backend" entity
func (h *Handler) listGroups() ([]models.Group, error) {
var groups []models.Group
return groups, h.db.Select(&groups, "SELECT * FROM groups")
}
func (h *Handler) getArticlesCount(g models.Group) (int, error) {
var count int
return count, h.db.Select(&count, "SELECT COUNT(*) FROM articles_to_groups WHERE group_id = ?", g.ID)
}

View File

@ -3,14 +3,12 @@ package server
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/ChronosX88/yans/internal" "github.com/ChronosX88/yans/internal/backend"
"github.com/ChronosX88/yans/internal/backend/sqlite"
"github.com/ChronosX88/yans/internal/common" "github.com/ChronosX88/yans/internal/common"
"github.com/ChronosX88/yans/internal/config" "github.com/ChronosX88/yans/internal/config"
"github.com/ChronosX88/yans/internal/protocol" "github.com/ChronosX88/yans/internal/protocol"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/jmoiron/sqlx"
_ "github.com/mattn/go-sqlite3"
"github.com/pressly/goose/v3"
"log" "log"
"net" "net"
"sync" "sync"
@ -28,47 +26,61 @@ type NNTPServer struct {
ctx context.Context ctx context.Context
cancelFunc context.CancelFunc cancelFunc context.CancelFunc
ln net.Listener ln net.Listener
port int cfg config.Config
db *sqlx.DB backend backend.StorageBackend
sessionPool map[string]*Session sessionPool map[string]*Session
sessionPoolMutex sync.Mutex sessionPoolMutex sync.Mutex
} }
func NewNNTPServer(cfg config.Config) (*NNTPServer, error) { func NewNNTPServer(cfg config.Config) (*NNTPServer, error) {
db, err := sqlx.Open("sqlite3", cfg.DatabasePath) b, err := initBackend(cfg)
if err != nil { if err != nil {
return nil, err return nil, err
} }
goose.SetBaseFS(internal.Migrations)
if err := goose.SetDialect("sqlite3"); err != nil {
return nil, err
}
if err := goose.Up(db.DB, "migrations"); err != nil {
return nil, err
}
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
ns := &NNTPServer{ ns := &NNTPServer{
ctx: ctx, ctx: ctx,
cancelFunc: cancel, cancelFunc: cancel,
port: cfg.Port, cfg: cfg,
db: db, backend: b,
sessionPool: map[string]*Session{}, sessionPool: map[string]*Session{},
} }
return ns, nil return ns, nil
} }
func initBackend(cfg config.Config) (backend.StorageBackend, error) {
var sb backend.StorageBackend
switch cfg.BackendType {
case config.SQLiteBackendType:
{
sqliteBackend, err := sqlite.NewSQLiteBackend(cfg.SQLite)
if err != nil {
return nil, err
}
sb = sqliteBackend
}
default:
{
return nil, fmt.Errorf("invalid backend type, supported backends: %s", backend.SupportedBackendList)
}
}
return sb, nil
}
func (ns *NNTPServer) Start() error { func (ns *NNTPServer) Start() error {
ln, err := net.Listen("tcp", fmt.Sprintf(":%d", ns.port)) address := fmt.Sprintf("%s:%d", ns.cfg.Address, ns.cfg.Port)
ln, err := net.Listen("tcp", address)
if err != nil { if err != nil {
return err return err
} }
log.Printf("Listening on %s...", address)
go func(ctx context.Context) { go func(ctx context.Context) {
for { for {
select { select {
@ -84,7 +96,7 @@ func (ns *NNTPServer) Start() error {
id, _ := uuid.NewUUID() id, _ := uuid.NewUUID()
closed := make(chan bool) closed := make(chan bool)
session, err := NewSession(ctx, conn, Capabilities, id.String(), closed, NewHandler(ns.db)) session, err := NewSession(ctx, conn, Capabilities, id.String(), closed, NewHandler(ns.backend))
ns.sessionPoolMutex.Lock() ns.sessionPoolMutex.Lock()
ns.sessionPool[id.String()] = session ns.sessionPool[id.String()] = session
ns.sessionPoolMutex.Unlock() ns.sessionPoolMutex.Unlock()