package vz
import (
"context"
"encoding/binary"
"errors"
"io"
"net"
"os"
"sync"
"syscall"
"time"
"github.com/balajiv113/fd"
"github.com/sirupsen/logrus"
)
// PassFDToUnix connects to the unix stream socket at unixSock, creates a
// socket pair via createSockPair, and hands one end ("server") to the peer
// over that connection using fd.Put. The other end ("client") is returned
// for the caller to attach to the VM; the caller owns and must close it.
//
// All descriptors this function opens besides the returned client are
// closed before it returns, on both success and error paths.
func PassFDToUnix(unixSock string) (*os.File, error) {
	unixAddr, err := net.ResolveUnixAddr("unix", unixSock)
	if err != nil {
		return nil, err
	}
	unixConn, err := net.DialUnix("unix", nil, unixAddr)
	if err != nil {
		return nil, err
	}
	// The connection is only a transport for handing over the descriptor;
	// previously it was left for the GC finalizer to close at an arbitrary time.
	defer unixConn.Close()

	server, client, err := createSockPair()
	if err != nil {
		return nil, err
	}
	// Once passed, the peer holds its own descriptor for this socket. Keeping
	// our copy open would keep the socket alive after the peer exits, so
	// writes on client would not observe the peer going away.
	defer server.Close()

	if err := fd.Put(unixConn, server); err != nil {
		client.Close()
		return nil, err
	}
	return client, nil
}
// DialQemu dials the unix stream socket at unixSock and bridges it to a new
// socket pair created by createSockPair. Frames on the stream use a 4-byte
// length prefix (see qemuPacketConn). On the socket-pair side each frame is
// one write/read (see packetConn). The returned client end is meant to be
// attached to the VM; the caller owns and must close it. Forwarding runs in
// a background goroutine (forwardPackets) until either side stops.
//
// ctx bounds only the dial; it does not control the forwarding goroutine.
func DialQemu(ctx context.Context, unixSock string) (*os.File, error) {
	var dialer net.Dialer
	unixConn, err := dialer.DialContext(ctx, "unix", unixSock)
	if err != nil {
		return nil, err
	}
	qemuConn := &qemuPacketConn{Conn: unixConn}

	server, client, err := createSockPair()
	if err != nil {
		unixConn.Close()
		return nil, err
	}
	// net.FileConn duplicates the descriptor, and closing the returned conn
	// does not close server. Our copy must be closed here, otherwise the
	// socket outlives vzConn and the client end never sees its peer close.
	defer server.Close()

	dgramConn, err := net.FileConn(server)
	if err != nil {
		unixConn.Close()
		client.Close()
		return nil, err
	}
	vzConn := &packetConn{Conn: dgramConn}
	go forwardPackets(qemuConn, vzConn)
	return client, nil
}
// forwardPackets relays frames in both directions between qemuConn and
// vzConn, and returns once both directions have stopped. Both connections
// are closed before it returns.
//
// io.Copy keeps frame boundaries here because each Read on either conn
// yields exactly one frame, and io.Copy writes exactly what each Read returned.
func forwardPackets(qemuConn *qemuPacketConn, vzConn *packetConn) {
	var closeOnce sync.Once
	closeAll := func() {
		closeOnce.Do(func() {
			qemuConn.Close()
			vzConn.Close()
		})
	}

	var wg sync.WaitGroup
	relay := func(dst io.Writer, src io.Reader, direction string) {
		defer wg.Done()
		// Whichever direction stops first tears down both connections so the
		// opposite io.Copy, typically blocked in Read, returns too. Without
		// this the other goroutine (and both sockets) would leak.
		defer closeAll()
		// net.ErrClosed is the expected result of closeAll interrupting the
		// other direction, not a forwarding failure.
		if _, err := io.Copy(dst, src); err != nil && !errors.Is(err, net.ErrClosed) {
			logrus.Errorf("Failed to forward packets from %s: %s", direction, err)
		}
	}

	wg.Add(2)
	go relay(qemuConn, vzConn, "VZ to VMNET")
	go relay(vzConn, qemuConn, "VMNET to VZ")
	wg.Wait()
}
// qemuPacketConn adapts a stream connection carrying length-prefixed frames
// (a 4-byte big-endian length followed by that many payload bytes), as used
// on the socket dialed by DialQemu. Each Read returns exactly one frame's
// payload and each Write sends exactly one frame.
type qemuPacketConn struct {
	net.Conn
}

// Read reads one frame and copies its payload into b, returning the payload
// length. It returns io.EOF if the stream ends before a header, and
// io.ErrUnexpectedEOF if it ends partway through a frame. It returns
// io.ErrShortBuffer if the frame does not fit in b; the stream is then no
// longer positioned on a frame boundary, so the caller should stop reading.
func (c *qemuPacketConn) Read(b []byte) (n int, err error) {
	var hdr [4]byte
	if _, err = io.ReadFull(c.Conn, hdr[:]); err != nil {
		return 0, err
	}
	size := binary.BigEndian.Uint32(hdr[:])
	// The length comes from the peer; slicing b[:size] with an oversized
	// value would panic, so reject frames larger than the caller's buffer.
	if uint64(size) > uint64(len(b)) {
		return 0, io.ErrShortBuffer
	}
	return io.ReadFull(c.Conn, b[:size])
}

// Write sends b as a single frame: a 4-byte big-endian length header
// followed by the payload. On success it returns len(b). On failure it
// returns 0 and the error; part of the frame may already have been sent.
func (c *qemuPacketConn) Write(b []byte) (int, error) {
	var hdr [4]byte
	binary.BigEndian.PutUint32(hdr[:], uint32(len(b)))
	// net.Buffers lets connections that support batched writes emit the
	// header and payload in one vectored write. Others fall back to
	// sequential Write calls, each of which writes fully or fails.
	bufs := net.Buffers{hdr[:], b}
	if _, err := bufs.WriteTo(c.Conn); err != nil {
		return 0, err
	}
	return len(b), nil
}
// writeRetryDelay is how long packetConn.Write sleeps between attempts when
// the underlying write fails with ENOBUFS.
const writeRetryDelay = 100 * time.Microsecond

// packetConn wraps a datagram-style net.Conn (the socket-pair end used by
// DialQemu) and retries writes that the kernel rejects with ENOBUFS.
type packetConn struct {
	net.Conn
}

// Write sends b with a single underlying Write. While that write reports
// ENOBUFS without having written anything, it sleeps writeRetryDelay and
// tries again, with no upper bound on the number of attempts. Any other
// error is returned with a count of 0. A write shorter than len(b) is
// reported as an error together with the short count.
func (c *packetConn) Write(b []byte) (int, error) {
	var attempts uint64
	written, err := c.Conn.Write(b)
	for written == 0 && errors.Is(err, syscall.ENOBUFS) {
		time.Sleep(writeRetryDelay)
		attempts++
		written, err = c.Conn.Write(b)
	}

	switch {
	case err != nil:
		return 0, err
	case written < len(b):
		return written, errors.New("incomplete write to unixgram socket")
	}

	if attempts > 0 {
		logrus.Debugf("Write completed after %d retries", attempts)
	}
	return written, nil
}