More refactoring and beginning of server work

This commit is contained in:
Chad Retz 2019-02-25 16:28:19 -06:00
parent c79ea0e04c
commit 3a0fa8ce14
7 changed files with 352 additions and 169 deletions

View File

@ -8,7 +8,10 @@ import (
) )
type Gun struct { type Gun struct {
peers []*Peer // Never mutated, always overwritten
currentPeers []*Peer
currentPeersLock sync.RWMutex
storage Storage storage Storage
soulGen func() string soulGen func() string
peerErrorHandler func(*ErrPeer) peerErrorHandler func(*ErrPeer)
@ -16,12 +19,15 @@ type Gun struct {
myPeerID string myPeerID string
tracking Tracking tracking Tracking
messageIDListeners map[string]chan<- *MessageReceived serversCancelFn context.CancelFunc
messageIDListeners map[string]chan<- *messageReceived
messageIDListenersLock sync.RWMutex messageIDListenersLock sync.RWMutex
} }
type Config struct { type Config struct {
PeerURLs []string PeerURLs []string
Servers []Server
Storage Storage Storage Storage
SoulGen func() string SoulGen func() string
PeerErrorHandler func(*ErrPeer) PeerErrorHandler func(*ErrPeer)
@ -43,14 +49,14 @@ const DefaultOldestAllowedStorageValue = 7 * (60 * time.Minute)
func New(ctx context.Context, config Config) (*Gun, error) { func New(ctx context.Context, config Config) (*Gun, error) {
g := &Gun{ g := &Gun{
peers: make([]*Peer, len(config.PeerURLs)), currentPeers: make([]*Peer, len(config.PeerURLs)),
storage: config.Storage, storage: config.Storage,
soulGen: config.SoulGen, soulGen: config.SoulGen,
peerErrorHandler: config.PeerErrorHandler, peerErrorHandler: config.PeerErrorHandler,
peerSleepOnError: config.PeerSleepOnError, peerSleepOnError: config.PeerSleepOnError,
myPeerID: config.MyPeerID, myPeerID: config.MyPeerID,
tracking: config.Tracking, tracking: config.Tracking,
messageIDListeners: map[string]chan<- *MessageReceived{}, messageIDListeners: map[string]chan<- *messageReceived{},
} }
// Create all the peers // Create all the peers
sleepOnError := config.PeerSleepOnError sleepOnError := config.PeerSleepOnError
@ -61,13 +67,13 @@ func New(ctx context.Context, config Config) (*Gun, error) {
for i := 0; i < len(config.PeerURLs) && err == nil; i++ { for i := 0; i < len(config.PeerURLs) && err == nil; i++ {
peerURL := config.PeerURLs[i] peerURL := config.PeerURLs[i]
newConn := func() (PeerConn, error) { return NewPeerConn(ctx, peerURL) } newConn := func() (PeerConn, error) { return NewPeerConn(ctx, peerURL) }
if g.peers[i], err = newPeer(peerURL, newConn, sleepOnError); err != nil { if g.currentPeers[i], err = newPeer(peerURL, newConn, sleepOnError); err != nil {
err = fmt.Errorf("Failed connecting to peer %v: %v", peerURL, err) err = fmt.Errorf("Failed connecting to peer %v: %v", peerURL, err)
} }
} }
// If there was an error, we need to close what we did create // If there was an error, we need to close what we did create
if err != nil { if err != nil {
for _, peer := range g.peers { for _, peer := range g.currentPeers {
if peer != nil { if peer != nil {
peer.Close() peer.Close()
} }
@ -84,8 +90,12 @@ func New(ctx context.Context, config Config) (*Gun, error) {
if g.myPeerID == "" { if g.myPeerID == "" {
g.myPeerID = randString(9) g.myPeerID = randString(9)
} }
// Start receiving // Start receiving from peers
g.startReceiving() for _, peer := range g.currentPeers {
go g.startReceiving(peer)
}
// Start all the servers
go g.startServers(config.Servers)
return g, nil return g, nil
} }
@ -99,11 +109,12 @@ func (g *Gun) Scoped(ctx context.Context, key string, children ...string) *Scope
func (g *Gun) Close() error { func (g *Gun) Close() error {
var errs []error var errs []error
for _, p := range g.peers { for _, p := range g.peers() {
if err := p.Close(); err != nil { if err := p.Close(); err != nil {
errs = append(errs, err) errs = append(errs, err)
} }
} }
g.serversCancelFn()
if err := g.storage.Close(); err != nil { if err := g.storage.Close(); err != nil {
errs = append(errs, err) errs = append(errs, err)
} }
@ -116,13 +127,41 @@ func (g *Gun) Close() error {
} }
} }
func (g *Gun) peers() []*Peer {
g.currentPeersLock.RLock()
defer g.currentPeersLock.RUnlock()
return g.currentPeers
}
func (g *Gun) addPeer(p *Peer) {
g.currentPeersLock.Lock()
defer g.currentPeersLock.Unlock()
prev := g.currentPeers
g.currentPeers = make([]*Peer, len(prev)+1)
copy(g.currentPeers, prev)
g.currentPeers[len(prev)] = p
}
func (g *Gun) removePeer(p *Peer) {
g.currentPeersLock.Lock()
defer g.currentPeersLock.Unlock()
prev := g.currentPeers
g.currentPeers = make([]*Peer, 0, len(prev))
for _, peer := range prev {
if peer != p {
g.currentPeers = append(g.currentPeers, peer)
}
}
}
func (g *Gun) send(ctx context.Context, msg *Message, ignorePeer *Peer) <-chan *ErrPeer { func (g *Gun) send(ctx context.Context, msg *Message, ignorePeer *Peer) <-chan *ErrPeer {
ch := make(chan *ErrPeer, len(g.peers)) peers := g.peers()
ch := make(chan *ErrPeer, len(peers))
// Everything async // Everything async
go func() { go func() {
defer close(ch) defer close(ch)
var wg sync.WaitGroup var wg sync.WaitGroup
for _, peer := range g.peers { for _, peer := range peers {
if peer == ignorePeer { if peer == ignorePeer {
continue continue
} }
@ -131,6 +170,9 @@ func (g *Gun) send(ctx context.Context, msg *Message, ignorePeer *Peer) <-chan *
defer wg.Done() defer wg.Done()
// Just do nothing if the peer is bad and we couldn't send // Just do nothing if the peer is bad and we couldn't send
if _, err := peer.send(ctx, msg); err != nil { if _, err := peer.send(ctx, msg); err != nil {
if !peer.reconnectSupported() {
g.removePeer(peer)
}
peerErr := &ErrPeer{err, peer} peerErr := &ErrPeer{err, peer}
go g.onPeerError(peerErr) go g.onPeerError(peerErr)
ch <- peerErr ch <- peerErr
@ -142,9 +184,7 @@ func (g *Gun) send(ctx context.Context, msg *Message, ignorePeer *Peer) <-chan *
return ch return ch
} }
func (g *Gun) startReceiving() { func (g *Gun) startReceiving(peer *Peer) {
for _, peer := range g.peers {
go func(peer *Peer) {
// TDO: some kind of overall context is probably needed // TDO: some kind of overall context is probably needed
ctx, cancelFn := context.WithCancel(context.TODO()) ctx, cancelFn := context.WithCancel(context.TODO())
defer cancelFn() defer cancelFn()
@ -155,20 +195,22 @@ func (g *Gun) startReceiving() {
if err != nil { if err != nil {
go g.onPeerError(&ErrPeer{err, peer}) go g.onPeerError(&ErrPeer{err, peer})
} }
// Always sleep at least the err duration // If can reconnect, sleep at least the err duration, otherwise remove
if peer.reconnectSupported() {
time.Sleep(g.peerSleepOnError) time.Sleep(g.peerSleepOnError)
} else {
g.removePeer(peer)
}
} else { } else {
// Go over each message and see if it needs delivering or rebroadcasting // Go over each message and see if it needs delivering or rebroadcasting
for _, msg := range msgs { for _, msg := range msgs {
g.onPeerMessage(ctx, &MessageReceived{Message: msg, Peer: peer}) g.onPeerMessage(ctx, &messageReceived{Message: msg, peer: peer})
} }
} }
} }
}(peer)
}
} }
func (g *Gun) onPeerMessage(ctx context.Context, msg *MessageReceived) { func (g *Gun) onPeerMessage(ctx context.Context, msg *messageReceived) {
// If we're tracking everything, persist all puts here. // If we're tracking everything, persist all puts here.
if g.tracking == TrackingEverything { if g.tracking == TrackingEverything {
for parentSoul, node := range msg.Put { for parentSoul, node := range msg.Put {
@ -195,17 +237,20 @@ func (g *Gun) onPeerMessage(ctx context.Context, msg *MessageReceived) {
if msg.PID == "" { if msg.PID == "" {
// This is a request, set the PID and send it back // This is a request, set the PID and send it back
msg.PID = g.myPeerID msg.PID = g.myPeerID
if _, err := msg.Peer.send(ctx, msg.Message); err != nil { if _, err := msg.peer.send(ctx, msg.Message); err != nil {
go g.onPeerError(&ErrPeer{err, msg.Peer}) go g.onPeerError(&ErrPeer{err, msg.peer})
if !msg.peer.reconnectSupported() {
g.removePeer(msg.peer)
}
} }
} else { } else {
// This is them telling us theirs // This is them telling us theirs
msg.Peer.id = msg.PID msg.peer.id = msg.PID
} }
return return
} }
// Unhandled message means rebroadcast // Unhandled message means rebroadcast
g.send(ctx, msg.Message, msg.Peer) g.send(ctx, msg.Message, msg.peer)
} }
func (g *Gun) onPeerError(err *ErrPeer) { func (g *Gun) onPeerError(err *ErrPeer) {
@ -214,7 +259,7 @@ func (g *Gun) onPeerError(err *ErrPeer) {
} }
} }
func (g *Gun) registerMessageIDListener(id string, ch chan<- *MessageReceived) { func (g *Gun) registerMessageIDListener(id string, ch chan<- *messageReceived) {
g.messageIDListenersLock.Lock() g.messageIDListenersLock.Lock()
defer g.messageIDListenersLock.Unlock() defer g.messageIDListenersLock.Unlock()
g.messageIDListeners[id] = ch g.messageIDListeners[id] = ch
@ -226,7 +271,7 @@ func (g *Gun) unregisterMessageIDListener(id string) {
delete(g.messageIDListeners, id) delete(g.messageIDListeners, id)
} }
func safeReceivedMessageSend(ch chan<- *MessageReceived, msg *MessageReceived) { func safeReceivedMessageSend(ch chan<- *messageReceived, msg *messageReceived) {
// Due to the fact that we may send on a closed channel here, we ignore the panic // Due to the fact that we may send on a closed channel here, we ignore the panic
defer func() { recover() }() defer func() { recover() }()
ch <- msg ch <- msg

View File

@ -21,7 +21,9 @@ type MessageGetRequest struct {
Field string `json:".,omitempty"` Field string `json:".,omitempty"`
} }
type MessageReceived struct { type messageReceived struct {
*Message *Message
Peer *Peer
peer *Peer
stored bool
} }

View File

@ -2,13 +2,10 @@ package gun
import ( import (
"context" "context"
"encoding/json"
"fmt" "fmt"
"net/url" "net/url"
"sync" "sync"
"time" "time"
"github.com/gorilla/websocket"
) )
type ErrPeer struct { type ErrPeer struct {
@ -19,7 +16,7 @@ type ErrPeer struct {
func (e *ErrPeer) Error() string { return fmt.Sprintf("Error on peer %v: %v", e.Peer, e.Err) } func (e *ErrPeer) Error() string { return fmt.Sprintf("Error on peer %v: %v", e.Peer, e.Err) }
type Peer struct { type Peer struct {
url string name string
newConn func() (PeerConn, error) newConn func() (PeerConn, error)
sleepOnErr time.Duration // TODO: would be better as backoff sleepOnErr time.Duration // TODO: would be better as backoff
id string id string
@ -29,8 +26,8 @@ type Peer struct {
connLock sync.Mutex connLock sync.Mutex
} }
func newPeer(url string, newConn func() (PeerConn, error), sleepOnErr time.Duration) (*Peer, error) { func newPeer(name string, newConn func() (PeerConn, error), sleepOnErr time.Duration) (*Peer, error) {
p := &Peer{url: url, newConn: newConn, sleepOnErr: sleepOnErr} p := &Peer{name: name, newConn: newConn, sleepOnErr: sleepOnErr}
var err error var err error
if p.connCurrent, err = newConn(); err != nil { if p.connCurrent, err = newConn(); err != nil {
return nil, err return nil, err
@ -45,14 +42,21 @@ func (p *Peer) String() string {
if p.id != "" { if p.id != "" {
id = "(id: " + p.id + ")" id = "(id: " + p.id + ")"
} }
connStatus := "connected" connStatus := "disconnected"
if p.Conn() == nil { if conn := p.Conn(); conn != nil {
connStatus = "disconnected" connStatus = "connected to " + conn.RemoteURL()
} }
return fmt.Sprintf("Peer%v %v (%v)", id, p.url, connStatus) return fmt.Sprintf("Peer%v %v (%v)", id, p.name, connStatus)
}
func (p *Peer) reconnectSupported() bool {
return p.sleepOnErr > 0
} }
func (p *Peer) reconnect() (err error) { func (p *Peer) reconnect() (err error) {
if !p.reconnectSupported() {
return fmt.Errorf("Reconnect not supported")
}
p.connLock.Lock() p.connLock.Lock()
defer p.connLock.Unlock() defer p.connLock.Unlock()
if p.connCurrent == nil && p.connBad { if p.connCurrent == nil && p.connBad {
@ -73,6 +77,10 @@ func (p *Peer) Conn() PeerConn {
} }
func (p *Peer) markConnErrored(conn PeerConn) { func (p *Peer) markConnErrored(conn PeerConn) {
if !p.reconnectSupported() {
p.Close()
return
}
p.connLock.Lock() p.connLock.Lock()
defer p.connLock.Unlock() defer p.connLock.Unlock()
if conn == p.connCurrent { if conn == p.connCurrent {
@ -91,12 +99,12 @@ func (p *Peer) send(ctx context.Context, msg *Message, moreMsgs ...*Message) (ok
// Clone them with peer "to" // Clone them with peer "to"
updatedMsg := &Message{} updatedMsg := &Message{}
*updatedMsg = *msg *updatedMsg = *msg
updatedMsg.To = p.url updatedMsg.To = conn.RemoteURL()
updatedMoreMsgs := make([]*Message, len(moreMsgs)) updatedMoreMsgs := make([]*Message, len(moreMsgs))
for i, moreMsg := range moreMsgs { for i, moreMsg := range moreMsgs {
updatedMoreMsg := &Message{} updatedMoreMsg := &Message{}
*updatedMoreMsg = *moreMsg *updatedMoreMsg = *moreMsg
updatedMoreMsg.To = p.url updatedMoreMsg.To = conn.RemoteURL()
updatedMoreMsgs[i] = updatedMoreMsg updatedMoreMsgs[i] = updatedMoreMsg
} }
if err = conn.Send(ctx, updatedMsg, updatedMoreMsgs...); err != nil { if err = conn.Send(ctx, updatedMsg, updatedMoreMsgs...); err != nil {
@ -138,8 +146,8 @@ func (p *Peer) Closed() bool {
type PeerConn interface { type PeerConn interface {
Send(ctx context.Context, msg *Message, moreMsgs ...*Message) error Send(ctx context.Context, msg *Message, moreMsgs ...*Message) error
// Chan is closed on first err, when context is closed, or when peer is closed
Receive(ctx context.Context) ([]*Message, error) Receive(ctx context.Context) ([]*Message, error)
RemoteURL() string
Close() error Close() error
} }
@ -151,10 +159,10 @@ func init() {
schemeChangedURL := &url.URL{} schemeChangedURL := &url.URL{}
*schemeChangedURL = *peerURL *schemeChangedURL = *peerURL
schemeChangedURL.Scheme = "ws" schemeChangedURL.Scheme = "ws"
return NewPeerConnWebSocket(ctx, schemeChangedURL) return DialPeerConnWebSocket(ctx, schemeChangedURL)
}, },
"ws": func(ctx context.Context, peerURL *url.URL) (PeerConn, error) { "ws": func(ctx context.Context, peerURL *url.URL) (PeerConn, error) {
return NewPeerConnWebSocket(ctx, peerURL) return DialPeerConnWebSocket(ctx, peerURL)
}, },
} }
} }
@ -168,92 +176,3 @@ func NewPeerConn(ctx context.Context, peerURL string) (PeerConn, error) {
return peerNew(ctx, parsedURL) return peerNew(ctx, parsedURL)
} }
} }
type PeerConnWebSocket struct {
Underlying *websocket.Conn
WriteLock sync.Mutex
}
func NewPeerConnWebSocket(ctx context.Context, peerUrl *url.URL) (*PeerConnWebSocket, error) {
conn, _, err := websocket.DefaultDialer.DialContext(ctx, peerUrl.String(), nil)
if err != nil {
return nil, err
}
return &PeerConnWebSocket{Underlying: conn}, nil
}
func (p *PeerConnWebSocket) Send(ctx context.Context, msg *Message, moreMsgs ...*Message) error {
// If there are more, send all as an array of JSON strings, otherwise just the msg
var toWrite interface{}
if len(moreMsgs) == 0 {
toWrite = msg
} else {
b, err := json.Marshal(msg)
if err != nil {
return err
}
msgs := []string{string(b)}
for _, nextMsg := range moreMsgs {
if b, err = json.Marshal(nextMsg); err != nil {
return err
}
msgs = append(msgs, string(b))
}
toWrite = msgs
}
// Send async so we can wait on context
errCh := make(chan error, 1)
go func() {
p.WriteLock.Lock()
defer p.WriteLock.Unlock()
errCh <- p.Underlying.WriteJSON(toWrite)
}()
select {
case err := <-errCh:
return err
case <-ctx.Done():
return ctx.Err()
}
}
func (p *PeerConnWebSocket) Receive(ctx context.Context) ([]*Message, error) {
bytsCh := make(chan []byte, 1)
errCh := make(chan error, 1)
go func() {
if _, b, err := p.Underlying.ReadMessage(); err != nil {
errCh <- err
} else {
bytsCh <- b
}
}()
select {
case err := <-errCh:
return nil, err
case <-ctx.Done():
return nil, ctx.Err()
case byts := <-bytsCh:
// If it's a JSON array, it means it's an array of JSON strings, otherwise it's one message
if byts[0] != '[' {
var msg Message
if err := json.Unmarshal(byts, &msg); err != nil {
return nil, err
}
return []*Message{&msg}, nil
}
var jsonStrs []string
if err := json.Unmarshal(byts, &jsonStrs); err != nil {
return nil, err
}
msgs := make([]*Message, len(jsonStrs))
for i, jsonStr := range jsonStrs {
if err := json.Unmarshal([]byte(jsonStr), &(msgs[i])); err != nil {
return nil, err
}
}
return msgs, nil
}
}
func (p *PeerConnWebSocket) Close() error {
return p.Underlying.Close()
}

View File

@ -5,6 +5,24 @@ import (
"fmt" "fmt"
) )
type fetchResultListener struct {
id string
results chan *FetchResult
receivedMessages chan *messageReceived
}
type FetchResult struct {
// This can be a context error on cancelation
Err error
Field string
// Nil if the value doesn't exist, exists and is nil, or there's an error
Value Value
State State // This can be 0 for errors or top-level value relations
ValueExists bool
// Nil when local and sometimes on error
Peer *Peer
}
func (s *Scoped) FetchOne(ctx context.Context) *FetchResult { func (s *Scoped) FetchOne(ctx context.Context) *FetchResult {
// Try local before remote // Try local before remote
if r := s.FetchOneLocal(ctx); r.Err != nil || r.ValueExists { if r := s.FetchOneLocal(ctx); r.Err != nil || r.ValueExists {
@ -83,7 +101,7 @@ func (s *Scoped) fetchRemote(ctx context.Context, ch chan *FetchResult) {
// Make a chan to listen for received messages and link it to // Make a chan to listen for received messages and link it to
// the given one so we can turn it "off". Off will close this // the given one so we can turn it "off". Off will close this
// chan. // chan.
msgCh := make(chan *MessageReceived) msgCh := make(chan *messageReceived)
s.fetchResultListenersLock.Lock() s.fetchResultListenersLock.Lock()
s.fetchResultListeners[ch] = &fetchResultListener{req.ID, ch, msgCh} s.fetchResultListeners[ch] = &fetchResultListener{req.ID, ch, msgCh}
s.fetchResultListenersLock.Unlock() s.fetchResultListenersLock.Unlock()
@ -104,7 +122,7 @@ func (s *Scoped) fetchRemote(ctx context.Context, ch chan *FetchResult) {
if !ok { if !ok {
return return
} }
r := &FetchResult{Field: s.field, Peer: msg.Peer} r := &FetchResult{Field: s.field, Peer: msg.peer}
// We asked for a single field, should only get that field or it doesn't exist // We asked for a single field, should only get that field or it doesn't exist
if msg.Err != "" { if msg.Err != "" {
r.Err = fmt.Errorf("Remote error: %v", msg.Err) r.Err = fmt.Errorf("Remote error: %v", msg.Err)
@ -166,21 +184,3 @@ func safeFetchResultSend(ch chan<- *FetchResult, r *FetchResult) {
defer func() { recover() }() defer func() { recover() }()
ch <- r ch <- r
} }
type fetchResultListener struct {
id string
results chan *FetchResult
receivedMessages chan *MessageReceived
}
type FetchResult struct {
// This can be a context error on cancelation
Err error
Field string
// Nil if the value doesn't exist, exists and is nil, or there's an error
Value Value
State State // This can be 0 for errors or top-level value relations
ValueExists bool
// Nil when local and sometimes on error
Peer *Peer
}

View File

@ -8,7 +8,7 @@ import (
type putResultListener struct { type putResultListener struct {
id string id string
results chan *PutResult results chan *PutResult
receivedMessages chan *MessageReceived receivedMessages chan *messageReceived
} }
type PutResult struct { type PutResult struct {
@ -120,7 +120,7 @@ func (s *Scoped) Put(ctx context.Context, val Value, opts ...PutOption) <-chan *
Values: map[string]Value{s.field: val}, Values: map[string]Value{s.field: val},
} }
// Make a msg chan and register it to listen for acks // Make a msg chan and register it to listen for acks
msgCh := make(chan *MessageReceived) msgCh := make(chan *messageReceived)
s.putResultListenersLock.Lock() s.putResultListenersLock.Lock()
s.putResultListeners[ch] = &putResultListener{req.ID, ch, msgCh} s.putResultListeners[ch] = &putResultListener{req.ID, ch, msgCh}
s.putResultListenersLock.Unlock() s.putResultListenersLock.Unlock()
@ -137,7 +137,7 @@ func (s *Scoped) Put(ctx context.Context, val Value, opts ...PutOption) <-chan *
if !ok { if !ok {
return return
} }
r := &PutResult{Peer: msg.Peer} r := &PutResult{Peer: msg.peer}
if msg.Err != "" { if msg.Err != "" {
r.Err = fmt.Errorf("Remote error: %v", msg.Err) r.Err = fmt.Errorf("Remote error: %v", msg.Err)
} else if msg.OK != 1 { } else if msg.OK != 1 {

55
gun/server.go Normal file
View File

@ -0,0 +1,55 @@
package gun
import (
"context"
)
type Server interface {
Serve() error // Hangs forever
Accept() (PeerConn, error)
Close() error
}
func (g *Gun) startServers(servers []Server) {
ctx := context.Background()
ctx, g.serversCancelFn = context.WithCancel(ctx)
for _, server := range servers {
// TODO: log error?
go g.serve(ctx, server)
}
}
func (g *Gun) serve(ctx context.Context, s Server) error {
errCh := make(chan error, 1)
ctx, cancelFn := context.WithCancel(ctx)
defer cancelFn()
// Start the server
go func() { errCh <- s.Serve() }()
defer s.Close()
// Accept connections and break off
go func() {
if conn, err := s.Accept(); err == nil {
// TODO: log error (for accept and handle)?
go g.onNewPeerConn(ctx, conn)
}
}()
// Wait for server close or context close
select {
case err := <-errCh:
return err
case <-ctx.Done():
return ctx.Err()
}
}
func (g *Gun) onNewPeerConn(ctx context.Context, conn PeerConn) error {
ctx, cancelFn := context.WithCancel(ctx)
defer cancelFn()
defer conn.Close()
// We always send a DAM req first
if err := conn.Send(ctx, &Message{DAM: "?"}); err != nil {
return err
}
// Now add the connection to Gun
panic("TODO")
}

162
gun/websocket.go Normal file
View File

@ -0,0 +1,162 @@
package gun
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"sync"
"github.com/gorilla/websocket"
)
type serverWebSocket struct {
server *http.Server
acceptCh chan *websocket.Conn
acceptCtx context.Context
acceptCancelFn context.CancelFunc
serveErrCh chan error
}
func NewServerWebSocket(server *http.Server, upgrader *websocket.Upgrader) Server {
if upgrader == nil {
upgrader = &websocket.Upgrader{}
}
s := &serverWebSocket{
server: server,
acceptCh: make(chan *websocket.Conn),
serveErrCh: make(chan error, 1),
}
// Setup the accepter
s.acceptCtx, s.acceptCancelFn = context.WithCancel(context.Background())
server.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
if server.ErrorLog != nil {
server.ErrorLog.Printf("Failed upgrading websocket: %v", err)
}
return
}
select {
case <-s.acceptCtx.Done():
case s.acceptCh <- conn:
}
})
return s
}
func (s *serverWebSocket) Serve() error {
return s.server.ListenAndServe()
}
func (s *serverWebSocket) Accept() (PeerConn, error) {
select {
case <-s.acceptCtx.Done():
return nil, http.ErrServerClosed
case conn := <-s.acceptCh:
return NewPeerConnWebSocket(conn), nil
}
}
func (s *serverWebSocket) Close() error {
s.acceptCancelFn()
return s.server.Close()
}
type PeerConnWebSocket struct {
Underlying *websocket.Conn
WriteLock sync.Mutex
}
func DialPeerConnWebSocket(ctx context.Context, peerUrl *url.URL) (*PeerConnWebSocket, error) {
conn, _, err := websocket.DefaultDialer.DialContext(ctx, peerUrl.String(), nil)
if err != nil {
return nil, err
}
return NewPeerConnWebSocket(conn), nil
}
func NewPeerConnWebSocket(underlying *websocket.Conn) *PeerConnWebSocket {
return &PeerConnWebSocket{Underlying: underlying}
}
func (p *PeerConnWebSocket) Send(ctx context.Context, msg *Message, moreMsgs ...*Message) error {
// If there are more, send all as an array of JSON strings, otherwise just the msg
var toWrite interface{}
if len(moreMsgs) == 0 {
toWrite = msg
} else {
b, err := json.Marshal(msg)
if err != nil {
return err
}
msgs := []string{string(b)}
for _, nextMsg := range moreMsgs {
if b, err = json.Marshal(nextMsg); err != nil {
return err
}
msgs = append(msgs, string(b))
}
toWrite = msgs
}
// Send async so we can wait on context
errCh := make(chan error, 1)
go func() {
p.WriteLock.Lock()
defer p.WriteLock.Unlock()
errCh <- p.Underlying.WriteJSON(toWrite)
}()
select {
case err := <-errCh:
return err
case <-ctx.Done():
return ctx.Err()
}
}
func (p *PeerConnWebSocket) Receive(ctx context.Context) ([]*Message, error) {
bytsCh := make(chan []byte, 1)
errCh := make(chan error, 1)
go func() {
if _, b, err := p.Underlying.ReadMessage(); err != nil {
errCh <- err
} else {
bytsCh <- b
}
}()
select {
case err := <-errCh:
return nil, err
case <-ctx.Done():
return nil, ctx.Err()
case byts := <-bytsCh:
// If it's a JSON array, it means it's an array of JSON strings, otherwise it's one message
if byts[0] != '[' {
var msg Message
if err := json.Unmarshal(byts, &msg); err != nil {
return nil, err
}
return []*Message{&msg}, nil
}
var jsonStrs []string
if err := json.Unmarshal(byts, &jsonStrs); err != nil {
return nil, err
}
msgs := make([]*Message, len(jsonStrs))
for i, jsonStr := range jsonStrs {
if err := json.Unmarshal([]byte(jsonStr), &(msgs[i])); err != nil {
return nil, err
}
}
return msgs, nil
}
}
func (p *PeerConnWebSocket) RemoteURL() string {
return fmt.Sprintf("http://%v", p.Underlying.RemoteAddr())
}
func (p *PeerConnWebSocket) Close() error {
return p.Underlying.Close()
}