Projects STRLCPY dnstt Commits 6e5ba30a
🤬
  • Use net.Dial, rather than net.DialTCP, to dial upstream.

    The usual use case for upstream is that it is a localhost IP address and
    port, but it may also be a hostname and port. Resolving the address once
    with net.ResolveTCPAddr and then dialing with net.DialTCP fixes the
    hostname's resolution once and for all, and uses only one of its IP
    addresses if there is more than one. net.Dial instead resolves the
    hostname on every call and tries each of its IP addresses in turn until
    it is able to establish a connection.
    
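    Now upstream is kept as a string variable all the way through the call
    chain. For the sake of usability, main still checks the address at
    startup, so that an error or warning is emitted right away rather than
    deferred until the first stream: it parses the address with
    net.SplitHostPort (a malformed address or an empty host is fatal) and
    tries to resolve the host with net.ResolveIPAddr (a resolution failure
    is only a warning, because the name is re-resolved on each net.Dial).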
  • Loading...
  • David Fifield committed 3 years ago
    6e5ba30a
    1 parent 064c53e3
  • ■ ■ ■ ■ ■ ■
    dnstt-server/main.go
    skipped 180 lines
    181 181   
    182 182  // handleStream bidirectionally connects a client stream with a TCP socket
    183 183  // addressed by upstream.
    184  -func handleStream(stream *smux.Stream, upstream *net.TCPAddr, conv uint32) error {
    185  - conn, err := net.DialTCP("tcp", nil, upstream)
     184 +func handleStream(stream *smux.Stream, upstream string, conv uint32) error {
     185 + upstreamConn, err := net.Dial("tcp", upstream)
    186 186   if err != nil {
    187 187   return fmt.Errorf("stream %08x:%d connect upstream: %v", conv, stream.ID(), err)
    188 188   }
    189  - defer conn.Close()
     189 + defer upstreamConn.Close()
     190 + upstreamTCPConn := upstreamConn.(*net.TCPConn)
    190 191   
    191 192   var wg sync.WaitGroup
    192 193   wg.Add(2)
    193 194   go func() {
    194 195   defer wg.Done()
    195  - _, err := io.Copy(stream, conn)
     196 + _, err := io.Copy(stream, upstreamTCPConn)
    196 197   if err == io.EOF {
    197 198   // smux Stream.Write may return io.EOF.
    198 199   err = nil
    skipped 1 lines
    200 201   if err != nil {
    201 202   log.Printf("stream %08x:%d copy stream←upstream: %v", conv, stream.ID(), err)
    202 203   }
    203  - conn.CloseRead()
     204 + upstreamTCPConn.CloseRead()
    204 205   stream.Close()
    205 206   }()
    206 207   go func() {
    207 208   defer wg.Done()
    208  - _, err := io.Copy(conn, stream)
     209 + _, err := io.Copy(upstreamTCPConn, stream)
    209 210   if err == io.EOF {
    210 211   // smux Stream.WriteTo may return io.EOF.
    211 212   err = nil
    skipped 1 lines
    213 214   if err != nil && err != io.ErrClosedPipe {
    214 215   log.Printf("stream %08x:%d copy upstream←stream: %v", conv, stream.ID(), err)
    215 216   }
    216  - conn.CloseWrite()
     217 + upstreamTCPConn.CloseWrite()
    217 218   }()
    218 219   wg.Wait()
    219 220   
    skipped 2 lines
    222 223   
    223 224  // acceptStreams wraps a KCP session in a Noise channel and an smux.Session,
    224 225  // then awaits smux streams. It passes each stream to handleStream.
    225  -func acceptStreams(conn *kcp.UDPSession, privkey, pubkey []byte, upstream *net.TCPAddr) error {
     226 +func acceptStreams(conn *kcp.UDPSession, privkey, pubkey []byte, upstream string) error {
    226 227   // Put a Noise channel on top of the KCP conn.
    227 228   rw, err := noise.NewServer(conn, privkey, pubkey)
    228 229   if err != nil {
    skipped 34 lines
    263 264   
    264 265  // acceptSessions listens for incoming KCP connections and passes them to
    265 266  // acceptStreams.
    266  -func acceptSessions(ln *kcp.Listener, privkey, pubkey []byte, mtu int, upstream *net.TCPAddr) error {
     267 +func acceptSessions(ln *kcp.Listener, privkey, pubkey []byte, mtu int, upstream string) error {
    267 268   for {
    268 269   conn, err := ln.AcceptKCP()
    269 270   if err != nil {
    skipped 469 lines
    739 740   return low
    740 741  }
    741 742   
    742  -func run(privkey, pubkey []byte, domain dns.Name, upstream net.Addr, dnsConn net.PacketConn) error {
     743 +func run(privkey, pubkey []byte, domain dns.Name, upstream string, dnsConn net.PacketConn) error {
    743 744   defer dnsConn.Close()
    744 745   
    745 746   log.Printf("pubkey %x", pubkey)
    skipped 24 lines
    770 771   }
    771 772   defer ln.Close()
    772 773   go func() {
    773  - err := acceptSessions(ln, privkey, pubkey, mtu, upstream.(*net.TCPAddr))
     774 + err := acceptSessions(ln, privkey, pubkey, mtu, upstream)
    774 775   if err != nil {
    775 776   log.Printf("acceptSessions: %v", err)
    776 777   }
    skipped 65 lines
    842 843   fmt.Fprintf(os.Stderr, "invalid domain %+q: %v\n", flag.Arg(0), err)
    843 844   os.Exit(1)
    844 845   }
    845  - upstream, err := net.ResolveTCPAddr("tcp", flag.Arg(1))
    846  - if err != nil {
    847  - fmt.Fprintf(os.Stderr, "cannot resolve %+q: %v\n", flag.Arg(1), err)
    848  - os.Exit(1)
     846 + upstream := flag.Arg(1)
     847 + // We keep upstream as a string in order to eventually pass it
     848 + // to net.Dial in handleStream. But for the sake of displaying
     849 + // an error or warning at startup, rather than only when the
     850 + // first stream occurs, we apply some parsing and name
     851 + // resolution checks here.
     852 + {
     853 + upstreamHost, _, err := net.SplitHostPort(upstream)
     854 + if err != nil {
     855 + // host:port format is required in all cases, so
     856 + // this is a fatal error.
     857 + fmt.Fprintf(os.Stderr, "cannot parse upstream address %+q: %v\n", upstream, err)
     858 + os.Exit(1)
     859 + }
     860 + upstreamIPAddr, err := net.ResolveIPAddr("ip", upstreamHost)
     861 + if err != nil {
     862 + // Failure to resolve the host portion is only a
     863 + // warning. The name will be re-resolved on each
     864 + // net.Dial in handleStream.
     865 + log.Printf("warning: cannot resolve upstream host %+q: %v", upstreamHost, err)
     866 + } else if upstreamIPAddr.IP == nil {
     867 + // Handle the special case of an empty string
     868 + // for the host portion, which resolves to a nil
     869 + // IP. This is a fatal error as we will not be
     870 + // able to dial this address.
     871 + fmt.Fprintf(os.Stderr, "cannot parse upstream address %+q: missing host in address\n", upstream)
     872 + os.Exit(1)
     873 + }
    849 874   }
    850 875   
    851 876   if udpAddr == "" {
    skipped 52 lines
Please wait...
Page is in error, reload to recover