// +build !js
package websocket
import (
"bytes"
"crypto/sha1"
"encoding/base64"
"errors"
"fmt"
"io"
"log"
"net/http"
"net/textproto"
"net/url"
"path/filepath"
"strings"
"nhooyr.io/websocket/internal/errd"
)
// AcceptOptions represents Accept's options.
//
// A nil *AcceptOptions passed to Accept is treated like the zero value.
type AcceptOptions struct {
	// Subprotocols lists the WebSocket subprotocols that Accept will negotiate with the client.
	// Entries are tried in order and compared case insensitively against the
	// subprotocols offered by the client; the first match is selected.
	// The empty subprotocol will always be negotiated as per RFC 6455. If you would like to
	// reject it, close the connection when c.Subprotocol() == "".
	Subprotocols []string

	// InsecureSkipVerify is used to disable Accept's origin verification behaviour.
	//
	// You probably want to use OriginPatterns instead.
	InsecureSkipVerify bool

	// OriginPatterns lists the host patterns for authorized origins.
	// The request host is always authorized.
	// Use this to enable cross origin WebSockets.
	//
	// For example, JavaScript running on example.com wants to access a WebSocket server at chat.example.com.
	// In such a case, example.com is the origin and chat.example.com is the request host.
	// One would set this field to []string{"example.com"} to authorize example.com to connect.
	//
	// Each pattern is matched case insensitively against the request origin host
	// with filepath.Match.
	// See https://golang.org/pkg/path/filepath/#Match
	//
	// A malformed pattern causes Accept to log the error and reject the request
	// with 403 Forbidden.
	//
	// Please ensure you understand the ramifications of enabling this.
	// If used incorrectly your WebSocket server will be open to CSRF attacks.
	//
	// Do not use * as a pattern to allow any origin, prefer to use InsecureSkipVerify instead
	// to bring attention to the danger of such a setting.
	OriginPatterns []string

	// CompressionMode controls the compression mode.
	// Defaults to CompressionNoContextTakeover.
	//
	// See docs on CompressionMode for details.
	CompressionMode CompressionMode

	// CompressionThreshold controls the minimum size of a message before compression is applied.
	//
	// Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes
	// for CompressionContextTakeover.
	CompressionThreshold int
}
// Accept accepts a WebSocket handshake from a client and upgrades the
// connection to a WebSocket.
//
// Accept will not allow cross origin requests by default.
// See the InsecureSkipVerify and OriginPatterns options to allow cross origin requests.
//
// Accept will write a response to w on all errors.
//
// A nil opts is equivalent to &AcceptOptions{}.
func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) {
	return accept(w, r, opts)
}
func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Conn, err error) {
defer errd.Wrap(&err, "failed to accept WebSocket connection")
if opts == nil {
opts = &AcceptOptions{}
}
opts = &*opts
errCode, err := verifyClientRequest(w, r)
if err != nil {
http.Error(w, err.Error(), errCode)
return nil, err
}
if !opts.InsecureSkipVerify {
err = authenticateOrigin(r, opts.OriginPatterns)
if err != nil {
if errors.Is(err, filepath.ErrBadPattern) {
log.Printf("websocket: %v", err)
err = errors.New(http.StatusText(http.StatusForbidden))
}
http.Error(w, err.Error(), http.StatusForbidden)
return nil, err
}
}
hj, ok := w.(http.Hijacker)
if !ok {
err = errors.New("http.ResponseWriter does not implement http.Hijacker")
http.Error(w, http.StatusText(http.StatusNotImplemented), http.StatusNotImplemented)
return nil, err
}
w.Header().Set("Upgrade", "websocket")
w.Header().Set("Connection", "Upgrade")
key := r.Header.Get("Sec-WebSocket-Key")
w.Header().Set("Sec-WebSocket-Accept", secWebSocketAccept(key))
subproto := selectSubprotocol(r, opts.Subprotocols)
if subproto != "" {
w.Header().Set("Sec-WebSocket-Protocol", subproto)
}
copts, err := acceptCompression(r, w, opts.CompressionMode)
if err != nil {
return nil, err
}
w.WriteHeader(http.StatusSwitchingProtocols)
// See https://github.com/nhooyr/websocket/issues/166
if ginWriter, ok := w.(interface {
WriteHeaderNow()
}); ok {
ginWriter.WriteHeaderNow()
}
netConn, brw, err := hj.Hijack()
if err != nil {
err = fmt.Errorf("failed to hijack connection: %w", err)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return nil, err
}
// https://github.com/golang/go/issues/32314
b, _ := brw.Reader.Peek(brw.Reader.Buffered())
brw.Reader.Reset(io.MultiReader(bytes.NewReader(b), netConn))
return newConn(connConfig{
subprotocol: w.Header().Get("Sec-WebSocket-Protocol"),
rwc: netConn,
client: false,
copts: copts,
flateThreshold: opts.CompressionThreshold,
br: brw.Reader,
bw: brw.Writer,
}), nil
}
// verifyClientRequest checks that r is a well-formed WebSocket opening
// handshake. On failure it returns the HTTP status code the caller should
// respond with and a descriptive error. When the Connection or Upgrade
// header is wrong, it also sets response headers advertising the expected
// upgrade. When the version is wrong, it advertises version 13.
func verifyClientRequest(w http.ResponseWriter, r *http.Request) (errCode int, _ error) {
	advertiseUpgrade := func() {
		h := w.Header()
		h.Set("Connection", "Upgrade")
		h.Set("Upgrade", "websocket")
	}

	switch {
	case !r.ProtoAtLeast(1, 1):
		return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: handshake request must be at least HTTP/1.1: %q", r.Proto)

	case !headerContainsTokenIgnoreCase(r.Header, "Connection", "Upgrade"):
		advertiseUpgrade()
		return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", r.Header.Get("Connection"))

	case !headerContainsTokenIgnoreCase(r.Header, "Upgrade", "websocket"):
		advertiseUpgrade()
		return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", r.Header.Get("Upgrade"))

	case r.Method != "GET":
		return http.StatusMethodNotAllowed, fmt.Errorf("WebSocket protocol violation: handshake request method is not GET but %q", r.Method)

	case r.Header.Get("Sec-WebSocket-Version") != "13":
		w.Header().Set("Sec-WebSocket-Version", "13")
		return http.StatusBadRequest, fmt.Errorf("unsupported WebSocket protocol version (only 13 is supported): %q", r.Header.Get("Sec-WebSocket-Version"))

	case r.Header.Get("Sec-WebSocket-Key") == "":
		return http.StatusBadRequest, errors.New("WebSocket protocol violation: missing Sec-WebSocket-Key")
	}

	return 0, nil
}
// authenticateOrigin allows r if it has no Origin header, if the Origin host
// equals the request Host (case insensitively), or if the Origin host matches
// one of originHosts. Otherwise, or if the Origin or a pattern cannot be
// parsed, it returns an error. Pattern errors wrap filepath.ErrBadPattern.
func authenticateOrigin(r *http.Request, originHosts []string) error {
	origin := r.Header.Get("Origin")
	if origin == "" {
		return nil
	}

	originURL, err := url.Parse(origin)
	if err != nil {
		return fmt.Errorf("failed to parse Origin header %q: %w", origin, err)
	}
	if strings.EqualFold(r.Host, originURL.Host) {
		return nil
	}

	for i := range originHosts {
		allowed, err := match(originHosts[i], originURL.Host)
		switch {
		case err != nil:
			return fmt.Errorf("failed to parse filepath pattern %q: %w", originHosts[i], err)
		case allowed:
			return nil
		}
	}

	return fmt.Errorf("request Origin %q is not authorized for Host %q", origin, r.Host)
}
// match reports whether s matches the filepath.Match pattern, ignoring case
// by lowercasing both sides first.
func match(pattern, s string) (bool, error) {
	lowerPattern := strings.ToLower(pattern)
	lowerS := strings.ToLower(s)
	return filepath.Match(lowerPattern, lowerS)
}
// selectSubprotocol picks the subprotocol for the response. It walks the
// server's subprotocols in order and returns the first one the client
// offered in Sec-WebSocket-Protocol (compared case insensitively), spelled
// as the client sent it. It returns "" when there is no overlap.
func selectSubprotocol(r *http.Request, subprotocols []string) string {
	offered := headerTokens(r.Header, "Sec-WebSocket-Protocol")
	for _, want := range subprotocols {
		for _, got := range offered {
			if !strings.EqualFold(want, got) {
				continue
			}
			return got
		}
	}
	return ""
}
// acceptCompression negotiates compression with the client. It returns nil
// options when compression is disabled or the client did not offer
// permessage-deflate. Otherwise it negotiates the first permessage-deflate
// offer.
func acceptCompression(r *http.Request, w http.ResponseWriter, mode CompressionMode) (*compressionOptions, error) {
	if mode == CompressionDisabled {
		return nil, nil
	}

	exts := websocketExtensions(r.Header)
	for i := range exts {
		// x-webkit-deflate-frame (acceptWebkitDeflate) is disabled for now,
		// see https://github.com/nhooyr/websocket/issues/218
		if exts[i].name == "permessage-deflate" {
			return acceptDeflate(w, exts[i], mode)
		}
	}
	return nil, nil
}
// acceptDeflate applies the client's permessage-deflate parameters on top of
// mode's defaults and writes the negotiated extension header to w.
//
// client_max_window_bits is accepted but ignored. Any other unknown parameter
// causes a 400 response and an error.
func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) {
	copts := mode.opts()

	for _, param := range ext.params {
		switch {
		case param == "client_no_context_takeover":
			copts.clientNoContextTakeover = true
		case param == "server_no_context_takeover":
			copts.serverNoContextTakeover = true
		case strings.HasPrefix(param, "client_max_window_bits"):
			// We cannot adjust the read sliding window so cannot make use of this.
		default:
			err := fmt.Errorf("unsupported permessage-deflate parameter: %q", param)
			http.Error(w, err.Error(), http.StatusBadRequest)
			return nil, err
		}
	}

	copts.setHeader(w.Header())
	return copts, nil
}
// acceptWebkitDeflate negotiates the x-webkit-deflate-frame extension. Only
// the no_context_takeover parameter is understood; any other parameter causes
// a 400 response and an error.
//
// NOTE(review): the request parameter sets serverNoContextTakeover, but the
// response advertises no_context_takeover based on clientNoContextTakeover.
// Confirm this asymmetry is intended.
func acceptWebkitDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) {
	copts := mode.opts()
	// The peer must explicitly request it.
	copts.serverNoContextTakeover = false

	for _, param := range ext.params {
		if param != "no_context_takeover" {
			// We explicitly fail on x-webkit-deflate-frame's max_window_bits parameter
			// instead of ignoring it, because the draft spec is unclear. It says the
			// server can ignore it, but parameters are set one way, so the server has
			// no way to tell the client it was ignored. Ignoring it would make the
			// client believe we understood it.
			// See https://tools.ietf.org/html/draft-tyoshino-hybi-websocket-perframe-deflate-06#section-4.1
			//
			// This is only implemented for webkit, which never sends max_window_bits.
			err := fmt.Errorf("unsupported x-webkit-deflate-frame parameter: %q", param)
			http.Error(w, err.Error(), http.StatusBadRequest)
			return nil, err
		}
		copts.serverNoContextTakeover = true
	}

	header := "x-webkit-deflate-frame"
	if copts.clientNoContextTakeover {
		header += "; no_context_takeover"
	}
	w.Header().Set("Sec-WebSocket-Extensions", header)
	return copts, nil
}
// headerContainsTokenIgnoreCase reports whether any comma-separated token of
// header key equals token, ignoring case.
func headerContainsTokenIgnoreCase(h http.Header, key, token string) bool {
	tokens := headerTokens(h, key)
	for i := range tokens {
		if strings.EqualFold(tokens[i], token) {
			return true
		}
	}
	return false
}
// websocketExtension is one parsed entry of a Sec-WebSocket-Extensions header.
type websocketExtension struct {
	// name is the extension token before the first ";", e.g. "permessage-deflate".
	name string
	// params holds the remaining ";"-separated parameters, whitespace-trimmed
	// but otherwise verbatim (e.g. "client_max_window_bits=10").
	params []string
}
// websocketExtensions parses every Sec-WebSocket-Extensions entry in h into
// its name and parameters, skipping empty entries.
func websocketExtensions(h http.Header) []websocketExtension {
	var exts []websocketExtension
	for _, raw := range headerTokens(h, "Sec-WebSocket-Extensions") {
		if raw != "" {
			parts := strings.Split(raw, ";")
			for i, part := range parts {
				parts[i] = strings.TrimSpace(part)
			}
			exts = append(exts, websocketExtension{name: parts[0], params: parts[1:]})
		}
	}
	return exts
}
// headerTokens splits every value of header key (looked up in canonical
// form) on commas and returns the whitespace-trimmed tokens in order.
// Empty tokens are kept; a missing header yields nil.
func headerTokens(h http.Header, key string) []string {
	var tokens []string
	for _, value := range h[textproto.CanonicalMIMEHeaderKey(key)] {
		// Trimming each token also covers leading/trailing space on value.
		for _, tok := range strings.Split(value, ",") {
			tokens = append(tokens, strings.TrimSpace(tok))
		}
	}
	return tokens
}
// keyGUID is the fixed GUID from RFC 6455 section 1.3 that is appended to
// the client's key when computing Sec-WebSocket-Accept.
var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")

// secWebSocketAccept returns the Sec-WebSocket-Accept value for the given
// Sec-WebSocket-Key: base64(SHA-1(key + keyGUID)).
func secWebSocketAccept(secWebSocketKey string) string {
	input := make([]byte, 0, len(secWebSocketKey)+len(keyGUID))
	input = append(input, secWebSocketKey...)
	input = append(input, keyGUID...)
	digest := sha1.Sum(input)
	return base64.StdEncoding.EncodeToString(digest[:])
}
package websocket
import (
"errors"
"fmt"
)
// StatusCode represents a WebSocket status code.
// https://tools.ietf.org/html/rfc6455#section-7.4
type StatusCode int

// https://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number
//
// These are only the status codes defined by the protocol.
//
// You can define custom codes in the 3000-4999 range.
// The 3000-3999 range is reserved for use by libraries, frameworks and applications.
// The 4000-4999 range is reserved for private use.
const (
	StatusNormalClosure   StatusCode = 1000
	StatusGoingAway       StatusCode = 1001
	StatusProtocolError   StatusCode = 1002
	StatusUnsupportedData StatusCode = 1003

	// 1004 is reserved and so unexported.
	statusReserved StatusCode = 1004

	// StatusNoStatusRcvd cannot be sent in a close message.
	// It is reserved for when a close message is received without
	// a status code.
	StatusNoStatusRcvd StatusCode = 1005

	// StatusAbnormalClosure is exported for use only with Wasm.
	// In non Wasm Go, the returned error will indicate whether the
	// connection was closed abnormally.
	StatusAbnormalClosure StatusCode = 1006

	StatusInvalidFramePayloadData StatusCode = 1007
	StatusPolicyViolation         StatusCode = 1008
	StatusMessageTooBig           StatusCode = 1009
	StatusMandatoryExtension      StatusCode = 1010
	StatusInternalError           StatusCode = 1011
	StatusServiceRestart          StatusCode = 1012
	StatusTryAgainLater           StatusCode = 1013
	StatusBadGateway              StatusCode = 1014

	// StatusTLSHandshake is only exported for use with Wasm.
	// In non Wasm Go, the returned error will indicate whether there was
	// a TLS handshake failure.
	StatusTLSHandshake StatusCode = 1015
)
// CloseError is returned when the connection is closed with a status and reason.
//
// Use Go 1.13's errors.As to check for this error.
// Also see the CloseStatus helper.
type CloseError struct {
	// Code is the status code carried by the close frame.
	Code StatusCode
	// Reason is the text that followed the status code; it may be empty.
	Reason string
}
func (ce CloseError) Error() string {
return fmt.Sprintf("status = %v and reason = %q", ce.Code, ce.Reason)
}
// CloseStatus is a convenience wrapper around Go 1.13's errors.As to grab
// the status code from a CloseError anywhere in err's chain.
//
// -1 will be returned if the passed error is nil or does not wrap a CloseError.
func CloseStatus(err error) StatusCode {
	var ce CloseError
	if !errors.As(err, &ce) {
		return -1
	}
	return ce.Code
}
// +build !js
package websocket
import (
"context"
"encoding/binary"
"errors"
"fmt"
"log"
"time"
"nhooyr.io/websocket/internal/errd"
)
// Close performs the WebSocket close handshake with the given status code and reason.
//
// It writes a WebSocket close frame and then waits up to 5s for
// the peer to send a close frame.
// All data messages received from the peer during the close handshake will be discarded.
//
// The close frame is written at most once. Subsequent calls to Close do not
// send another close frame and return an error.
//
// The reason must fit in a control frame next to the 2 byte status code, so
// it may be at most 123 bytes (125 - 2). If the code and reason cannot be
// encoded, a StatusInternalError close frame without a reason is sent instead
// and Close returns an error. Avoid sending a dynamic reason.
//
// Close will unblock all goroutines interacting with the connection once
// complete.
func (c *Conn) Close(code StatusCode, reason string) error {
	return c.closeHandshake(code, reason)
}
// closeHandshake writes our close frame and then waits for the peer's.
// An error writing the close frame wins. A CloseError from the wait means
// the peer's close frame arrived and counts as success. Any other wait error
// is returned.
func (c *Conn) closeHandshake(code StatusCode, reason string) (err error) {
	defer errd.Wrap(&err, "failed to close WebSocket")

	writeErr := c.writeClose(code, reason)
	waitErr := c.waitCloseHandshake()

	switch {
	case writeErr != nil:
		return writeErr
	case CloseStatus(waitErr) == -1:
		return waitErr
	default:
		return nil
	}
}
// errAlreadyWroteClose is returned by writeClose on every call after the first.
var errAlreadyWroteClose = errors.New("already wrote close")

// writeClose writes a close frame carrying code and reason. Only the first
// call on a connection does anything; later calls return errAlreadyWroteClose.
//
// StatusNoStatusRcvd is sent as an empty close payload. If code and reason
// cannot be marshalled, the marshal error is logged, the fallback payload
// returned by ce.bytes() is sent, and the marshal error is returned in
// preference to any write error. The close error is recorded with setCloseErr
// whether or not the write succeeded.
func (c *Conn) writeClose(code StatusCode, reason string) error {
	// Claim the right to write the close frame under closeMu so concurrent
	// callers cannot both send one.
	c.closeMu.Lock()
	wroteClose := c.wroteClose
	c.wroteClose = true
	c.closeMu.Unlock()
	if wroteClose {
		return errAlreadyWroteClose
	}

	ce := CloseError{
		Code:   code,
		Reason: reason,
	}

	var p []byte
	var marshalErr error
	if ce.Code != StatusNoStatusRcvd {
		p, marshalErr = ce.bytes()
		if marshalErr != nil {
			log.Printf("websocket: %v", marshalErr)
		}
	}

	writeErr := c.writeControl(context.Background(), opClose, p)
	if CloseStatus(writeErr) != -1 {
		// Not a real error if it's due to a close frame being received.
		writeErr = nil
	}

	// We do this after in case there was an error writing the close frame.
	c.setCloseErr(fmt.Errorf("sent close frame: %w", ce))

	if marshalErr != nil {
		return marshalErr
	}
	return writeErr
}
// waitCloseHandshake waits up to 5 seconds for the peer's close frame and
// always closes the connection afterwards (deferred c.close(nil)).
//
// If a close frame was already read, it returns c.readCloseFrameErr.
// Otherwise it keeps reading frames, discarding their payloads, until the
// read loop returns an error, which is returned as is. closeHandshake
// treats a CloseError here as success.
func (c *Conn) waitCloseHandshake() error {
	defer c.close(nil)

	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
	defer cancel()

	err := c.readMu.lock(ctx)
	if err != nil {
		return err
	}
	defer c.readMu.unlock()

	if c.readCloseFrameErr != nil {
		return c.readCloseFrameErr
	}

	for {
		// NOTE(review): presumably readLoop handles control frames itself and
		// returns the header of the next data frame with its payload still
		// unread. Confirm.
		h, err := c.readLoop(ctx)
		if err != nil {
			return err
		}

		// Discard the frame payload byte by byte.
		for i := int64(0); i < h.payloadLength; i++ {
			_, err := c.br.ReadByte()
			if err != nil {
				return err
			}
		}
	}
}
// parseClosePayload decodes a received close frame payload. An empty payload
// means the peer sent no status (StatusNoStatusRcvd). Otherwise the first two
// bytes are the big-endian status code, which must be valid on the wire, and
// the rest is the reason.
func parseClosePayload(p []byte) (CloseError, error) {
	switch {
	case len(p) == 0:
		return CloseError{Code: StatusNoStatusRcvd}, nil
	case len(p) < 2:
		return CloseError{}, fmt.Errorf("close payload %q too small, cannot even contain the 2 byte status code", p)
	}

	code := StatusCode(binary.BigEndian.Uint16(p[:2]))
	if !validWireCloseCode(code) {
		return CloseError{}, fmt.Errorf("invalid status code %v", code)
	}
	return CloseError{Code: code, Reason: string(p[2:])}, nil
}
// validWireCloseCode reports whether code may appear in a close frame.
// Codes 1000-1014 are allowed except the reserved 1004 and the local-only
// 1005 and 1006. 1015 is also local-only. The application range 3000-4999 is
// allowed.
//
// See http://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number
// and https://tools.ietf.org/html/rfc6455#section-7.4.1
func validWireCloseCode(code StatusCode) bool {
	switch {
	case code == statusReserved, code == StatusNoStatusRcvd,
		code == StatusAbnormalClosure, code == StatusTLSHandshake:
		return false
	case code >= StatusNormalClosure && code <= StatusBadGateway:
		return true
	default:
		return code >= 3000 && code <= 4999
	}
}
// bytes encodes ce as a close frame payload. If ce cannot be encoded, it
// returns the wrapped encoding error together with the payload for a bare
// StatusInternalError, so a close frame can still be sent.
func (ce CloseError) bytes() ([]byte, error) {
	p, err := ce.bytesErr()
	if err == nil {
		return p, nil
	}

	fallback := CloseError{Code: StatusInternalError}
	p, _ = fallback.bytesErr()
	return p, fmt.Errorf("failed to marshal close frame: %w", err)
}
// maxCloseReason is the longest reason that fits in a control frame payload
// alongside the 2 byte status code.
const maxCloseReason = maxControlPayload - 2

// bytesErr encodes ce as the big-endian status code followed by the reason.
// It fails if the reason exceeds maxCloseReason or the code is not allowed
// on the wire; the length check is performed first.
func (ce CloseError) bytesErr() ([]byte, error) {
	if len(ce.Reason) > maxCloseReason {
		return nil, fmt.Errorf("reason string max is %v but got %q with length %v", maxCloseReason, ce.Reason, len(ce.Reason))
	}
	if !validWireCloseCode(ce.Code) {
		return nil, fmt.Errorf("status code %v cannot be set", ce.Code)
	}

	payload := make([]byte, 2, 2+len(ce.Reason))
	binary.BigEndian.PutUint16(payload, uint16(ce.Code))
	return append(payload, ce.Reason...), nil
}
// setCloseErr records err as the close reason, unless one is already
// recorded, taking closeMu for the duration.
func (c *Conn) setCloseErr(err error) {
	c.closeMu.Lock()
	defer c.closeMu.Unlock()
	c.setCloseErrLocked(err)
}
// setCloseErrLocked records err, wrapped as "WebSocket closed: ...", as the
// close reason. Only the first recorded reason is kept. The caller must
// hold closeMu.
func (c *Conn) setCloseErrLocked(err error) {
	if c.closeErr != nil {
		return
	}
	c.closeErr = fmt.Errorf("WebSocket closed: %w", err)
}
// isClosed reports, without blocking, whether c.closed has been closed.
func (c *Conn) isClosed() bool {
	closed := false
	select {
	case <-c.closed:
		closed = true
	default:
	}
	return closed
}
// +build !js
package websocket
import (
"io"
"net/http"
"sync"
"github.com/klauspost/compress/flate"
)
// opts returns fresh compression options for m. Context takeover is
// disabled in both directions only for CompressionNoContextTakeover.
func (m CompressionMode) opts() *compressionOptions {
	noTakeover := m == CompressionNoContextTakeover
	return &compressionOptions{
		clientNoContextTakeover: noTakeover,
		serverNoContextTakeover: noTakeover,
	}
}
// compressionOptions holds permessage-deflate parameters, either as offered
// by this side or as negotiated with the peer.
type compressionOptions struct {
	// clientNoContextTakeover corresponds to the client_no_context_takeover
	// extension parameter.
	clientNoContextTakeover bool
	// serverNoContextTakeover corresponds to the server_no_context_takeover
	// extension parameter.
	serverNoContextTakeover bool
}
// setHeader writes these options as a permessage-deflate entry into the
// Sec-WebSocket-Extensions header of h, replacing any existing value.
func (copts *compressionOptions) setHeader(h http.Header) {
	ext := []byte("permessage-deflate")
	if copts.clientNoContextTakeover {
		ext = append(ext, "; client_no_context_takeover"...)
	}
	if copts.serverNoContextTakeover {
		ext = append(ext, "; server_no_context_takeover"...)
	}
	h.Set("Sec-WebSocket-Extensions", string(ext))
}
// deflateMessageTail holds the bytes required to get flate.Reader to return.
// They are removed when sending to avoid the overhead, because WebSocket
// framing already tells when the message has ended. They must be added back
// when reading, otherwise flate.Reader keeps trying to return more bytes.
// See RFC 7692, section 7.2.
const deflateMessageTail = "\x00\x00\xff\xff"
// trimLastFourBytesWriter forwards writes to w but always holds the most
// recent 4 bytes of the stream back in tail, so the final 4 bytes written
// are never passed to w.
// NOTE(review): presumably used to strip deflateMessageTail from compressed
// output. Confirm at the call sites.
type trimLastFourBytesWriter struct {
	w    io.Writer
	tail []byte
}
// reset drops any held-back tail bytes while keeping the tail's storage.
// It is safe on a nil receiver and before the first Write.
func (tw *trimLastFourBytesWriter) reset() {
	if tw == nil || tw.tail == nil {
		return
	}
	tw.tail = tw.tail[:0]
}
// Write implements io.Writer. It forwards p to tw.w, except that the most
// recent 4 bytes of the stream are kept in tw.tail instead of being written.
// Bytes already held in tail are flushed to tw.w as newer bytes push them out.
//
// Bytes retained in tail are counted as written in the returned length.
func (tw *trimLastFourBytesWriter) Write(p []byte) (int, error) {
	if tw.tail == nil {
		tw.tail = make([]byte, 0, 4)
	}

	// extra is how many bytes of tail+p do not fit in the 4 byte tail.
	extra := len(tw.tail) + len(p) - 4

	if extra <= 0 {
		// Everything still fits in the tail; nothing reaches tw.w yet.
		tw.tail = append(tw.tail, p...)
		return len(p), nil
	}

	// Now we need to write as many extra bytes as we can from the previous tail.
	if extra > len(tw.tail) {
		extra = len(tw.tail)
	}
	if extra > 0 {
		_, err := tw.w.Write(tw.tail[:extra])
		if err != nil {
			return 0, err
		}

		// Shift remaining bytes in tail over.
		n := copy(tw.tail, tw.tail[extra:])
		tw.tail = tw.tail[:n]
	}

	// If p is less than or equal to 4 bytes,
	// all of it is part of the tail.
	if len(p) <= 4 {
		tw.tail = append(tw.tail, p...)
		return len(p), nil
	}

	// Otherwise, only the last 4 bytes are.
	tw.tail = append(tw.tail, p[len(p)-4:]...)

	p = p[:len(p)-4]
	n, err := tw.w.Write(p)
	// The 4 bytes moved into tail count as written.
	return n + 4, err
}
// flateReaderPool recycles flate readers handed out by getFlateReader.
var flateReaderPool sync.Pool

// getFlateReader returns a flate reader over r primed with dict. It reuses
// a pooled reader, reset onto r and dict, when one is available.
func getFlateReader(r io.Reader, dict []byte) io.Reader {
	if pooled, ok := flateReaderPool.Get().(io.Reader); ok {
		// NOTE(review): the error returned by Reset is discarded.
		pooled.(flate.Resetter).Reset(r, dict)
		return pooled
	}
	return flate.NewReaderDict(r, dict)
}

// putFlateReader returns fr to the pool for reuse by getFlateReader.
func putFlateReader(fr io.Reader) {
	flateReaderPool.Put(fr)
}
// slidingWindow retains the most recent cap(buf) bytes written to it.
// NOTE(review): presumably used as the deflate dictionary when context
// takeover is enabled. Confirm at the call sites.
type slidingWindow struct {
	buf []byte
}
// swPoolMu guards swPool, which maps a window size to the pool of buffers
// of that capacity. Entries are only ever added, never removed.
var swPoolMu sync.RWMutex
var swPool = map[int]*sync.Pool{}

// slidingWindowPool returns the shared pool for buffers of capacity n,
// creating it on first use. Every caller asking for the same n gets the same
// pool, even when called concurrently.
func slidingWindowPool(n int) *sync.Pool {
	swPoolMu.RLock()
	p, ok := swPool[n]
	swPoolMu.RUnlock()
	if ok {
		return p
	}

	swPoolMu.Lock()
	defer swPoolMu.Unlock()
	// Re-check: another goroutine may have created the pool between our
	// RUnlock and Lock. Without this, a second pool would overwrite the first.
	if p, ok := swPool[n]; ok {
		return p
	}
	p = &sync.Pool{}
	swPool[n] = p
	return p
}
// init gives the window an empty buffer of capacity n (32768 when n is 0),
// reusing one from the size's pool when possible. It is a no-op if the
// window already has a buffer.
func (sw *slidingWindow) init(n int) {
	if sw.buf != nil {
		return
	}

	size := n
	if size == 0 {
		size = 32768
	}

	if buf, ok := slidingWindowPool(size).Get().([]byte); ok {
		sw.buf = buf[:0]
		return
	}
	sw.buf = make([]byte, 0, size)
}
// close returns the window's buffer to the pool for its capacity and
// clears it. It is a no-op if the window holds no buffer.
func (sw *slidingWindow) close() {
	if sw.buf == nil {
		return
	}

	// Only a map read is needed. sync.Pool.Put is safe for concurrent use,
	// so a shared lock suffices and Put runs outside it. The entry for
	// cap(sw.buf) exists because init obtained the buffer through
	// slidingWindowPool.
	swPoolMu.RLock()
	p := swPool[cap(sw.buf)]
	swPoolMu.RUnlock()

	p.Put(sw.buf)
	sw.buf = nil
}
// write appends p to the window, discarding the oldest bytes so that at most
// cap(sw.buf) of the most recently written bytes are kept.
func (sw *slidingWindow) write(p []byte) {
	size := cap(sw.buf)

	if len(p) >= size {
		// p alone fills the window: keep only its last size bytes.
		sw.buf = sw.buf[:size]
		copy(sw.buf, p[len(p)-size:])
		return
	}

	// overflow is how many of the oldest bytes must go to make room for p.
	if overflow := len(sw.buf) + len(p) - size; overflow > 0 {
		copy(sw.buf, sw.buf[overflow:])
		sw.buf = sw.buf[:len(sw.buf)-overflow]
	}
	sw.buf = append(sw.buf, p...)
}
// +build !js
package websocket
import (
"bufio"
"context"
"errors"
"fmt"
"io"
"runtime"
"strconv"
"sync"
"sync/atomic"
)
// Conn represents a WebSocket connection.
// All methods may be called concurrently except for Reader and Read.
//
// You must always read from the connection. Otherwise control
// frames will not be handled. See Reader and CloseRead.
//
// Be sure to call Close on the connection when you
// are finished with it to release associated resources.
//
// On any error from any method, the connection is closed
// with an appropriate reason.
type Conn struct {
	subprotocol    string
	rwc            io.ReadWriteCloser
	client         bool
	copts          *compressionOptions
	flateThreshold int
	br             *bufio.Reader
	bw             *bufio.Writer

	// readTimeout and writeTimeout deliver the contexts whose deadlines
	// timeoutLoop enforces for reads and writes.
	readTimeout  chan context.Context
	writeTimeout chan context.Context

	// Read state.
	// readMu is held by the active reader; waitCloseHandshake acquires it too.
	// readCloseFrameErr, when set, is returned by waitCloseHandshake instead of
	// reading further.
	readMu            *mu
	readHeaderBuf     [8]byte
	readControlBuf    [maxControlPayload]byte
	msgReader         *msgReader
	readCloseFrameErr error

	// Write state.
	// writeBuf is only populated for client connections, in newConn.
	msgWriterState *msgWriterState
	writeFrameMu   *mu
	writeBuf       []byte
	writeHeaderBuf [8]byte
	writeHeader    header

	// closed is closed exactly once, by close. closeMu guards closeErr and
	// wroteClose; closeErr holds the first recorded close reason.
	closed     chan struct{}
	closeMu    sync.Mutex
	closeErr   error
	wroteClose bool

	// pingCounter is incremented atomically to generate ping payloads.
	// activePingsMu guards activePings, which maps each outstanding ping
	// payload to the channel ping waits on for the matching pong.
	pingCounter   int32
	activePingsMu sync.Mutex
	activePings   map[string]chan<- struct{}
}
// connConfig holds the parameters for newConn; each field initializes the
// Conn field of the same name.
type connConfig struct {
	// subprotocol is the negotiated subprotocol ("" for none).
	subprotocol string
	rwc         io.ReadWriteCloser
	// client is true for dialed connections and false for accepted ones.
	client bool
	// copts is nil when compression is not in use.
	copts *compressionOptions
	// flateThreshold of 0 selects the default chosen in newConn.
	flateThreshold int

	br *bufio.Reader
	bw *bufio.Writer
}
// newConn builds a Conn from cfg and initializes its locks and message
// reader/writer state. It registers a finalizer that closes the connection if
// it is garbage collected, and starts timeoutLoop.
//
// When compression is enabled and cfg.flateThreshold is 0, the threshold
// defaults to 128 bytes with context takeover and 512 bytes without.
func newConn(cfg connConfig) *Conn {
	c := &Conn{
		subprotocol:    cfg.subprotocol,
		rwc:            cfg.rwc,
		client:         cfg.client,
		copts:          cfg.copts,
		flateThreshold: cfg.flateThreshold,

		br: cfg.br,
		bw: cfg.bw,

		readTimeout:  make(chan context.Context),
		writeTimeout: make(chan context.Context),

		closed:      make(chan struct{}),
		activePings: make(map[string]chan<- struct{}),
	}

	c.readMu = newMu(c)
	c.writeFrameMu = newMu(c)

	c.msgReader = newMsgReader(c)

	c.msgWriterState = newMsgWriterState(c)
	if c.client {
		// NOTE(review): presumably lets client frame writes reuse bw's
		// internal buffer. Confirm in extractBufioWriterBuf.
		c.writeBuf = extractBufioWriterBuf(c.bw, c.rwc)
	}

	if c.flate() && c.flateThreshold == 0 {
		c.flateThreshold = 128
		if !c.msgWriterState.flateContextTakeover() {
			c.flateThreshold = 512
		}
	}

	// NOTE(review): the timeoutLoop goroutine started below references c,
	// so c stays reachable while it runs. This finalizer can likely only
	// fire after timeoutLoop has returned. Confirm it is still useful.
	runtime.SetFinalizer(c, func(c *Conn) {
		c.close(errors.New("connection garbage collected"))
	})

	go c.timeoutLoop()

	return c
}
// Subprotocol returns the subprotocol negotiated during the handshake.
// An empty string means the default protocol (no subprotocol).
func (c *Conn) Subprotocol() string {
	negotiated := c.subprotocol
	return negotiated
}
// close shuts the connection down. It records err as the close reason
// unless one is already recorded, closes c.closed, clears the finalizer and
// closes the underlying rwc. Only the first call has any effect; later calls
// return immediately.
func (c *Conn) close(err error) {
	c.closeMu.Lock()
	defer c.closeMu.Unlock()

	if c.isClosed() {
		return
	}
	c.setCloseErrLocked(err)
	close(c.closed)
	runtime.SetFinalizer(c, nil)

	// Have to close after c.closed is closed to ensure any goroutine that wakes up
	// from the connection being closed also sees that c.closed is closed and returns
	// closeErr.
	c.rwc.Close()

	// NOTE(review): the reader/writer state is closed asynchronously,
	// presumably because those close calls may block on locks that are only
	// released once other goroutines observe the closure. Confirm.
	go func() {
		c.msgWriterState.close()
		c.msgReader.close()
	}()
}
// timeoutLoop runs for the lifetime of the connection and enforces the
// contexts most recently sent on c.readTimeout and c.writeTimeout.
//
// When the read context expires, it records a read timeout as the close
// error and calls writeError with StatusPolicyViolation in a new goroutine.
// This happens once per expired read context. When the write context expires,
// it closes the connection and returns. It also returns once c.closed is
// closed.
func (c *Conn) timeoutLoop() {
	readCtx := context.Background()
	writeCtx := context.Background()

	for {
		select {
		case <-c.closed:
			return

		case writeCtx = <-c.writeTimeout:
		case readCtx = <-c.readTimeout:

		case <-readCtx.Done():
			c.setCloseErr(fmt.Errorf("read timed out: %w", readCtx.Err()))
			go c.writeError(StatusPolicyViolation, errors.New("timed out"))
			// A Done channel stays closed. Without this reset the select would
			// choose this case on every iteration, busy-spinning and spawning a
			// new writeError goroutine each time until the connection closes.
			readCtx = context.Background()
		case <-writeCtx.Done():
			c.close(fmt.Errorf("write timed out: %w", writeCtx.Err()))
			return
		}
	}
}
// flate reports whether compression is in use on this connection, i.e.
// whether compression options are set.
func (c *Conn) flate() bool {
	enabled := c.copts != nil
	return enabled
}
// Ping sends a ping to the peer and waits for a pong.
// Use this to measure latency or ensure the peer is responsive.
// Ping must be called concurrently with Reader as it does
// not read from the connection but instead waits for a Reader call
// to read the pong.
//
// Each ping carries a unique decimal counter as its payload.
//
// TCP Keepalives should suffice for most use cases.
func (c *Conn) Ping(ctx context.Context) error {
	id := atomic.AddInt32(&c.pingCounter, 1)
	if err := c.ping(ctx, strconv.Itoa(int(id))); err != nil {
		return fmt.Errorf("failed to ping: %w", err)
	}
	return nil
}
// ping registers payload p as outstanding, sends a ping frame carrying it
// and waits until the matching pong is signalled, the connection closes, or
// ctx is done. If ctx is done first, the connection is closed. The
// registration is always removed before returning.
func (c *Conn) ping(ctx context.Context, p string) error {
	pongCh := make(chan struct{}, 1)

	c.activePingsMu.Lock()
	c.activePings[p] = pongCh
	c.activePingsMu.Unlock()
	defer func() {
		c.activePingsMu.Lock()
		defer c.activePingsMu.Unlock()
		delete(c.activePings, p)
	}()

	if err := c.writeControl(ctx, opPing, []byte(p)); err != nil {
		return err
	}

	select {
	case <-pongCh:
		return nil
	case <-c.closed:
		return c.closeErr
	case <-ctx.Done():
		err := fmt.Errorf("failed to wait for pong: %w", ctx.Err())
		c.close(err)
		return err
	}
}
// mu is a mutex implemented with a channel, so that acquiring it (see lock)
// can be abandoned when a context is done or the connection c is closed.
// Holding the lock means a token is buffered in ch.
type mu struct {
	c  *Conn
	ch chan struct{}
}
// newMu returns an unlocked mu tied to the connection c.
func newMu(c *Conn) *mu {
	m := &mu{c: c}
	// Capacity 1: a single buffered token means the lock is held.
	m.ch = make(chan struct{}, 1)
	return m
}
// forceLock acquires m, blocking until it is free. Unlike lock, it ignores
// contexts and connection closure.
func (m *mu) forceLock() {
	var token struct{}
	m.ch <- token
}
// lock acquires m. It gives up if the connection is closed, returning the
// recorded close error. It also gives up if ctx is done first; in that case
// the connection is closed with a wrapped ctx error, which is returned. On a
// nil return the caller holds m and must release it with unlock.
func (m *mu) lock(ctx context.Context) error {
	select {
	case <-m.c.closed:
		return m.c.closeErr
	case <-ctx.Done():
		err := fmt.Errorf("failed to acquire lock: %w", ctx.Err())
		m.c.close(err)
		return err
	case m.ch <- struct{}{}:
		// To make sure the connection is certainly alive.
		// As it's possible the send on m.ch was selected
		// over the receive on closed.
		select {
		case <-m.c.closed:
			// Make sure to release.
			m.unlock()
			return m.c.closeErr
		default:
		}
		return nil
	}
}
// unlock releases m. It never blocks: if m is not currently held, it does
// nothing.
func (m *mu) unlock() {
	select {
	case <-m.ch:
	default:
	}
}
// +build !js
package websocket
import (
"bufio"
"bytes"
"context"
"crypto/rand"
"encoding/base64"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/url"
"strings"
"sync"
"time"
"nhooyr.io/websocket/internal/errd"
)
// DialOptions represents Dial's options.
//
// A nil *DialOptions passed to Dial is treated like the zero value.
type DialOptions struct {
	// HTTPClient is used for the connection.
	// Its Transport must return writable bodies for WebSocket handshakes.
	// http.Transport does beginning with Go 1.12.
	//
	// Defaults to http.DefaultClient. If the client has a Timeout, it is
	// applied to the handshake through the dial context. It is not applied to
	// the resulting WebSocket connection.
	HTTPClient *http.Client

	// HTTPHeader specifies the HTTP headers included in the handshake request.
	// The Connection, Upgrade, Sec-WebSocket-Version and Sec-WebSocket-Key
	// headers are always set by Dial and override any values given here.
	HTTPHeader http.Header

	// Subprotocols lists the WebSocket subprotocols to negotiate with the server.
	Subprotocols []string

	// CompressionMode controls the compression mode.
	// Defaults to CompressionNoContextTakeover.
	//
	// See docs on CompressionMode for details.
	CompressionMode CompressionMode

	// CompressionThreshold controls the minimum size of a message before compression is applied.
	//
	// Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes
	// for CompressionContextTakeover.
	CompressionThreshold int
}
// Dial performs a WebSocket handshake on url.
//
// The response is the WebSocket handshake response from the server.
// You never need to close resp.Body yourself.
//
// If an error occurs, the returned response may be non nil.
// However, you can only read the first 1024 bytes of the body.
//
// This function requires at least Go 1.12 as it uses a new feature
// in net/http to perform WebSocket handshakes.
// See docs on the HTTPClient option and https://github.com/golang/go/issues/26937#issuecomment-415855861
//
// URLs with http/https schemes will work and are interpreted as ws/wss.
func Dial(ctx context.Context, u string, opts *DialOptions) (*Conn, *http.Response, error) {
return dial(ctx, u, opts, nil)
}
// dial implements Dial. rand is the entropy source for the
// Sec-WebSocket-Key; nil means crypto/rand.
//
// On failure after a response was received, the response body is replaced by
// an in-memory copy of at most its first 1024 bytes (read within 3 seconds)
// and the original body is closed.
func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (_ *Conn, _ *http.Response, err error) {
	defer errd.Wrap(&err, "failed to WebSocket dial")

	// Work on a private value copy so the defaults filled in below never
	// leak into the caller's DialOptions, which may be reused for other dials.
	var o DialOptions
	if opts != nil {
		o = *opts
	}
	opts = &o

	if opts.HTTPClient == nil {
		opts.HTTPClient = http.DefaultClient
	} else if opts.HTTPClient.Timeout > 0 {
		// http.Client.Timeout also covers reading the response body, which
		// here is the WebSocket connection itself. Apply the timeout to the
		// handshake through ctx instead, and dial with a copy of the client
		// that has no Timeout.
		var cancel context.CancelFunc
		ctx, cancel = context.WithTimeout(ctx, opts.HTTPClient.Timeout)
		defer cancel()

		newClient := *opts.HTTPClient
		newClient.Timeout = 0
		opts.HTTPClient = &newClient
	}

	if opts.HTTPHeader == nil {
		opts.HTTPHeader = http.Header{}
	}

	secWebSocketKey, err := secWebSocketKey(rand)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to generate Sec-WebSocket-Key: %w", err)
	}

	var copts *compressionOptions
	if opts.CompressionMode != CompressionDisabled {
		copts = opts.CompressionMode.opts()
	}

	resp, err := handshakeRequest(ctx, urls, opts, copts, secWebSocketKey)
	if err != nil {
		return nil, resp, err
	}
	respBody := resp.Body
	resp.Body = nil
	defer func() {
		if err != nil {
			// We read a bit of the body for easier debugging.
			r := io.LimitReader(respBody, 1024)

			timer := time.AfterFunc(time.Second*3, func() {
				respBody.Close()
			})
			defer timer.Stop()

			b, _ := ioutil.ReadAll(r)
			respBody.Close()
			resp.Body = ioutil.NopCloser(bytes.NewReader(b))
		}
	}()

	copts, err = verifyServerResponse(opts, copts, secWebSocketKey, resp)
	if err != nil {
		return nil, resp, err
	}

	rwc, ok := respBody.(io.ReadWriteCloser)
	if !ok {
		return nil, resp, fmt.Errorf("response body is not a io.ReadWriteCloser: %T", respBody)
	}

	return newConn(connConfig{
		subprotocol:    resp.Header.Get("Sec-WebSocket-Protocol"),
		rwc:            rwc,
		client:         true,
		copts:          copts,
		flateThreshold: opts.CompressionThreshold,
		br:             getBufioReader(rwc),
		bw:             getBufioWriter(rwc),
	}), resp, nil
}
// handshakeRequest sends the WebSocket opening handshake for urls with
// opts.HTTPClient and returns the raw response without validating it.
//
// ws and wss URLs are rewritten to http and https. Schemes other than ws,
// wss, http and https are rejected.
func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, copts *compressionOptions, secWebSocketKey string) (*http.Response, error) {
	u, err := url.Parse(urls)
	if err != nil {
		return nil, fmt.Errorf("failed to parse url: %w", err)
	}

	switch u.Scheme {
	case "ws":
		u.Scheme = "http"
	case "wss":
		u.Scheme = "https"
	case "http", "https":
	default:
		return nil, fmt.Errorf("unexpected url scheme: %q", u.Scheme)
	}

	req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil)
	if err != nil {
		// e.g. a nil ctx; req would be nil and dereferenced below.
		return nil, fmt.Errorf("failed to create handshake request: %w", err)
	}

	// Header.Clone returns nil for a nil header, and Set on a nil map panics.
	req.Header = opts.HTTPHeader.Clone()
	if req.Header == nil {
		req.Header = http.Header{}
	}
	req.Header.Set("Connection", "Upgrade")
	req.Header.Set("Upgrade", "websocket")
	req.Header.Set("Sec-WebSocket-Version", "13")
	req.Header.Set("Sec-WebSocket-Key", secWebSocketKey)
	if len(opts.Subprotocols) > 0 {
		req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ","))
	}
	if copts != nil {
		copts.setHeader(req.Header)
	}

	resp, err := opts.HTTPClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to send handshake request: %w", err)
	}
	return resp, nil
}
func secWebSocketKey(rr io.Reader) (string, error) {
if rr == nil {
rr = rand.Reader
}
b := make([]byte, 16)
_, err := io.ReadFull(rr, b)
if err != nil {
return "", fmt.Errorf("failed to read random data from rand.Reader: %w", err)
}
return base64.StdEncoding.EncodeToString(b), nil
}
// verifyServerResponse checks that resp completes the handshake started with
// secWebSocketKey. It checks the status, the Connection/Upgrade headers,
// Sec-WebSocket-Accept and the subprotocol. It then returns the compression
// options negotiated from the server's extensions.
func verifyServerResponse(opts *DialOptions, copts *compressionOptions, secWebSocketKey string, resp *http.Response) (*compressionOptions, error) {
	if resp.StatusCode != http.StatusSwitchingProtocols {
		return nil, fmt.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode)
	}

	h := resp.Header
	if !headerContainsTokenIgnoreCase(h, "Connection", "Upgrade") {
		return nil, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", h.Get("Connection"))
	}
	if !headerContainsTokenIgnoreCase(h, "Upgrade", "WebSocket") {
		return nil, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", h.Get("Upgrade"))
	}
	if got := h.Get("Sec-WebSocket-Accept"); got != secWebSocketAccept(secWebSocketKey) {
		return nil, fmt.Errorf("WebSocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q",
			got,
			secWebSocketKey,
		)
	}

	if err := verifySubprotocol(opts.Subprotocols, resp); err != nil {
		return nil, err
	}
	return verifyServerExtensions(copts, h)
}
// verifySubprotocol accepts a response with no Sec-WebSocket-Protocol, or
// with one that case-insensitively equals one of the offered subprotos.
func verifySubprotocol(subprotos []string, resp *http.Response) error {
	proto := resp.Header.Get("Sec-WebSocket-Protocol")

	accepted := proto == ""
	for i := 0; !accepted && i < len(subprotos); i++ {
		accepted = strings.EqualFold(subprotos[i], proto)
	}
	if accepted {
		return nil
	}
	return fmt.Errorf("WebSocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q", proto)
}
// verifyServerExtensions validates the server's Sec-WebSocket-Extensions
// response header against the compression options the client offered.
//
// If the server returned no extensions, it returns nil options and no error.
// Otherwise the server must have accepted exactly one extension,
// permessage-deflate, and only if the client offered compression
// (copts != nil). Only the client_no_context_takeover and
// server_no_context_takeover parameters are accepted.
//
// The returned options are a copy of *copts with the server's parameters
// applied. copts itself is never modified.
func verifyServerExtensions(copts *compressionOptions, h http.Header) (*compressionOptions, error) {
	exts := websocketExtensions(h)
	if len(exts) == 0 {
		return nil, nil
	}

	ext := exts[0]
	if ext.name != "permessage-deflate" || len(exts) > 1 || copts == nil {
		// Report every extension the server returned. exts[1:] would omit the
		// offending first extension, and print [] when there is only one.
		return nil, fmt.Errorf("WebSocket protocol violation: unsupported extensions from server: %+v", exts)
	}

	// &*copts evaluates to the same pointer as copts, so it would have
	// mutated the caller's options. Take a real copy instead.
	negotiated := *copts
	for _, p := range ext.params {
		switch p {
		case "client_no_context_takeover":
			negotiated.clientNoContextTakeover = true
		case "server_no_context_takeover":
			negotiated.serverNoContextTakeover = true
		default:
			return nil, fmt.Errorf("unsupported permessage-deflate parameter: %q", p)
		}
	}
	return &negotiated, nil
}
// bufioReaderPool recycles *bufio.Reader values.
var bufioReaderPool sync.Pool

// getBufioReader returns a *bufio.Reader reading from r. It reuses a pooled
// reader when one is available and allocates a new one otherwise.
func getBufioReader(r io.Reader) *bufio.Reader {
	if br, ok := bufioReaderPool.Get().(*bufio.Reader); ok {
		br.Reset(r)
		return br
	}
	return bufio.NewReader(r)
}

// putBufioReader hands br back to the pool. br must not be used afterwards.
func putBufioReader(br *bufio.Reader) {
	bufioReaderPool.Put(br)
}
// bufioWriterPool recycles *bufio.Writer values.
var bufioWriterPool sync.Pool

// getBufioWriter returns a *bufio.Writer writing to w. It reuses a pooled
// writer when one is available and allocates a new one otherwise.
func getBufioWriter(w io.Writer) *bufio.Writer {
	if bw, ok := bufioWriterPool.Get().(*bufio.Writer); ok {
		bw.Reset(w)
		return bw
	}
	return bufio.NewWriter(w)
}

// putBufioWriter hands bw back to the pool. bw must not be used afterwards.
func putBufioWriter(bw *bufio.Writer) {
	bufioWriterPool.Put(bw)
}
package websocket
import (
"bufio"
"encoding/binary"
"fmt"
"io"
"math"
"math/bits"
"nhooyr.io/websocket/internal/errd"
)
// opcode represents a WebSocket frame opcode (the low 4 bits of the first
// header byte).
type opcode int

// Opcodes defined by RFC 6455.
// https://tools.ietf.org/html/rfc6455#section-11.8.
const (
	opContinuation opcode = iota // 0x0
	opText                       // 0x1
	opBinary                     // 0x2
	// 0x3 - 0x7 (3 - 7) are reserved for further non-control frames.
	_
	_
	_
	_
	_
	opClose // 0x8
	opPing  // 0x9
	opPong  // 0xA
	// 0xB - 0xF (11 - 15) are reserved for further control frames. Opcodes
	// are only 4 bits wide, so there is no opcode 16.
)
// header is a decoded WebSocket frame header.
// See https://tools.ietf.org/html/rfc6455#section-5.2.
type header struct {
	// FIN flag and the three reserved bits, in wire order, then the opcode.
	fin, rsv1, rsv2, rsv3 bool
	opcode                opcode

	// payloadLength is the decoded payload length, including the extended
	// 16 and 64 bit forms.
	payloadLength int64

	// masked reports whether a masking key is present. maskKey holds it,
	// decoded with binary.LittleEndian by readFrameHeader.
	masked  bool
	maskKey uint32
}
// readFrameHeader decodes one frame header from r.
//
// readBuf is scratch space for the extended length and the masking key. The
// 64 bit length form reads len(readBuf) bytes and decodes them with Uint64,
// so readBuf must be at least 8 bytes long.
// See https://tools.ietf.org/html/rfc6455#section-5.2.
func readFrameHeader(r *bufio.Reader, readBuf []byte) (h header, err error) {
	defer errd.Wrap(&err, "failed to read frame header")

	b0, err := r.ReadByte()
	if err != nil {
		return header{}, err
	}
	// First byte: FIN, RSV1-3, then the 4 bit opcode.
	h.fin = b0&0x80 != 0
	h.rsv1 = b0&0x40 != 0
	h.rsv2 = b0&0x20 != 0
	h.rsv3 = b0&0x10 != 0
	h.opcode = opcode(b0 & 0x0f)

	b1, err := r.ReadByte()
	if err != nil {
		return header{}, err
	}
	// Second byte: MASK bit, then 7 bits that are either the length itself or
	// a marker for a 16 bit (126) or 64 bit (127) big endian length.
	h.masked = b1&0x80 != 0
	switch lenField := b1 & 0x7f; lenField {
	case 126:
		_, err = io.ReadFull(r, readBuf[:2])
		h.payloadLength = int64(binary.BigEndian.Uint16(readBuf))
	case 127:
		_, err = io.ReadFull(r, readBuf)
		h.payloadLength = int64(binary.BigEndian.Uint64(readBuf))
	default:
		h.payloadLength = int64(lenField)
	}
	if err != nil {
		return header{}, err
	}

	// A 64 bit length with its top bit set wraps negative in int64.
	if h.payloadLength < 0 {
		return header{}, fmt.Errorf("received negative payload length: %v", h.payloadLength)
	}

	if !h.masked {
		return h, nil
	}
	_, err = io.ReadFull(r, readBuf[:4])
	if err != nil {
		return header{}, err
	}
	h.maskKey = binary.LittleEndian.Uint32(readBuf)
	return h, nil
}
// Control frames (close, ping, pong) may carry at most this many payload
// bytes. See https://tools.ietf.org/html/rfc6455#section-5.5.
const (
	maxControlPayload = 125
)
// writeFrameHeader encodes h and writes it to w.
//
// buf is scratch space for the extended length and the masking key. The 64
// bit length form writes all of buf, so buf is expected to be exactly 8 bytes
// long.
// See https://tools.ietf.org/html/rfc6455#section-5.2
func writeFrameHeader(h header, w *bufio.Writer, buf []byte) (err error) {
	defer errd.Wrap(&err, "failed to write frame header")

	// First byte: FIN, RSV1-3 and the opcode.
	b0 := byte(h.opcode)
	if h.fin {
		b0 |= 0x80
	}
	if h.rsv1 {
		b0 |= 0x40
	}
	if h.rsv2 {
		b0 |= 0x20
	}
	if h.rsv3 {
		b0 |= 0x10
	}
	if err = w.WriteByte(b0); err != nil {
		return err
	}

	// Second byte: MASK bit plus either the length itself or the 126/127
	// marker announcing a 16/64 bit extended length.
	longLen := h.payloadLength > math.MaxUint16
	mediumLen := !longLen && h.payloadLength > 125
	var b1 byte
	if h.masked {
		b1 = 0x80
	}
	switch {
	case longLen:
		b1 |= 127
	case mediumLen:
		b1 |= 126
	case h.payloadLength >= 0:
		b1 |= byte(h.payloadLength)
	}
	if err = w.WriteByte(b1); err != nil {
		return err
	}

	switch {
	case longLen:
		binary.BigEndian.PutUint64(buf, uint64(h.payloadLength))
		_, err = w.Write(buf)
	case mediumLen:
		binary.BigEndian.PutUint16(buf, uint16(h.payloadLength))
		_, err = w.Write(buf[:2])
	}
	if err != nil {
		return err
	}

	if !h.masked {
		return nil
	}
	binary.LittleEndian.PutUint32(buf, h.maskKey)
	if _, err = w.Write(buf[:4]); err != nil {
		return err
	}
	return nil
}
// mask applies the WebSocket masking algorithm to b in place, XORing it with
// the given key.
// See https://tools.ietf.org/html/rfc6455#section-5.3
//
// The returned value is the key, correctly rotated so it can be passed back
// in to continue masking/unmasking the rest of the same message. This is
// needed when a payload is processed in several chunks.
//
// It is optimized for little endian and expects the key to be in little
// endian byte order (as produced by binary.LittleEndian.Uint32).
//
// The loops below are unrolled by hand for speed. Keep that structure when
// editing.
// See https://github.com/golang/go/issues/31586
func mask(key uint32, b []byte) uint32 {
if len(b) >= 8 {
// The 4 byte key is repeated twice so 8 bytes can be XORed at once.
// Every chunk processed below is a multiple of 4 bytes long, so the key
// stays aligned with the data and needs no rotation until the byte tail.
key64 := uint64(key)<<32 | uint64(key)
// At some point in the future we can clean these unrolled loops up.
// See https://github.com/golang/go/issues/31586#issuecomment-487436401
// Then we xor until b is less than 128 bytes.
for len(b) >= 128 {
v := binary.LittleEndian.Uint64(b)
binary.LittleEndian.PutUint64(b, v^key64)
v = binary.LittleEndian.Uint64(b[8:16])
binary.LittleEndian.PutUint64(b[8:16], v^key64)
v = binary.LittleEndian.Uint64(b[16:24])
binary.LittleEndian.PutUint64(b[16:24], v^key64)
v = binary.LittleEndian.Uint64(b[24:32])
binary.LittleEndian.PutUint64(b[24:32], v^key64)
v = binary.LittleEndian.Uint64(b[32:40])
binary.LittleEndian.PutUint64(b[32:40], v^key64)
v = binary.LittleEndian.Uint64(b[40:48])
binary.LittleEndian.PutUint64(b[40:48], v^key64)
v = binary.LittleEndian.Uint64(b[48:56])
binary.LittleEndian.PutUint64(b[48:56], v^key64)
v = binary.LittleEndian.Uint64(b[56:64])
binary.LittleEndian.PutUint64(b[56:64], v^key64)
v = binary.LittleEndian.Uint64(b[64:72])
binary.LittleEndian.PutUint64(b[64:72], v^key64)
v = binary.LittleEndian.Uint64(b[72:80])
binary.LittleEndian.PutUint64(b[72:80], v^key64)
v = binary.LittleEndian.Uint64(b[80:88])
binary.LittleEndian.PutUint64(b[80:88], v^key64)
v = binary.LittleEndian.Uint64(b[88:96])
binary.LittleEndian.PutUint64(b[88:96], v^key64)
v = binary.LittleEndian.Uint64(b[96:104])
binary.LittleEndian.PutUint64(b[96:104], v^key64)
v = binary.LittleEndian.Uint64(b[104:112])
binary.LittleEndian.PutUint64(b[104:112], v^key64)
v = binary.LittleEndian.Uint64(b[112:120])
binary.LittleEndian.PutUint64(b[112:120], v^key64)
v = binary.LittleEndian.Uint64(b[120:128])
binary.LittleEndian.PutUint64(b[120:128], v^key64)
b = b[128:]
}
// Then we xor until b is less than 64 bytes.
for len(b) >= 64 {
v := binary.LittleEndian.Uint64(b)
binary.LittleEndian.PutUint64(b, v^key64)
v = binary.LittleEndian.Uint64(b[8:16])
binary.LittleEndian.PutUint64(b[8:16], v^key64)
v = binary.LittleEndian.Uint64(b[16:24])
binary.LittleEndian.PutUint64(b[16:24], v^key64)
v = binary.LittleEndian.Uint64(b[24:32])
binary.LittleEndian.PutUint64(b[24:32], v^key64)
v = binary.LittleEndian.Uint64(b[32:40])
binary.LittleEndian.PutUint64(b[32:40], v^key64)
v = binary.LittleEndian.Uint64(b[40:48])
binary.LittleEndian.PutUint64(b[40:48], v^key64)
v = binary.LittleEndian.Uint64(b[48:56])
binary.LittleEndian.PutUint64(b[48:56], v^key64)
v = binary.LittleEndian.Uint64(b[56:64])
binary.LittleEndian.PutUint64(b[56:64], v^key64)
b = b[64:]
}
// Then we xor until b is less than 32 bytes.
for len(b) >= 32 {
v := binary.LittleEndian.Uint64(b)
binary.LittleEndian.PutUint64(b, v^key64)
v = binary.LittleEndian.Uint64(b[8:16])
binary.LittleEndian.PutUint64(b[8:16], v^key64)
v = binary.LittleEndian.Uint64(b[16:24])
binary.LittleEndian.PutUint64(b[16:24], v^key64)
v = binary.LittleEndian.Uint64(b[24:32])
binary.LittleEndian.PutUint64(b[24:32], v^key64)
b = b[32:]
}
// Then we xor until b is less than 16 bytes.
for len(b) >= 16 {
v := binary.LittleEndian.Uint64(b)
binary.LittleEndian.PutUint64(b, v^key64)
v = binary.LittleEndian.Uint64(b[8:16])
binary.LittleEndian.PutUint64(b[8:16], v^key64)
b = b[16:]
}
// Then we xor until b is less than 8 bytes.
for len(b) >= 8 {
v := binary.LittleEndian.Uint64(b)
binary.LittleEndian.PutUint64(b, v^key64)
b = b[8:]
}
}
// Then we xor until b is less than 4 bytes.
for len(b) >= 4 {
v := binary.LittleEndian.Uint32(b)
binary.LittleEndian.PutUint32(b, v^key)
b = b[4:]
}
// xor remaining bytes. Each byte consumes the low byte of the key, and the
// key is rotated right by 8 bits, so the returned key starts at the right
// key byte for the next call.
for i := range b {
b[i] ^= byte(key)
key = bits.RotateLeft32(key, -8)
}
return key
}
package bpool
import (
"bytes"
"sync"
)
// bpool holds reusable *bytes.Buffer values.
var bpool sync.Pool

// Get returns an empty buffer. It takes one from the pool when available and
// allocates a fresh one otherwise.
func Get() *bytes.Buffer {
	if v := bpool.Get(); v != nil {
		return v.(*bytes.Buffer)
	}
	return new(bytes.Buffer)
}

// Put resets b and returns it to the pool for reuse.
// b must not be used after Put.
func Put(b *bytes.Buffer) {
	b.Reset()
	bpool.Put(b)
}
package errd
import (
"fmt"
)
// Wrap replaces a non-nil *err with fmt.Errorf(f+": %w", v..., *err), so the
// original error stays reachable through errors.Is and errors.As. A nil *err
// is left untouched.
//
// It is meant to be deferred in functions with a named error result.
// Inspired by https://github.com/golang/go/issues/32676.
func Wrap(err *error, f string, v ...interface{}) {
	if *err == nil {
		return
	}
	args := append(v, *err)
	*err = fmt.Errorf(f+": %w", args...)
}
package xsync
import (
"fmt"
)
// Go runs fn in a new goroutine and returns a channel that receives fn's
// error. If fn panics with a non-nil value, the panic is recovered and
// delivered on the channel as an error instead.
//
// The channel has a buffer of one, so the goroutine never blocks on sending
// even if nobody receives.
func Go(fn func() error) <-chan error {
	done := make(chan error, 1)
	go func() {
		defer func() {
			if r := recover(); r != nil {
				// fn never returned, so nothing was sent yet. The default case
				// is purely defensive.
				select {
				case done <- fmt.Errorf("panic in go fn: %v", r):
				default:
				}
			}
		}()
		done <- fn()
	}()
	return done
}
package xsync
import (
"sync/atomic"
)
// Int64 is an int64 that can be loaded and stored atomically.
// The zero value is ready to use and loads as 0.
type Int64 struct {
	// An atomic.Value is used instead of atomic.LoadInt64/StoreInt64.
	// Those require the word to be 64 bit aligned, which is not guaranteed
	// on 32 bit platforms, and we need 64 bit integers everywhere.
	i atomic.Value
}

// Load returns the last stored value, or 0 if nothing was stored yet.
func (v *Int64) Load() int64 {
	n, _ := v.i.Load().(int64)
	return n
}

// Store atomically replaces the value with i.
func (v *Int64) Store(i int64) {
	v.i.Store(i)
}
package websocket
import (
"context"
"fmt"
"io"
"math"
"net"
"sync"
"time"
)
// NetConn converts a *websocket.Conn into a net.Conn.
//
// It's for tunneling arbitrary protocols over WebSockets.
// Few users of the library will need this but it's tricky to implement
// correctly and so provided in the library.
// See https://github.com/nhooyr/websocket/issues/100.
//
// Every Write to the net.Conn will correspond to a message write of
// the given type on *websocket.Conn.
//
// The passed ctx bounds the lifetime of the net.Conn. If cancelled,
// all reads and writes on the net.Conn will be cancelled.
//
// If a message is read that is not of the correct type, the connection
// will be closed with StatusUnsupportedData and an error will be returned.
//
// Close will close the *websocket.Conn with StatusNormalClosure.
//
// When a deadline is hit, the connection will be closed. This is
// different from most net.Conn implementations where only the
// reading/writing goroutines are interrupted but the connection is kept alive.
//
// The Addr methods will return a mock net.Addr that returns "websocket" for Network
// and "websocket/unknown-addr" for String.
//
// A received StatusNormalClosure or StatusGoingAway close frame will be translated to
// io.EOF when reading.
func NetConn(ctx context.Context, c *Conn, msgType MessageType) net.Conn {
	nc := &netConn{
		c:       c,
		msgType: msgType,
	}

	// Each direction gets its own context, cancelled by a deadline timer.
	// The timers are created with an effectively infinite duration and then
	// stopped, so no deadline is active until SetDeadline and friends arm them.
	//
	// Timers created by time.AfterFunc have a nil C channel, so there is
	// nothing to drain after Stop; receiving from C would block forever.
	// Stop always succeeds here anyway, since the timer cannot have fired yet.
	var cancel context.CancelFunc
	nc.writeContext, cancel = context.WithCancel(ctx)
	nc.writeTimer = time.AfterFunc(math.MaxInt64, cancel)
	nc.writeTimer.Stop()

	nc.readContext, cancel = context.WithCancel(ctx)
	nc.readTimer = time.AfterFunc(math.MaxInt64, cancel)
	nc.readTimer.Stop()

	return nc
}
// netConn adapts a *Conn to net.Conn; see NetConn.
type netConn struct {
	c       *Conn
	msgType MessageType

	// Deadlines are timers that cancel the matching context when they fire.
	writeTimer   *time.Timer
	writeContext context.Context
	readTimer    *time.Timer
	readContext  context.Context

	// readMu is held by Read while it uses eofed and reader, which track the
	// end of the stream and the message currently being read.
	readMu sync.Mutex
	eofed  bool
	reader io.Reader
}

// Compile-time check that netConn satisfies net.Conn.
var _ net.Conn = (*netConn)(nil)
// Close closes the underlying *Conn with StatusNormalClosure and an empty
// reason.
func (c *netConn) Close() error {
	const reason = ""
	return c.c.Close(StatusNormalClosure, reason)
}
// Write sends p as a single message of c.msgType, bounded by the write
// deadline context. It reports either all of p as written or none of it.
func (c *netConn) Write(p []byte) (int, error) {
	if err := c.c.Write(c.writeContext, c.msgType, p); err != nil {
		return 0, err
	}
	return len(p), nil
}
// Read reads from the message currently in progress. When no message is in
// progress, it fetches the next one from the *Conn.
//
// The end of a message is reported as a nil error, not io.EOF, so the stream
// continues with the next message. io.EOF is returned only after a
// StatusNormalClosure or StatusGoingAway close. A message of the wrong type
// closes the connection with StatusUnsupportedData.
func (c *netConn) Read(p []byte) (int, error) {
	c.readMu.Lock()
	defer c.readMu.Unlock()

	if c.eofed {
		return 0, io.EOF
	}

	if c.reader == nil {
		typ, r, err := c.c.Reader(c.readContext)
		if err != nil {
			if code := CloseStatus(err); code == StatusNormalClosure || code == StatusGoingAway {
				c.eofed = true
				return 0, io.EOF
			}
			return 0, err
		}
		if typ != c.msgType {
			err = fmt.Errorf("unexpected frame type read (expected %v): %v", c.msgType, typ)
			c.c.Close(StatusUnsupportedData, err.Error())
			return 0, err
		}
		c.reader = r
	}

	n, err := c.reader.Read(p)
	if err == io.EOF {
		// This message is finished; the next Read starts a new one.
		c.reader = nil
		return n, nil
	}
	return n, err
}
// websocketAddr is the placeholder net.Addr reported by netConn for both
// its local and remote address.
type websocketAddr struct{}

// Network always reports "websocket".
func (a websocketAddr) Network() string { return "websocket" }

// String always reports "websocket/unknown-addr".
func (a websocketAddr) String() string { return "websocket/unknown-addr" }
// RemoteAddr returns a placeholder websocketAddr.
func (c *netConn) RemoteAddr() net.Addr {
	var addr websocketAddr
	return addr
}

// LocalAddr returns a placeholder websocketAddr.
func (c *netConn) LocalAddr() net.Addr {
	var addr websocketAddr
	return addr
}
// SetDeadline sets the write and then the read deadline to t. The zero time
// clears both. It always returns nil.
func (c *netConn) SetDeadline(t time.Time) error {
	_ = c.SetWriteDeadline(t)
	_ = c.SetReadDeadline(t)
	return nil
}
// SetWriteDeadline arms the write deadline timer to fire at t, or disarms it
// when t is the zero time. When the timer fires, the write context is
// cancelled. It always returns nil.
func (c *netConn) SetWriteDeadline(t time.Time) error {
	if t.IsZero() {
		c.writeTimer.Stop()
		return nil
	}
	// time.Until is the idiomatic form of t.Sub(time.Now()) (staticcheck S1024).
	c.writeTimer.Reset(time.Until(t))
	return nil
}
// SetReadDeadline arms the read deadline timer to fire at t, or disarms it
// when t is the zero time. When the timer fires, the read context is
// cancelled. It always returns nil.
func (c *netConn) SetReadDeadline(t time.Time) error {
	if t.IsZero() {
		c.readTimer.Stop()
		return nil
	}
	// time.Until is the idiomatic form of t.Sub(time.Now()) (staticcheck S1024).
	c.readTimer.Reset(time.Until(t))
	return nil
}
// +build !js
package websocket
import (
"bufio"
"context"
"errors"
"fmt"
"io"
"io/ioutil"
"strings"
"time"
"nhooyr.io/websocket/internal/errd"
"nhooyr.io/websocket/internal/xsync"
)
// Reader waits for the next WebSocket data message and returns its type
// together with an io.Reader for its payload. Ping, pong and close frames
// that arrive first are handled along the way.
//
// ctx bounds both the wait and later reads from the returned reader.
// Ensure you read to EOF, otherwise the connection will hang.
//
// Call CloseRead if you do not expect any data messages from the peer.
//
// Only one Reader may be open at a time.
func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
	typ, r, err := c.reader(ctx)
	return typ, r, err
}
// Read reads one whole message via Reader and returns its type and payload.
// If reading the payload fails partway, the bytes read so far are returned
// along with the error.
func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) {
	typ, r, err := c.Reader(ctx)
	if err != nil {
		return 0, nil, err
	}

	payload, err := ioutil.ReadAll(r)
	return typ, payload, err
}
// CloseRead starts a goroutine that keeps reading from the connection until
// it is closed or a data message is received.
//
// After calling CloseRead you cannot read any messages from the connection.
// The returned context is cancelled when the goroutine stops.
//
// If a data message is received, the connection is closed with
// StatusPolicyViolation.
//
// Call CloseRead when you do not expect to read any more messages. Because it
// actively reads, ping, pong and close frames are still responded to, so
// c.Ping and c.Close keep working as expected.
func (c *Conn) CloseRead(ctx context.Context) context.Context {
	readCtx, cancel := context.WithCancel(ctx)
	go func() {
		defer cancel()
		// Reader only returns once a data message arrives or reading fails.
		// Either way nothing more is expected, so close the connection.
		c.Reader(readCtx)
		c.Close(StatusPolicyViolation, "unexpected data message")
	}()
	return readCtx
}
// SetReadLimit sets the maximum number of bytes a single message may contain.
// It applies to both the Reader and Read methods.
//
// By default, the connection has a message read limit of 32768 bytes.
//
// When the limit is hit, the connection will be closed with StatusMessageTooBig.
func (c *Conn) SetReadLimit(n int64) {
	// Store one byte more than n, in case a fin frame still needs to be read
	// after n bytes have been consumed.
	limit := n + 1
	c.msgReader.limitReader.limit.Store(limit)
}
// defaultReadLimit is the default maximum message size in bytes; see SetReadLimit.
const defaultReadLimit = 32768

// newMsgReader builds the message reader for c.
//
// fin starts out true so that the first call to reader does not report an
// unfinished previous message.
func newMsgReader(c *Conn) *msgReader {
	mr := &msgReader{c: c, fin: true}
	mr.readFunc = mr.read
	// As in SetReadLimit, one byte beyond the limit is allowed.
	mr.limitReader = newLimitReader(c, mr.readFunc, defaultReadLimit+1)
	return mr
}
// resetFlate prepares mr to inflate a new compressed message. It:
//   - rewinds flateTail, which read serves after the final frame
//     (presumably the deflate trailer senders strip; confirm in compress.go);
//   - initializes the sliding window when context takeover applies;
//   - obtains a flate reader over the raw payload and routes the limit
//     reader through it.
func (mr *msgReader) resetFlate() {
	mr.flateTail.Reset(deflateMessageTail)

	if mr.flateContextTakeover() {
		mr.dict.init(32768)
	}
	if mr.flateBufio == nil {
		mr.flateBufio = getBufioReader(mr.readFunc)
	}

	mr.flateReader = getFlateReader(mr.flateBufio, mr.dict.buf)
	mr.limitReader.r = mr.flateReader
}
// putFlateReader returns the current flate reader, if any, to its pool and
// forgets it.
func (mr *msgReader) putFlateReader() {
	if mr.flateReader == nil {
		return
	}
	putFlateReader(mr.flateReader)
	mr.flateReader = nil
}
// close takes the read lock via forceLock and releases the reader's pooled
// resources. For client connections, it also returns the connection's
// bufio.Reader to the pool and clears it.
func (mr *msgReader) close() {
	mr.c.readMu.forceLock()

	mr.putFlateReader()
	mr.dict.close()
	if br := mr.flateBufio; br != nil {
		putBufioReader(br)
	}

	if !mr.c.client {
		return
	}
	putBufioReader(mr.c.br)
	mr.c.br = nil
}
// flateContextTakeover reports whether the peer's compression context, and
// therefore our inflate window, persists across messages. A client reads what
// the server compresses, so the server's no-takeover flag applies, and vice
// versa.
func (mr *msgReader) flateContextTakeover() bool {
	noTakeover := mr.c.copts.clientNoContextTakeover
	if mr.c.client {
		noTakeover = mr.c.copts.serverNoContextTakeover
	}
	return !noTakeover
}
// readRSV1Illegal reports whether a set RSV1 bit in h is a protocol violation.
// RSV1 is only allowed when compression is enabled, and only on text or binary
// frames, i.e. frames that begin a message.
func (c *Conn) readRSV1Illegal(h header) bool {
	startsMessage := h.opcode == opText || h.opcode == opBinary
	return !c.flate() || !startsMessage
}
// readLoop reads frame headers until it finds a data frame (text, binary or
// continuation), which it returns. Control frames found on the way are handled
// in place by handleControl.
//
// Headers with illegal reserved bits, unknown opcodes, or masking that is
// wrong for our side of the connection fail the connection with
// StatusProtocolError. Errors from a received close frame that carry a close
// status are passed through without extra wrapping.
func (c *Conn) readLoop(ctx context.Context) (header, error) {
	for {
		h, err := c.readFrameHeader(ctx)
		if err != nil {
			return header{}, err
		}

		if h.rsv1 && c.readRSV1Illegal(h) || h.rsv2 || h.rsv3 {
			err := fmt.Errorf("received header with unexpected rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3)
			c.writeError(StatusProtocolError, err)
			return header{}, err
		}

		// RFC 6455 section 5.1: clients must mask every frame and servers must
		// never mask. Either violation must fail the connection, so report it
		// like the other protocol errors above instead of only returning an
		// error. A masked server frame would otherwise be delivered still
		// masked, because msgReader only unmasks data on the server side.
		if !c.client && !h.masked {
			err := errors.New("received unmasked frame from client")
			c.writeError(StatusProtocolError, err)
			return header{}, err
		}
		if c.client && h.masked {
			err := errors.New("received masked frame from server")
			c.writeError(StatusProtocolError, err)
			return header{}, err
		}

		switch h.opcode {
		case opClose, opPing, opPong:
			err = c.handleControl(ctx, h)
			if err != nil {
				// Pass through CloseErrors when receiving a close frame.
				if h.opcode == opClose && CloseStatus(err) != -1 {
					return header{}, err
				}
				return header{}, fmt.Errorf("failed to handle control frame %v: %w", h.opcode, err)
			}
		case opContinuation, opText, opBinary:
			return h, nil
		default:
			err := fmt.Errorf("received unknown opcode %v", h.opcode)
			c.writeError(StatusProtocolError, err)
			return header{}, err
		}
	}
}
// readFrameHeader reads the next frame header from c.br.
//
// ctx is sent on c.readTimeout for the duration of the read and replaced with
// context.Background() afterwards. A read error closes the connection unless
// the connection is already closed or ctx is done; in those cases the
// corresponding error is returned instead.
func (c *Conn) readFrameHeader(ctx context.Context) (header, error) {
	select {
	case c.readTimeout <- ctx:
	case <-c.closed:
		return header{}, c.closeErr
	}

	h, err := readFrameHeader(c.br, c.readHeaderBuf[:])
	if err != nil {
		select {
		case <-c.closed:
			return header{}, c.closeErr
		case <-ctx.Done():
			return header{}, ctx.Err()
		default:
		}
		c.close(err)
		return header{}, err
	}

	select {
	case c.readTimeout <- context.Background():
	case <-c.closed:
		return header{}, c.closeErr
	}
	return h, nil
}
// readFramePayload fills p completely from c.br.
//
// Like readFrameHeader, it sends ctx on c.readTimeout while reading and
// replaces it with context.Background() afterwards. A read error closes the
// connection unless the connection is already closed or ctx is done.
func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) {
	select {
	case c.readTimeout <- ctx:
	case <-c.closed:
		return 0, c.closeErr
	}

	n, err := io.ReadFull(c.br, p)
	if err != nil {
		select {
		case <-c.closed:
			return n, c.closeErr
		case <-ctx.Done():
			return n, ctx.Err()
		default:
		}
		err = fmt.Errorf("failed to read frame payload: %w", err)
		c.close(err)
		return n, err
	}

	select {
	case c.readTimeout <- context.Background():
	case <-c.closed:
		return n, c.closeErr
	}
	return n, nil
}
// handleControl reads the payload of the control frame described by h and
// acts on it.
//
// Control frames must be final and carry at most maxControlPayload bytes;
// otherwise the connection fails with StatusProtocolError.
//
// How each frame is handled:
//   - Ping: answered with a pong that echoes its payload.
//   - Pong: signals the channel registered in c.activePings under its
//     payload, if any (presumably a pending Ping call).
//   - Anything else is treated as a close frame, since readLoop only passes
//     close, ping and pong here. Its payload is parsed, a close frame with the
//     same code and reason is written back, and the connection is closed. The
//     resulting error is also stored in c.readCloseFrameErr.
//
// Reading the payload and any reply are bounded by ctx plus a 5 second
// timeout.
func (c *Conn) handleControl(ctx context.Context, h header) (err error) {
	if h.payloadLength < 0 || h.payloadLength > maxControlPayload {
		err := fmt.Errorf("received control frame payload with invalid length: %d", h.payloadLength)
		c.writeError(StatusProtocolError, err)
		return err
	}
	if !h.fin {
		err := errors.New("received fragmented control frame")
		c.writeError(StatusProtocolError, err)
		return err
	}

	ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
	defer cancel()

	payload := c.readControlBuf[:h.payloadLength]
	if _, err = c.readFramePayload(ctx, payload); err != nil {
		return err
	}
	if h.masked {
		mask(h.maskKey, payload)
	}

	if h.opcode == opPing {
		return c.writeControl(ctx, opPong, payload)
	}
	if h.opcode == opPong {
		c.activePingsMu.Lock()
		waiter, ok := c.activePings[string(payload)]
		c.activePingsMu.Unlock()
		if ok {
			// Never block the read loop; drop the signal if nobody is ready.
			select {
			case waiter <- struct{}{}:
			default:
			}
		}
		return nil
	}

	defer func() {
		c.readCloseFrameErr = err
	}()

	ce, err := parseClosePayload(payload)
	if err != nil {
		err = fmt.Errorf("received invalid close payload: %w", err)
		c.writeError(StatusProtocolError, err)
		return err
	}

	err = fmt.Errorf("received close frame: %w", ce)
	c.setCloseErr(err)
	c.writeClose(ce.Code, ce.Reason)
	c.close(err)
	return err
}
// reader implements Reader. It holds the read lock for the duration of the
// call and refuses to start while the previous message is unfinished (which
// also closes the connection). Otherwise it resets c.msgReader for the next
// data message found by readLoop. A continuation frame that does not belong
// to a message is a protocol error.
func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err error) {
	defer errd.Wrap(&err, "failed to get reader")

	if err = c.readMu.lock(ctx); err != nil {
		return 0, nil, err
	}
	defer c.readMu.unlock()

	if !c.msgReader.fin {
		err = errors.New("previous message not read to completion")
		c.close(fmt.Errorf("failed to get reader: %w", err))
		return 0, nil, err
	}

	h, err := c.readLoop(ctx)
	if err != nil {
		return 0, nil, err
	}
	if h.opcode == opContinuation {
		err = errors.New("received continuation frame without text or binary frame")
		c.writeError(StatusProtocolError, err)
		return 0, nil, err
	}

	c.msgReader.reset(ctx, h)
	return MessageType(h.opcode), c.msgReader, nil
}
// msgReader reads the payload of the current data message, which may span
// several frames, and optionally inflates it.
type msgReader struct {
	c   *Conn
	ctx context.Context

	// Decompression state. flate is taken from RSV1 of the first frame.
	flate       bool
	flateReader io.Reader
	flateBufio  *bufio.Reader
	flateTail   strings.Reader
	limitReader *limitReader
	dict        slidingWindow

	// State of the frame currently being read.
	fin           bool
	payloadLength int64
	maskKey       uint32

	// readFunc stores mr.read as a readerFunc once, so the conversion (and
	// its allocation) is not repeated on every use.
	readFunc readerFunc
}
// reset starts reading a new message whose first frame header is h.
func (mr *msgReader) reset(ctx context.Context, h header) {
	mr.ctx, mr.flate = ctx, h.rsv1
	mr.limitReader.reset(mr.readFunc)
	if mr.flate {
		mr.resetFlate()
	}
	mr.setFrame(h)
}
// setFrame records the per-frame state from h: the FIN bit, the remaining
// payload length and the masking key.
func (mr *msgReader) setFrame(h header) {
	mr.fin, mr.payloadLength, mr.maskKey = h.fin, h.payloadLength, h.maskKey
}
// Read reads from the current message through the limit reader, which goes
// through the flate reader for compressed messages. It returns io.EOF at the
// end of the message. Any other error closes the connection.
func (mr *msgReader) Read(p []byte) (n int, err error) {
	if err = mr.c.readMu.lock(mr.ctx); err != nil {
		return 0, fmt.Errorf("failed to read: %w", err)
	}
	defer mr.c.readMu.unlock()

	n, err = mr.limitReader.Read(p)
	if mr.flate && mr.flateContextTakeover() {
		// Feed the inflated bytes into the window used for the next message.
		mr.dict.write(p[:n])
	}

	// For compressed messages, io.ErrUnexpectedEOF on the final frame also
	// marks a normal end of message.
	endOfMessage := errors.Is(err, io.EOF) || (errors.Is(err, io.ErrUnexpectedEOF) && mr.fin && mr.flate)
	if endOfMessage {
		mr.putFlateReader()
		return n, io.EOF
	}

	if err != nil {
		err = fmt.Errorf("failed to read: %w", err)
		mr.c.close(err)
	}
	return n, err
}
// read is the raw payload reader behind readFunc.
//
// It reads from the current frame, unmasking on the server side. When a
// non-final frame is exhausted, it moves on to the next continuation frame
// via readLoop. At the end of the final frame it returns io.EOF, or serves
// flateTail first for compressed messages.
func (mr *msgReader) read(p []byte) (int, error) {
	for mr.payloadLength == 0 {
		if mr.fin {
			if mr.flate {
				return mr.flateTail.Read(p)
			}
			return 0, io.EOF
		}

		h, err := mr.c.readLoop(mr.ctx)
		if err != nil {
			return 0, err
		}
		if h.opcode != opContinuation {
			err := errors.New("received new data message without finishing the previous message")
			mr.c.writeError(StatusProtocolError, err)
			return 0, err
		}
		mr.setFrame(h)
	}

	if remaining := mr.payloadLength; int64(len(p)) > remaining {
		p = p[:remaining]
	}

	n, err := mr.c.readFramePayload(mr.ctx, p)
	if err != nil {
		return n, err
	}
	mr.payloadLength -= int64(n)
	if !mr.c.client {
		mr.maskKey = mask(mr.maskKey, p)
	}
	return n, nil
}
// limitReader caps how many bytes may be read from r per message. Exceeding
// the cap writes a StatusMessageTooBig close.
type limitReader struct {
	c     *Conn
	r     io.Reader
	limit xsync.Int64 // configured cap, updated by SetReadLimit
	n     int64       // bytes still allowed for the current message
}

// newLimitReader returns a limitReader over r that allows limit bytes.
func newLimitReader(c *Conn, r io.Reader, limit int64) *limitReader {
	lr := &limitReader{c: c}
	lr.limit.Store(limit)
	lr.reset(r)
	return lr
}

// reset points lr at r and restores the full byte budget.
func (lr *limitReader) reset(r io.Reader) {
	lr.r = r
	lr.n = lr.limit.Load()
}

// Read reads at most the remaining budget from lr.r.
func (lr *limitReader) Read(p []byte) (int, error) {
	if lr.n <= 0 {
		err := fmt.Errorf("read limited at %v bytes", lr.limit.Load())
		lr.c.writeError(StatusMessageTooBig, err)
		return 0, err
	}

	if budget := lr.n; int64(len(p)) > budget {
		p = p[:budget]
	}
	n, err := lr.r.Read(p)
	lr.n -= int64(n)
	return n, err
}
// readerFunc adapts a plain function to io.Reader.
type readerFunc func(p []byte) (int, error)

// Read calls f(p).
func (f readerFunc) Read(p []byte) (n int, err error) {
	n, err = f(p)
	return n, err
}
// +build !js
package websocket
import (
"bufio"
"context"
"crypto/rand"
"encoding/binary"
"errors"
"fmt"
"io"
"time"
"github.com/klauspost/compress/flate"
"nhooyr.io/websocket/internal/errd"
)
// Writer returns a writer, bounded by ctx, for one WebSocket message of type
// typ.
//
// You must close the writer once you have written the entire message.
//
// Only one writer can be open at a time. Further calls block until the
// previous writer is closed.
func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) {
	w, err := c.writer(ctx, typ)
	if err == nil {
		return w, nil
	}
	return nil, fmt.Errorf("failed to get writer: %w", err)
}
// Write writes p to the connection as one message of type typ.
//
// See the Writer method if you want to stream a message.
//
// If compression is disabled or the threshold is not met, the message is
// written in a single frame.
func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error {
	if _, err := c.write(ctx, typ, p); err != nil {
		return fmt.Errorf("failed to write msg: %w", err)
	}
	return nil
}
// msgWriter is the io.WriteCloser returned by Writer. It forwards to the
// connection's shared msgWriterState and rejects any use after Close.
type msgWriter struct {
	mw     *msgWriterState
	closed bool
}

// Write forwards p to the message state unless the writer was closed.
func (mw *msgWriter) Write(p []byte) (int, error) {
	if !mw.closed {
		return mw.mw.Write(p)
	}
	return 0, errors.New("cannot use closed writer")
}

// Close finishes the message. Closing twice is an error.
func (mw *msgWriter) Close() error {
	if !mw.closed {
		mw.closed = true
		return mw.mw.Close()
	}
	return errors.New("cannot use closed writer")
}
// msgWriterState holds the state of the message currently being written.
type msgWriterState struct {
	c *Conn

	// mu is taken by reset and released when the message is done.
	// writeMu serializes individual Write and Close calls.
	mu      *mu
	writeMu *mu

	ctx    context.Context
	opcode opcode
	flate  bool

	// Compression state; trimWriter is created lazily by ensureFlate.
	trimWriter *trimLastFourBytesWriter
	dict       slidingWindow
}
// newMsgWriterState returns the message writer state for c, with fresh
// message and write locks.
func newMsgWriterState(c *Conn) *msgWriterState {
	return &msgWriterState{
		c:       c,
		mu:      newMu(c),
		writeMu: newMu(c),
	}
}
// ensureFlate enables compression for the current message. On first use it
// creates the trimming writer, which feeds compressed output to mw.write. It
// also initializes the dictionary.
func (mw *msgWriterState) ensureFlate() {
	if mw.trimWriter == nil {
		mw.trimWriter = &trimLastFourBytesWriter{w: writerFunc(mw.write)}
	}
	mw.dict.init(8192)
	mw.flate = true
}
// flateContextTakeover reports whether our own compression context persists
// across messages. We are the sender, so our side's no-takeover flag applies:
// the client flag for clients, the server flag for servers.
func (mw *msgWriterState) flateContextTakeover() bool {
	noTakeover := mw.c.copts.serverNoContextTakeover
	if mw.c.client {
		noTakeover = mw.c.copts.clientNoContextTakeover
	}
	return !noTakeover
}
// writer resets the shared message state for a new message of type typ and
// returns a fresh, open msgWriter over it.
func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) {
	state := c.msgWriterState
	if err := state.reset(ctx, typ); err != nil {
		return nil, err
	}
	return &msgWriter{mw: state}, nil
}
// write sends p as a whole message of type typ.
//
// With compression enabled, it goes through the message writer so the
// compression threshold applies. Otherwise p is written directly as a single
// final frame.
func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error) {
	mw, err := c.writer(ctx, typ)
	if err != nil {
		return 0, err
	}

	if c.flate() {
		n, err := mw.Write(p)
		if err != nil {
			return n, err
		}
		return n, mw.Close()
	}

	// The msgWriter is never closed on this path, so release the message
	// lock that writer took here.
	defer c.msgWriterState.mu.unlock()
	return c.writeFrame(ctx, true, false, c.msgWriterState.opcode, p)
}
// reset takes the message lock (bounded by ctx) and prepares the state for a
// new, initially uncompressed message of type typ.
func (mw *msgWriterState) reset(ctx context.Context, typ MessageType) error {
	if err := mw.mu.lock(ctx); err != nil {
		return err
	}

	mw.ctx, mw.opcode, mw.flate = ctx, opcode(typ), false
	// NOTE(review): trimWriter is still nil until ensureFlate first runs, so
	// this relies on trimLastFourBytesWriter.reset being nil-safe. Confirm.
	mw.trimWriter.reset()
	return nil
}
// Write writes p as part of the current message.
//
// Compression is enabled for the message only when the connection negotiated
// it and a write made while the opcode is still that of the first frame
// carries at least c.flateThreshold bytes. Any error closes the connection.
func (mw *msgWriterState) Write(p []byte) (_ int, err error) {
	if err = mw.writeMu.lock(mw.ctx); err != nil {
		return 0, fmt.Errorf("failed to write: %w", err)
	}
	defer mw.writeMu.unlock()

	defer func() {
		if err == nil {
			return
		}
		err = fmt.Errorf("failed to write: %w", err)
		mw.c.close(err)
	}()

	firstFrame := mw.opcode != opContinuation
	if mw.c.flate() && firstFrame && len(p) >= mw.c.flateThreshold {
		mw.ensureFlate()
	}

	if !mw.flate {
		return mw.write(p)
	}

	if err = flate.StatelessDeflate(mw.trimWriter, p, false, mw.dict.buf); err != nil {
		return 0, err
	}
	mw.dict.write(p)
	return len(p), nil
}
// write emits p as a non-final frame of the current message. After the first
// successful frame, the opcode switches to continuation.
func (mw *msgWriterState) write(p []byte) (int, error) {
	n, err := mw.c.writeFrame(mw.ctx, false, mw.flate, mw.opcode, p)
	if err == nil {
		mw.opcode = opContinuation
		return n, nil
	}
	return n, fmt.Errorf("failed to write data frame: %w", err)
}
// Close writes the final, empty frame of the message (FIN set), which flushes
// it to the connection. It then drops the compression dictionary when context
// takeover is off and releases the message lock. If writing the final frame
// fails, the message lock is not released.
func (mw *msgWriterState) Close() (err error) {
	defer errd.Wrap(&err, "failed to close writer")

	if err = mw.writeMu.lock(mw.ctx); err != nil {
		return err
	}
	defer mw.writeMu.unlock()

	if _, err = mw.c.writeFrame(mw.ctx, true, mw.flate, mw.opcode, nil); err != nil {
		return fmt.Errorf("failed to write fin frame: %w", err)
	}

	if mw.flate && !mw.flateContextTakeover() {
		mw.dict.close()
	}
	mw.mu.unlock()
	return nil
}
// close releases writer resources. For client connections it takes
// writeFrameMu via forceLock and returns the connection's bufio.Writer to the
// pool. It then takes the write lock via forceLock and frees the dictionary.
func (mw *msgWriterState) close() {
	c := mw.c
	if c.client {
		c.writeFrameMu.forceLock()
		putBufioWriter(c.bw)
	}

	mw.writeMu.forceLock()
	mw.dict.close()
}
// writeControl writes a single final control frame with payload p. It allows
// at most 5 seconds, or less if ctx ends sooner.
func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error {
	ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
	defer cancel()

	if _, err := c.writeFrame(ctx, true, false, opcode, p); err != nil {
		return fmt.Errorf("failed to write control frame %v: %w", opcode, err)
	}
	return nil
}
// writeFrame writes a single frame to the connection. Every write to the
// connection goes through it, serialized by writeFrameMu.
//
// fin sets the FIN bit; flate sets RSV1, but only for text and binary
// opcodes; p is the payload. Client frames are masked with a fresh random key
// from crypto/rand. The buffered writer is flushed only after a final frame.
//
// Once a close frame has been written, any other frame waits for the
// connection to close (or for ctx to be done) and fails with that error.
//
// If writing fails, the connection is closed. The returned error is the
// connection's close error or ctx's error when either is already available,
// and the underlying write error otherwise.
func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opcode, p []byte) (_ int, err error) {
	err = c.writeFrameMu.lock(ctx)
	if err != nil {
		return 0, err
	}
	defer c.writeFrameMu.unlock()

	// If the state says a close has already been written, we wait until
	// the connection is closed and return that error.
	//
	// However, if the frame being written is a close, that means it's the close from
	// the state being set so we let it go through.
	c.closeMu.Lock()
	wroteClose := c.wroteClose
	c.closeMu.Unlock()
	if wroteClose && opcode != opClose {
		select {
		case <-ctx.Done():
			return 0, ctx.Err()
		case <-c.closed:
			return 0, c.closeErr
		}
	}

	select {
	case <-c.closed:
		return 0, c.closeErr
	case c.writeTimeout <- ctx:
	}

	defer func() {
		if err != nil {
			// Prefer the close or context error when one is already
			// available, but never wait for either. Without the default case,
			// a plain I/O error (or a masking key generation failure) on a
			// still-open connection with a ctx that is never done would block
			// here forever while holding writeFrameMu.
			select {
			case <-c.closed:
				err = c.closeErr
			case <-ctx.Done():
				err = ctx.Err()
			default:
			}
			c.close(err)
			err = fmt.Errorf("failed to write frame: %w", err)
		}
	}()

	c.writeHeader.fin = fin
	c.writeHeader.opcode = opcode
	c.writeHeader.payloadLength = int64(len(p))

	if c.client {
		c.writeHeader.masked = true
		_, err = io.ReadFull(rand.Reader, c.writeHeaderBuf[:4])
		if err != nil {
			return 0, fmt.Errorf("failed to generate masking key: %w", err)
		}
		c.writeHeader.maskKey = binary.LittleEndian.Uint32(c.writeHeaderBuf[:])
	}

	c.writeHeader.rsv1 = false
	if flate && (opcode == opText || opcode == opBinary) {
		c.writeHeader.rsv1 = true
	}

	err = writeFrameHeader(c.writeHeader, c.bw, c.writeHeaderBuf[:])
	if err != nil {
		return 0, err
	}

	n, err := c.writeFramePayload(p)
	if err != nil {
		return n, err
	}

	if c.writeHeader.fin {
		err = c.bw.Flush()
		if err != nil {
			return n, fmt.Errorf("failed to flush: %w", err)
		}
	}

	select {
	case <-c.closed:
		return n, c.closeErr
	case c.writeTimeout <- context.Background():
	}

	return n, nil
}
// writeFramePayload copies p into c.bw. When the current header is masked,
// the copy is masked in place inside the buffer, so p itself is never
// modified. This relies on c.writeBuf aliasing c.bw's internal buffer (see
// extractBufioWriterBuf).
func (c *Conn) writeFramePayload(p []byte) (n int, err error) {
	defer errd.Wrap(&err, "failed to write frame payload")

	if !c.writeHeader.masked {
		return c.bw.Write(p)
	}

	key := c.writeHeader.maskKey
	for len(p) > 0 {
		// Make room when the buffer is full.
		if c.bw.Available() == 0 {
			if err = c.bw.Flush(); err != nil {
				return n, err
			}
		}

		// Copy only as much as fits without triggering a flush, then mask
		// those bytes where they now sit in the buffer.
		start := c.bw.Buffered()
		chunk := p
		if avail := c.bw.Available(); len(chunk) > avail {
			chunk = chunk[:avail]
		}
		if _, err = c.bw.Write(chunk); err != nil {
			return n, err
		}
		key = mask(key, c.writeBuf[start:c.bw.Buffered()])

		p = p[len(chunk):]
		n += len(chunk)
	}
	return n, nil
}
// writerFunc adapts a plain function to io.Writer.
type writerFunc func(p []byte) (int, error)

// Write calls f(p).
func (f writerFunc) Write(p []byte) (n int, err error) {
	n, err = f(p)
	return n, err
}
// extractBufioWriterBuf returns the whole internal buffer backing bw, and
// leaves bw reset to write to w.
//
// It captures the buffer by briefly pointing bw at a writer that records the
// slice it is flushed with, extended to its full capacity.
func extractBufioWriterBuf(bw *bufio.Writer, w io.Writer) []byte {
	var backing []byte
	capture := writerFunc(func(flushed []byte) (int, error) {
		backing = flushed[:cap(flushed)]
		return len(flushed), nil
	})

	bw.Reset(capture)
	bw.WriteByte(0)
	bw.Flush()
	bw.Reset(w)
	return backing
}
// writeError fails the connection because of err. It records err as the close
// error, sends a close frame with code and err's text as the reason, and
// closes the connection.
func (c *Conn) writeError(code StatusCode, err error) {
	reason := err.Error()
	c.setCloseErr(err)
	c.writeClose(code, reason)
	c.close(nil)
}
// Package wsjson provides helpers for reading and writing JSON messages.
package wsjson // import "nhooyr.io/websocket/wsjson"
import (
"context"
"encoding/json"
"fmt"
"nhooyr.io/websocket"
"nhooyr.io/websocket/internal/bpool"
"nhooyr.io/websocket/internal/errd"
)
// Read reads the next message from c and decodes it as JSON into v.
// It reuses pooled buffers between calls to avoid allocations.
func Read(ctx context.Context, c *websocket.Conn, v interface{}) error {
	err := read(ctx, c, v)
	return err
}
// read implements Read. The message is buffered in a pooled buffer and then
// unmarshaled. Invalid JSON closes the connection with
// StatusInvalidFramePayloadData.
func read(ctx context.Context, c *websocket.Conn, v interface{}) (err error) {
	defer errd.Wrap(&err, "failed to read JSON message")

	_, r, err := c.Reader(ctx)
	if err != nil {
		return err
	}

	buf := bpool.Get()
	defer bpool.Put(buf)

	if _, err = buf.ReadFrom(r); err != nil {
		return err
	}

	if err = json.Unmarshal(buf.Bytes(), v); err != nil {
		c.Close(websocket.StatusInvalidFramePayloadData, "failed to unmarshal JSON")
		return fmt.Errorf("failed to unmarshal JSON: %w", err)
	}
	return nil
}
// Write encodes v as JSON and sends it to c as a text message.
// It reuses pooled buffers between calls to avoid allocations.
func Write(ctx context.Context, c *websocket.Conn, v interface{}) error {
	err := write(ctx, c, v)
	return err
}
// write implements Write. It encodes v as JSON into a pooled buffer and
// sends the result as a single text message via c.Write.
//
// Encoding happens before anything is written to the connection. The
// previous version obtained a message writer from c.Writer first and
// returned early on an encoding error (e.g. an unsupported type such as a
// channel) without closing that writer. Conn.Writer requires its writer to
// be closed before another can be opened, so that path left the
// connection's writer unreleased. Sending a fully encoded buffer with
// c.Write avoids the problem and mirrors how wspb writes messages.
//
// The payload includes the trailing newline that json.Encoder emits, exactly
// as before. Every returned error is wrapped with
// "failed to write JSON message".
func write(ctx context.Context, c *websocket.Conn, v interface{}) (err error) {
	defer errd.Wrap(&err, "failed to write JSON message")

	b := bpool.Get()
	defer bpool.Put(b)

	// json.Marshal cannot reuse buffers between calls because it must return
	// a fresh byte slice. An Encoder writing into a pooled buffer can.
	err = json.NewEncoder(b).Encode(v)
	if err != nil {
		return fmt.Errorf("failed to marshal JSON: %w", err)
	}

	return c.Write(ctx, websocket.MessageText, b.Bytes())
}
// Package wspb provides helpers for reading and writing protobuf messages.
package wspb // import "nhooyr.io/websocket/wspb"
import (
"bytes"
"context"
"fmt"
"github.com/golang/protobuf/proto"
"nhooyr.io/websocket"
"nhooyr.io/websocket/internal/bpool"
"nhooyr.io/websocket/internal/errd"
)
// Read reads the next binary message from c and unmarshals it as a
// protobuf into v. Buffers are pooled and reused across calls to avoid
// allocations.
func Read(ctx context.Context, c *websocket.Conn, v proto.Message) error {
	err := read(ctx, c, v)
	return err
}
// read implements Read.
//
// Non-binary messages are rejected and the connection is closed with
// StatusUnsupportedData. Binary payloads are buffered into a pooled buffer
// and unmarshalled into v. If unmarshalling fails, the connection is closed
// with StatusInvalidFramePayloadData. Every returned error is wrapped with
// "failed to read protobuf message".
func read(ctx context.Context, c *websocket.Conn, v proto.Message) (err error) {
	defer errd.Wrap(&err, "failed to read protobuf message")

	msgType, msg, err := c.Reader(ctx)
	if err != nil {
		return err
	}
	if msgType != websocket.MessageBinary {
		// Closing is best effort; the type mismatch is what gets reported.
		c.Close(websocket.StatusUnsupportedData, "expected binary message")
		return fmt.Errorf("expected binary message for protobuf but got: %v", msgType)
	}

	buf := bpool.Get()
	defer bpool.Put(buf)

	if _, err = buf.ReadFrom(msg); err != nil {
		return err
	}

	if err = proto.Unmarshal(buf.Bytes(), v); err != nil {
		c.Close(websocket.StatusInvalidFramePayloadData, "failed to unmarshal protobuf")
		return fmt.Errorf("failed to unmarshal protobuf: %w", err)
	}
	return nil
}
// Write marshals v as a protobuf and sends it to c as a binary message.
// Buffers are pooled and reused across calls to avoid allocations.
func Write(ctx context.Context, c *websocket.Conn, v proto.Message) error {
	err := write(ctx, c, v)
	return err
}
// write implements Write. It marshals v into a proto.Buffer that starts
// from a pooled buffer's storage, then sends the encoded bytes as one binary
// message. Every returned error is wrapped with
// "failed to write protobuf message".
func write(ctx context.Context, c *websocket.Conn, v proto.Message) (err error) {
	defer errd.Wrap(&err, "failed to write protobuf message")

	pooled := bpool.Get()
	pbuf := proto.NewBuffer(pooled.Bytes())
	// pbuf may have grown beyond the pooled buffer's capacity while
	// marshalling. Return whatever backing array it ended up with to the pool.
	defer func() {
		bpool.Put(bytes.NewBuffer(pbuf.Bytes()))
	}()

	if err = pbuf.Marshal(v); err != nil {
		return fmt.Errorf("failed to marshal protobuf: %w", err)
	}
	return c.Write(ctx, websocket.MessageBinary, pbuf.Bytes())
}