socket_test.go 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216
  1. package socket
  2. import (
  3. "errors"
  4. "io"
  5. "io/fs"
  6. "net"
  7. "os"
  8. "runtime"
  9. "strings"
  10. "sync/atomic"
  11. "testing"
  12. "time"
  13. "gotest.tools/v3/assert"
  14. "gotest.tools/v3/poll"
  15. )
  16. func TestPluginServer(t *testing.T) {
  17. t.Run("connection closes with EOF when server closes", func(t *testing.T) {
  18. called := make(chan struct{})
  19. srv, err := NewPluginServer(func(_ net.Conn) { close(called) })
  20. assert.NilError(t, err)
  21. assert.Assert(t, srv != nil, "returned nil server but no error")
  22. addr, err := net.ResolveUnixAddr("unix", srv.Addr().String())
  23. assert.NilError(t, err, "failed to resolve server address")
  24. conn, err := net.DialUnix("unix", nil, addr)
  25. assert.NilError(t, err, "failed to dial returned server")
  26. defer conn.Close()
  27. done := make(chan error, 1)
  28. go func() {
  29. _, err := conn.Read(make([]byte, 1))
  30. done <- err
  31. }()
  32. select {
  33. case <-called:
  34. case <-time.After(10 * time.Millisecond):
  35. t.Fatal("handler not called")
  36. }
  37. srv.Close()
  38. select {
  39. case err := <-done:
  40. if !errors.Is(err, io.EOF) {
  41. t.Fatalf("exepcted EOF error, got: %v", err)
  42. }
  43. case <-time.After(10 * time.Millisecond):
  44. }
  45. })
  46. t.Run("allows reconnects", func(t *testing.T) {
  47. var calls int32
  48. h := func(_ net.Conn) {
  49. atomic.AddInt32(&calls, 1)
  50. }
  51. srv, err := NewPluginServer(h)
  52. assert.NilError(t, err)
  53. defer srv.Close()
  54. assert.Check(t, srv.Addr() != nil, "returned nil addr but no error")
  55. addr, err := net.ResolveUnixAddr("unix", srv.Addr().String())
  56. assert.NilError(t, err, "failed to resolve server address")
  57. waitForCalls := func(n int) {
  58. poll.WaitOn(t, func(t poll.LogT) poll.Result {
  59. if atomic.LoadInt32(&calls) == int32(n) {
  60. return poll.Success()
  61. }
  62. return poll.Continue("waiting for handler to be called")
  63. })
  64. }
  65. otherConn, err := net.DialUnix("unix", nil, addr)
  66. assert.NilError(t, err, "failed to dial returned server")
  67. otherConn.Close()
  68. waitForCalls(1)
  69. conn, err := net.DialUnix("unix", nil, addr)
  70. assert.NilError(t, err, "failed to redial server")
  71. defer conn.Close()
  72. waitForCalls(2)
  73. // and again but don't close the existing connection
  74. conn2, err := net.DialUnix("unix", nil, addr)
  75. assert.NilError(t, err, "failed to redial server")
  76. defer conn2.Close()
  77. waitForCalls(3)
  78. srv.Close()
  79. // now make sure we get EOF on the existing connections
  80. buf := make([]byte, 1)
  81. _, err = conn.Read(buf)
  82. assert.ErrorIs(t, err, io.EOF, "expected EOF error, got: %v", err)
  83. _, err = conn2.Read(buf)
  84. assert.ErrorIs(t, err, io.EOF, "expected EOF error, got: %v", err)
  85. })
  86. t.Run("does not leak sockets to local directory", func(t *testing.T) {
  87. srv, err := NewPluginServer(nil)
  88. assert.NilError(t, err)
  89. assert.Check(t, srv != nil, "returned nil server but no error")
  90. checkDirNoNewPluginServer(t)
  91. addr, err := net.ResolveUnixAddr("unix", srv.Addr().String())
  92. assert.NilError(t, err, "failed to resolve server address")
  93. _, err = net.DialUnix("unix", nil, addr)
  94. assert.NilError(t, err, "failed to dial returned server")
  95. checkDirNoNewPluginServer(t)
  96. })
  97. t.Run("does not panic on Close if server is nil", func(t *testing.T) {
  98. var srv *PluginServer
  99. defer func() {
  100. if r := recover(); r != nil {
  101. t.Errorf("panicked on Close")
  102. }
  103. }()
  104. err := srv.Close()
  105. assert.NilError(t, err)
  106. })
  107. }
  108. func checkDirNoNewPluginServer(t *testing.T) {
  109. t.Helper()
  110. files, err := os.ReadDir(".")
  111. assert.NilError(t, err, "failed to list files in dir to check for leaked sockets")
  112. for _, f := range files {
  113. info, err := f.Info()
  114. assert.NilError(t, err, "failed to check file info")
  115. // check for a socket with `docker_cli_` in the name (from `SetupConn()`)
  116. if strings.Contains(f.Name(), "docker_cli_") && info.Mode().Type() == fs.ModeSocket {
  117. t.Fatal("found socket in a local directory")
  118. }
  119. }
  120. }
  121. func TestConnectAndWait(t *testing.T) {
  122. t.Run("calls cancel func on EOF", func(t *testing.T) {
  123. srv, err := NewPluginServer(nil)
  124. assert.NilError(t, err, "failed to setup server")
  125. defer srv.Close()
  126. done := make(chan struct{})
  127. t.Setenv(EnvKey, srv.Addr().String())
  128. cancelFunc := func() {
  129. done <- struct{}{}
  130. }
  131. ConnectAndWait(cancelFunc)
  132. select {
  133. case <-done:
  134. t.Fatal("unexpectedly done")
  135. default:
  136. }
  137. srv.Close()
  138. select {
  139. case <-done:
  140. case <-time.After(10 * time.Millisecond):
  141. t.Fatal("cancel function not closed after 10ms")
  142. }
  143. })
  144. // TODO: this test cannot be executed with `t.Parallel()`, due to
  145. // relying on goroutine numbers to ensure correct behaviour
  146. t.Run("connect goroutine exits after EOF", func(t *testing.T) {
  147. runtime.LockOSThread()
  148. defer runtime.UnlockOSThread()
  149. srv, err := NewPluginServer(nil)
  150. assert.NilError(t, err, "failed to setup server")
  151. defer srv.Close()
  152. t.Setenv(EnvKey, srv.Addr().String())
  153. runtime.Gosched()
  154. numGoroutines := runtime.NumGoroutine()
  155. ConnectAndWait(func() {})
  156. runtime.Gosched()
  157. poll.WaitOn(t, func(t poll.LogT) poll.Result {
  158. // +1 goroutine for the poll.WaitOn
  159. // +1 goroutine for the connect goroutine
  160. if runtime.NumGoroutine() < numGoroutines+1+1 {
  161. return poll.Continue("waiting for connect goroutine to spawn")
  162. }
  163. return poll.Success()
  164. }, poll.WithDelay(1*time.Millisecond), poll.WithTimeout(500*time.Millisecond))
  165. srv.Close()
  166. runtime.Gosched()
  167. poll.WaitOn(t, func(t poll.LogT) poll.Result {
  168. // +1 goroutine for the poll.WaitOn
  169. if runtime.NumGoroutine() > numGoroutines+1 {
  170. return poll.Continue("waiting for connect goroutine to exit")
  171. }
  172. return poll.Success()
  173. }, poll.WithDelay(1*time.Millisecond), poll.WithTimeout(500*time.Millisecond))
  174. })
  175. }