Skip to content

Commit 8238757

Browse files
committed
fix(network): clean up resources on panic in callback to prevent memory leaks
1 parent 5650ade commit 8238757

4 files changed

Lines changed: 90 additions & 36 deletions

File tree

cluster/acceptor.go

Lines changed: 23 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@ package cluster
33
import (
44
"context"
55
"net"
6+
"sync/atomic"
67

78
"github.com/lonng/nano/cluster/clusterpb"
89
"github.com/lonng/nano/internal/env"
@@ -22,17 +23,20 @@ type acceptor struct {
2223
lastMid uint64
2324
rpcHandler rpcHandler
2425
gateAddr string
26+
state atomic.Int32 // current acceptor state
2527
}
2628

2729
// newAcceptor 构造函数
2830
func newAcceptor(sid int64, node *Node, gateClient clusterpb.GateClient, rpcHandler rpcHandler, gateAddr string) *acceptor {
29-
return &acceptor{
31+
a := &acceptor{
3032
sid: sid,
3133
node: node,
3234
gateClient: gateClient,
3335
rpcHandler: rpcHandler,
3436
gateAddr: gateAddr,
3537
}
38+
a.state.Store(statusWorking)
39+
return a
3640
}
3741

3842
// RemoteAddr 返回一个假的地址
@@ -47,6 +51,9 @@ func (a *acceptor) LastMid() uint64 {
4751

4852
// RPC 调用集群内的服务
4953
func (a *acceptor) RPC(route string, v any) error {
54+
if a.status() == statusClosed {
55+
return ErrBrokenPipe
56+
}
5057
// TODO: buffer
5158
data, err := env.Marshal(v)
5259
if err != nil {
@@ -63,6 +70,9 @@ func (a *acceptor) RPC(route string, v any) error {
6370

6471
// Push 调用 Gate, 推送数据给客户端
6572
func (a *acceptor) Push(route string, v any) error {
73+
if a.status() == statusClosed {
74+
return ErrBrokenPipe
75+
}
6676
// TODO: buffer
6777
data, err := env.Marshal(v)
6878
if err != nil {
@@ -84,6 +94,9 @@ func (a *acceptor) Response(v any) error {
8494

8595
// ResponseMid 调用 Gate, 返回响应数据给客户端
8696
func (a *acceptor) ResponseMid(mid uint64, v any) error {
97+
if a.status() == statusClosed {
98+
return ErrBrokenPipe
99+
}
87100
// TODO: buffer
88101
data, err := env.Marshal(v)
89102
if err != nil {
@@ -100,10 +113,13 @@ func (a *acceptor) ResponseMid(mid uint64, v any) error {
100113

101114
// Close 集群模式下, Worker 节点关闭会话, 通知 Gate 也关闭(主动关闭)
102115
func (a *acceptor) Close() error {
116+
if !a.state.CompareAndSwap(statusWorking, statusClosed) {
117+
return ErrCloseClosedSession
118+
}
103119
// TODO: buffer
104120
// 先删除
105121
s, found := a.node.delSession(a.sid)
106-
// 通知 gate
122+
// 通知 Gate 关闭连接
107123
request := &clusterpb.CloseSessionRequest{
108124
SessionId: a.sid,
109125
}
@@ -114,3 +130,8 @@ func (a *acceptor) Close() error {
114130
}
115131
return err
116132
}
133+
134+
// status 获取当前状态
135+
func (a *acceptor) status() int32 {
136+
return a.state.Load()
137+
}

cluster/agent.go

Lines changed: 55 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@ import (
44
"errors"
55
"fmt"
66
"net"
7-
"reflect"
7+
"sync"
88
"sync/atomic"
99
"time"
1010

@@ -33,17 +33,19 @@ var _ session.NetworkEntity = (*agent)(nil)
3333
// agent 与客户端直接通信的网络对象(单点模式中的节点, 或集群模式的网关)
3434
type agent struct {
3535
// regular agent member
36-
session *session.Session // session
37-
conn net.Conn // low-level conn fd
38-
lastMid uint64 // last message id
39-
state atomic.Int32 // current agent state
40-
chDie chan struct{} // wait for close
41-
chSend chan pendingMessage // push message queue
42-
lastAt int64 // last heartbeat unix time stamp
43-
decoder *packet.Decoder // binary decoder
44-
pipeline pipeline.Pipeline //
45-
rpcHandler rpcHandler //
46-
srv reflect.Value // cached session reflect.Value
36+
session *session.Session // session
37+
conn net.Conn // low-level conn fd
38+
lastMid uint64 // last message id
39+
chDie chan struct{} // wait for close
40+
chSend chan pendingMessage // push message queue
41+
lastAt atomic.Int64 // last heartbeat unix time stamp
42+
decoder *packet.Decoder // binary decoder
43+
pipeline pipeline.Pipeline //
44+
rpcHandler rpcHandler //
45+
writeReady atomic.Bool // write 协程是否已经启动
46+
connCloseOnce sync.Once // 确保 conn 只关闭一次
47+
chanCloseOnce sync.Once // 确保 chDie 和 chSend 只关闭一次
48+
state atomic.Int32 // current agent state
4749
}
4850

4951
// pendingMessage 待发送的消息
@@ -59,18 +61,17 @@ func newAgent(conn net.Conn, pipeline pipeline.Pipeline, rpcHandler rpcHandler)
5961
a := &agent{
6062
conn: conn,
6163
chDie: make(chan struct{}),
62-
lastAt: time.Now().Unix(),
6364
chSend: make(chan pendingMessage, agentWriteBacklog),
6465
decoder: packet.NewDecoder(),
6566
pipeline: pipeline,
6667
rpcHandler: rpcHandler,
6768
}
69+
a.lastAt.Store(time.Now().Unix())
6870
a.state.Store(statusStart)
6971

7072
// binding session
7173
s := session.New(a)
7274
a.session = s
73-
a.srv = reflect.ValueOf(s)
7475
return a
7576
}
7677

@@ -157,7 +158,7 @@ func (a *agent) ResponseMid(mid uint64, v any) error {
157158
return a.send(pendingMessage{typ: message.Response, mid: mid, payload: v})
158159
}
159160

160-
// Close 关闭低级连接, 任何被阻塞读取或写入作都将被取消阻塞并返回错误
161+
// Close 设置关闭状态, 发送关闭信号; 如果 write 协程未就绪, 直接关闭底层连接; 否则由 write 协程 flush 完数据后关闭
161162
func (a *agent) Close() error {
162163
if a.status() == statusClosed {
163164
return ErrCloseClosedSession
@@ -168,19 +169,38 @@ func (a *agent) Close() error {
168169
log.Info("Session closing, ID=%d, UID=%d, IP=%s", a.session.ID(), a.session.UID(), a.conn.RemoteAddr())
169170
}
170171

171-
// 防止关闭已经是关闭状态的 chan
172-
select {
173-
case <-a.chDie:
174-
// expect
175-
default:
176-
close(a.chDie)
172+
// 关闭 chan, 发出停止信号
173+
a.closeChanOnce()
174+
175+
// 如果 write 协程已经启动, 则不需要关闭底层连接, 因为 write 协程 flush 完会自动处理
176+
if a.writeReady.Load() {
177+
return nil
177178
}
178-
return nil
179+
180+
// 如果 write 协程还未启动, 则直接关闭底层连接
181+
return a.closeConnOnce()
179182
}
180183

181184
// String 返回描述信息
182185
func (a *agent) String() string {
183-
return fmt.Sprintf("Remote=%s, LastTime=%d", a.conn.RemoteAddr().String(), atomic.LoadInt64(&a.lastAt))
186+
return fmt.Sprintf("Remote=%s, LastTime=%d", a.conn.RemoteAddr().String(), a.lastAt.Load())
187+
}
188+
189+
// closeChanOnce 确保 chan 只关闭一次
190+
func (a *agent) closeChanOnce() {
191+
a.chanCloseOnce.Do(func() {
192+
close(a.chDie)
193+
close(a.chSend)
194+
})
195+
}
196+
197+
// closeConnOnce 确保 conn 只关闭一次
198+
func (a *agent) closeConnOnce() error {
199+
var err error
200+
a.connCloseOnce.Do(func() {
201+
err = a.conn.Close()
202+
})
203+
return err
184204
}
185205

186206
// status 获取当前状态
@@ -202,28 +222,33 @@ func (a *agent) write() {
202222
// clean func
203223
defer func() {
204224
ticker.Stop()
205-
close(a.chSend)
206225
close(chWrite)
207-
//非强制退出, 则将所有待发送的消息写入底层连接
226+
// 关闭 chan, 必须关闭 chan 后才能执行 flush, 否则阻塞
227+
a.closeChanOnce()
228+
// 非强制退出, 则将所有待发送的消息写入底层连接
208229
if !forceQuit {
209230
a.flush(chWrite)
210231
}
211-
//更改 agent 状态, 必须先更改状态再关闭底层连接
232+
// 更改 agent 状态, 必须先更改状态再关闭底层连接
212233
_ = a.Close()
213-
//关闭底层连接, 此时 conn.Read() 将返回错误, 因上一步已经把状态关闭, 所以读协程会跳过日志退出
214-
_ = a.conn.Close()
234+
// 关闭底层连接, 此时 conn.Read() 将返回错误, 因上一步已经把状态关闭, 所以读协程会跳过日志退出
235+
_ = a.closeConnOnce()
215236
if env.Debug {
216237
log.Info("Session write goroutine exit, SessionID=%d, UID=%d", a.session.ID(), a.session.UID())
217238
}
218239
}()
219240

241+
// 标记 write 协程已经就绪
242+
a.writeReady.Store(true)
243+
220244
for {
221245
select {
222246
// 心跳检测
223247
case <-ticker.C:
224248
deadline := time.Now().Add(-2 * heartbeat).Unix()
225-
if atomic.LoadInt64(&a.lastAt) < deadline {
226-
log.Info("Session heartbeat timeout, LastTime=%d, Deadline=%d", atomic.LoadInt64(&a.lastAt), deadline)
249+
lastAt := a.lastAt.Load()
250+
if lastAt < deadline {
251+
log.Info("Session heartbeat timeout, LastTime=%d, Deadline=%d", lastAt, deadline)
227252
return
228253
}
229254
chWrite <- getHbd()

cluster/handler.go

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -299,7 +299,7 @@ func (h *LocalHandler) processPacket(agent *agent, p *packet.Packet) error {
299299
return errors.New("invalid packet type")
300300
}
301301

302-
agent.lastAt = time.Now().Unix()
302+
agent.lastAt.Store(time.Now().Unix())
303303
return nil
304304
}
305305

session/event.go

Lines changed: 11 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -39,6 +39,14 @@ func (lt *event) FireSessionCreated(s *Session) {
3939
return
4040
}
4141

42+
// 回调过程中如果发生 panic, 需要关闭底层连接以防止连接泄露
43+
defer func() {
44+
if r := recover(); r != nil {
45+
s.Close()
46+
panic(r) // 重新抛出 panic
47+
}
48+
}()
49+
4250
for _, fn := range lt.onSessionCreated {
4351
fn(s)
4452
}
@@ -55,12 +63,12 @@ func (lt *event) FireSessionClosed(s *Session) {
5563
return
5664
}
5765

66+
// 执行完回调, 清除会话关联的数据
67+
defer s.Clear()
68+
5869
for _, fn := range lt.onSessionClosed {
5970
fn(s)
6071
}
61-
62-
// 执行完回调, 清除会话关联的数据
63-
s.Clear()
6472
}
6573

6674
// MessagePushing 设置消息推送前的回调, 可以修改消息内容和路由

0 commit comments

Comments
 (0)