Path: blob/dev/pkg/protocols/http/race/syncedreadcloser.go
2073 views
package race

import (
	"fmt"
	"io"
	"sync"
	"time"
)

// Compile-time check that SyncedReadCloser provides Read, Seek and Close.
var _ io.ReadSeekCloser = (*SyncedReadCloser)(nil)

// SyncedReadCloser is an in-memory io.ReadSeeker whose final chunk of data
// is held back behind a gate. A Read that would deliver the last remaining
// bytes waits until the gate is opened, which lets several requests be
// released at (nearly) the same moment for race condition testing.
type SyncedReadCloser struct {
	data   []byte // full payload, buffered at construction time
	p      int64  // current read position, always within [0, length]
	length int64  // len(data)

	openGate chan struct{} // closed exactly once when the gate is opened
	openOnce sync.Once     // guards the close of openGate

	enableBlocking bool // when false, Read never waits for the gate
}

// NewSyncedReadCloser buffers everything from r, closes r, and returns a
// SyncedReadCloser with blocking enabled and the gate shut. It returns nil
// if r cannot be read completely.
func NewSyncedReadCloser(r io.ReadCloser) *SyncedReadCloser {
	// Registered before the read so the source is released on failure too.
	// The close error is ignored on purpose: the data is already buffered
	// (or the read error is what gets reported through the nil result).
	defer func() {
		_ = r.Close()
	}()

	data, err := io.ReadAll(r)
	if err != nil {
		return nil
	}
	return &SyncedReadCloser{
		data:           data,
		length:         int64(len(data)),
		openGate:       make(chan struct{}),
		enableBlocking: true,
	}
}

// NewOpenGateWithTimeout creates a SyncedReadCloser whose gate opens by
// itself after d. It returns nil if r cannot be read completely.
func NewOpenGateWithTimeout(r io.ReadCloser, d time.Duration) *SyncedReadCloser {
	s := NewSyncedReadCloser(r)
	if s == nil {
		return nil
	}
	s.OpenGateAfter(d)
	return s
}

// SetOpenGate enables (true) or disables (false) blocking at the gate.
// With blocking disabled, Read never waits.
func (s *SyncedReadCloser) SetOpenGate(status bool) {
	s.enableBlocking = status
}

// OpenGate opens the gate, releasing every Read currently waiting on it and
// letting all later reads pass straight through. It never blocks and is
// safe to call more than once.
func (s *SyncedReadCloser) OpenGate() {
	s.openOnce.Do(func() {
		// A zero-value SyncedReadCloser has no channel to close.
		if s.openGate != nil {
			close(s.openGate)
		}
	})
}

// OpenGateAfter schedules the gate to be opened after d.
func (s *SyncedReadCloser) OpenGateAfter(d time.Duration) {
	time.AfterFunc(d, s.OpenGate)
}

// Seek implements io.Seeker. The resulting position must lie within
// [0, Len()]; otherwise an error is returned and the position is unchanged.
//
// For io.SeekEnd, a negative offset follows the io.Seeker convention
// (position = length + offset). A non-negative offset keeps this type's
// historical meaning of "offset bytes back from the end"
// (position = length - offset), so existing callers are unaffected.
func (s *SyncedReadCloser) Seek(offset int64, whence int) (int64, error) {
	var target int64
	switch whence {
	case io.SeekStart:
		target = offset
	case io.SeekCurrent:
		// Compared by subtraction so a huge offset cannot overflow s.p+offset.
		if offset > s.length-s.p {
			return s.p, fmt.Errorf("offset is too big")
		}
		target = s.p + offset
	case io.SeekEnd:
		switch {
		case offset < 0:
			target = s.length + offset
		case offset > s.length:
			return s.p, fmt.Errorf("offset is too big")
		default:
			target = s.length - offset
		}
	default:
		return s.p, fmt.Errorf("invalid whence: %d", whence)
	}

	if target < 0 {
		return s.p, fmt.Errorf("negative position")
	}
	if target > s.length {
		return s.p, fmt.Errorf("offset is too big")
	}
	s.p = target
	return s.p, nil
}

// Read implements io.Reader. When blocking is enabled and p is large enough
// to receive all remaining data, Read waits for the gate to open before
// copying. io.EOF is returned together with the final bytes.
func (s *SyncedReadCloser) Read(p []byte) (n int, err error) {
	// The last chunk is the one held back: wait here until the gate opens.
	if s.p+int64(len(p)) >= s.length && s.enableBlocking {
		<-s.openGate
	}
	n = copy(p, s.data[s.p:])
	s.p += int64(n)
	if s.p == s.length {
		err = io.EOF
	}
	return n, err
}

// Close implements io.Closer. The source reader was already closed by the
// constructor, so there is nothing left to release.
func (s *SyncedReadCloser) Close() error {
	return nil
}

// Len returns the total length of the buffered data in bytes.
func (s *SyncedReadCloser) Len() int {
	return int(s.length)
}