Go 语言中 context 包的深度解析与应用场景

所有的期待和失望都是因为你把自己看得太重要了

Posted by yishuifengxiao on 2024-08-03

Context 基础概念

Context 的核心作用

  • 取消传播:在调用链中传递取消信号
  • 超时控制:设置操作的截止时间
  • 值传递:在调用链中安全传递请求范围数据
  • 截止时间:设置操作的绝对过期时间点

创建 Context 的四种方式

// 1. 创建根Context
ctx := context.Background()

// 2. 创建TODO占位Context
ctx = context.TODO()

// 3. 创建可取消的Context
ctx, cancel := context.WithCancel(context.Background())
defer cancel() // 确保资源释放

// 4. 创建带超时的Context
ctx, cancel = context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()

// 5. 创建带截止时间的Context
deadline := time.Now().Add(3 * time.Second)
ctx, cancel = context.WithDeadline(context.Background(), deadline)
defer cancel()

// 6. 创建带值的Context
ctx = context.WithValue(context.Background(), "requestID", "12345")

Context 的五大核心应用场景

请求取消传播

场景:当用户取消操作时,中断整个调用链

func handleRequest(ctx context.Context) {
// 启动数据库查询
resultCh := make(chan Result)
go queryDatabase(ctx, resultCh)

// 启动外部服务调用
serviceCh := make(chan ServiceResponse)
go callExternalService(ctx, serviceCh)

select {
case res := <-resultCh:
process(res)
case svc := <-serviceCh:
process(svc)
case <-ctx.Done():
log.Println("请求被取消:", ctx.Err())
cleanupResources()
}
}

func queryDatabase(ctx context.Context, out chan<- Result) {
// 模拟长时间查询
select {
case <-time.After(3 * time.Second):
out <- Result{Data: "查询结果"}
case <-ctx.Done():
log.Println("数据库查询取消")
return
}
}

超时控制

场景:确保操作在指定时间内完成

func fetchWithTimeout(url string) (string, error) {
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()

req, _ := http.NewRequestWithContext(ctx, "GET", url, nil)
resp, err := http.DefaultClient.Do(req)
if err != nil {
return "", err
}
defer resp.Body.Close()

body, err := io.ReadAll(resp.Body)
if err != nil {
return "", err
}
return string(body), nil
}

服务间传递元数据

场景:在微服务架构中传递请求ID、认证信息等

type contextKey string

const (
RequestIDKey contextKey = "requestID"
AuthTokenKey contextKey = "authToken"
)

func middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// 创建新Context并添加值
ctx := r.Context()
ctx = context.WithValue(ctx, RequestIDKey, uuid.New().String())
ctx = context.WithValue(ctx, AuthTokenKey, r.Header.Get("Authorization"))

// 传递新Context
next.ServeHTTP(w, r.WithContext(ctx))
})
}

func handler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()

// 获取值
requestID, ok := ctx.Value(RequestIDKey).(string)
if !ok {
requestID = "unknown"
}

authToken, _ := ctx.Value(AuthTokenKey).(string)

log.Printf("[%s] 处理请求, Token: %s", requestID, authToken)
// ...
}

分布式追踪

场景:跨服务传递追踪信息

func callDownstreamService(ctx context.Context, serviceURL string) {
// 从Context获取追踪信息
traceID := ctx.Value(TraceIDKey).(string)
spanID := generateSpanID()

// 创建下游请求
req, _ := http.NewRequest("GET", serviceURL, nil)
req = req.WithContext(ctx)

// 添加追踪头
req.Header.Set("X-Trace-ID", traceID)
req.Header.Set("X-Span-ID", spanID)

// 执行请求
resp, err := http.DefaultClient.Do(req)
// ...
}

资源清理协调

场景:确保所有相关资源在取消时被清理

func processTask(ctx context.Context) error {
// 创建资源
dbConn, err := openDatabaseConnection()
if err != nil {
return err
}
defer dbConn.Close()

file, err := os.Create("data.txt")
if err != nil {
return err
}
defer file.Close()

// 监控Context取消
go func() {
<-ctx.Done()
log.Println("任务取消,清理资源")
dbConn.Cancel() // 取消数据库操作
file.Close() // 立即关闭文件
}()

// 处理任务
if err := dbConn.Execute(ctx, "SELECT ..."); err != nil {
return err
}

// 写入文件
if _, err := file.Write(...); err != nil {
return err
}

return nil
}

Context 的高级用法

组合多个 Context

func combineContexts(parent context.Context, timeout time.Duration) context.Context {
// 创建带超时的Context
timeoutCtx, cancel := context.WithTimeout(parent, timeout)

// 创建带值的Context
valueCtx := context.WithValue(timeoutCtx, "combined", true)

// 返回可取消的组合Context
return &combinedContext{
Context: valueCtx,
cancel: cancel,
}
}

type combinedContext struct {
context.Context
cancel context.CancelFunc
}

func (c *combinedContext) Cancel() {
c.cancel()
}

Context 树管理

type ContextTree struct {
root context.Context
nodes map[string]context.Context
mu sync.Mutex
}

func NewContextTree() *ContextTree {
return &ContextTree{
root: context.Background(),
nodes: make(map[string]context.Context),
}
}

func (t *ContextTree) AddBranch(name string, opts ...ContextOption) {
t.mu.Lock()
defer t.mu.Unlock()

parent := t.root
if existing, ok := t.nodes[name]; ok {
parent = existing
}

ctx := parent
for _, opt := range opts {
ctx = opt(ctx)
}

t.nodes[name] = ctx
}

func (t *ContextTree) CancelBranch(name string) {
t.mu.Lock()
defer t.mu.Unlock()

if node, ok := t.nodes[name]; ok {
if cancel, ok := node.(interface{ Cancel() }); ok {
cancel.Cancel()
}
delete(t.nodes, name)
}
}

自定义 Context 实现

type customContext struct {
context.Context
extraData map[string]interface{}
done chan struct{}
mu sync.RWMutex
}

func NewCustomContext(parent context.Context) *customContext {
ctx, cancel := context.WithCancel(parent)
return &customContext{
Context: ctx,
extraData: make(map[string]interface{}),
done: make(chan struct{}),
}
}

func (c *customContext) Deadline() (deadline time.Time, ok bool) {
return c.Context.Deadline()
}

func (c *customContext) Done() <-chan struct{} {
return c.done
}

func (c *customContext) Err() error {
select {
case <-c.done:
return context.Canceled
default:
return nil
}
}

func (c *customContext) Value(key interface{}) interface{} {
if k, ok := key.(string); ok {
c.mu.RLock()
defer c.mu.RUnlock()
return c.extraData[k]
}
return c.Context.Value(key)
}

func (c *customContext) SetValue(key, value interface{}) {
if k, ok := key.(string); ok {
c.mu.Lock()
defer c.mu.Unlock()
c.extraData[k] = value
}
}

func (c *customContext) Cancel() {
close(c.done)
if cancel, ok := c.Context.(context.CancelFunc); ok {
cancel()
}
}

Context 的最佳实践

Context 使用原则

  1. 作为第一个参数:Context 应该是函数的第一个参数

    func DoSomething(ctx context.Context, arg1, arg2 string) error
  2. 不存储 Context:避免在结构体中存储 Context

    // 错误方式
    type Server struct {
    ctx context.Context
    }

    // 正确方式
    func (s *Server) HandleRequest(ctx context.Context) {...}
  3. 传递派生 Context:使用 WithCancel, WithTimeout 等派生新 Context

    func Process(ctx context.Context) {
    childCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
    defer cancel()
    // 使用 childCtx
    }
  4. 及时取消:使用 defer cancel() 确保资源释放

    ctx, cancel := context.WithCancel(context.Background())
    defer cancel() // 确保函数返回时取消

Context 与并发模式

工作池模式

func workerPool(ctx context.Context, numWorkers int) {
taskCh := make(chan Task)
var wg sync.WaitGroup

// 启动工作池
for i := 0; i < numWorkers; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
for {
select {
case task := <-taskCh:
processTask(ctx, task)
case <-ctx.Done():
return // Context取消时退出
}
}
}(i)
}

// 分发任务
for _, task := range fetchTasks() {
select {
case taskCh <- task:
case <-ctx.Done():
break // Context取消时停止分发
}
}

close(taskCh)
wg.Wait()
}

扇出/扇入模式

func processPipeline(ctx context.Context, input <-chan Data) <-chan Result {
// 阶段1: 数据清洗
cleaned := cleanData(ctx, input)

// 阶段2: 并行处理
workers := make([]<-chan Result, 5)
for i := 0; i < 5; i++ {
workers[i] = processWorker(ctx, cleaned)
}

// 阶段3: 结果聚合
return mergeResults(ctx, workers...)
}

func cleanData(ctx context.Context, in <-chan Data) <-chan Data {
out := make(chan Data)
go func() {
defer close(out)
for data := range in {
if !isValid(data) {
continue
}
select {
case out <- data:
case <-ctx.Done():
return
}
}
}()
return out
}

Context 陷阱与避免

陷阱1:未检查 Done()

// 错误方式:可能阻塞
func processBlocking(ctx context.Context) {
result := longRunningOperation() // 阻塞操作
select {
case <-ctx.Done():
return
default:
use(result)
}
}

// 正确方式:检查Context
func processNonBlocking(ctx context.Context) {
resultCh := make(chan Result)
go func() { resultCh <- longRunningOperation() }()

select {
case res := <-resultCh:
use(res)
case <-ctx.Done():
// 取消操作
}
}

陷阱2:Value 类型安全

// 错误方式:直接类型断言
token := ctx.Value("token").(string) // panic风险

// 正确方式:安全断言
if token, ok := ctx.Value("token").(string); ok {
// 使用token
} else {
// 处理缺失情况
}

陷阱3:过度使用 Value

// 错误方式:传递业务参数
ctx = context.WithValue(ctx, "userID", 12345)

// 正确方式:保留给请求范围数据
ctx = context.WithValue(ctx, "requestID", "req-123")

Context 在不同领域的应用

HTTP 服务

func main() {
mux := http.NewServeMux()
mux.HandleFunc("/api", apiHandler)

server := &http.Server{
Addr: ":8080",
Handler: mux,
}

// 优雅关闭
go func() {
sig := make(chan os.Signal, 1)
signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM)
<-sig

ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()

if err := server.Shutdown(ctx); err != nil {
log.Fatal("强制关闭服务器:", err)
}
}()

if err := server.ListenAndServe(); err != http.ErrServerClosed {
log.Fatal("服务器错误:", err)
}
}

func apiHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()

// 设置处理超时
ctx, cancel := context.WithTimeout(ctx, 2*time.Second)
defer cancel()

// 处理请求
result, err := processRequest(ctx)
if err != nil {
if errors.Is(err, context.DeadlineExceeded) {
http.Error(w, "请求超时", http.StatusGatewayTimeout)
return
}
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}

json.NewEncoder(w).Encode(result)
}

gRPC 服务

type server struct {
pb.UnimplementedMyServiceServer
}

func (s *server) GetData(ctx context.Context, req *pb.DataRequest) (*pb.DataResponse, error) {
// 检查客户端是否断开连接
if ctx.Err() == context.Canceled {
return nil, status.Error(codes.Canceled, "客户端取消请求")
}

// 设置处理超时
ctx, cancel := context.WithTimeout(ctx, 1*time.Second)
defer cancel()

// 处理请求
data, err := fetchData(ctx, req.Id)
if err != nil {
if errors.Is(err, context.DeadlineExceeded) {
return nil, status.Error(codes.DeadlineExceeded, "处理超时")
}
return nil, status.Errorf(codes.Internal, "获取数据错误: %v", err)
}

return &pb.DataResponse{Data: data}, nil
}

func main() {
lis, err := net.Listen("tcp", ":50051")
if err != nil {
log.Fatalf("监听失败: %v", err)
}

s := grpc.NewServer()
pb.RegisterMyServiceServer(s, &server{})

if err := s.Serve(lis); err != nil {
log.Fatalf("服务失败: %v", err)
}
}

数据库操作

type Database struct {
pool *pgxpool.Pool
}

func (db *Database) QueryUser(ctx context.Context, userID int) (*User, error) {
ctx, cancel := context.WithTimeout(ctx, 500*time.Millisecond)
defer cancel()

var user User
err := db.pool.QueryRow(ctx,
"SELECT id, name, email FROM users WHERE id = $1", userID).
Scan(&user.ID, &user.Name, &user.Email)

if err != nil {
if errors.Is(err, context.DeadlineExceeded) {
return nil, fmt.Errorf("数据库查询超时")
}
return nil, fmt.Errorf("查询失败: %w", err)
}

return &user, nil
}

func main() {
config, _ := pgxpool.ParseConfig("postgres://user:pass@localhost/db")
pool, _ := pgxpool.ConnectConfig(context.Background(), config)
defer pool.Close()

db := &Database{pool: pool}

ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()

user, err := db.QueryUser(ctx, 42)
if err != nil {
log.Fatal(err)
}

fmt.Printf("用户: %+v\n", user)
}

分布式任务调度

type TaskScheduler struct {
tasks chan Task
workers int
}

func (s *TaskScheduler) Start(ctx context.Context) {
var wg sync.WaitGroup
wg.Add(s.workers)

for i := 0; i < s.workers; i++ {
go func(id int) {
defer wg.Done()
for {
select {
case task := <-s.tasks:
if err := task.Execute(ctx); err != nil {
log.Printf("任务%d执行失败: %v", task.ID, err)
}
case <-ctx.Done():
log.Printf("工作者%d收到关闭信号", id)
return
}
}
}(i)
}

// 等待所有工作者退出
wg.Wait()
log.Println("调度器已停止")
}

func main() {
scheduler := &TaskScheduler{
tasks: make(chan Task, 100),
workers: 10,
}

ctx, cancel := context.WithCancel(context.Background())

// 添加任务
go func() {
for i := 0; i < 100; i++ {
scheduler.tasks <- Task{ID: i}
}
}()

// 启动调度器
go scheduler.Start(ctx)

// 10秒后停止调度器
time.Sleep(10 * time.Second)
cancel()
log.Println("发送关闭信号")
}

Context 性能优化

减少 Context 创建

// 优化前:每次请求创建新Context
func handleRequest(r *http.Request) {
ctx := context.WithValue(r.Context(), "requestID", genID())
// ...
}

// 优化后:复用Context
var baseCtx = context.Background()

func handleRequest(r *http.Request) {
ctx := context.WithValue(baseCtx, "requestID", genID())
// ...
}

高效值存储

type contextKey struct{}

var requestIDKey = contextKey{}

// 存储值
ctx := context.WithValue(parent, requestIDKey, "12345")

// 获取值
if id, ok := ctx.Value(requestIDKey).(string); ok {
// ...
}

避免深层 Context 链

// 创建扁平Context而不是深层嵌套
func createContext() context.Context {
ctx := context.Background()
ctx = context.WithValue(ctx, "key1", "value1")
ctx = context.WithValue(ctx, "key2", "value2")
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
// 而不是嵌套创建
return ctx
}

Context 池化

type ContextPool struct {
pool sync.Pool
}

func NewContextPool() *ContextPool {
return &ContextPool{
pool: sync.Pool{
New: func() interface{} {
return context.Background()
},
},
}
}

func (p *ContextPool) Get() context.Context {
return p.pool.Get().(context.Context)
}

func (p *ContextPool) Put(ctx context.Context) {
// 重置Context状态
if cancel, ok := ctx.(interface{ Cancel() }); ok {
cancel.Cancel()
}
p.pool.Put(ctx)
}