Path: blob/main/mitm-socket/go/domain_socket_piper.go
1029 views
package main12import (3"errors"4"fmt"5"io"6"net"7"sync"8"time"9)1011type DomainSocketPiper struct {12id int13client net.Conn14isClosed bool15wg sync.WaitGroup16signals *Signals17debug bool18}1920func (piper *DomainSocketPiper) Pipe(remote net.Conn) {21client := piper.client2223piper.wg.Add(2)24clientHasDataChan := make(chan bool, 1)25// Pipe data26go piper.copy(client, remote, clientHasDataChan, true)27go piper.copy(remote, client, clientHasDataChan, false)2829piper.wg.Wait()30SendToIpc(piper.id, "closing", nil)31}3233func (piper *DomainSocketPiper) copy(dst net.Conn, src net.Conn, clientHasData chan bool, isReadingFromRemote bool) {34var totalBytes int6435var n int36var w int37var neterr net.Error38var ok bool39var writeErr error40var readErr error41var direction string42var waitForData bool4344if piper.debug {45if isReadingFromRemote {46direction = "from remote"47} else {48direction = "from client"49}50}5152data := make([]byte, 5*1096)5354defer piper.wg.Done()5556for {57if isReadingFromRemote == true && waitForData {58select {59case <-clientHasData:60waitForData = false61case <-time.After(50 * time.Millisecond):62if piper.signals.IsClosed || piper.isClosed {63return64}65}66if waitForData {67continue68}69}70src.SetReadDeadline(time.Now().Add(2 * time.Second)) // Set the deadline71n, readErr = src.Read(data)7273if n > 0 {74if isReadingFromRemote == false && len(clientHasData) == 0 {75clientHasData <- true76}77w, writeErr = dst.Write(data[0:n])78if w < 0 || n < w {79w = 080if writeErr == nil {81writeErr = errors.New("invalid write result")82}83}84totalBytes += int64(w)8586if writeErr == nil && n != w {87writeErr = io.ErrShortWrite88}89if writeErr != nil {90SendErrorToIpc(piper.id, "writeErr", writeErr)91piper.isClosed = true92return93}94}9596if piper.debug {97fmt.Printf("[id=%d] Read %d bytes %s. Total: %d\n", piper.id, n, direction, totalBytes)98}99100if n == 0 && readErr == io.EOF {101if isReadingFromRemote {102if totalBytes == 0 {103piper.isClosed = true104return105}106107SendToIpc(piper.id, "eof", nil)108if len(clientHasData) > 0 {109// drain110<-clientHasData111}112waitForData = true113} else {114piper.isClosed = true115return116}117}118119if readErr != nil && readErr != io.EOF {120neterr, ok = readErr.(net.Error)121// if not a timeout, stop and return122if !ok || !neterr.Timeout() {123SendErrorToIpc(piper.id, "readErr", readErr)124piper.isClosed = true125return126}127}128129if piper.signals.IsClosed || piper.isClosed {130return131}132133if n == 0 || readErr != nil {134time.Sleep(50 * time.Millisecond)135}136}137}138139func (piper *DomainSocketPiper) Close() {140piper.isClosed = true141}142143144