Go实现线程安全的缓存

场景

某些函数调用频繁,但其计算却非常耗时,为了避免每次调用时都重新计算一遍,我们需要保存函数的计算结果,这样在对函数进行调用的时候,只需要计算一次,之后的调用可以直接从缓存中返回计算结果。

使用下面的httpGetBody()作为我们需要缓存的函数样例。

func httpGetBody(url string) (interface{}, error) {
        resp, err := http.Get(url)
        if err != nil {
                return nil, err
        }
        defer resp.Body.Close()
        return ioutil.ReadAll(resp.Body) // ReadAll会返回两个结果,一个[]byte数组和一个错误
}

要求

缓存的设计要求是并发安全的,并且要尽量高效。

版本1:使用互斥量实现并发安全

版本1

// Func 是待缓存的函数(即key)
type Func func(key string) (interface{}, error)
// Result 作为缓存结果(即value)
type result struct {
        value interface{}
        err error
}
// 缓存通过调用 f 函数得到的结果
type Memo struct {
        f Func
        cache map[string]result
}

func NewMemo(f Func) *Memo {
        memo := &Memo{f, make(map[string]result)}
        return memo
}

// Get方法,线程不安全
func (memo *Memo) Get(url string) (interface{}, error) {
        res, ok := memo.cache[url]
        if !ok { // 如果缓存中不存在,通过调用memo中的f函数计算出结果,并把结果缓存起来
                res.value, res.err = memo.f(url)
                memo.cache[url] = res
        }
        return res.value, res.err
}

Memo实例会记录需要缓存的函数f(类型为Func),以及缓存内容(里面是一个string到result映射的map)。

这是一个最简单的实现,由于没有加锁,是线程不安全的。我们先对其进行简单的测试。测试函数如下:

var urls = []string {
        "https://www.nowcoder.com/",
        "https://www.nowcoder.com/contestRoom",
        "https://www.nowcoder.com/interview/ai/index",
        "https://www.nowcoder.com/courses",
        "https://www.nowcoder.com/recommend",
        "https://www.nowcoder.com/courses",     // 重复的url,测试缓存效果
        "https://www.nowcoder.com/contestRoom", // 重复的url,测试缓存效果
}

// 单个goroutine,顺序调用
func TestMemoSingle(t *testing.T) {
        m := NewMemo(httpGetBody)
        totalTime := time.Now()
        for _, url := range urls {
                start := time.Now()
                value, err := m.Get(url)
                if err != nil {
                        log.Println(err)
                }
                fmt.Printf("%s, %s, %d bytes\n", url, time.Since(start), len(value.([]byte)))
        }
        fmt.Printf("total time used: %s\n", time.Since(totalTime))
}

// 并发调用
// 使用 sync.WaitGroup 来等待所有的请求都完成再返回
func TestMemoConcurrency(t *testing.T) {
        m := NewMemo(httpGetBody)
        var group sync.WaitGroup
        totalTime := time.Now()
        for _, url := range urls {
                group.Add(1)

                go func(url string) {
                        start := time.Now()
                        value, err := m.Get(url)
                        if err != nil {
                                log.Println(err)
                        }
                        fmt.Printf("%s, %s, %d bytes\n", url, time.Since(start), len(value.([]byte)))

                        group.Done() // equals ==> group.Add(-1)
                }(url)
        }
        group.Wait()
        fmt.Printf("total time used: %s\n", time.Since(totalTime))
}

首先测试单个goroutine顺序执行的情况,测试结果如下:

$ go test -v -run=TestMemoSingle
=== RUN   TestMemoSingle
https://www.nowcoder.com/, 289.8287ms, 95378 bytes
https://www.nowcoder.com/contestRoom, 178.8973ms, 71541 bytes
https://www.nowcoder.com/interview/ai/index, 68.9602ms, 21320 bytes
https://www.nowcoder.com/courses, 148.9146ms, 64304 bytes
https://www.nowcoder.com/recommend, 121.932ms, 90666 bytes
https://www.nowcoder.com/courses, 0s, 64304 bytes     // 可以看到,本次调用直接从缓存中获取结果,耗时为0
https://www.nowcoder.com/contestRoom, 0s, 71541 bytes // 同上
total time used: 809.5305ms
--- PASS: TestMemoSingle (0.81s)
PASS
ok      _/D_/workspace/GoRepo/gopl/ch9/memo1    1.546s

可以清楚的看到,当访问之前已经被访问过的 url 时,可以立刻从缓存中返回结果。我们再来试试看并发访问的情况。

$ go test -v -run=TestMemoConcurrency
=== RUN   TestMemoConcurrency
https://www.nowcoder.com/interview/ai/index, 252.8542ms, 21320 bytes
https://www.nowcoder.com/, 253.8524ms, 95378 bytes
https://www.nowcoder.com/recommend, 279.8401ms, 90666 bytes
https://www.nowcoder.com/courses, 280.8377ms, 64304 bytes
https://www.nowcoder.com/courses, 318.8194ms, 64304 bytes
https://www.nowcoder.com/contestRoom, 359.7913ms, 71541 bytes
https://www.nowcoder.com/contestRoom, 404.7649ms, 71541 bytes
total time used: 404.7649ms
--- PASS: TestMemoConcurrency (0.40s)
PASS
ok      _/D_/workspace/GoRepo/gopl/ch9/memo1    3.034s

并发访问时(请多测试几次),可以看到,总的用时比单个gouroutine顺序访问时少了差不多一半。但访问相同 url 时似乎没有达到缓存的效果。原因很简单嘛,我们在实现Get()方法时,没有加锁限制,因此多个goroutine可能同时访问memo实例,也就是出现了数据竞争

在 Go 中,我们可以利用-race标签,它能帮助我们识别代码中是否出现了数据竞争。比如:

$ go test -v -race -run=TestMemoConcurrency
=== RUN   TestMemoConcurrency
...
==================
WARNING: DATA RACE
Write at 0x00c000078cc0 by goroutine 10: // 在 goroutine 10 中写入
  runtime.mapassign_faststr()
      D:/soft/Go/src/runtime/map_faststr.go:202 +0x0
  _/D_/workspace/GoRepo/gopl/ch9/memo1.(*Memo).Get()
      D:/workspace/GoRepo/gopl/ch9/memo1/memo.go:40 +0x1d5
  _/D_/workspace/GoRepo/gopl/ch9/memo1.TestMemoConcurrency.func1()
      D:/workspace/GoRepo/gopl/ch9/memo1/memo_test.go:46 +0x96

Previous write at 0x00c000078cc0 by goroutine 7: // 在 goroutine 7 中也出现写入
  runtime.mapassign_faststr()
      D:/soft/Go/src/runtime/map_faststr.go:202 +0x0
  _/D_/workspace/GoRepo/gopl/ch9/memo1.(*Memo).Get()
      D:/workspace/GoRepo/gopl/ch9/memo1/memo.go:40 +0x1d5
  _/D_/workspace/GoRepo/gopl/ch9/memo1.TestMemoConcurrency.func1()
      D:/workspace/GoRepo/gopl/ch9/memo1/memo_test.go:46 +0x96
...

FAIL
exit status 1
FAIL    _/D_/workspace/GoRepo/gopl/ch9/memo1    0.699s

可以看到,memo.go 的第40行(对应memo.cache[url] = res)出现了2次,说明有两个goroutine在没有同步干预的情况下更新了cache map。这表明Get不是并发安全的,存在数据竞争。

OK,那我们就设法对Get()方法进行加锁(mutex),最粗暴的方式莫过于如下:

// Get is concurrency-safe.
func (memo *Memo) Get(key string) (value interface{}, err error) {
    memo.mu.Lock()
    res, ok := memo.cache[key]
    if !ok {
        res.value, res.err = memo.f(key)
        memo.cache[key] = res
    }
    memo.mu.Unlock()
    return res.value, res.err
}

这样做当然实现了所谓的“并发安全”,但是也失去了“并发性”,每次对f的调用期间都会持有锁,Get将本来可以并行运行的I/O操作串行化了。显然,这不是我们所希望的。

我们试图降低锁的粒度,查找阶段获取一次,如果查找没有返回任何内容,那么进入更新阶段会再次获取。在这两次获取锁的中间阶段,其它goroutine可以随意使用cache。

func (memo *Memo) Get(key string) (value interface{}, err error) {
    memo.mu.Lock()
    res, ok := memo.cache[key]
    memo.mu.Unlock()
    if !ok {
        res.value, res.err = memo.f(key)

        // Between the two critical sections, several goroutines
        // may race to compute f(key) and update the map.
        memo.mu.Lock()
        memo.cache[key] = res
        memo.mu.Unlock()
    }
    return res.value, res.err
}

这种实现在两个以上的goroutine同一时刻调用Get来请求同样的URL时,会导致同样的url被重复计算。多个goroutine一起查询cache,发现没有值,然后一起调用f这个慢不拉叽的函数。在得到结果后,也都会去更新map。其中一个获得的结果会覆盖掉另一个的结果。理想情况下是应该避免掉多余的工作的,这种“避免”工作一般被称为duplicate suppression(重复抑制/避免)。

版本2:使用“互斥量+channel”实现并发安全机制

该版本的Memo每一个map元素都是指向一个条目的指针。每一个条目包含对函数f的调用结果。与之前不同的是这次entry还包含了一个叫ready的channel。在条目的结果被设置之后,这个channel就会被关闭,以向其它goroutine广播——“现在去读取该条目内的结果是安全的了”。

// Func 是待缓存的函数,作为 key
type Func func(key string) (interface{}, error)

// entry 作为缓存的 value, 除了包含一个结果result,还包含一个channel
type entry struct {
        res result
        ready chan struct{}
}

type result struct {
        value interface{}
        err   error
}

// 缓存通过调用 f 函数得到的结果
type Memo struct {
        f     Func
        mu    sync.Mutex
        cache map[string]*entry
}

func NewMemo(f Func) *Memo {
        memo := &Memo{f: f, cache: make(map[string]*entry)}
        return memo
}

// 使用一个互斥量(即 unbuffered channel)来保护多个goroutine调用Get时的共享map变量
func (memo *Memo) Get(url string) (interface{}, error) {
        memo.mu.Lock()
        e := memo.cache[url]
        if e == nil {
                // 如果查询结果为空,说明这是对该url的第一次查询
                // 因此,让这个goroutine负责计算这个url对应的值
                // 当计算好后,再广播通知所有的其他goroutine,
                // 告诉它们这个url对应的缓存已经存在了,可以直接取用
                e = &entry{ready: make(chan struct{})}
                memo.cache[url] = e // 注意这里只是存入了一个“空的”条目,真正的结果还没计算出来
                memo.mu.Unlock()

                e.res.value, e.res.err = memo.f(url)

                close(e.ready) // broadcast ready condition
        } else {
                // 如果查询到结果非空,则立马先把锁给释放掉
                memo.mu.Unlock()

                // 但是要注意,这里的非空并不代表马上就可以返回结果
                // 因为有可能是其他goroutine还在计算中
                // 因此要等待ready condition
                <-e.ready
        }
        return e.res.value, e.res.err
}

获取互斥锁来保护共享变量cache map,查询map中是否存在指定条目,如果没有找到,那么分配空间插入一个新条目,释放互斥锁。如果条目存在但其值并没有写入完成时(也就是有其它的goroutine在调用 f 这个慢函数),goroutine则必须等待ready之后才能读到条目的结果。ready condition由一个无缓存channel来实现,对无缓存channel的读取操作(即<-e.ready)在channel关闭之前一直是阻塞。

如果没有条目的话,需要向map中插入一个没有准备好的条目,当前正在调用的goroutine就需要负责调用慢函数、更新条目以及向其它所有goroutine广播条目已经ready可读的消息了。

条目中的e.res.value和e.res.err变量是在多个goroutine之间共享的。创建条目的goroutine同时也会设置条目的值,其它goroutine在收到"ready"的广播消息之后立刻会去读取条目的值。尽管会被多个goroutine同时访问,但却并不需要互斥锁。ready channel的关闭一定会发生在其它goroutine接收到广播事件之前,因此第一个goroutine对这些变量的写操作是一定发生在这些读操作之前的。不会发生数据竞争。

版本3:通过goroutine通信实现并发安全

在版本2的实现中,我们使用了一个互斥量来保护多个goroutine调用Get时的共享变量map。在Go中,还有另外一种设计方案——把共享变量map限制在一个单独的goroutine中(我们称这样的goroutine为monitor goroutine),对缓存的查询和写入均通过monitor goroutine进行。

Func、result和entry的声明和之前保持一致,这里不再重复。Memo类型的定义则做了很大的改动,只包含了一个叫做requests的channel,Get的调用者用这个channel来和monitor goroutine通信。

type request struct {
        url string
        // 负责发送响应结果, 只发送, 不接收
        response chan<- result 
}

type Memo struct {
        requests chan request
}

func NewMemo(f Func) *Memo {
        memo := &Memo{requests:make(chan request)}
        go memo.server(f)
        return memo
}

func (memo *Memo) Get(url string) (interface{}, error) {
        response := make(chan result)
        memo.requests <- request{url, response}
        res := <-response
        return res.value, res.err
}

func (memo *Memo) Close() {
        close(memo.requests)
}

上面的Get方法,会创建一个response channel,把它放进request结构中,然后发送给monitor goroutine,然后马上又会接收它。

cache变量被限制在了monitor goroutine中,即server()函数,下面会看到。monitor会在循环中一直读取请求,直到request channel被Close方法关闭。每一个请求都会去查询cache,如果没有找到条目的话,那么就会创建/插入一个新的条目。

func (memo *Memo) server(f Func) {
        cache := make(map[string]*entry)
        for req := range memo.requests {
                e := cache[req.url]
                if e == nil {
                        // this is the first request for this url
                        e = &entry{ready: make(chan struct{})}
                        cache[req.url] = e
                        go e.call(f,req.url)
                }
                go e.deliver(req.response)
        }
}

func (e *entry) call(f Func, url string) {
        // Evaluate the function.
        e.res.value, e.res.err = f(url)
        // broadcast ready condition
        close(e.ready)
}

func (e *entry) deliver(response chan<- result) {
        // wait for the ready condition
        <-e.ready
        // send the result to the client
        response <- e.res
}

和基于互斥量的版本类似,第一个对某个key的请求需要负责去调用函数f并传入这个key,将结果存在条目里,并关闭ready channel来广播条目的ready消息。使用(*entry).call来完成上述工作。

紧接着对同一个key的请求会发现map中已经有了存在的条目,然后会等待结果变为ready,并将结果从response发送给客户端的goroutien。上述工作是用(*entry).deliver来完成的。对call和deliver方法的调用必须让它们在自己的goroutine中进行以确保monitor goroutines不会因此而被阻塞住而没法处理新的请求。

总结

在Go中,我们可以通过使用互斥量(加锁),或者通信来建立并发程序。后者实现起来会难一些,初学也比较难理解。我也理解不深,暂记录于此。


本文是对《The Go Programming Language》 9.7 节的学习笔记,大家去看原文吧~