// +build !js package websocket import ( "bytes" "crypto/sha1" "encoding/base64" "errors" "fmt" "io" "log" "net/http" "net/textproto" "net/url" "path/filepath" "strings" "nhooyr.io/websocket/internal/errd" ) // AcceptOptions represents Accept's options. type AcceptOptions struct { // Subprotocols lists the WebSocket subprotocols that Accept will negotiate with the client. // The empty subprotocol will always be negotiated as per RFC 6455. If you would like to // reject it, close the connection when c.Subprotocol() == "". Subprotocols []string // InsecureSkipVerify is used to disable Accept's origin verification behaviour. // // You probably want to use OriginPatterns instead. InsecureSkipVerify bool // OriginPatterns lists the host patterns for authorized origins. // The request host is always authorized. // Use this to enable cross origin WebSockets. // // i.e javascript running on example.com wants to access a WebSocket server at chat.example.com. // In such a case, example.com is the origin and chat.example.com is the request host. // One would set this field to []string{"example.com"} to authorize example.com to connect. // // Each pattern is matched case insensitively against the request origin host // with filepath.Match. // See https://golang.org/pkg/path/filepath/#Match // // Please ensure you understand the ramifications of enabling this. // If used incorrectly your WebSocket server will be open to CSRF attacks. // // Do not use * as a pattern to allow any origin, prefer to use InsecureSkipVerify instead // to bring attention to the danger of such a setting. OriginPatterns []string // CompressionMode controls the compression mode. // Defaults to CompressionNoContextTakeover. // // See docs on CompressionMode for details. CompressionMode CompressionMode // CompressionThreshold controls the minimum size of a message before compression is applied. // // Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes // for CompressionContextTakeover. 
CompressionThreshold int } // Accept accepts a WebSocket handshake from a client and upgrades the // the connection to a WebSocket. // // Accept will not allow cross origin requests by default. // See the InsecureSkipVerify and OriginPatterns options to allow cross origin requests. // // Accept will write a response to w on all errors. func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) { return accept(w, r, opts) } func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Conn, err error) { defer errd.Wrap(&err, "failed to accept WebSocket connection") if opts == nil { opts = &AcceptOptions{} } opts = &*opts errCode, err := verifyClientRequest(w, r) if err != nil { http.Error(w, err.Error(), errCode) return nil, err } if !opts.InsecureSkipVerify { err = authenticateOrigin(r, opts.OriginPatterns) if err != nil { if errors.Is(err, filepath.ErrBadPattern) { log.Printf("websocket: %v", err) err = errors.New(http.StatusText(http.StatusForbidden)) } http.Error(w, err.Error(), http.StatusForbidden) return nil, err } } hj, ok := w.(http.Hijacker) if !ok { err = errors.New("http.ResponseWriter does not implement http.Hijacker") http.Error(w, http.StatusText(http.StatusNotImplemented), http.StatusNotImplemented) return nil, err } w.Header().Set("Upgrade", "websocket") w.Header().Set("Connection", "Upgrade") key := r.Header.Get("Sec-WebSocket-Key") w.Header().Set("Sec-WebSocket-Accept", secWebSocketAccept(key)) subproto := selectSubprotocol(r, opts.Subprotocols) if subproto != "" { w.Header().Set("Sec-WebSocket-Protocol", subproto) } copts, err := acceptCompression(r, w, opts.CompressionMode) if err != nil { return nil, err } w.WriteHeader(http.StatusSwitchingProtocols) // See https://github.com/nhooyr/websocket/issues/166 if ginWriter, ok := w.(interface { WriteHeaderNow() }); ok { ginWriter.WriteHeaderNow() } netConn, brw, err := hj.Hijack() if err != nil { err = fmt.Errorf("failed to hijack connection: %w", err) 
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return nil, err } // https://github.com/golang/go/issues/32314 b, _ := brw.Reader.Peek(brw.Reader.Buffered()) brw.Reader.Reset(io.MultiReader(bytes.NewReader(b), netConn)) return newConn(connConfig{ subprotocol: w.Header().Get("Sec-WebSocket-Protocol"), rwc: netConn, client: false, copts: copts, flateThreshold: opts.CompressionThreshold, br: brw.Reader, bw: brw.Writer, }), nil } func verifyClientRequest(w http.ResponseWriter, r *http.Request) (errCode int, _ error) { if !r.ProtoAtLeast(1, 1) { return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: handshake request must be at least HTTP/1.1: %q", r.Proto) } if !headerContainsTokenIgnoreCase(r.Header, "Connection", "Upgrade") { w.Header().Set("Connection", "Upgrade") w.Header().Set("Upgrade", "websocket") return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", r.Header.Get("Connection")) } if !headerContainsTokenIgnoreCase(r.Header, "Upgrade", "websocket") { w.Header().Set("Connection", "Upgrade") w.Header().Set("Upgrade", "websocket") return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", r.Header.Get("Upgrade")) } if r.Method != "GET" { return http.StatusMethodNotAllowed, fmt.Errorf("WebSocket protocol violation: handshake request method is not GET but %q", r.Method) } if r.Header.Get("Sec-WebSocket-Version") != "13" { w.Header().Set("Sec-WebSocket-Version", "13") return http.StatusBadRequest, fmt.Errorf("unsupported WebSocket protocol version (only 13 is supported): %q", r.Header.Get("Sec-WebSocket-Version")) } if r.Header.Get("Sec-WebSocket-Key") == "" { return http.StatusBadRequest, errors.New("WebSocket protocol violation: missing Sec-WebSocket-Key") } return 0, nil } func authenticateOrigin(r *http.Request, originHosts []string) error { origin := 
r.Header.Get("Origin") if origin == "" { return nil } u, err := url.Parse(origin) if err != nil { return fmt.Errorf("failed to parse Origin header %q: %w", origin, err) } if strings.EqualFold(r.Host, u.Host) { return nil } for _, hostPattern := range originHosts { matched, err := match(hostPattern, u.Host) if err != nil { return fmt.Errorf("failed to parse filepath pattern %q: %w", hostPattern, err) } if matched { return nil } } return fmt.Errorf("request Origin %q is not authorized for Host %q", origin, r.Host) } func match(pattern, s string) (bool, error) { return filepath.Match(strings.ToLower(pattern), strings.ToLower(s)) } func selectSubprotocol(r *http.Request, subprotocols []string) string { cps := headerTokens(r.Header, "Sec-WebSocket-Protocol") for _, sp := range subprotocols { for _, cp := range cps { if strings.EqualFold(sp, cp) { return cp } } } return "" } func acceptCompression(r *http.Request, w http.ResponseWriter, mode CompressionMode) (*compressionOptions, error) { if mode == CompressionDisabled { return nil, nil } for _, ext := range websocketExtensions(r.Header) { switch ext.name { case "permessage-deflate": return acceptDeflate(w, ext, mode) // Disabled for now, see https://github.com/nhooyr/websocket/issues/218 // case "x-webkit-deflate-frame": // return acceptWebkitDeflate(w, ext, mode) } } return nil, nil } func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) { copts := mode.opts() for _, p := range ext.params { switch p { case "client_no_context_takeover": copts.clientNoContextTakeover = true continue case "server_no_context_takeover": copts.serverNoContextTakeover = true continue } if strings.HasPrefix(p, "client_max_window_bits") { // We cannot adjust the read sliding window so cannot make use of this. 
continue } err := fmt.Errorf("unsupported permessage-deflate parameter: %q", p) http.Error(w, err.Error(), http.StatusBadRequest) return nil, err } copts.setHeader(w.Header()) return copts, nil } func acceptWebkitDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) { copts := mode.opts() // The peer must explicitly request it. copts.serverNoContextTakeover = false for _, p := range ext.params { if p == "no_context_takeover" { copts.serverNoContextTakeover = true continue } // We explicitly fail on x-webkit-deflate-frame's max_window_bits parameter instead // of ignoring it as the draft spec is unclear. It says the server can ignore it // but the server has no way of signalling to the client it was ignored as the parameters // are set one way. // Thus us ignoring it would make the client think we understood it which would cause issues. // See https://tools.ietf.org/html/draft-tyoshino-hybi-websocket-perframe-deflate-06#section-4.1 // // Either way, we're only implementing this for webkit which never sends the max_window_bits // parameter so we don't need to worry about it. 
err := fmt.Errorf("unsupported x-webkit-deflate-frame parameter: %q", p) http.Error(w, err.Error(), http.StatusBadRequest) return nil, err } s := "x-webkit-deflate-frame" if copts.clientNoContextTakeover { s += "; no_context_takeover" } w.Header().Set("Sec-WebSocket-Extensions", s) return copts, nil } func headerContainsTokenIgnoreCase(h http.Header, key, token string) bool { for _, t := range headerTokens(h, key) { if strings.EqualFold(t, token) { return true } } return false } type websocketExtension struct { name string params []string } func websocketExtensions(h http.Header) []websocketExtension { var exts []websocketExtension extStrs := headerTokens(h, "Sec-WebSocket-Extensions") for _, extStr := range extStrs { if extStr == "" { continue } vals := strings.Split(extStr, ";") for i := range vals { vals[i] = strings.TrimSpace(vals[i]) } e := websocketExtension{ name: vals[0], params: vals[1:], } exts = append(exts, e) } return exts } func headerTokens(h http.Header, key string) []string { key = textproto.CanonicalMIMEHeaderKey(key) var tokens []string for _, v := range h[key] { v = strings.TrimSpace(v) for _, t := range strings.Split(v, ",") { t = strings.TrimSpace(t) tokens = append(tokens, t) } } return tokens } var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") func secWebSocketAccept(secWebSocketKey string) string { h := sha1.New() h.Write([]byte(secWebSocketKey)) h.Write(keyGUID) return base64.StdEncoding.EncodeToString(h.Sum(nil)) }
package websocket

import (
	"errors"
	"fmt"
)

// StatusCode represents a WebSocket status code.
// https://tools.ietf.org/html/rfc6455#section-7.4
type StatusCode int

// https://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number
//
// These are only the status codes defined by the protocol.
//
// You can define custom codes in the 3000-4999 range.
// The 3000-3999 range is reserved for use by libraries, frameworks and applications.
// The 4000-4999 range is reserved for private use.
const (
	// StatusNormalClosure (RFC 6455 section 7.4.1): the purpose for which the
	// connection was established has been fulfilled.
	StatusNormalClosure StatusCode = 1000
	// StatusGoingAway (RFC 6455 section 7.4.1): an endpoint is going away,
	// such as a server going down or a browser navigating away from a page.
	StatusGoingAway StatusCode = 1001
	// StatusProtocolError (RFC 6455 section 7.4.1): the connection is being
	// terminated due to a protocol error.
	StatusProtocolError StatusCode = 1002
	// StatusUnsupportedData (RFC 6455 section 7.4.1): a type of data was
	// received that the endpoint cannot accept.
	StatusUnsupportedData StatusCode = 1003

	// 1004 is reserved and so unexported.
	statusReserved StatusCode = 1004

	// StatusNoStatusRcvd cannot be sent in a close message.
	// It is reserved for when a close message is received without
	// a status code.
	StatusNoStatusRcvd StatusCode = 1005

	// StatusAbnormalClosure is exported for use only with Wasm.
	// In non Wasm Go, the returned error will indicate whether the
	// connection was closed abnormally.
	StatusAbnormalClosure StatusCode = 1006

	// StatusInvalidFramePayloadData (RFC 6455 section 7.4.1): message data
	// was inconsistent with the message type, e.g. non UTF-8 text.
	StatusInvalidFramePayloadData StatusCode = 1007
	// StatusPolicyViolation (RFC 6455 section 7.4.1): a message violated the
	// endpoint's policy.
	StatusPolicyViolation StatusCode = 1008
	// StatusMessageTooBig (RFC 6455 section 7.4.1): a message was too big to
	// process.
	StatusMessageTooBig StatusCode = 1009
	// StatusMandatoryExtension (RFC 6455 section 7.4.1): the client expected
	// the server to negotiate one or more extensions, but it did not.
	StatusMandatoryExtension StatusCode = 1010
	// StatusInternalError (RFC 6455 section 7.4.1): the server encountered an
	// unexpected condition.
	StatusInternalError StatusCode = 1011
	// StatusServiceRestart is registered with IANA as "Service Restart".
	StatusServiceRestart StatusCode = 1012
	// StatusTryAgainLater is registered with IANA as "Try Again Later".
	StatusTryAgainLater StatusCode = 1013
	// StatusBadGateway is registered with IANA as "Bad Gateway".
	StatusBadGateway StatusCode = 1014

	// StatusTLSHandshake is only exported for use with Wasm.
	// In non Wasm Go, the returned error will indicate whether there was
	// a TLS handshake failure.
	StatusTLSHandshake StatusCode = 1015
)

// CloseError is returned when the connection is closed with a status and reason.
//
// Use Go 1.13's errors.As to check for this error.
// Also see the CloseStatus helper.
type CloseError struct {
	Code   StatusCode
	Reason string
}

func (ce CloseError) Error() string {
	return fmt.Sprintf("status = %v and reason = %q", ce.Code, ce.Reason)
}

// CloseStatus is a convenience wrapper around Go 1.13's errors.As to grab
// the status code from a CloseError.
//
// -1 will be returned if the passed error is nil or not a CloseError.
func CloseStatus(err error) StatusCode {
	var ce CloseError
	if errors.As(err, &ce) {
		return ce.Code
	}
	return -1
}
// +build !js package websocket import ( "context" "encoding/binary" "errors" "fmt" "log" "time" "nhooyr.io/websocket/internal/errd" ) // Close performs the WebSocket close handshake with the given status code and reason. // // It will write a WebSocket close frame with a timeout of 5s and then wait 5s for // the peer to send a close frame. // All data messages received from the peer during the close handshake will be discarded. // // The connection can only be closed once. Additional calls to Close // are no-ops. // // The maximum length of reason must be 125 bytes. Avoid // sending a dynamic reason. // // Close will unblock all goroutines interacting with the connection once // complete. func (c *Conn) Close(code StatusCode, reason string) error { return c.closeHandshake(code, reason) } func (c *Conn) closeHandshake(code StatusCode, reason string) (err error) { defer errd.Wrap(&err, "failed to close WebSocket") writeErr := c.writeClose(code, reason) closeHandshakeErr := c.waitCloseHandshake() if writeErr != nil { return writeErr } if CloseStatus(closeHandshakeErr) == -1 { return closeHandshakeErr } return nil } var errAlreadyWroteClose = errors.New("already wrote close") func (c *Conn) writeClose(code StatusCode, reason string) error { c.closeMu.Lock() wroteClose := c.wroteClose c.wroteClose = true c.closeMu.Unlock() if wroteClose { return errAlreadyWroteClose } ce := CloseError{ Code: code, Reason: reason, } var p []byte var marshalErr error if ce.Code != StatusNoStatusRcvd { p, marshalErr = ce.bytes() if marshalErr != nil { log.Printf("websocket: %v", marshalErr) } } writeErr := c.writeControl(context.Background(), opClose, p) if CloseStatus(writeErr) != -1 { // Not a real error if it's due to a close frame being received. writeErr = nil } // We do this after in case there was an error writing the close frame. 
c.setCloseErr(fmt.Errorf("sent close frame: %w", ce)) if marshalErr != nil { return marshalErr } return writeErr } func (c *Conn) waitCloseHandshake() error { defer c.close(nil) ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() err := c.readMu.lock(ctx) if err != nil { return err } defer c.readMu.unlock() if c.readCloseFrameErr != nil { return c.readCloseFrameErr } for { h, err := c.readLoop(ctx) if err != nil { return err } for i := int64(0); i < h.payloadLength; i++ { _, err := c.br.ReadByte() if err != nil { return err } } } } func parseClosePayload(p []byte) (CloseError, error) { if len(p) == 0 { return CloseError{ Code: StatusNoStatusRcvd, }, nil } if len(p) < 2 { return CloseError{}, fmt.Errorf("close payload %q too small, cannot even contain the 2 byte status code", p) } ce := CloseError{ Code: StatusCode(binary.BigEndian.Uint16(p)), Reason: string(p[2:]), } if !validWireCloseCode(ce.Code) { return CloseError{}, fmt.Errorf("invalid status code %v", ce.Code) } return ce, nil } // See http://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number // and https://tools.ietf.org/html/rfc6455#section-7.4.1 func validWireCloseCode(code StatusCode) bool { switch code { case statusReserved, StatusNoStatusRcvd, StatusAbnormalClosure, StatusTLSHandshake: return false } if code >= StatusNormalClosure && code <= StatusBadGateway { return true } if code >= 3000 && code <= 4999 { return true } return false } func (ce CloseError) bytes() ([]byte, error) { p, err := ce.bytesErr() if err != nil { err = fmt.Errorf("failed to marshal close frame: %w", err) ce = CloseError{ Code: StatusInternalError, } p, _ = ce.bytesErr() } return p, err } const maxCloseReason = maxControlPayload - 2 func (ce CloseError) bytesErr() ([]byte, error) { if len(ce.Reason) > maxCloseReason { return nil, fmt.Errorf("reason string max is %v but got %q with length %v", maxCloseReason, ce.Reason, len(ce.Reason)) } if !validWireCloseCode(ce.Code) { return 
nil, fmt.Errorf("status code %v cannot be set", ce.Code) } buf := make([]byte, 2+len(ce.Reason)) binary.BigEndian.PutUint16(buf, uint16(ce.Code)) copy(buf[2:], ce.Reason) return buf, nil } func (c *Conn) setCloseErr(err error) { c.closeMu.Lock() c.setCloseErrLocked(err) c.closeMu.Unlock() } func (c *Conn) setCloseErrLocked(err error) { if c.closeErr == nil { c.closeErr = fmt.Errorf("WebSocket closed: %w", err) } } func (c *Conn) isClosed() bool { select { case <-c.closed: return true default: return false } }
// +build !js package websocket import ( "io" "net/http" "sync" "github.com/klauspost/compress/flate" ) func (m CompressionMode) opts() *compressionOptions { return &compressionOptions{ clientNoContextTakeover: m == CompressionNoContextTakeover, serverNoContextTakeover: m == CompressionNoContextTakeover, } } type compressionOptions struct { clientNoContextTakeover bool serverNoContextTakeover bool } func (copts *compressionOptions) setHeader(h http.Header) { s := "permessage-deflate" if copts.clientNoContextTakeover { s += "; client_no_context_takeover" } if copts.serverNoContextTakeover { s += "; server_no_context_takeover" } h.Set("Sec-WebSocket-Extensions", s) } // These bytes are required to get flate.Reader to return. // They are removed when sending to avoid the overhead as // WebSocket framing tell's when the message has ended but then // we need to add them back otherwise flate.Reader keeps // trying to return more bytes. const deflateMessageTail = "\x00\x00\xff\xff" type trimLastFourBytesWriter struct { w io.Writer tail []byte } func (tw *trimLastFourBytesWriter) reset() { if tw != nil && tw.tail != nil { tw.tail = tw.tail[:0] } } func (tw *trimLastFourBytesWriter) Write(p []byte) (int, error) { if tw.tail == nil { tw.tail = make([]byte, 0, 4) } extra := len(tw.tail) + len(p) - 4 if extra <= 0 { tw.tail = append(tw.tail, p...) return len(p), nil } // Now we need to write as many extra bytes as we can from the previous tail. if extra > len(tw.tail) { extra = len(tw.tail) } if extra > 0 { _, err := tw.w.Write(tw.tail[:extra]) if err != nil { return 0, err } // Shift remaining bytes in tail over. n := copy(tw.tail, tw.tail[extra:]) tw.tail = tw.tail[:n] } // If p is less than or equal to 4 bytes, // all of it is is part of the tail. if len(p) <= 4 { tw.tail = append(tw.tail, p...) return len(p), nil } // Otherwise, only the last 4 bytes are. tw.tail = append(tw.tail, p[len(p)-4:]...) 
p = p[:len(p)-4] n, err := tw.w.Write(p) return n + 4, err } var flateReaderPool sync.Pool func getFlateReader(r io.Reader, dict []byte) io.Reader { fr, ok := flateReaderPool.Get().(io.Reader) if !ok { return flate.NewReaderDict(r, dict) } fr.(flate.Resetter).Reset(r, dict) return fr } func putFlateReader(fr io.Reader) { flateReaderPool.Put(fr) } type slidingWindow struct { buf []byte } var swPoolMu sync.RWMutex var swPool = map[int]*sync.Pool{} func slidingWindowPool(n int) *sync.Pool { swPoolMu.RLock() p, ok := swPool[n] swPoolMu.RUnlock() if ok { return p } p = &sync.Pool{} swPoolMu.Lock() swPool[n] = p swPoolMu.Unlock() return p } func (sw *slidingWindow) init(n int) { if sw.buf != nil { return } if n == 0 { n = 32768 } p := slidingWindowPool(n) buf, ok := p.Get().([]byte) if ok { sw.buf = buf[:0] } else { sw.buf = make([]byte, 0, n) } } func (sw *slidingWindow) close() { if sw.buf == nil { return } swPoolMu.Lock() swPool[cap(sw.buf)].Put(sw.buf) swPoolMu.Unlock() sw.buf = nil } func (sw *slidingWindow) write(p []byte) { if len(p) >= cap(sw.buf) { sw.buf = sw.buf[:cap(sw.buf)] p = p[len(p)-cap(sw.buf):] copy(sw.buf, p) return } left := cap(sw.buf) - len(sw.buf) if left < len(p) { // We need to shift spaceNeeded bytes from the end to make room for p at the end. spaceNeeded := len(p) - left copy(sw.buf, sw.buf[spaceNeeded:]) sw.buf = sw.buf[:len(sw.buf)-spaceNeeded] } sw.buf = append(sw.buf, p...) }
// +build !js package websocket import ( "bufio" "context" "errors" "fmt" "io" "runtime" "strconv" "sync" "sync/atomic" ) // Conn represents a WebSocket connection. // All methods may be called concurrently except for Reader and Read. // // You must always read from the connection. Otherwise control // frames will not be handled. See Reader and CloseRead. // // Be sure to call Close on the connection when you // are finished with it to release associated resources. // // On any error from any method, the connection is closed // with an appropriate reason. type Conn struct { subprotocol string rwc io.ReadWriteCloser client bool copts *compressionOptions flateThreshold int br *bufio.Reader bw *bufio.Writer readTimeout chan context.Context writeTimeout chan context.Context // Read state. readMu *mu readHeaderBuf [8]byte readControlBuf [maxControlPayload]byte msgReader *msgReader readCloseFrameErr error // Write state. msgWriterState *msgWriterState writeFrameMu *mu writeBuf []byte writeHeaderBuf [8]byte writeHeader header closed chan struct{} closeMu sync.Mutex closeErr error wroteClose bool pingCounter int32 activePingsMu sync.Mutex activePings map[string]chan<- struct{} } type connConfig struct { subprotocol string rwc io.ReadWriteCloser client bool copts *compressionOptions flateThreshold int br *bufio.Reader bw *bufio.Writer } func newConn(cfg connConfig) *Conn { c := &Conn{ subprotocol: cfg.subprotocol, rwc: cfg.rwc, client: cfg.client, copts: cfg.copts, flateThreshold: cfg.flateThreshold, br: cfg.br, bw: cfg.bw, readTimeout: make(chan context.Context), writeTimeout: make(chan context.Context), closed: make(chan struct{}), activePings: make(map[string]chan<- struct{}), } c.readMu = newMu(c) c.writeFrameMu = newMu(c) c.msgReader = newMsgReader(c) c.msgWriterState = newMsgWriterState(c) if c.client { c.writeBuf = extractBufioWriterBuf(c.bw, c.rwc) } if c.flate() && c.flateThreshold == 0 { c.flateThreshold = 128 if !c.msgWriterState.flateContextTakeover() { 
c.flateThreshold = 512 } } runtime.SetFinalizer(c, func(c *Conn) { c.close(errors.New("connection garbage collected")) }) go c.timeoutLoop() return c } // Subprotocol returns the negotiated subprotocol. // An empty string means the default protocol. func (c *Conn) Subprotocol() string { return c.subprotocol } func (c *Conn) close(err error) { c.closeMu.Lock() defer c.closeMu.Unlock() if c.isClosed() { return } c.setCloseErrLocked(err) close(c.closed) runtime.SetFinalizer(c, nil) // Have to close after c.closed is closed to ensure any goroutine that wakes up // from the connection being closed also sees that c.closed is closed and returns // closeErr. c.rwc.Close() go func() { c.msgWriterState.close() c.msgReader.close() }() } func (c *Conn) timeoutLoop() { readCtx := context.Background() writeCtx := context.Background() for { select { case <-c.closed: return case writeCtx = <-c.writeTimeout: case readCtx = <-c.readTimeout: case <-readCtx.Done(): c.setCloseErr(fmt.Errorf("read timed out: %w", readCtx.Err())) go c.writeError(StatusPolicyViolation, errors.New("timed out")) case <-writeCtx.Done(): c.close(fmt.Errorf("write timed out: %w", writeCtx.Err())) return } } } func (c *Conn) flate() bool { return c.copts != nil } // Ping sends a ping to the peer and waits for a pong. // Use this to measure latency or ensure the peer is responsive. // Ping must be called concurrently with Reader as it does // not read from the connection but instead waits for a Reader call // to read the pong. // // TCP Keepalives should suffice for most use cases. 
func (c *Conn) Ping(ctx context.Context) error { p := atomic.AddInt32(&c.pingCounter, 1) err := c.ping(ctx, strconv.Itoa(int(p))) if err != nil { return fmt.Errorf("failed to ping: %w", err) } return nil } func (c *Conn) ping(ctx context.Context, p string) error { pong := make(chan struct{}, 1) c.activePingsMu.Lock() c.activePings[p] = pong c.activePingsMu.Unlock() defer func() { c.activePingsMu.Lock() delete(c.activePings, p) c.activePingsMu.Unlock() }() err := c.writeControl(ctx, opPing, []byte(p)) if err != nil { return err } select { case <-c.closed: return c.closeErr case <-ctx.Done(): err := fmt.Errorf("failed to wait for pong: %w", ctx.Err()) c.close(err) return err case <-pong: return nil } } type mu struct { c *Conn ch chan struct{} } func newMu(c *Conn) *mu { return &mu{ c: c, ch: make(chan struct{}, 1), } } func (m *mu) forceLock() { m.ch <- struct{}{} } func (m *mu) lock(ctx context.Context) error { select { case <-m.c.closed: return m.c.closeErr case <-ctx.Done(): err := fmt.Errorf("failed to acquire lock: %w", ctx.Err()) m.c.close(err) return err case m.ch <- struct{}{}: // To make sure the connection is certainly alive. // As it's possible the send on m.ch was selected // over the receive on closed. select { case <-m.c.closed: // Make sure to release. m.unlock() return m.c.closeErr default: } return nil } } func (m *mu) unlock() { select { case <-m.ch: default: } }
// +build !js package websocket import ( "bufio" "bytes" "context" "crypto/rand" "encoding/base64" "fmt" "io" "io/ioutil" "net/http" "net/url" "strings" "sync" "time" "nhooyr.io/websocket/internal/errd" ) // DialOptions represents Dial's options. type DialOptions struct { // HTTPClient is used for the connection. // Its Transport must return writable bodies for WebSocket handshakes. // http.Transport does beginning with Go 1.12. HTTPClient *http.Client // HTTPHeader specifies the HTTP headers included in the handshake request. HTTPHeader http.Header // Subprotocols lists the WebSocket subprotocols to negotiate with the server. Subprotocols []string // CompressionMode controls the compression mode. // Defaults to CompressionNoContextTakeover. // // See docs on CompressionMode for details. CompressionMode CompressionMode // CompressionThreshold controls the minimum size of a message before compression is applied. // // Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes // for CompressionContextTakeover. CompressionThreshold int } // Dial performs a WebSocket handshake on url. // // The response is the WebSocket handshake response from the server. // You never need to close resp.Body yourself. // // If an error occurs, the returned response may be non nil. // However, you can only read the first 1024 bytes of the body. // // This function requires at least Go 1.12 as it uses a new feature // in net/http to perform WebSocket handshakes. // See docs on the HTTPClient option and https://github.com/golang/go/issues/26937#issuecomment-415855861 // // URLs with http/https schemes will work and are interpreted as ws/wss. 
func Dial(ctx context.Context, u string, opts *DialOptions) (*Conn, *http.Response, error) { return dial(ctx, u, opts, nil) } func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (_ *Conn, _ *http.Response, err error) { defer errd.Wrap(&err, "failed to WebSocket dial") if opts == nil { opts = &DialOptions{} } opts = &*opts if opts.HTTPClient == nil { opts.HTTPClient = http.DefaultClient } else if opts.HTTPClient.Timeout > 0 { var cancel context.CancelFunc ctx, cancel = context.WithTimeout(ctx, opts.HTTPClient.Timeout) defer cancel() newClient := *opts.HTTPClient newClient.Timeout = 0 opts.HTTPClient = &newClient } if opts.HTTPHeader == nil { opts.HTTPHeader = http.Header{} } secWebSocketKey, err := secWebSocketKey(rand) if err != nil { return nil, nil, fmt.Errorf("failed to generate Sec-WebSocket-Key: %w", err) } var copts *compressionOptions if opts.CompressionMode != CompressionDisabled { copts = opts.CompressionMode.opts() } resp, err := handshakeRequest(ctx, urls, opts, copts, secWebSocketKey) if err != nil { return nil, resp, err } respBody := resp.Body resp.Body = nil defer func() { if err != nil { // We read a bit of the body for easier debugging. 
r := io.LimitReader(respBody, 1024) timer := time.AfterFunc(time.Second*3, func() { respBody.Close() }) defer timer.Stop() b, _ := ioutil.ReadAll(r) respBody.Close() resp.Body = ioutil.NopCloser(bytes.NewReader(b)) } }() copts, err = verifyServerResponse(opts, copts, secWebSocketKey, resp) if err != nil { return nil, resp, err } rwc, ok := respBody.(io.ReadWriteCloser) if !ok { return nil, resp, fmt.Errorf("response body is not a io.ReadWriteCloser: %T", respBody) } return newConn(connConfig{ subprotocol: resp.Header.Get("Sec-WebSocket-Protocol"), rwc: rwc, client: true, copts: copts, flateThreshold: opts.CompressionThreshold, br: getBufioReader(rwc), bw: getBufioWriter(rwc), }), resp, nil } func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, copts *compressionOptions, secWebSocketKey string) (*http.Response, error) { u, err := url.Parse(urls) if err != nil { return nil, fmt.Errorf("failed to parse url: %w", err) } switch u.Scheme { case "ws": u.Scheme = "http" case "wss": u.Scheme = "https" case "http", "https": default: return nil, fmt.Errorf("unexpected url scheme: %q", u.Scheme) } req, _ := http.NewRequestWithContext(ctx, "GET", u.String(), nil) req.Header = opts.HTTPHeader.Clone() req.Header.Set("Connection", "Upgrade") req.Header.Set("Upgrade", "websocket") req.Header.Set("Sec-WebSocket-Version", "13") req.Header.Set("Sec-WebSocket-Key", secWebSocketKey) if len(opts.Subprotocols) > 0 { req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ",")) } if copts != nil { copts.setHeader(req.Header) } resp, err := opts.HTTPClient.Do(req) if err != nil { return nil, fmt.Errorf("failed to send handshake request: %w", err) } return resp, nil } func secWebSocketKey(rr io.Reader) (string, error) { if rr == nil { rr = rand.Reader } b := make([]byte, 16) _, err := io.ReadFull(rr, b) if err != nil { return "", fmt.Errorf("failed to read random data from rand.Reader: %w", err) } return base64.StdEncoding.EncodeToString(b), nil } func 
verifyServerResponse(opts *DialOptions, copts *compressionOptions, secWebSocketKey string, resp *http.Response) (*compressionOptions, error) { if resp.StatusCode != http.StatusSwitchingProtocols { return nil, fmt.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode) } if !headerContainsTokenIgnoreCase(resp.Header, "Connection", "Upgrade") { return nil, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", resp.Header.Get("Connection")) } if !headerContainsTokenIgnoreCase(resp.Header, "Upgrade", "WebSocket") { return nil, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", resp.Header.Get("Upgrade")) } if resp.Header.Get("Sec-WebSocket-Accept") != secWebSocketAccept(secWebSocketKey) { return nil, fmt.Errorf("WebSocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q", resp.Header.Get("Sec-WebSocket-Accept"), secWebSocketKey, ) } err := verifySubprotocol(opts.Subprotocols, resp) if err != nil { return nil, err } return verifyServerExtensions(copts, resp.Header) } func verifySubprotocol(subprotos []string, resp *http.Response) error { proto := resp.Header.Get("Sec-WebSocket-Protocol") if proto == "" { return nil } for _, sp2 := range subprotos { if strings.EqualFold(sp2, proto) { return nil } } return fmt.Errorf("WebSocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q", proto) } func verifyServerExtensions(copts *compressionOptions, h http.Header) (*compressionOptions, error) { exts := websocketExtensions(h) if len(exts) == 0 { return nil, nil } ext := exts[0] if ext.name != "permessage-deflate" || len(exts) > 1 || copts == nil { return nil, fmt.Errorf("WebSocket protcol violation: unsupported extensions from server: %+v", exts[1:]) } copts = &*copts for _, p := range ext.params { switch p { case "client_no_context_takeover": copts.clientNoContextTakeover = true continue case "server_no_context_takeover": 
copts.serverNoContextTakeover = true continue } return nil, fmt.Errorf("unsupported permessage-deflate parameter: %q", p) } return copts, nil } var bufioReaderPool sync.Pool func getBufioReader(r io.Reader) *bufio.Reader { br, ok := bufioReaderPool.Get().(*bufio.Reader) if !ok { return bufio.NewReader(r) } br.Reset(r) return br } func putBufioReader(br *bufio.Reader) { bufioReaderPool.Put(br) } var bufioWriterPool sync.Pool func getBufioWriter(w io.Writer) *bufio.Writer { bw, ok := bufioWriterPool.Get().(*bufio.Writer) if !ok { return bufio.NewWriter(w) } bw.Reset(w) return bw } func putBufioWriter(bw *bufio.Writer) { bufioWriterPool.Put(bw) }
package websocket import ( "bufio" "encoding/binary" "fmt" "io" "math" "math/bits" "nhooyr.io/websocket/internal/errd" ) // opcode represents a WebSocket opcode. type opcode int // https://tools.ietf.org/html/rfc6455#section-11.8. const ( opContinuation opcode = iota opText opBinary // 3 - 7 are reserved for further non-control frames. _ _ _ _ _ opClose opPing opPong // 11-16 are reserved for further control frames. ) // header represents a WebSocket frame header. // See https://tools.ietf.org/html/rfc6455#section-5.2. type header struct { fin bool rsv1 bool rsv2 bool rsv3 bool opcode opcode payloadLength int64 masked bool maskKey uint32 } // readFrameHeader reads a header from the reader. // See https://tools.ietf.org/html/rfc6455#section-5.2. func readFrameHeader(r *bufio.Reader, readBuf []byte) (h header, err error) { defer errd.Wrap(&err, "failed to read frame header") b, err := r.ReadByte() if err != nil { return header{}, err } h.fin = b&(1<<7) != 0 h.rsv1 = b&(1<<6) != 0 h.rsv2 = b&(1<<5) != 0 h.rsv3 = b&(1<<4) != 0 h.opcode = opcode(b & 0xf) b, err = r.ReadByte() if err != nil { return header{}, err } h.masked = b&(1<<7) != 0 payloadLength := b &^ (1 << 7) switch { case payloadLength < 126: h.payloadLength = int64(payloadLength) case payloadLength == 126: _, err = io.ReadFull(r, readBuf[:2]) h.payloadLength = int64(binary.BigEndian.Uint16(readBuf)) case payloadLength == 127: _, err = io.ReadFull(r, readBuf) h.payloadLength = int64(binary.BigEndian.Uint64(readBuf)) } if err != nil { return header{}, err } if h.payloadLength < 0 { return header{}, fmt.Errorf("received negative payload length: %v", h.payloadLength) } if h.masked { _, err = io.ReadFull(r, readBuf[:4]) if err != nil { return header{}, err } h.maskKey = binary.LittleEndian.Uint32(readBuf) } return h, nil } // maxControlPayload is the maximum length of a control frame payload. // See https://tools.ietf.org/html/rfc6455#section-5.5. 
const maxControlPayload = 125 // writeFrameHeader writes the bytes of the header to w. // See https://tools.ietf.org/html/rfc6455#section-5.2 func writeFrameHeader(h header, w *bufio.Writer, buf []byte) (err error) { defer errd.Wrap(&err, "failed to write frame header") var b byte if h.fin { b |= 1 << 7 } if h.rsv1 { b |= 1 << 6 } if h.rsv2 { b |= 1 << 5 } if h.rsv3 { b |= 1 << 4 } b |= byte(h.opcode) err = w.WriteByte(b) if err != nil { return err } lengthByte := byte(0) if h.masked { lengthByte |= 1 << 7 } switch { case h.payloadLength > math.MaxUint16: lengthByte |= 127 case h.payloadLength > 125: lengthByte |= 126 case h.payloadLength >= 0: lengthByte |= byte(h.payloadLength) } err = w.WriteByte(lengthByte) if err != nil { return err } switch { case h.payloadLength > math.MaxUint16: binary.BigEndian.PutUint64(buf, uint64(h.payloadLength)) _, err = w.Write(buf) case h.payloadLength > 125: binary.BigEndian.PutUint16(buf, uint16(h.payloadLength)) _, err = w.Write(buf[:2]) } if err != nil { return err } if h.masked { binary.LittleEndian.PutUint32(buf, h.maskKey) _, err = w.Write(buf[:4]) if err != nil { return err } } return nil } // mask applies the WebSocket masking algorithm to p // with the given key. // See https://tools.ietf.org/html/rfc6455#section-5.3 // // The returned value is the correctly rotated key to // to continue to mask/unmask the message. // // It is optimized for LittleEndian and expects the key // to be in little endian. // // See https://github.com/golang/go/issues/31586 func mask(key uint32, b []byte) uint32 { if len(b) >= 8 { key64 := uint64(key)<<32 | uint64(key) // At some point in the future we can clean these unrolled loops up. // See https://github.com/golang/go/issues/31586#issuecomment-487436401 // Then we xor until b is less than 128 bytes. 
for len(b) >= 128 { v := binary.LittleEndian.Uint64(b) binary.LittleEndian.PutUint64(b, v^key64) v = binary.LittleEndian.Uint64(b[8:16]) binary.LittleEndian.PutUint64(b[8:16], v^key64) v = binary.LittleEndian.Uint64(b[16:24]) binary.LittleEndian.PutUint64(b[16:24], v^key64) v = binary.LittleEndian.Uint64(b[24:32]) binary.LittleEndian.PutUint64(b[24:32], v^key64) v = binary.LittleEndian.Uint64(b[32:40]) binary.LittleEndian.PutUint64(b[32:40], v^key64) v = binary.LittleEndian.Uint64(b[40:48]) binary.LittleEndian.PutUint64(b[40:48], v^key64) v = binary.LittleEndian.Uint64(b[48:56]) binary.LittleEndian.PutUint64(b[48:56], v^key64) v = binary.LittleEndian.Uint64(b[56:64]) binary.LittleEndian.PutUint64(b[56:64], v^key64) v = binary.LittleEndian.Uint64(b[64:72]) binary.LittleEndian.PutUint64(b[64:72], v^key64) v = binary.LittleEndian.Uint64(b[72:80]) binary.LittleEndian.PutUint64(b[72:80], v^key64) v = binary.LittleEndian.Uint64(b[80:88]) binary.LittleEndian.PutUint64(b[80:88], v^key64) v = binary.LittleEndian.Uint64(b[88:96]) binary.LittleEndian.PutUint64(b[88:96], v^key64) v = binary.LittleEndian.Uint64(b[96:104]) binary.LittleEndian.PutUint64(b[96:104], v^key64) v = binary.LittleEndian.Uint64(b[104:112]) binary.LittleEndian.PutUint64(b[104:112], v^key64) v = binary.LittleEndian.Uint64(b[112:120]) binary.LittleEndian.PutUint64(b[112:120], v^key64) v = binary.LittleEndian.Uint64(b[120:128]) binary.LittleEndian.PutUint64(b[120:128], v^key64) b = b[128:] } // Then we xor until b is less than 64 bytes. 
for len(b) >= 64 { v := binary.LittleEndian.Uint64(b) binary.LittleEndian.PutUint64(b, v^key64) v = binary.LittleEndian.Uint64(b[8:16]) binary.LittleEndian.PutUint64(b[8:16], v^key64) v = binary.LittleEndian.Uint64(b[16:24]) binary.LittleEndian.PutUint64(b[16:24], v^key64) v = binary.LittleEndian.Uint64(b[24:32]) binary.LittleEndian.PutUint64(b[24:32], v^key64) v = binary.LittleEndian.Uint64(b[32:40]) binary.LittleEndian.PutUint64(b[32:40], v^key64) v = binary.LittleEndian.Uint64(b[40:48]) binary.LittleEndian.PutUint64(b[40:48], v^key64) v = binary.LittleEndian.Uint64(b[48:56]) binary.LittleEndian.PutUint64(b[48:56], v^key64) v = binary.LittleEndian.Uint64(b[56:64]) binary.LittleEndian.PutUint64(b[56:64], v^key64) b = b[64:] } // Then we xor until b is less than 32 bytes. for len(b) >= 32 { v := binary.LittleEndian.Uint64(b) binary.LittleEndian.PutUint64(b, v^key64) v = binary.LittleEndian.Uint64(b[8:16]) binary.LittleEndian.PutUint64(b[8:16], v^key64) v = binary.LittleEndian.Uint64(b[16:24]) binary.LittleEndian.PutUint64(b[16:24], v^key64) v = binary.LittleEndian.Uint64(b[24:32]) binary.LittleEndian.PutUint64(b[24:32], v^key64) b = b[32:] } // Then we xor until b is less than 16 bytes. for len(b) >= 16 { v := binary.LittleEndian.Uint64(b) binary.LittleEndian.PutUint64(b, v^key64) v = binary.LittleEndian.Uint64(b[8:16]) binary.LittleEndian.PutUint64(b[8:16], v^key64) b = b[16:] } // Then we xor until b is less than 8 bytes. for len(b) >= 8 { v := binary.LittleEndian.Uint64(b) binary.LittleEndian.PutUint64(b, v^key64) b = b[8:] } } // Then we xor until b is less than 4 bytes. for len(b) >= 4 { v := binary.LittleEndian.Uint32(b) binary.LittleEndian.PutUint32(b, v^key) b = b[4:] } // xor remaining bytes. for i := range b { b[i] ^= byte(key) key = bits.RotateLeft32(key, -8) } return key }
package bpool import ( "bytes" "sync" ) var bpool sync.Pool // Get returns a buffer from the pool or creates a new one if // the pool is empty. func Get() *bytes.Buffer { b := bpool.Get() if b == nil { return &bytes.Buffer{} } return b.(*bytes.Buffer) } // Put returns a buffer into the pool. func Put(b *bytes.Buffer) { b.Reset() bpool.Put(b) }
package errd import ( "fmt" ) // Wrap wraps err with fmt.Errorf if err is non nil. // Intended for use with defer and a named error return. // Inspired by https://github.com/golang/go/issues/32676. func Wrap(err *error, f string, v ...interface{}) { if *err != nil { *err = fmt.Errorf(f+": %w", append(v, *err)...) } }
package xsync import ( "fmt" ) // Go allows running a function in another goroutine // and waiting for its error. func Go(fn func() error) <-chan error { errs := make(chan error, 1) go func() { defer func() { r := recover() if r != nil { select { case errs <- fmt.Errorf("panic in go fn: %v", r): default: } } }() errs <- fn() }() return errs }
package xsync import ( "sync/atomic" ) // Int64 represents an atomic int64. type Int64 struct { // We do not use atomic.Load/StoreInt64 since it does not // work on 32 bit computers but we need 64 bit integers. i atomic.Value } // Load loads the int64. func (v *Int64) Load() int64 { i, _ := v.i.Load().(int64) return i } // Store stores the int64. func (v *Int64) Store(i int64) { v.i.Store(i) }
package websocket import ( "context" "fmt" "io" "math" "net" "sync" "time" ) // NetConn converts a *websocket.Conn into a net.Conn. // // It's for tunneling arbitrary protocols over WebSockets. // Few users of the library will need this but it's tricky to implement // correctly and so provided in the library. // See https://github.com/nhooyr/websocket/issues/100. // // Every Write to the net.Conn will correspond to a message write of // the given type on *websocket.Conn. // // The passed ctx bounds the lifetime of the net.Conn. If cancelled, // all reads and writes on the net.Conn will be cancelled. // // If a message is read that is not of the correct type, the connection // will be closed with StatusUnsupportedData and an error will be returned. // // Close will close the *websocket.Conn with StatusNormalClosure. // // When a deadline is hit, the connection will be closed. This is // different from most net.Conn implementations where only the // reading/writing goroutines are interrupted but the connection is kept alive. // // The Addr methods will return a mock net.Addr that returns "websocket" for Network // and "websocket/unknown-addr" for String. // // A received StatusNormalClosure or StatusGoingAway close frame will be translated to // io.EOF when reading. 
func NetConn(ctx context.Context, c *Conn, msgType MessageType) net.Conn { nc := &netConn{ c: c, msgType: msgType, } var cancel context.CancelFunc nc.writeContext, cancel = context.WithCancel(ctx) nc.writeTimer = time.AfterFunc(math.MaxInt64, cancel) if !nc.writeTimer.Stop() { <-nc.writeTimer.C } nc.readContext, cancel = context.WithCancel(ctx) nc.readTimer = time.AfterFunc(math.MaxInt64, cancel) if !nc.readTimer.Stop() { <-nc.readTimer.C } return nc } type netConn struct { c *Conn msgType MessageType writeTimer *time.Timer writeContext context.Context readTimer *time.Timer readContext context.Context readMu sync.Mutex eofed bool reader io.Reader } var _ net.Conn = &netConn{} func (c *netConn) Close() error { return c.c.Close(StatusNormalClosure, "") } func (c *netConn) Write(p []byte) (int, error) { err := c.c.Write(c.writeContext, c.msgType, p) if err != nil { return 0, err } return len(p), nil } func (c *netConn) Read(p []byte) (int, error) { c.readMu.Lock() defer c.readMu.Unlock() if c.eofed { return 0, io.EOF } if c.reader == nil { typ, r, err := c.c.Reader(c.readContext) if err != nil { switch CloseStatus(err) { case StatusNormalClosure, StatusGoingAway: c.eofed = true return 0, io.EOF } return 0, err } if typ != c.msgType { err := fmt.Errorf("unexpected frame type read (expected %v): %v", c.msgType, typ) c.c.Close(StatusUnsupportedData, err.Error()) return 0, err } c.reader = r } n, err := c.reader.Read(p) if err == io.EOF { c.reader = nil err = nil } return n, err } type websocketAddr struct { } func (a websocketAddr) Network() string { return "websocket" } func (a websocketAddr) String() string { return "websocket/unknown-addr" } func (c *netConn) RemoteAddr() net.Addr { return websocketAddr{} } func (c *netConn) LocalAddr() net.Addr { return websocketAddr{} } func (c *netConn) SetDeadline(t time.Time) error { c.SetWriteDeadline(t) c.SetReadDeadline(t) return nil } func (c *netConn) SetWriteDeadline(t time.Time) error { if t.IsZero() { c.writeTimer.Stop() 
} else { c.writeTimer.Reset(t.Sub(time.Now())) } return nil } func (c *netConn) SetReadDeadline(t time.Time) error { if t.IsZero() { c.readTimer.Stop() } else { c.readTimer.Reset(t.Sub(time.Now())) } return nil }
// +build !js

package websocket

import (
	"bufio"
	"context"
	"errors"
	"fmt"
	"io"
	"io/ioutil"
	"strings"
	"time"

	"nhooyr.io/websocket/internal/errd"
	"nhooyr.io/websocket/internal/xsync"
)

// Reader reads from the connection until there is a WebSocket
// data message to be read. It will handle ping, pong and close frames as appropriate.
//
// It returns the type of the message and an io.Reader to read it.
// The passed context will also bound the reader.
// Ensure you read to EOF otherwise the connection will hang.
//
// Call CloseRead if you do not expect any data messages from the peer.
//
// Only one Reader may be open at a time.
func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
	return c.reader(ctx)
}

// Read is a convenience method around Reader to read a single message
// from the connection.
func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) {
	typ, r, err := c.Reader(ctx)
	if err != nil {
		return 0, nil, err
	}

	b, err := ioutil.ReadAll(r)
	return typ, b, err
}

// CloseRead starts a goroutine to read from the connection until it is closed
// or a data message is received.
//
// Once CloseRead is called you cannot read any messages from the connection.
// The returned context will be cancelled when the connection is closed.
//
// If a data message is received, the connection will be closed with StatusPolicyViolation.
//
// Call CloseRead when you do not expect to read any more messages.
// Since it actively reads from the connection, it will ensure that ping, pong and close
// frames are responded to. This means c.Ping and c.Close will still work as expected.
func (c *Conn) CloseRead(ctx context.Context) context.Context {
	ctx, cancel := context.WithCancel(ctx)
	go func() {
		defer cancel()
		// Reader only returns once a data message arrives or reading fails;
		// either way no further data messages are wanted.
		c.Reader(ctx)
		c.Close(StatusPolicyViolation, "unexpected data message")
	}()
	return ctx
}

// SetReadLimit sets the max number of bytes to read for a single message.
// It applies to the Reader and Read methods.
//
// By default, the connection has a message read limit of 32768 bytes.
//
// When the limit is hit, the connection will be closed with StatusMessageTooBig.
func (c *Conn) SetReadLimit(n int64) {
	// We allow reading one more byte than the limit in case
	// there is a fin frame that needs to be read.
	c.msgReader.limitReader.limit.Store(n + 1)
}

const defaultReadLimit = 32768

// newMsgReader returns a msgReader for c. It starts in the fin state
// (no message in progress). Its limitReader is seeded with
// defaultReadLimit+1, mirroring the +1 that SetReadLimit adds.
func newMsgReader(c *Conn) *msgReader {
	mr := &msgReader{
		c:   c,
		fin: true,
	}
	// Bind the method value once so later uses don't allocate.
	mr.readFunc = mr.read

	mr.limitReader = newLimitReader(c, mr.readFunc, defaultReadLimit+1)
	return mr
}

// resetFlate prepares mr to decompress a new permessage-deflate message.
// It:
//   - (re)initializes the dictionary when context takeover is in use;
//   - lazily creates a bufio.Reader over the raw frame payload;
//   - installs a flate reader beneath limitReader;
//   - rewinds flateTail so read() can append deflateMessageTail after the
//     final frame.
func (mr *msgReader) resetFlate() {
	if mr.flateContextTakeover() {
		mr.dict.init(32768)
	}
	if mr.flateBufio == nil {
		mr.flateBufio = getBufioReader(mr.readFunc)
	}

	mr.flateReader = getFlateReader(mr.flateBufio, mr.dict.buf)
	mr.limitReader.r = mr.flateReader
	mr.flateTail.Reset(deflateMessageTail)
}

// putFlateReader releases the current flate reader via putFlateReader
// (presumably back to a pool; defined elsewhere) and clears the field.
func (mr *msgReader) putFlateReader() {
	if mr.flateReader != nil {
		putFlateReader(mr.flateReader)
		mr.flateReader = nil
	}
}

// close permanently acquires readMu so no further reads can start, then
// releases the flate reader, dictionary and bufio readers.
func (mr *msgReader) close() {
	mr.c.readMu.forceLock()
	mr.putFlateReader()
	mr.dict.close()
	if mr.flateBufio != nil {
		putBufioReader(mr.flateBufio)
	}

	if mr.c.client {
		putBufioReader(mr.c.br)
		mr.c.br = nil
	}
}

// flateContextTakeover reports whether the peer's compressor keeps its
// context between messages. A client checks the server's
// no_context_takeover flag; a server checks the client's.
func (mr *msgReader) flateContextTakeover() bool {
	if mr.c.client {
		return !mr.c.copts.serverNoContextTakeover
	}
	return !mr.c.copts.clientNoContextTakeover
}

// readRSV1Illegal reports whether a set rsv1 bit on h is a protocol error.
func (c *Conn) readRSV1Illegal(h header) bool {
	// If compression is disabled, rsv1 is illegal.
	if !c.flate() {
		return true
	}
	// rsv1 is only allowed on data frames beginning messages.
	if h.opcode != opText && h.opcode != opBinary {
		return true
	}
	return false
}

// readLoop reads frame headers until it finds a data or continuation frame.
// It handles control frames inline, and returns an error for invalid rsv
// bits, unmasked client frames or unknown opcodes.
func (c *Conn) readLoop(ctx context.Context) (header, error) {
	for {
		h, err := c.readFrameHeader(ctx)
		if err != nil {
			return header{}, err
		}

		if h.rsv1 && c.readRSV1Illegal(h) || h.rsv2 || h.rsv3 {
			err := fmt.Errorf("received header with unexpected rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3)
			c.writeError(StatusProtocolError, err)
			return header{}, err
		}

		// RFC 6455 requires every client-to-server frame to be masked.
		if !c.client && !h.masked {
			return header{}, errors.New("received unmasked frame from client")
		}

		switch h.opcode {
		case opClose, opPing, opPong:
			err = c.handleControl(ctx, h)
			if err != nil {
				// Pass through CloseErrors when receiving a close frame.
				if h.opcode == opClose && CloseStatus(err) != -1 {
					return header{}, err
				}
				return header{}, fmt.Errorf("failed to handle control frame %v: %w", h.opcode, err)
			}
		case opContinuation, opText, opBinary:
			return h, nil
		default:
			err := fmt.Errorf("received unknown opcode %v", h.opcode)
			c.writeError(StatusProtocolError, err)
			return header{}, err
		}
	}
}

// readFrameHeader reads one frame header from c.br, bounded by ctx.
//
// It sends ctx on c.readTimeout before the blocking read and
// context.Background() after it. NOTE(review): the receiver of readTimeout
// is defined elsewhere; presumably it closes the connection when ctx ends.
// On a read error, a closed connection or a done ctx takes precedence over
// the raw error; otherwise the connection is closed with that error.
func (c *Conn) readFrameHeader(ctx context.Context) (header, error) {
	select {
	case <-c.closed:
		return header{}, c.closeErr
	case c.readTimeout <- ctx:
	}

	h, err := readFrameHeader(c.br, c.readHeaderBuf[:])
	if err != nil {
		select {
		case <-c.closed:
			return header{}, c.closeErr
		case <-ctx.Done():
			return header{}, ctx.Err()
		default:
			c.close(err)
			return header{}, err
		}
	}

	select {
	case <-c.closed:
		return header{}, c.closeErr
	case c.readTimeout <- context.Background():
	}

	return h, nil
}

// readFramePayload fills p completely from c.br, bounded by ctx. It uses
// the same readTimeout handshake and error precedence as readFrameHeader.
func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) {
	select {
	case <-c.closed:
		return 0, c.closeErr
	case c.readTimeout <- ctx:
	}

	n, err := io.ReadFull(c.br, p)
	if err != nil {
		select {
		case <-c.closed:
			return n, c.closeErr
		case <-ctx.Done():
			return n, ctx.Err()
		default:
			err = fmt.Errorf("failed to read frame payload: %w", err)
			c.close(err)
			return n, err
		}
	}

	select {
	case <-c.closed:
		return n, c.closeErr
	case c.readTimeout <- context.Background():
	}

	return n, err
}

// handleControl validates and processes a ping, pong or close frame.
//
//   - Pings are answered with a pong carrying the same payload.
//   - Pongs wake the matching entry in c.activePings, if any, without
//     blocking.
//   - Close frames are parsed. The close is echoed back and the connection
//     is closed. The resulting error is also recorded in c.readCloseFrameErr.
//
// Reading the payload and writing the reply share a 5 second timeout.
func (c *Conn) handleControl(ctx context.Context, h header) (err error) {
	if h.payloadLength < 0 || h.payloadLength > maxControlPayload {
		err := fmt.Errorf("received control frame payload with invalid length: %d", h.payloadLength)
		c.writeError(StatusProtocolError, err)
		return err
	}

	if !h.fin {
		err := errors.New("received fragmented control frame")
		c.writeError(StatusProtocolError, err)
		return err
	}

	ctx, cancel := context.WithTimeout(ctx, time.Second*5)
	defer cancel()

	b := c.readControlBuf[:h.payloadLength]
	_, err = c.readFramePayload(ctx, b)
	if err != nil {
		return err
	}

	if h.masked {
		mask(h.maskKey, b)
	}

	switch h.opcode {
	case opPing:
		return c.writeControl(ctx, opPong, b)
	case opPong:
		c.activePingsMu.Lock()
		pong, ok := c.activePings[string(b)]
		c.activePingsMu.Unlock()
		if ok {
			select {
			case pong <- struct{}{}:
			default:
			}
		}
		return nil
	}

	// Only close frames reach this point.
	defer func() {
		c.readCloseFrameErr = err
	}()

	ce, err := parseClosePayload(b)
	if err != nil {
		err = fmt.Errorf("received invalid close payload: %w", err)
		c.writeError(StatusProtocolError, err)
		return err
	}

	err = fmt.Errorf("received close frame: %w", ce)
	c.setCloseErr(err)
	c.writeClose(ce.Code, ce.Reason)
	c.close(err)
	return err
}

// reader implements Conn.Reader. It holds readMu while locating the next
// data message. If the previous message was not fully read, it closes the
// connection.
func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err error) {
	defer errd.Wrap(&err, "failed to get reader")

	err = c.readMu.lock(ctx)
	if err != nil {
		return 0, nil, err
	}
	defer c.readMu.unlock()

	if !c.msgReader.fin {
		err = errors.New("previous message not read to completion")
		c.close(fmt.Errorf("failed to get reader: %w", err))
		return 0, nil, err
	}

	h, err := c.readLoop(ctx)
	if err != nil {
		return 0, nil, err
	}

	if h.opcode == opContinuation {
		err := errors.New("received continuation frame without text or binary frame")
		c.writeError(StatusProtocolError, err)
		return 0, nil, err
	}

	c.msgReader.reset(ctx, h)

	return MessageType(h.opcode), c.msgReader, nil
}

// msgReader is the io.Reader returned by Conn.Reader. It reads one data
// message across its continuation frames, optionally through a flate
// decompressor, and enforces the read limit.
type msgReader struct {
	c *Conn

	ctx         context.Context
	flate       bool
	flateReader io.Reader
	flateBufio  *bufio.Reader
	flateTail   strings.Reader
	limitReader *limitReader
	dict        slidingWindow

	fin           bool
	payloadLength int64
	maskKey       uint32

	// readerFunc(mr.Read) to avoid continuous allocations.
	readFunc readerFunc
}

// reset starts reading a new message whose first frame header is h.
// rsv1 on the first frame marks the message as compressed.
func (mr *msgReader) reset(ctx context.Context, h header) {
	mr.ctx = ctx
	mr.flate = h.rsv1
	mr.limitReader.reset(mr.readFunc)

	if mr.flate {
		mr.resetFlate()
	}

	mr.setFrame(h)
}

// setFrame records the per-frame state needed to read h's payload.
func (mr *msgReader) setFrame(h header) {
	mr.fin = h.fin
	mr.payloadLength = h.payloadLength
	mr.maskKey = h.maskKey
}

// Read reads decompressed (if applicable), limit-checked message bytes.
// It returns io.EOF at the end of the message. Any other error closes the
// connection.
func (mr *msgReader) Read(p []byte) (n int, err error) {
	err = mr.c.readMu.lock(mr.ctx)
	if err != nil {
		return 0, fmt.Errorf("failed to read: %w", err)
	}
	defer mr.c.readMu.unlock()

	n, err = mr.limitReader.Read(p)
	if mr.flate && mr.flateContextTakeover() {
		p = p[:n]
		mr.dict.write(p)
	}
	// && binds tighter than ||: the message is finished on io.EOF, or on
	// io.ErrUnexpectedEOF from the flate reader after the final frame of a
	// compressed message.
	if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) && mr.fin && mr.flate {
		mr.putFlateReader()
		return n, io.EOF
	}
	if err != nil {
		err = fmt.Errorf("failed to read: %w", err)
		mr.c.close(err)
	}
	return n, err
}

// read returns raw (still compressed, if applicable) payload bytes. It
// crosses continuation frames as needed and unmasks client frames when
// running as a server. After the final frame of a compressed message, it
// serves flateTail before io.EOF.
func (mr *msgReader) read(p []byte) (int, error) {
	for {
		if mr.payloadLength == 0 {
			if mr.fin {
				if mr.flate {
					return mr.flateTail.Read(p)
				}
				return 0, io.EOF
			}

			h, err := mr.c.readLoop(mr.ctx)
			if err != nil {
				return 0, err
			}
			if h.opcode != opContinuation {
				err := errors.New("received new data message without finishing the previous message")
				mr.c.writeError(StatusProtocolError, err)
				return 0, err
			}
			mr.setFrame(h)

			continue
		}

		if int64(len(p)) > mr.payloadLength {
			p = p[:mr.payloadLength]
		}

		n, err := mr.c.readFramePayload(mr.ctx, p)
		if err != nil {
			return n, err
		}

		mr.payloadLength -= int64(n)

		if !mr.c.client {
			mr.maskKey = mask(mr.maskKey, p)
		}

		return n, nil
	}
}

// limitReader caps the number of bytes read from r per message. On hitting
// the cap, it closes the connection with StatusMessageTooBig.
type limitReader struct {
	c     *Conn
	r     io.Reader
	limit xsync.Int64
	n     int64
}

func newLimitReader(c *Conn, r io.Reader, limit int64) *limitReader {
	lr := &limitReader{
		c: c,
	}
	lr.limit.Store(limit)
	lr.reset(r)
	return lr
}

// reset restores the remaining budget to the current limit and reads from r.
func (lr *limitReader) reset(r io.Reader) {
	lr.n = lr.limit.Load()
	lr.r = r
}

func (lr *limitReader) Read(p []byte) (int, error) {
	if lr.n <= 0 {
		// NOTE(review): the stored limit includes the +1 added by
		// SetReadLimit/newMsgReader, so this message reports one more byte
		// than the user-configured limit.
		err := fmt.Errorf("read limited at %v bytes", lr.limit.Load())
		lr.c.writeError(StatusMessageTooBig, err)
		return 0, err
	}

	if int64(len(p)) > lr.n {
		p = p[:lr.n]
	}
	n, err := lr.r.Read(p)
	lr.n -= int64(n)
	return n, err
}

// readerFunc adapts a plain function to io.Reader.
type readerFunc func(p []byte) (int, error)

func (f readerFunc) Read(p []byte) (int, error) {
	return f(p)
}
// +build !js package websocket import ( "bufio" "context" "crypto/rand" "encoding/binary" "errors" "fmt" "io" "time" "github.com/klauspost/compress/flate" "nhooyr.io/websocket/internal/errd" ) // Writer returns a writer bounded by the context that will write // a WebSocket message of type dataType to the connection. // // You must close the writer once you have written the entire message. // // Only one writer can be open at a time, multiple calls will block until the previous writer // is closed. func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) { w, err := c.writer(ctx, typ) if err != nil { return nil, fmt.Errorf("failed to get writer: %w", err) } return w, nil } // Write writes a message to the connection. // // See the Writer method if you want to stream a message. // // If compression is disabled or the threshold is not met, then it // will write the message in a single frame. func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error { _, err := c.write(ctx, typ, p) if err != nil { return fmt.Errorf("failed to write msg: %w", err) } return nil } type msgWriter struct { mw *msgWriterState closed bool } func (mw *msgWriter) Write(p []byte) (int, error) { if mw.closed { return 0, errors.New("cannot use closed writer") } return mw.mw.Write(p) } func (mw *msgWriter) Close() error { if mw.closed { return errors.New("cannot use closed writer") } mw.closed = true return mw.mw.Close() } type msgWriterState struct { c *Conn mu *mu writeMu *mu ctx context.Context opcode opcode flate bool trimWriter *trimLastFourBytesWriter dict slidingWindow } func newMsgWriterState(c *Conn) *msgWriterState { mw := &msgWriterState{ c: c, mu: newMu(c), writeMu: newMu(c), } return mw } func (mw *msgWriterState) ensureFlate() { if mw.trimWriter == nil { mw.trimWriter = &trimLastFourBytesWriter{ w: writerFunc(mw.write), } } mw.dict.init(8192) mw.flate = true } func (mw *msgWriterState) flateContextTakeover() bool { if mw.c.client { return 
!mw.c.copts.clientNoContextTakeover } return !mw.c.copts.serverNoContextTakeover } func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) { err := c.msgWriterState.reset(ctx, typ) if err != nil { return nil, err } return &msgWriter{ mw: c.msgWriterState, closed: false, }, nil } func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error) { mw, err := c.writer(ctx, typ) if err != nil { return 0, err } if !c.flate() { defer c.msgWriterState.mu.unlock() return c.writeFrame(ctx, true, false, c.msgWriterState.opcode, p) } n, err := mw.Write(p) if err != nil { return n, err } err = mw.Close() return n, err } func (mw *msgWriterState) reset(ctx context.Context, typ MessageType) error { err := mw.mu.lock(ctx) if err != nil { return err } mw.ctx = ctx mw.opcode = opcode(typ) mw.flate = false mw.trimWriter.reset() return nil } // Write writes the given bytes to the WebSocket connection. func (mw *msgWriterState) Write(p []byte) (_ int, err error) { err = mw.writeMu.lock(mw.ctx) if err != nil { return 0, fmt.Errorf("failed to write: %w", err) } defer mw.writeMu.unlock() defer func() { if err != nil { err = fmt.Errorf("failed to write: %w", err) mw.c.close(err) } }() if mw.c.flate() { // Only enables flate if the length crosses the // threshold on the first frame if mw.opcode != opContinuation && len(p) >= mw.c.flateThreshold { mw.ensureFlate() } } if mw.flate { err = flate.StatelessDeflate(mw.trimWriter, p, false, mw.dict.buf) if err != nil { return 0, err } mw.dict.write(p) return len(p), nil } return mw.write(p) } func (mw *msgWriterState) write(p []byte) (int, error) { n, err := mw.c.writeFrame(mw.ctx, false, mw.flate, mw.opcode, p) if err != nil { return n, fmt.Errorf("failed to write data frame: %w", err) } mw.opcode = opContinuation return n, nil } // Close flushes the frame to the connection. 
func (mw *msgWriterState) Close() (err error) { defer errd.Wrap(&err, "failed to close writer") err = mw.writeMu.lock(mw.ctx) if err != nil { return err } defer mw.writeMu.unlock() _, err = mw.c.writeFrame(mw.ctx, true, mw.flate, mw.opcode, nil) if err != nil { return fmt.Errorf("failed to write fin frame: %w", err) } if mw.flate && !mw.flateContextTakeover() { mw.dict.close() } mw.mu.unlock() return nil } func (mw *msgWriterState) close() { if mw.c.client { mw.c.writeFrameMu.forceLock() putBufioWriter(mw.c.bw) } mw.writeMu.forceLock() mw.dict.close() } func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error { ctx, cancel := context.WithTimeout(ctx, time.Second*5) defer cancel() _, err := c.writeFrame(ctx, true, false, opcode, p) if err != nil { return fmt.Errorf("failed to write control frame %v: %w", opcode, err) } return nil } // frame handles all writes to the connection. func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opcode, p []byte) (_ int, err error) { err = c.writeFrameMu.lock(ctx) if err != nil { return 0, err } defer c.writeFrameMu.unlock() // If the state says a close has already been written, we wait until // the connection is closed and return that error. // // However, if the frame being written is a close, that means its the close from // the state being set so we let it go through. 
c.closeMu.Lock() wroteClose := c.wroteClose c.closeMu.Unlock() if wroteClose && opcode != opClose { select { case <-ctx.Done(): return 0, ctx.Err() case <-c.closed: return 0, c.closeErr } } select { case <-c.closed: return 0, c.closeErr case c.writeTimeout <- ctx: } defer func() { if err != nil { select { case <-c.closed: err = c.closeErr case <-ctx.Done(): err = ctx.Err() } c.close(err) err = fmt.Errorf("failed to write frame: %w", err) } }() c.writeHeader.fin = fin c.writeHeader.opcode = opcode c.writeHeader.payloadLength = int64(len(p)) if c.client { c.writeHeader.masked = true _, err = io.ReadFull(rand.Reader, c.writeHeaderBuf[:4]) if err != nil { return 0, fmt.Errorf("failed to generate masking key: %w", err) } c.writeHeader.maskKey = binary.LittleEndian.Uint32(c.writeHeaderBuf[:]) } c.writeHeader.rsv1 = false if flate && (opcode == opText || opcode == opBinary) { c.writeHeader.rsv1 = true } err = writeFrameHeader(c.writeHeader, c.bw, c.writeHeaderBuf[:]) if err != nil { return 0, err } n, err := c.writeFramePayload(p) if err != nil { return n, err } if c.writeHeader.fin { err = c.bw.Flush() if err != nil { return n, fmt.Errorf("failed to flush: %w", err) } } select { case <-c.closed: return n, c.closeErr case c.writeTimeout <- context.Background(): } return n, nil } func (c *Conn) writeFramePayload(p []byte) (n int, err error) { defer errd.Wrap(&err, "failed to write frame payload") if !c.writeHeader.masked { return c.bw.Write(p) } maskKey := c.writeHeader.maskKey for len(p) > 0 { // If the buffer is full, we need to flush. if c.bw.Available() == 0 { err = c.bw.Flush() if err != nil { return n, err } } // Start of next write in the buffer. 
i := c.bw.Buffered() j := len(p) if j > c.bw.Available() { j = c.bw.Available() } _, err := c.bw.Write(p[:j]) if err != nil { return n, err } maskKey = mask(maskKey, c.writeBuf[i:c.bw.Buffered()]) p = p[j:] n += j } return n, nil } type writerFunc func(p []byte) (int, error) func (f writerFunc) Write(p []byte) (int, error) { return f(p) } // extractBufioWriterBuf grabs the []byte backing a *bufio.Writer // and returns it. func extractBufioWriterBuf(bw *bufio.Writer, w io.Writer) []byte { var writeBuf []byte bw.Reset(writerFunc(func(p2 []byte) (int, error) { writeBuf = p2[:cap(p2)] return len(p2), nil })) bw.WriteByte(0) bw.Flush() bw.Reset(w) return writeBuf } func (c *Conn) writeError(code StatusCode, err error) { c.setCloseErr(err) c.writeClose(code, err.Error()) c.close(nil) }
// Package wsjson provides helpers for reading and writing JSON messages. package wsjson // import "nhooyr.io/websocket/wsjson" import ( "context" "encoding/json" "fmt" "nhooyr.io/websocket" "nhooyr.io/websocket/internal/bpool" "nhooyr.io/websocket/internal/errd" ) // Read reads a JSON message from c into v. // It will reuse buffers in between calls to avoid allocations. func Read(ctx context.Context, c *websocket.Conn, v interface{}) error { return read(ctx, c, v) } func read(ctx context.Context, c *websocket.Conn, v interface{}) (err error) { defer errd.Wrap(&err, "failed to read JSON message") _, r, err := c.Reader(ctx) if err != nil { return err } b := bpool.Get() defer bpool.Put(b) _, err = b.ReadFrom(r) if err != nil { return err } err = json.Unmarshal(b.Bytes(), v) if err != nil { c.Close(websocket.StatusInvalidFramePayloadData, "failed to unmarshal JSON") return fmt.Errorf("failed to unmarshal JSON: %w", err) } return nil } // Write writes the JSON message v to c. // It will reuse buffers in between calls to avoid allocations. func Write(ctx context.Context, c *websocket.Conn, v interface{}) error { return write(ctx, c, v) } func write(ctx context.Context, c *websocket.Conn, v interface{}) (err error) { defer errd.Wrap(&err, "failed to write JSON message") w, err := c.Writer(ctx, websocket.MessageText) if err != nil { return err } // json.Marshal cannot reuse buffers between calls as it has to return // a copy of the byte slice but Encoder does as it directly writes to w. err = json.NewEncoder(w).Encode(v) if err != nil { return fmt.Errorf("failed to marshal JSON: %w", err) } return w.Close() }
// Package wspb provides helpers for reading and writing protobuf messages. package wspb // import "nhooyr.io/websocket/wspb" import ( "bytes" "context" "fmt" "github.com/golang/protobuf/proto" "nhooyr.io/websocket" "nhooyr.io/websocket/internal/bpool" "nhooyr.io/websocket/internal/errd" ) // Read reads a protobuf message from c into v. // It will reuse buffers in between calls to avoid allocations. func Read(ctx context.Context, c *websocket.Conn, v proto.Message) error { return read(ctx, c, v) } func read(ctx context.Context, c *websocket.Conn, v proto.Message) (err error) { defer errd.Wrap(&err, "failed to read protobuf message") typ, r, err := c.Reader(ctx) if err != nil { return err } if typ != websocket.MessageBinary { c.Close(websocket.StatusUnsupportedData, "expected binary message") return fmt.Errorf("expected binary message for protobuf but got: %v", typ) } b := bpool.Get() defer bpool.Put(b) _, err = b.ReadFrom(r) if err != nil { return err } err = proto.Unmarshal(b.Bytes(), v) if err != nil { c.Close(websocket.StatusInvalidFramePayloadData, "failed to unmarshal protobuf") return fmt.Errorf("failed to unmarshal protobuf: %w", err) } return nil } // Write writes the protobuf message v to c. // It will reuse buffers in between calls to avoid allocations. func Write(ctx context.Context, c *websocket.Conn, v proto.Message) error { return write(ctx, c, v) } func write(ctx context.Context, c *websocket.Conn, v proto.Message) (err error) { defer errd.Wrap(&err, "failed to write protobuf message") b := bpool.Get() pb := proto.NewBuffer(b.Bytes()) defer func() { bpool.Put(bytes.NewBuffer(pb.Bytes())) }() err = pb.Marshal(v) if err != nil { return fmt.Errorf("failed to marshal protobuf: %w", err) } return c.Write(ctx, websocket.MessageBinary, pb.Bytes()) }