0. 概述

最近在写一段代码的时候,想保存中间状态的一些 slice,但是,不同的传递参数方式会导致不同的结果,于是我就展开分析一下。

1. 错误示范

在开始正文之前,先来看段代码:

[[email protected]]# cat combine.go
func combine(i, k, n int, curr []int, rst *[][]int) {
    if len(curr) == k {
        *rst = append(*rst, curr)
        return
    }

    for j := i; j < n; j++ {
        curr = append(curr, j)
        combine(j+1, k, n, curr, rst)
        curr = curr[:len(curr)-1]
    }
}

这是一段求给定范围的组合数的代码,输入分别是:

当我用一组 case 来验证的时候,错误就出现了,你能想象这段代码的输出是什么吗?

[[email protected]]# cat main.go
func main() {
    var rst [][]int
    combine(0, 2, 3, []int{}, &rst)
    for i := 0; i < len(rst); i++ {
        fmt.Printf("%v\n", rst[i])
    }
}

如果你不想猜,那么直接看结论:

[[email protected]]# go run main.go
[0 2]
[0 2]
[1 2]

很显然,结果是错误的,正确的结果应该是:

[[email protected]]# go run main.go
go run *.go
[0 1]
[0 2]
[1 2]

2. 修改后的代码

于是我就对这段代码做了一个简单的修改:

[[email protected]]# cat combine.go
func combine(i, k, n int, curr []int, rst *[][]int) {
    if len(curr) == k {
        *rst = append(*rst, curr)
        return
    }

    for j := i; j < n; j++ {
        combine(j+1, k, n, append(curr, j), rst)
    }
}
图 1:代码改动

就这么稍微得调整就让代码正确了。那么这里的修改到底影响了什么,为什么这个改动可以影响最终的结果。

3. 代码分析

为了分析两种代码的情形,我首先做了一个简单的事情就是将代码的运行过程的数据地址打印出来:

[[email protected]]# cat combine.go
...
func combine(i, k, n int, curr []int, rst *[][]int) {
    fmt.Printf("i: %d, k: %d, n: %d, curr: %p\n", i, k, n, &curr)
...

然后对比两种不同情况的输出:

我想从这里应该可以发现问题一目了然了。第一个代码会导致循环中的每一个传递的 curr 都是同一个 slice,而很明显,即使把这个结果保存到了 rst 数组中,也会因为底层的数据被修改了从而导致最终的结果被改变。

那么为什么第二段代码却能每次都是不同的 slice 呢?难道说每次 append 都会创建一个新的 slice 吗?于是我做了一个新的实验。

4. append 试验

为了验证 append 的是否一定会创建新的 slice,于是我写了这段代码:

[[email protected]]# cat main.go
func main() {
    var (
        s1 []int
        s2 = make([]int, 3)
    )
    fmt.Printf("s1: %p, len(s1): %d, cap(s1): %d\n", s1, len(s1), cap(s1))
    s1 = append(s1, 1)
    fmt.Printf("s1: %p, len(s1): %d, cap(s1): %d\n", s1, len(s1), cap(s1))
    s1 = append(s1, 1)
    fmt.Printf("s1: %p, len(s1): %d, cap(s1): %d\n", s1, len(s1), cap(s1))

    fmt.Printf("s2: %p, len(s2): %d, cap(s2): %d\n", s2, len(s2), cap(s2))
    s2 = s2[:0]
    fmt.Printf("s2: %p, len(s2): %d, cap(s2): %d\n", s2, len(s2), cap(s2))
    s2 = append(s2, 1)
    fmt.Printf("s2: %p, len(s2): %d, cap(s2): %d\n", s2, len(s2), cap(s2))
}

然后看一下输出的结果:

[[email protected]]# go run main.go
go run *.go
s1: 0x0, len(s1): 0, cap(s1): 0
s1: 0xc00001a158, len(s1): 1, cap(s1): 1
s1: 0xc00001a180, len(s1): 2, cap(s1): 2
s2: 0xc00010c000, len(s2): 3, cap(s2): 3
s2: 0xc00010c000, len(s2): 0, cap(s2): 3
s2: 0xc00010c000, len(s2): 1, cap(s2): 3

OK,从这里的结果可以看出,append 只有在增加了 cap 的情况下才会创建出新的 slice,如果 cap 足够容纳 len 的时候,slice 还是那个 slice。那么问题又来了,难道 append 就只会每次只增加一个 cap?所以我又做了另外一个试验:

[[email protected]]# cat main.go
func main() {
    var (
        s1 []int
        s2 = make([]int, 100)
    )
    fmt.Printf("s1: %p, len(s1): %d, cap(s1): %d\n", s1, len(s1), cap(s1))
    s1 = append(s1, 1)
    fmt.Printf("s1: %p, len(s1): %d, cap(s1): %d\n", s1, len(s1), cap(s1))
    s1 = append(s1, s2...)
    fmt.Printf("s1: %p, len(s1): %d, cap(s1): %d\n", s1, len(s1), cap(s1))
    s1 = append(s1, 1)
    fmt.Printf("s1: %p, len(s1): %d, cap(s1): %d\n", s1, len(s1), cap(s1))
}

然后再运行一下看看效果:

[[email protected]]# go run main.go
go run *.go
s1: 0x0, len(s1): 0, cap(s1): 0
s1: 0xc00001a170, len(s1): 1, cap(s1): 1
s1: 0xc000100000, len(s1): 101, cap(s1): 112
s1: 0xc000100000, len(s1): 102, cap(s1): 112

很显然,这里的答案是不是的。所以,这样再反推一下,可能前面的第二段代码是不正确的,所以我又尝试了一组新的 case。

5. 第二段代码还是错了

基于这里对 append 的验证,所以我提高了 combine 的参数的数值,对 7 个数组进行长度为 7 的组合,其实就只有一个结果:[0, 1, 2, 3, 4, 5, 6],然后看下代码效果:

[[email protected]]# cat main.go
func main() {
    var rst [][]int
    combine(0, 7, 7, []int{}, &rst)
    for i := 0; i < len(rst); i++ {
        fmt.Printf("%v\n", rst[i])
    }
}

执行一下:

[[email protected]]# go run main.go
[0 1 2 3 4 6 6]

很好,证明我又错了。那怎么办,这里的代码要怎么改?我最后又只改了一点点地方:

[[email protected]]# cat main.go
... ... 
    for j := i; j < n; j++ {
        combine(j+1, k, n, append(curr, j)[:len(curr)+1:len(curr)+1], rst)
    }
... ...

然后又换了一个更加多数据的测试样例:

[[email protected]]# go run main.go 13 14
[0 1 2 3 4 5 6 7 8 9 10 11 12]
[0 1 2 3 4 5 6 7 8 9 10 11 13]
[0 1 2 3 4 5 6 7 8 9 10 12 13]
[0 1 2 3 4 5 6 7 8 9 11 12 13]
[0 1 2 3 4 5 6 7 8 10 11 12 13]
[0 1 2 3 4 5 6 7 9 10 11 12 13]
[0 1 2 3 4 5 6 8 9 10 11 12 13]
[0 1 2 3 4 5 7 8 9 10 11 12 13]
[0 1 2 3 4 6 7 8 9 10 11 12 13]
[0 1 2 3 5 6 7 8 9 10 11 12 13]
[0 1 2 4 5 6 7 8 9 10 11 12 13]
[0 1 3 4 5 6 7 8 9 10 11 12 13]
[0 2 3 4 5 6 7 8 9 10 11 12 13]
[1 2 3 4 5 6 7 8 9 10 11 12 13]

正如你所期望的,这次是正确了,没有问题。

6. 总结

那么为什么这么一改就正确了呢?回顾一下前面刚开始说为什么第二段代码在只有 3,4 个数据规模的时候是正确的,但是7,8个数据的时候就不对了?原因就是 append 在数据稍微多了一点点之后就与预先分配多一些 slot,从而让下一次 append 的时候不会再次分配新的 slice 内存。

那么我这里一改为:append(curr, j)[:len(curr)+1:len(curr)+1],其实就是强制每次 append 完之后的 slot 就是刚好多一个,从而造成每次 append 都会产生一个新的 slice,因为没有多余的 slot 用于放置新的元素,那么也就保证了最终的结果是我想要的。

本文通过一些简单的例子,介绍了 append 和 slice 的一个简单原理,就是 slice 底层是有两个内部数据的,分别是 length 和 cap,而 cap 和 length 的关系决定了在 append 的时候需不需要拷贝 slice 的内存。关于 slice 的这个基础知识,我在以前的文章:golang 中神奇的 slice 中介绍过,有兴趣可以稍微看一下。