Path: blob/main/components/ws-proxy/pkg/sshproxy/forward.go
2500 views
// Copyright (c) 2021 Gitpod GmbH. All rights reserved.1// Licensed under the GNU Affero General Public License (AGPL).2// See License.AGPL.txt in the project root for license information.34package sshproxy56import (7"io"8"sync"9"time"1011"github.com/gitpod-io/gitpod/common-go/analytics"12"github.com/gitpod-io/gitpod/common-go/log"13tracker "github.com/gitpod-io/gitpod/ws-proxy/pkg/analytics"14"github.com/gitpod-io/golang-crypto/ssh"15"golang.org/x/net/context"16)1718func (s *Server) ChannelForward(ctx context.Context, session *Session, targetConn ssh.Conn, originChannel ssh.NewChannel) {19targetChan, targetReqs, err := targetConn.OpenChannel(originChannel.ChannelType(), originChannel.ExtraData())20if err != nil {21log.WithFields(log.OWI("", session.WorkspaceID, session.InstanceID)).WithError(err).Error("open target channel error")22_ = originChannel.Reject(ssh.ConnectionFailed, "open target channel error")23return24}25defer targetChan.Close()2627originChan, originReqs, err := originChannel.Accept()28if err != nil {29log.WithFields(log.OWI("", session.WorkspaceID, session.InstanceID)).WithError(err).Error("accept origin channel failed")30return31}32if originChannel.ChannelType() == "session" {33originChan = startHeartbeatingChannel(originChan, s.Heartbeater, session)34}35defer originChan.Close()3637maskedReqs := make(chan *ssh.Request, 1)3839go func() {40for req := range originReqs {41switch req.Type {42case "pty-req", "shell":43log.WithFields(log.OWI("", session.WorkspaceID, session.InstanceID)).Debugf("forwarding %s request", req.Type)44if channel, ok := originChan.(*heartbeatingChannel); ok && req.Type == "pty-req" {45channel.mux.Lock()46channel.requestedPty = true47channel.mux.Unlock()48}49}50maskedReqs <- req51}52close(maskedReqs)53}()5455originChannelWg := sync.WaitGroup{}56originChannelWg.Add(3)57targetChannelWg := sync.WaitGroup{}58targetChannelWg.Add(3)5960wg := sync.WaitGroup{}61wg.Add(2)6263go func() {64defer wg.Done()65_, _ = io.Copy(targetChan, originChan)66_ = targetChan.CloseWrite()67targetChannelWg.Done()68targetChannelWg.Wait()69_ = targetChan.Close()70}()7172go func() {73defer wg.Done()74_, _ = io.Copy(originChan, targetChan)75_ = originChan.CloseWrite()76originChannelWg.Done()77originChannelWg.Wait()78_ = originChan.Close()79}()8081go func() {82_, _ = io.Copy(targetChan.Stderr(), originChan.Stderr())83targetChannelWg.Done()84}()8586go func() {87_, _ = io.Copy(originChan.Stderr(), targetChan.Stderr())88originChannelWg.Done()89}()9091forward := func(sourceReqs <-chan *ssh.Request, targetChan ssh.Channel, channelWg *sync.WaitGroup) {92defer channelWg.Done()93for ctx.Err() == nil {94select {95case req, ok := <-sourceReqs:96if !ok {97return98}99b, err := targetChan.SendRequest(req.Type, req.WantReply, req.Payload)100_ = req.Reply(b, nil)101if err != nil {102return103}104case <-ctx.Done():105return106}107}108}109110go forward(maskedReqs, targetChan, &targetChannelWg)111go forward(targetReqs, originChan, &originChannelWg)112113wg.Wait()114log.WithFields(log.OWI("", session.WorkspaceID, session.InstanceID)).Debug("session forward stop")115}116117func TrackIDECloseSignal(session *Session) {118propertics := make(map[string]interface{})119propertics["workspaceId"] = session.WorkspaceID120propertics["instanceId"] = session.InstanceID121propertics["clientKind"] = "ssh"122tracker.Track(analytics.TrackMessage{123Identity: analytics.Identity{UserID: session.OwnerUserId},124Event: "ide_close_signal",125Properties: propertics,126})127}128129func startHeartbeatingChannel(c ssh.Channel, heartbeat Heartbeat, session *Session) ssh.Channel {130ctx, cancel := context.WithCancel(context.Background())131res := &heartbeatingChannel{132Channel: c,133t: time.NewTicker(30 * time.Second),134cancel: cancel,135}136go func() {137for {138select {139case <-res.t.C:140res.mux.Lock()141if !res.sawActivity || !res.requestedPty {142res.mux.Unlock()143continue144}145res.sawActivity = false146res.mux.Unlock()147heartbeat.SendHeartbeat(session.InstanceID, false, false)148case <-ctx.Done():149if res.requestedPty {150heartbeat.SendHeartbeat(session.InstanceID, true, false)151TrackIDECloseSignal(session)152log.WithField("instanceId", session.InstanceID).Info("send closed heartbeat")153}154return155}156}157}()158159return res160}161162type heartbeatingChannel struct {163ssh.Channel164165mux sync.Mutex166sawActivity bool167t *time.Ticker168169cancel context.CancelFunc170171requestedPty bool172}173174// Read reads up to len(data) bytes from the channel.175func (c *heartbeatingChannel) Read(data []byte) (written int, err error) {176written, err = c.Channel.Read(data)177if err == nil && written != 0 {178c.mux.Lock()179c.sawActivity = true180c.mux.Unlock()181}182return183}184185func (c *heartbeatingChannel) Close() error {186c.t.Stop()187c.cancel()188return c.Channel.Close()189}190191192