Golang常用Mock方式总结
1.通用Mock方式
类似Java Mockito。testify和mockery结合使用,testify是一个golang测试框架,主要有assert、mock和test suite三个特性,mockery利用testify的mock来生成mock的代码。
testify包下载:
go get github.com/stretchr/testify
mockery安装:
go get github.com/vektra/mockery/.../
mockery会根据定义的interface生成对应的mock struct。
示例代码
common/etcd/client.go
common/etcd/mocks/EtcdClient.go
sql-driver/rds/config/loader/remote_configuration_loader_test.go
1. 生成mock struct
命令行执行go generate
或者使用goland直接生成,此处会自动创建mocks目录,以及对应的mock struct文件。
生成的代码如下所示:
// Code generated by mockery v1.0.0. DO NOT EDIT.
package mocks
import (
etcd "github.com/huaweicloud/devcloud-go/common/etcd"
mock "github.com/stretchr/testify/mock"
clientv3 "go.etcd.io/etcd/client/v3"
)
// EtcdClient is an autogenerated mock type for the EtcdClient type.
// It embeds testify's mock.Mock; every method in this file forwards its
// arguments to the embedded mock's Called method.
type EtcdClient struct {
mock.Mock
}
// Close records the call on the embedded mock and returns the configured
// error. If the configured return value is a func() error, that function is
// invoked and its result is returned instead.
func (_m *EtcdClient) Close() error {
	args := _m.Called()

	if fn, isFunc := args.Get(0).(func() error); isFunc {
		return fn()
	}
	return args.Error(0)
}
// Del records the call with key on the embedded mock and returns the
// configured (int64, error) pair. Either configured value may instead be a
// function taking the key, in which case it is called to produce the result.
func (_m *EtcdClient) Del(key string) (int64, error) {
	args := _m.Called(key)

	var deleted int64
	if fn, isFunc := args.Get(0).(func(string) int64); isFunc {
		deleted = fn(key)
	} else {
		deleted = args.Get(0).(int64)
	}

	if fn, isFunc := args.Get(1).(func(string) error); isFunc {
		return deleted, fn(key)
	}
	return deleted, args.Error(1)
}
// Get records the call with key on the embedded mock and returns the
// configured (string, error) pair. Either configured value may instead be a
// function taking the key, in which case it is called to produce the result.
func (_m *EtcdClient) Get(key string) (string, error) {
	args := _m.Called(key)

	var value string
	if fn, isFunc := args.Get(0).(func(string) string); isFunc {
		value = fn(key)
	} else {
		value = args.Get(0).(string)
	}

	if fn, isFunc := args.Get(1).(func(string) error); isFunc {
		return value, fn(key)
	}
	return value, args.Error(1)
}
// List records the call with prefix on the embedded mock and returns the
// configured ([]*etcd.KeyValue, error) pair. Either configured value may
// instead be a function taking the prefix. A nil first return value is passed
// through as a nil slice.
func (_m *EtcdClient) List(prefix string) ([]*etcd.KeyValue, error) {
	args := _m.Called(prefix)

	var kvs []*etcd.KeyValue
	if fn, isFunc := args.Get(0).(func(string) []*etcd.KeyValue); isFunc {
		kvs = fn(prefix)
	} else if first := args.Get(0); first != nil {
		kvs = first.([]*etcd.KeyValue)
	}

	if fn, isFunc := args.Get(1).(func(string) error); isFunc {
		return kvs, fn(prefix)
	}
	return kvs, args.Error(1)
}
// Put records the call with key and value on the embedded mock and returns
// the configured (string, error) pair. Either configured value may instead be
// a function taking (key, value), in which case it is called to produce the
// result.
func (_m *EtcdClient) Put(key string, value string) (string, error) {
	args := _m.Called(key, value)

	var previous string
	if fn, isFunc := args.Get(0).(func(string, string) string); isFunc {
		previous = fn(key, value)
	} else {
		previous = args.Get(0).(string)
	}

	if fn, isFunc := args.Get(1).(func(string, string) error); isFunc {
		return previous, fn(key, value)
	}
	return previous, args.Error(1)
}
// Watch records a call with prefix, startIndex and onEvent on the embedded
// mock. It returns nothing and does not invoke onEvent itself.
func (_m *EtcdClient) Watch(prefix string, startIndex int64, onEvent func(*clientv3.Event)) {
_m.Called(prefix, startIndex, onEvent)
}
2. 业务代码调用EtcdClient
此处定义了一个RemoteConfigurationLoader,其成员etcdClient的类型为上述定义的EtcdClient interface,在Get方法中调用EtcdClient的Get方法。
// RemoteConfigurationLoader reads configuration values from etcd through an
// etcd.EtcdClient.
type RemoteConfigurationLoader struct {
// etcdClient is the client Get uses for every lookup.
etcdClient etcd.EtcdClient
// dataSourceKey, routerKey and activeKey are the etcd keys Get reads.
dataSourceKey string
routerKey string
activeKey string
// listeners is not used by the code shown here; presumably callbacks for
// router configuration changes — TODO confirm against the full source.
listeners []config.RouterConfigurationListener
}
// Get fetches three values from etcd, in order: the one stored under
// dataSourceKey, the one under routerKey and the one under activeKey.
//
// It returns them as (dataSourceConfig, routerConfig, active). If any lookup
// fails the error is logged and three empty strings are returned, so callers
// never receive a partial result.
func (l *RemoteConfigurationLoader) Get() (string, string, string) {
	dataSourceConfig, err := l.etcdClient.Get(l.dataSourceKey)
	if err != nil {
		// Logged like the two lookups below so that a failure here is not
		// indistinguishable from "all three values are empty".
		log.Printf("ERROR: get remote dataSourceConfig failed, %v", err)
		return "", "", ""
	}

	routerConfig, err := l.etcdClient.Get(l.routerKey)
	if err != nil {
		log.Printf("ERROR: get remote routerConfig failed, %v", err)
		return "", "", ""
	}

	active, err := l.etcdClient.Get(l.activeKey)
	if err != nil {
		log.Printf("ERROR: get remote active failed, %v", err)
		return "", "", ""
	}

	return dataSourceConfig, routerConfig, active
}
3. 测试代码编写
编写对RemoteConfigurationLoader的Get()方法的测试代码,对EtcdClient的Get方法进行mock。
1 import (
2 "testing"
3
4 "github.com/huaweicloud/devcloud-go/common/etcd/mocks"
5 "github.com/stretchr/testify/assert"
6 )
7
8 // TestRemoteConfigurationLoader_Get checks that Get returns the three values served by the mocked EtcdClient.
9 func TestRemoteConfigurationLoader_Get(t *testing.T) {
10 mockClient := &mocks.EtcdClient{}
11 loader := &RemoteConfigurationLoader{
12 dataSourceKey: "datasourceKey",
13 routerKey: "routerKey",
14 activeKey: "activeKey",
15 etcdClient: mockClient,
16 }
17 mockClient.On("Get", loader.dataSourceKey).Return("data", nil).Once()
18 mockClient.On("Get", loader.routerKey).Return("router", nil).Once()
19 mockClient.On("Get", loader.activeKey).Return("active", nil).Once()
20 datasource, router, active := loader.Get()
21 assert.Equal(t, "data", datasource)
22 assert.Equal(t, "router", router)
23 assert.Equal(t, "active", active)
24 }
其中17-19行代码就是在mock我们想要的数据。mockClient调用On方法,首先传入要mock的方法名字,然后传入方法参数(只有实际调用的参数与之匹配时,这条mock才会生效),此处是利用golang的反射来实现的。Return方法中传入想要mock的返回数据,最后调用Once()方法表示这条mock期望只对一次调用生效。
4. 参考文档
- https://segmentfault.com/a/1190000016897506
- https://www.xuanzhangjiong.top/2019/10/12/mockery%E4%BB%8B%E7%BB%8D%E5%8F%8A%E4%BD%BF%E7%94%A8/
- https://github.com/vektra/mockery
2. Mysql Mock
2.1 mock mysql server(推荐)
利用 github.com/dolthub/go-mysql-server,go-mysql-server基于Mysql语法,解析标准sql,它可以在内存中启动一个mysql server。
安装:
go get github.com/dolthub/go-mysql-server
示例代码:
package main
import (
"fmt"
"time"
sqle "github.com/dolthub/go-mysql-server"
"github.com/dolthub/go-mysql-server/auth"
"github.com/dolthub/go-mysql-server/memory"
"github.com/dolthub/go-mysql-server/server"
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/information_schema"
)
// Connection settings for the in-memory MySQL server started by main (user,
// passwd, address, port), and the names of the sample database and table
// built by createTestDatabase (dbName, tableName).
const (
user = "user"
passwd = "pass"
address = "localhost"
port = "13306"
dbName = "test"
tableName = "pets"
)
// main serves the sample database from createTestDatabase (plus
// information_schema) over the MySQL protocol on address:port, authenticating
// the single account user/passwd with all permissions.
//
// It panics if the server cannot be created or if serving fails.
func main() {
	engine := sqle.NewDefault(
		sql.NewDatabaseProvider(
			createTestDatabase(),
			information_schema.NewInformationSchemaDatabase(),
		))
	config := server.Config{
		Protocol: "tcp",
		Address:  fmt.Sprintf("%s:%s", address, port),
		Auth:     auth.NewNativeSingle(user, passwd, auth.AllPermissions),
	}
	s, err := server.NewDefaultServer(config, engine)
	if err != nil {
		panic(err)
	}

	fmt.Println("mysql-server started!")
	// Start blocks while the server is accepting connections, so it keeps the
	// process alive by itself. Calling it directly (instead of in a goroutine
	// followed by a receive on a never-written channel) also surfaces its
	// error instead of dropping it.
	if err := s.Start(); err != nil {
		panic(err)
	}
}
func createTestDatabase() *memory.Database {
db := memory.NewDatabase(dbName)
table := memory.NewTable(tableName, sql.Schema{
{Name: "name", Type: sql.Text, Nullable: false, Source: tableName},
{Name: "email", Type: sql.Text, Nullable: false, Source: tableName},
{Name: "phone_numbers", Type: sql.JSON, Nullable: false, Source: tableName},
{Name: "created_at", Type: sql.Timestamp, Nullable: false, Source: tableName},
})
db.AddTable(tableName, table)
ctx := sql.NewEmptyContext()
rows := []sql.Row{
sql.NewRow("John Doe", "jasonkay@doe.com", []string{"555-555-555"}, time.Now()),
sql.NewRow("John Doe", "johnalt@doe.com", []string{}, time.Now()),
sql.NewRow("Jane Doe", "jane@doe.com", []string{}, time.Now()),
sql.NewRow("Evil Bob", "jasonkay@gmail.com", []string{"555-666-555", "666-666-666"}, time.Now()),
}
for _, row := range rows {
_ = table.Insert(ctx, row)
}
return db
2.2 mock sql driver
使用 DATA-DOG/go-sqlmock,该包实现了Go标准库database/sql/driver的接口,本质上是一个mock驱动
安装:
go get github.com/DATA-DOG/go-sqlmock
示例代码
原生sql代码使用 示例代码见:https://github.com/DATA-DOG/go-sqlmock
编写ut时利用定义的globalMock做dml前置操作,具体使用方法见官方文档。
2.2.1 结合beego orm
// Package-level state behind GetOrmer: the regular Ormer (globalOrm) and the
// sqlmock-backed Ormer (globalMockOrm), each guarded by its own sync.Once, and
// the exported GlobalMock handle that GetOrmer assigns from sqlmock.New.
var (
globalOrm orm.Ormer
once sync.Once
mockOnce sync.Once
globalMockOrm orm.Ormer
GlobalMock sqlmock.Sqlmock
)
// GetOrmer returns the shared beego Ormer.
//
// When the environment variable isTest is "true" it returns an Ormer backed by
// a go-sqlmock connection, created on first use, and stores the matching
// expectation handle in GlobalMock. Otherwise it returns the regular Ormer,
// also created on first use, after setting orm.DefaultRowsLimit to -1.
//
// It panics if the sqlmock connection or the mock-backed Ormer cannot be
// created; previously those errors were discarded and surfaced later as a nil
// dereference.
func GetOrmer() orm.Ormer {
	if utils.GetenvOrDefault("isTest", "") == "true" {
		mockOnce.Do(func() {
			var (
				db  *sql.DB
				err error
			)
			db, GlobalMock, err = sqlmock.New()
			if err != nil {
				panic("create sqlmock connection: " + err.Error())
			}
			// Registered before NewOrmWithDB; presumably these are the
			// statements beego prepares while setting up the database —
			// TODO confirm against the beego version in use.
			GlobalMock.ExpectPrepare("SELECT TIMEDIFF")
			GlobalMock.ExpectPrepare("SELECT ENGINE")
			globalMockOrm, err = orm.NewOrmWithDB("mysql", "default", db)
			if err != nil {
				panic("create mock ormer: " + err.Error())
			}
		})
		return globalMockOrm
	}

	once.Do(func() {
		// override the default value(1000) to return all records when setting no limit
		orm.DefaultRowsLimit = -1
		globalOrm = orm.NewOrm()
	})
	return globalOrm
}
测试代码
// Book is the model read by the sqlmock tests. Its gorm tags name the id and
// title columns for the Id and Title fields.
type Book struct {
Id int64 `gorm:"column:id"`
Title string `gorm:"column:title"`
}
// TestSqlMockBeegoOrm reads the Book with Id 1 through the sqlmock-backed
// Ormer and checks that the title comes from the mocked result row.
func TestSqlMockBeegoOrm(t *testing.T) {
	// GetOrmer only hands out the sqlmock-backed Ormer when isTest is "true".
	if err := os.Setenv("isTest", "true"); err != nil {
		t.Fatalf("set isTest: %v", err)
	}
	ormer := driver_test.GetOrmer()

	// GlobalMock is declared next to GetOrmer, so it is qualified with the
	// same package.
	driver_test.GlobalMock.ExpectQuery("SELECT").WillReturnRows(
		sqlmock.NewRows([]string{"id", "title"}).
			AddRow(1, "one"))

	book := &Book{Id: 1}
	err := ormer.Read(book)
	assert.Nil(t, err)
	assert.Equal(t, "one", book.Title)
}
2.2.2 结合gorm
// Package-level state behind GetGormDB: the real and the sqlmock-backed
// *gorm.DB, each created at most once through its own sync.Once.
var (
	globalGormDB     *gorm.DB
	globalMockGormDB *gorm.DB
	// GlobalMock is the sqlmock expectation handle paired with
	// globalMockGormDB. It is exported because tests in other packages set
	// expectations through it.
	GlobalMock sqlmock.Sqlmock
	once       sync.Once
	mockOnce   sync.Once
)

// GetGormDB returns the shared *gorm.DB.
//
// When the environment variable isTest is "true" it returns a DB opened on a
// go-sqlmock connection, created on first use, and stores the matching
// expectation handle in GlobalMock. Otherwise it returns a DB opened with the
// MySQL DSN below, also created on first use.
//
// It panics if either DB cannot be opened; previously those errors were
// discarded and a nil *gorm.DB was returned.
func GetGormDB() *gorm.DB {
	if utils.GetenvOrDefault("isTest", "") == "true" {
		mockOnce.Do(func() {
			var (
				db  *sql.DB
				err error
			)
			db, GlobalMock, err = sqlmock.New()
			if err != nil {
				panic("create sqlmock connection: " + err.Error())
			}
			globalMockGormDB, err = gorm.Open(mysql.New(mysql.Config{
				Conn:                      db,
				SkipInitializeWithVersion: true,
			}), &gorm.Config{})
			if err != nil {
				panic("open mock gorm db: " + err.Error())
			}
		})
		return globalMockGormDB
	}

	once.Do(func() {
		var err error
		globalGormDB, err = gorm.Open(mysql.Open("user:pass@tcp(127.0.0.1:3306)/dbname"), &gorm.Config{})
		if err != nil {
			panic("open gorm db: " + err.Error())
		}
	})
	return globalGormDB
}
测试代码
// TestSqlMockGorm runs Find for []Book against the sqlmock-backed gorm DB and
// checks that both mocked rows come back with their ids and titles.
func TestSqlMockGorm(t *testing.T) {
	os.Setenv("isTest", "true")
	db := driver_test.GetGormDB()

	// Two rows served for the query gorm issues for Find(&books).
	mockedRows := sqlmock.NewRows([]string{"id", "title"})
	mockedRows.AddRow(1, "one")
	mockedRows.AddRow(2, "two")
	expectedSQL := regexp.QuoteMeta("SELECT * FROM `books`")
	driver_test.GlobalMock.ExpectQuery(expectedSQL).WillReturnRows(mockedRows)

	var books []Book
	result := db.Find(&books)
	assert.Nil(t, result.Error)
	assert.Equal(t, 2, len(books))

	first, second := books[0], books[1]
	assert.Equal(t, int64(1), first.Id)
	assert.Equal(t, "one", first.Title)
	assert.Equal(t, int64(2), second.Id)
	assert.Equal(t, "two", second.Title)
}
2.3 使用优劣
个人感觉mock mysql server最方便,对代码侵入较少,测试代码也会更少,测试范围会更广。
mock sql driver只适合对业务层mock数据库操作,测试业务代码;当使用复杂sql时,需要将orm的链式操作等转为一个复杂sql语句用于mock,需要编写大量测试代码。
而mock mysql server在此基础上还能测试数据库操作的代码是否正确。
2.4 参考文档
- 使用纯Go实现的Mysql数据库
- https://pkg.go.dev/github.com/dolthub/go-mysql-server#section-readme
- https://zhuanlan.zhihu.com/p/249313716
- https://github.com/DATA-DOG/go-sqlmock
- https://blog.csdn.net/weixin_44294408/article/details/120698482
3. Redis Mock
使用github.com/alicebob/miniredis/v2 ,miniredis可以在内存中启动一个redis server,支持大部分redis命令,具体支持情况见github readme
示例代码:
redis/devspore_client_test.go
import (
"context"
"testing"
"github.com/alicebob/miniredis/v2"
"github.com/go-redis/redis/v8"
"github.com/stretchr/testify/assert"
)
// TestMockRedis starts an in-memory miniredis server, writes a key through a
// go-redis client and reads it back.
func TestMockRedis(t *testing.T) {
	server, err := miniredis.Run()
	if err != nil {
		t.Fatalf("start miniredis: %v", err)
	}
	// Deferred so the server is shut down even when an assertion fails early.
	defer server.Close()

	client := redis.NewClient(&redis.Options{Addr: server.Addr()})
	defer client.Close()

	ctx := context.Background()
	// Check the write as well, so a failed Set is reported as such instead of
	// as a confusing mismatch on the following Get.
	assert.Nil(t, client.Set(ctx, "test", "val", 0).Err())

	res := client.Get(ctx, "test")
	assert.Nil(t, res.Err())
	assert.Equal(t, "val", res.Val())
}
4. Ginkgo测试框架
Ginkgo是一个BDD(Behavior Driven Development)风格的go测试框架,与Gomega配合使用,在需要写大量单测时,特别是需要一些通用代码时,Ginkgo可以使用BeforeEach和AfterEach将每个用例的通用步骤提取出来,会让代码看起来很清爽。
具体使用方法见 https://ke-chain.github.io/ginkgodoc/
示例代码
devcloud-go/sql-driver/mysql/devspore_driver_test.go 使用Ginkgo框架编写driver的CRUD测试代码,结合mock mysql server使用。
import (
"database/sql"
"fmt"
"testing"
"github.com/huaweicloud/devcloud-go/sql-driver/rds/config"
"github.com/huaweicloud/devcloud-go/sql-driver/rds/datasource"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
sqle "github.com/dolthub/go-mysql-server"
"github.com/dolthub/go-mysql-server/auth"
"github.com/dolthub/go-mysql-server/memory"
"github.com/dolthub/go-mysql-server/server"
mocksql "github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/information_schema"
)
// TestGinkgoSuite is the go test entry point for the Ginkgo specs in this
// package: it registers Ginkgo's Fail as the Gomega fail handler and then runs
// the suite under the description "mysql".
func TestGinkgoSuite(t *testing.T) {
	const suiteDescription = "mysql"

	RegisterFailHandler(Fail)
	RunSpecs(t, suiteDescription)
}
// CRUD specs for the devspore_mysql driver, run against the mock MySQL server
// started by startMockServer. Each spec goes through devsporeDB and, for
// writes, verifies the effect by querying the active node's master directly.
//
// id1, id2 and initDB are defined elsewhere in the package; presumably initDB
// prepares table foo with a row for id1 — TODO confirm.
var _ = Describe("CRUD", func() {
	var (
		devsporeDB *sql.DB
		masterDB   *sql.DB
		err        error
		activeNode *datasource.NodeDataSource
	)

	// startMockServer already runs the accept loop in its own goroutine and
	// returns right after creating the server, so it is called directly:
	// wrapping it in another goroutine only allowed specs to start before the
	// server had been created.
	startMockServer()

	BeforeEach(func() {
		devsporeDB, err = sql.Open("devspore_mysql", "../rds/resources/driver_test_config.yaml")
		Expect(err).NotTo(HaveOccurred())
		activeNode, err = initDB()
		Expect(err).NotTo(HaveOccurred())
		masterDB, err = sql.Open("mysql", activeNode.MasterDataSource.Dsn)
		Expect(err).NotTo(HaveOccurred())
	})

	AfterEach(func() {
		Expect(devsporeDB.Close()).NotTo(HaveOccurred())
		Expect(masterDB.Close()).NotTo(HaveOccurred())
	})

	It("Test Query", func() {
		var (
			val  string
			flag bool
		)
		err = devsporeDB.QueryRow("SELECT val FROM foo WHERE id=?", id1).Scan(&val)
		Expect(err).NotTo(HaveOccurred())
		// The value read must equal the name of one of the active node's
		// slaves.
		for _, slave := range activeNode.SlavesDatasource {
			if slave.Name == val {
				flag = true
				break
			}
		}
		Expect(flag).To(BeTrue())
	})

	It("Test Insert", func() {
		var val string
		_, err = devsporeDB.Exec(`INSERT INTO foo (id, val) VALUES (?, ?)`, id2, "insert")
		Expect(err).NotTo(HaveOccurred())
		err = masterDB.QueryRow("SELECT val FROM foo WHERE id=?", id2).Scan(&val)
		Expect(err).NotTo(HaveOccurred())
		Expect(val).To(Equal("insert"))
	})

	It("Test Update", func() {
		var val string
		_, err = devsporeDB.Exec(`UPDATE foo set val=? where id=?`, "update", id1)
		Expect(err).NotTo(HaveOccurred())
		err = masterDB.QueryRow("SELECT val FROM foo WHERE id=?", id1).Scan(&val)
		Expect(err).NotTo(HaveOccurred())
		Expect(val).To(Equal("update"))
	})

	It("Test Delete", func() {
		var val string
		_, err = devsporeDB.Exec(`DELETE FROM foo where id=?`, id1)
		Expect(err).NotTo(HaveOccurred())
		// After the delete, the lookup on the master must find no row.
		err = masterDB.QueryRow("SELECT val FROM foo WHERE id=?", id1).Scan(&val)
		Expect(err).To(HaveOccurred())
		Expect(err.Error()).To(Equal("sql: no rows in result set"))
	})
})
// Account and listen address used by startMockServer for the mock MySQL
// server.
const (
user = "root"
passwd = "root"
address = "localhost"
port = "13306"
)
// startMockServer creates an in-memory MySQL-protocol server on address:port
// with six empty databases (ds0, ds1 and two slaves for each) plus
// information_schema, authenticating the single account user/passwd with all
// permissions.
//
// The accept loop runs in its own goroutine, so the function returns as soon
// as the server has been created. It panics if the server cannot be created.
func startMockServer() {
	engine := sqle.NewDefault(
		mocksql.NewDatabaseProvider(
			memory.NewDatabase("ds0"),
			memory.NewDatabase("ds1"),
			memory.NewDatabase("ds0-slave0"),
			memory.NewDatabase("ds0-slave1"),
			memory.NewDatabase("ds1-slave0"),
			memory.NewDatabase("ds1-slave1"),
			information_schema.NewInformationSchemaDatabase(),
		))
	config := server.Config{
		Protocol: "tcp",
		Address:  fmt.Sprintf("%s:%s", address, port),
		Auth:     auth.NewNativeSingle(user, passwd, auth.AllPermissions),
	}
	s, err := server.NewDefaultServer(config, engine)
	if err != nil {
		panic(err)
	}

	go func() {
		// Start blocks while serving; report why it stopped instead of
		// dropping the error.
		if err := s.Start(); err != nil {
			fmt.Printf("mysql-server stopped: %v\n", err)
		}
	}()
	fmt.Println("mysql-server started!")
}
5. 代码参考
示例代码标有文件路径的均来自devcloud-go项目,具体可看https://github.com/huaweicloud/devcloud-go。本文是在本人开发devcloud-go过程中积累而成,各位看官可以移步devcloud-go项目点个star~
- 点赞
- 收藏
- 关注作者
评论(0)