1 | 1 | | package tunnel |
2 | 2 | | |
3 | 3 | | import ( |
4 | | - | "errors" |
5 | 4 | | "io" |
6 | 5 | | "net" |
7 | 6 | | "sync" |
| skipped 35 lines |
43 | 42 | | defer remoteConn.Close() |
44 | 43 | | |
45 | 44 | | log.Infof("[TCP] %s <-> %s", metadata.SourceAddress(), metadata.DestinationAddress()) |
46 | | - | if err = pipe(originConn, remoteConn); err != nil { |
47 | | - | log.Debugf("[TCP] %s <-> %s: %v", metadata.SourceAddress(), metadata.DestinationAddress(), err) |
48 | | - | } |
| 45 | + | pipe(originConn, remoteConn) |
49 | 46 | | } |
50 | 47 | | |
51 | 48 | | // pipe copies data to & from provided net.Conn(s) bidirectionally. |
52 | | - | func pipe(origin, remote net.Conn) error { |
| 49 | + | func pipe(origin, remote net.Conn) { |
53 | 50 | | wg := sync.WaitGroup{} |
54 | 51 | | wg.Add(2) |
55 | 52 | | |
56 | | - | var leftErr, rightErr error |
57 | | - | |
58 | | - | go func() { |
59 | | - | defer wg.Done() |
60 | | - | if err := copyBuffer(remote, origin); err != nil { |
61 | | - | leftErr = errors.Join(leftErr, err) |
62 | | - | } |
63 | | - | remote.SetReadDeadline(time.Now().Add(tcpWaitTimeout)) |
64 | | - | }() |
65 | | - | |
66 | | - | go func() { |
67 | | - | defer wg.Done() |
68 | | - | if err := copyBuffer(origin, remote); err != nil { |
69 | | - | rightErr = errors.Join(rightErr, err) |
70 | | - | } |
71 | | - | origin.SetReadDeadline(time.Now().Add(tcpWaitTimeout)) |
72 | | - | }() |
| 53 | + | go unidirectionalStream(remote, origin, "origin->remote", &wg) |
| 54 | + | go unidirectionalStream(origin, remote, "remote->origin", &wg) |
73 | 55 | | |
74 | 56 | | wg.Wait() |
75 | | - | return errors.Join(leftErr, rightErr) |
76 | 57 | | } |
77 | 58 | | |
78 | | - | func copyBuffer(dst io.Writer, src io.Reader) error { |
| 59 | + | func unidirectionalStream(dst, src net.Conn, dir string, wg *sync.WaitGroup) { |
| 60 | + | defer wg.Done() |
79 | 61 | | buf := pool.Get(pool.RelayBufferSize) |
80 | | - | defer pool.Put(buf) |
81 | | - | |
82 | | - | _, err := io.CopyBuffer(dst, src, buf) |
83 | | - | return err |
| 62 | + | if _, err := io.CopyBuffer(dst, src, buf); err != nil { |
| 63 | + | log.Debugf("[TCP] copy data for %s: %v", dir, err) |
| 64 | + | } |
| 65 | + | pool.Put(buf) |
| 66 | + | // Do the upload/download side TCP half-close. |
| 67 | + | if cr, ok := src.(interface{ CloseRead() error }); ok { |
| 68 | + | cr.CloseRead() |
| 69 | + | } |
| 70 | + | if cw, ok := dst.(interface{ CloseWrite() error }); ok { |
| 71 | + | cw.CloseWrite() |
| 72 | + | } |
| 73 | + | // Set TCP half-close timeout. |
| 74 | + | dst.SetReadDeadline(time.Now().Add(tcpWaitTimeout)) |
84 | 75 | | } |
85 | 76 | | |