Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 38 additions & 20 deletions pulsar/internal/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,11 +159,17 @@ type connection struct {

log log.Logger

incomingRequestsWG sync.WaitGroup
incomingRequestsCh chan *request
closeCh chan struct{}
readyCh chan struct{}
writeRequestsCh chan *dataRequest
// incomingRequestsWG tracks in-flight SendRequest/SendRequestNoWait
// callers. incomingRequestsLock serialises Add(1) with the Wait() in
// failLeftRequestsWhenClose so a late SendRequest cannot Add to a
// WaitGroup whose counter is already draining (which would trip
// "sync: WaitGroup is reused before previous Wait has returned").
incomingRequestsLock sync.RWMutex
incomingRequestsWG sync.WaitGroup
incomingRequestsCh chan *request
closeCh chan struct{}
readyCh chan struct{}
writeRequestsCh chan *dataRequest

pendingLock sync.Mutex
pendingReqs map[uint64]*request
Expand Down Expand Up @@ -373,6 +379,14 @@ func (c *connection) waitUntilReady() error {
}

func (c *connection) failLeftRequestsWhenClose() {
// Stop new SendRequest/SendRequestNoWait callers from adding to the
// WaitGroup before draining the in-flight ones. Without this barrier,
// a concurrent Add(1) racing with Wait() reaching zero panics with
// "sync: WaitGroup is reused before previous Wait has returned".
c.incomingRequestsLock.Lock()
c.setStateClosed()
c.incomingRequestsLock.Unlock()

// wait for outstanding incoming requests to complete before draining
// and closing the channel
c.incomingRequestsWG.Wait()
Expand Down Expand Up @@ -656,33 +670,37 @@ func (c *connection) checkServerError(err *pb.ServerError) {

func (c *connection) SendRequest(requestID uint64, req *pb.BaseCommand,
callback func(command *pb.BaseCommand, err error)) {
c.incomingRequestsLock.RLock()
if c.getState() == connectionClosed {
c.incomingRequestsLock.RUnlock()
callback(req, ErrConnectionClosed)
return
}
c.incomingRequestsWG.Add(1)
c.incomingRequestsLock.RUnlock()
defer c.incomingRequestsWG.Done()

if c.getState() == connectionClosed {
select {
case <-c.closeCh:
callback(req, ErrConnectionClosed)

} else {
select {
case <-c.closeCh:
callback(req, ErrConnectionClosed)

case c.incomingRequestsCh <- &request{
id: &requestID,
cmd: req,
callback: callback,
}:
}
case c.incomingRequestsCh <- &request{
id: &requestID,
cmd: req,
callback: callback,
}:
}
}

func (c *connection) SendRequestNoWait(req *pb.BaseCommand) error {
c.incomingRequestsWG.Add(1)
defer c.incomingRequestsWG.Done()

c.incomingRequestsLock.RLock()
if c.getState() == connectionClosed {
c.incomingRequestsLock.RUnlock()
return ErrConnectionClosed
}
c.incomingRequestsWG.Add(1)
c.incomingRequestsLock.RUnlock()
defer c.incomingRequestsWG.Done()

select {
case <-c.closeCh:
Expand Down
Loading