internal/langserver/langserver.go (149 lines of code) (raw):
package langserver
import (
"context"
"fmt"
"io"
"log"
"net"
"os"
"runtime"
"github.com/Azure/azapi-lsp/internal/langserver/session"
"github.com/creachadair/jrpc2"
"github.com/creachadair/jrpc2/channel"
"github.com/creachadair/jrpc2/server"
)
// langServer serves the language server protocol over JSON-RPC 2.0, either
// on a single reader/writer pair (StartAndWait) or on a TCP listener
// (StartTCP).
type langServer struct {
	// srvCtx is the server-wide context. Sessions are created from it and
	// its cancellation triggers shutdown.
	srvCtx context.Context

	// logger receives the server's own log messages.
	logger *log.Logger

	// srvOptions configures every jrpc2 server started by this langServer.
	srvOptions *jrpc2.ServerOptions

	// newSession builds the session service for each connection.
	newSession session.SessionFactory
}
// ctxReqConcurrency is the unexported context key under which
// WithRequestConcurrency stores the request concurrency. Because it is a
// distinct unexported type, it cannot collide with context keys defined by
// other packages.
type ctxReqConcurrency struct{}
// NewLangServer creates a language server bound to srvCtx that builds its
// per-connection sessions with sf.
//
// The request concurrency is read from srvCtx (see WithRequestConcurrency).
// When srvCtx carries none, DefaultConcurrency is used. Server push is
// enabled (AllowPush). The server's logger writes to io.Discard until
// SetLogger is called.
func NewLangServer(srvCtx context.Context, sf session.SessionFactory) *langServer {
	var concurrency int
	if c, found := requestConcurrencyFromCtx(srvCtx); found {
		concurrency = c
	} else {
		concurrency = DefaultConcurrency()
	}

	return &langServer{
		srvCtx: srvCtx,
		logger: log.New(io.Discard, "", 0),
		srvOptions: &jrpc2.ServerOptions{
			AllowPush:   true,
			Concurrency: concurrency,
		},
		newSession: sf,
	}
}
// WithRequestConcurrency returns a copy of parent carrying the number of
// requests the server should process concurrently. NewLangServer reads this
// value back when it builds its server options.
func WithRequestConcurrency(parent context.Context, concurrency int) context.Context {
	var key ctxReqConcurrency
	return context.WithValue(parent, key, concurrency)
}
// requestConcurrencyFromCtx extracts the concurrency stored by
// WithRequestConcurrency. The boolean reports whether ctx carried an int
// under that key. When it is false, the returned concurrency is 0.
func requestConcurrencyFromCtx(ctx context.Context) (int, bool) {
	if c, ok := ctx.Value(ctxReqConcurrency{}).(int); ok {
		return c, true
	}
	return 0, false
}
// DefaultConcurrency returns the request concurrency used when none is
// supplied via WithRequestConcurrency.
//
// On machines with four or more logical CPUs it returns half of them, which
// leaves headroom for other work running alongside the server. On smaller
// machines it returns the full CPU count.
func DefaultConcurrency() int {
	n := runtime.NumCPU()
	if n < 4 {
		return n
	}
	return n / 2
}
// SetLogger routes three outputs to logger:
//   - the server's own messages;
//   - jrpc2's internal logging;
//   - the RPC trace log (via rpcLogger).
//
// Sessions created afterwards by newService also receive logger.
func (ls *langServer) SetLogger(logger *log.Logger) {
	ls.logger = logger

	opts := ls.srvOptions
	opts.Logger = jrpc2.StdLogger(logger)
	opts.RPCLog = &rpcLogger{logger}
}
// newService builds a fresh session from the factory, bound to the server
// context, and hands it the server's current logger.
func (ls *langServer) newService() server.Service {
	sess := ls.newSession(ls.srvCtx)
	sess.SetLogger(ls.logger)
	return sess
}
// startServer creates a single-connection server around a fresh session and
// starts it on an LSP-framed channel over reader/writer. It returns without
// waiting for the server to finish.
func (ls *langServer) startServer(reader io.Reader, writer io.WriteCloser) (*singleServer, error) {
	ss, err := Server(ls.newService(), ls.srvOptions)
	if err != nil {
		return nil, err
	}

	ch := channel.LSP(reader, writer)
	ss.Start(ch)
	return ss, nil
}
// StartAndWait serves a single client over reader/writer using LSP framing.
// It blocks until either the server exits on its own (srv.Wait returns) or
// srvCtx is cancelled.
//
// In both cases the server is stopped. StartAndWait then waits for the
// goroutine running srv.Wait, which invokes the session's Finish hook,
// before returning.
//
// It returns an error only if the server could not be created.
func (ls *langServer) StartAndWait(reader io.Reader, writer io.WriteCloser) error {
	srv, err := ls.startServer(reader, writer)
	if err != nil {
		return err
	}
	ls.logger.Printf("Starting server (pid %d; concurrency: %d) ...",
		os.Getpid(), ls.srvOptions.Concurrency)

	// ctx is cancelled either by the parent srvCtx or by the waiter below
	// once the server exits on its own.
	ctx, cancelFunc := context.WithCancel(ls.srvCtx)
	defer cancelFunc()

	waiterDone := make(chan struct{})
	go func() {
		defer close(waiterDone)
		srv.Wait()
		cancelFunc()
	}()

	<-ctx.Done()
	ls.logger.Printf("Stopping server (pid %d) ...", os.Getpid())
	srv.Stop()

	// Stop makes srv.Wait return. Block until the waiter has also run the
	// session's Finish hook, so cleanup is not cut short when the caller
	// exits right after StartAndWait returns.
	<-waiterDone

	ls.logger.Printf("Server (pid %d) stopped.", os.Getpid())
	return nil
}
// StartTCP listens on address and serves every accepted TCP connection with
// its own session (built by newService), using LSP framing. It blocks until
// srvCtx is cancelled and then closes the listener.
//
// It returns an error if the listener cannot be opened (wrapping the
// net.Listen error) or cannot be closed.
func (ls *langServer) StartTCP(address string) error {
	ls.logger.Printf("Starting TCP server (pid %d; concurrency: %d) at %q ...",
		os.Getpid(), ls.srvOptions.Concurrency, address)
	lst, err := net.Listen("tcp", address)
	if err != nil {
		return fmt.Errorf("TCP Server failed to start: %w", err)
	}
	ls.logger.Printf("TCP server running at %q", lst.Addr())

	accepter := server.NetAccepter(lst, channel.LSP)
	go func() {
		ls.logger.Println("Starting loop server ...")
		// Declare err locally. Assigning the enclosing err here would be a
		// data race with the lst.Close() result handled below.
		err := server.Loop(accepter, ls.newService, &server.LoopOptions{
			ServerOptions: ls.srvOptions,
		})
		if err != nil {
			ls.logger.Printf("Loop server failed to start: %s", err)
		}
	}()

	<-ls.srvCtx.Done()
	ls.logger.Printf("Stopping TCP server (pid %d) ...", os.Getpid())
	if err := lst.Close(); err != nil {
		ls.logger.Printf("TCP server (pid %d) failed to stop: %s", os.Getpid(), err)
		return err
	}
	ls.logger.Printf("TCP server (pid %d) stopped.", os.Getpid())
	return nil
}
// singleServer runs one jrpc2.Server on behalf of a single server.Service.
// It adapts the Service's Assigner/Finish contract to jrpc2.NewServer:
//   - the assigner is obtained when the server is built (see Server);
//   - Finish receives the exit status when Wait observes the server exit.
type singleServer struct {
	// srv is the wrapped jrpc2 server.
	srv *jrpc2.Server

	// finishFunc reports the server's final status back to the service.
	finishFunc func(jrpc2.ServerStatus)
}
// Server wraps svc in a jrpc2 server configured with opts.
//
// It requests svc's method assigner up front, returning any error from
// svc.Assigner. It arranges for svc.Finish to be called with that assigner
// and the server's final status when Wait is used.
func Server(svc server.Service, opts *jrpc2.ServerOptions) (*singleServer, error) {
	assigner, err := svc.Assigner()
	if err != nil {
		return nil, err
	}

	finish := func(status jrpc2.ServerStatus) {
		svc.Finish(assigner, status)
	}
	ss := &singleServer{finishFunc: finish}
	ss.srv = jrpc2.NewServer(assigner, opts)
	return ss, nil
}
// Start begins serving ch. Use Wait to block until the server exits.
func (ss *singleServer) Start(ch channel.Channel) {
	running := ss.srv.Start(ch)
	ss.srv = running
}
// StartAndWait serves ch and blocks until the server has exited and the
// service's finish hook has run.
func (ss *singleServer) StartAndWait(ch channel.Channel) {
	ss.srv = ss.srv.Start(ch)
	ss.Wait()
}
// Wait blocks until the underlying jrpc2 server exits, then passes its
// final status to the service through finishFunc.
func (ss *singleServer) Wait() {
	ss.finishFunc(ss.srv.WaitStatus())
}
// Stop shuts down the underlying jrpc2 server. It does not itself invoke
// the service's finish hook; that happens in Wait once the server has
// exited.
func (ss *singleServer) Stop() {
	ss.srv.Stop()
}