Golang常用Mock方式总结

举报
junzhi 发表于 2021/12/09 20:50:51 2021/12/09
【摘要】 本文总结了golang常见的几种mock方式,包括基于interface的通用mock方式,针对mysql,redis的mock方式。

1.通用Mock方式

思路类似 Java 的 Mockito:结合使用 testify 与 mockery。testify 是一个 Go 测试框架,主要提供 assert、mock 和 suite(测试套件)三个功能;mockery 则基于 testify 的 mock 包,根据接口定义自动生成 mock 代码。

testify包下载:
go get github.com/stretchr/testify
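mockery安装:
go get github.com/vektra/mockery/.../
(以上为 GOPATH 模式下的旧写法;Go 1.16 及以上版本推荐使用 go install github.com/vektra/mockery/v2@latest 安装。注意 v2 生成的代码与下文 v1.0.0 生成的代码在细节上可能略有差异。)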

mockery会根据定义的interface生成对应的mock struct。

示例代码

common/etcd/client.go
common/etcd/mocks/EtcdClient.go
sql-driver/rds/config/loader/remote_configuration_loader_test.go

1. 生成mock struct

image.png
通过 mockery 生成 mock 代码:例如在接口定义处添加 //go:generate mockery --name=EtcdClient 注释后,在命令行执行 go generate;或者使用 GoLand 直接生成。此时会自动创建 mocks 目录以及对应的 mock struct 文件。
image.png
生成的代码如下所示:

// Code generated by mockery v1.0.0. DO NOT EDIT.

package mocks

import (
	etcd "github.com/huaweicloud/devcloud-go/common/etcd"
	mock "github.com/stretchr/testify/mock"
	clientv3 "go.etcd.io/etcd/client/v3"
)

// EtcdClient is an autogenerated mock type for the EtcdClient type.
// It embeds testify's mock.Mock: expectations are registered with
// On(...).Return(...), and each method below records its invocation through
// Called and answers with the configured return values.
type EtcdClient struct {
	mock.Mock
}

// Close mocks EtcdClient.Close. If the configured return value is a
// func() error, that function is called to produce the result; otherwise the
// configured error (possibly nil) is returned.
func (m *EtcdClient) Close() error {
	args := m.Called()
	if produce, isFunc := args.Get(0).(func() error); isFunc {
		return produce()
	}
	return args.Error(0)
}

// Del mocks EtcdClient.Del. Each return value may be configured either as a
// plain value or as a function of key that computes it; the count is resolved
// before the error.
func (m *EtcdClient) Del(key string) (int64, error) {
	args := m.Called(key)

	var count int64
	switch v := args.Get(0).(type) {
	case func(string) int64:
		count = v(key)
	default:
		count = v.(int64)
	}

	var err error
	switch v := args.Get(1).(type) {
	case func(string) error:
		err = v(key)
	default:
		err = args.Error(1)
	}

	return count, err
}

// Get mocks EtcdClient.Get. The value and the error may each be configured
// directly or as a func(string) that computes them from key; the value is
// resolved before the error.
func (m *EtcdClient) Get(key string) (string, error) {
	args := m.Called(key)

	var value string
	switch v := args.Get(0).(type) {
	case func(string) string:
		value = v(key)
	default:
		value = v.(string)
	}

	var err error
	switch v := args.Get(1).(type) {
	case func(string) error:
		err = v(key)
	default:
		err = args.Error(1)
	}

	return value, err
}

// List mocks EtcdClient.List. The slice may be configured directly, as a
// func(string) computing it from prefix, or left as nil; the error may be
// configured directly or as a func(string) error.
func (m *EtcdClient) List(prefix string) ([]*etcd.KeyValue, error) {
	args := m.Called(prefix)

	var kvs []*etcd.KeyValue
	switch v := args.Get(0).(type) {
	case func(string) []*etcd.KeyValue:
		kvs = v(prefix)
	case nil:
		// A nil return value leaves kvs as a nil slice.
	default:
		kvs = v.([]*etcd.KeyValue)
	}

	var err error
	switch v := args.Get(1).(type) {
	case func(string) error:
		err = v(prefix)
	default:
		err = args.Error(1)
	}

	return kvs, err
}

// Put mocks EtcdClient.Put. The returned string and error may each be
// configured directly or as a func(key, value) that computes them; the
// string is resolved before the error.
func (m *EtcdClient) Put(key string, value string) (string, error) {
	args := m.Called(key, value)

	var result string
	switch v := args.Get(0).(type) {
	case func(string, string) string:
		result = v(key, value)
	default:
		result = v.(string)
	}

	var err error
	switch v := args.Get(1).(type) {
	case func(string, string) error:
		err = v(key, value)
	default:
		err = args.Error(1)
	}

	return result, err
}

// Watch mocks EtcdClient.Watch. It only records the call (prefix, startIndex,
// onEvent) on the embedded mock; onEvent is never invoked here.
func (m *EtcdClient) Watch(prefix string, startIndex int64, onEvent func(*clientv3.Event)) {
	_ = m.Called(prefix, startIndex, onEvent)
}

2. 业务代码调用EtcdClient

此处定义了一个RemoteConfigurationLoader,其成员etcdClient的类型为上述定义的EtcdClient interface,在Get方法中调用EtcdClient的Get方法。

// RemoteConfigurationLoader reads configuration values from etcd. etcdClient
// is typed as the etcd.EtcdClient interface so tests can inject a mock such as
// mocks.EtcdClient.
type RemoteConfigurationLoader struct {
	etcdClient    etcd.EtcdClient
	dataSourceKey string
	routerKey     string
	activeKey     string
	listeners     []config.RouterConfigurationListener
}

// Get fetches the values stored under dataSourceKey, routerKey and activeKey,
// in that order, and returns them as (dataSourceConfig, routerConfig, active).
// If any lookup fails, the error is logged and three empty strings are
// returned.
func (l *RemoteConfigurationLoader) Get() (string, string, string) {
	dataSourceConfig, err := l.etcdClient.Get(l.dataSourceKey)
	if err != nil {
		// Log every failed lookup, so a missing data-source key is as
		// diagnosable as the router/active ones below.
		log.Printf("ERROR: get remote dataSourceConfig failed, %v", err)
		return "", "", ""
	}

	routerConfig, err := l.etcdClient.Get(l.routerKey)
	if err != nil {
		log.Printf("ERROR: get remote routerConfig failed, %v", err)
		return "", "", ""
	}

	active, err := l.etcdClient.Get(l.activeKey)
	if err != nil {
		log.Printf("ERROR: get remote active failed, %v", err)
		return "", "", ""
	}
	return dataSourceConfig, routerConfig, active
}

3. 测试代码编写

编写对RemoteConfigurationLoader的Get()方法的测试代码,对EtcdClient的Get方法进行mock。

1 import (
2	"fmt"
3	"testing"
4
5	"github.com/huaweicloud/devcloud-go/common/etcd/mocks"
6	"github.com/stretchr/testify/assert"
7 )
8
9 func TestRemoteConfigurationLoader_Get(t *testing.T) {
10 	mockClient := &mocks.EtcdClient{}
11 	loader := &RemoteConfigurationLoader{
12 		dataSourceKey: "datasourceKey",
13 		routerKey:     "routerKey",
14 		activeKey:     "activeKey",
15 		etcdClient:    mockClient,
16 	}
17 	mockClient.On("Get", loader.dataSourceKey).Return("data", nil).Once()
18 	mockClient.On("Get", loader.routerKey).Return("router", nil).Once()
19 	mockClient.On("Get", loader.activeKey).Return("active", nil).Once()
20 	datasource, router, active := loader.Get()
21 	assert.Equal(t, "data", datasource)
22 	assert.Equal(t, "router", router)
23 	assert.Equal(t, "active", active)
24 }

其中三行 mockClient.On(...) 调用就是在 mock 我们想要的数据:On 方法首先传入要 mock 的方法名,然后传入期望的方法参数,此处是利用 Go 的反射来实现的;Return 方法中传入想要 mock 的返回数据;最后调用 Once() 表示该期望只匹配一次调用。测试末尾还可以调用 mockClient.AssertExpectations(t),校验所有声明的期望调用都确实发生过。

4. 参考文档

2. MySQL Mock

2.1 mock mysql server(推荐)

利用 github.com/dolthub/go-mysql-server:它兼容 MySQL 语法,能够解析标准 SQL,并可以在内存中启动一个 MySQL server,测试代码像连接真实数据库一样连接它即可。
安装:

go get github.com/dolthub/go-mysql-server

示例代码:

package main

import (
	"fmt"
	"time"

	sqle "github.com/dolthub/go-mysql-server"
	"github.com/dolthub/go-mysql-server/auth"
	"github.com/dolthub/go-mysql-server/memory"
	"github.com/dolthub/go-mysql-server/server"
	"github.com/dolthub/go-mysql-server/sql"
	"github.com/dolthub/go-mysql-server/sql/information_schema"
)

const (
	user      = "user"
	passwd    = "pass"
	address   = "localhost"
	port      = "13306"
	dbName    = "test"
	tableName = "pets"
)

// main starts an in-memory MySQL-compatible server on address:port.
// The server exposes the database built by createTestDatabase plus
// information_schema, and accepts the single account user/passwd.
// It panics if the server cannot be created or if serving fails.
func main() {
	engine := sqle.NewDefault(
		sql.NewDatabaseProvider(
			createTestDatabase(),
			information_schema.NewInformationSchemaDatabase(),
		))

	config := server.Config{
		Protocol: "tcp",
		Address:  fmt.Sprintf("%s:%s", address, port),
		Auth:     auth.NewNativeSingle(user, passwd, auth.AllPermissions),
	}

	s, err := server.NewDefaultServer(config, engine)
	if err != nil {
		panic(err)
	}

	fmt.Println("mysql-server started!")

	// Start blocks while it serves connections. Running it on the main
	// goroutine keeps the process alive without an artificial
	// `<-make(chan ...)`, and its error is surfaced instead of discarded.
	if err := s.Start(); err != nil {
		panic(err)
	}
}

func createTestDatabase() *memory.Database {
	db := memory.NewDatabase(dbName)
	table := memory.NewTable(tableName, sql.Schema{
		{Name: "name", Type: sql.Text, Nullable: false, Source: tableName},
		{Name: "email", Type: sql.Text, Nullable: false, Source: tableName},
		{Name: "phone_numbers", Type: sql.JSON, Nullable: false, Source: tableName},
		{Name: "created_at", Type: sql.Timestamp, Nullable: false, Source: tableName},
	})

	db.AddTable(tableName, table)
	ctx := sql.NewEmptyContext()

	rows := []sql.Row{
		sql.NewRow("John Doe", "jasonkay@doe.com", []string{"555-555-555"}, time.Now()),
		sql.NewRow("John Doe", "johnalt@doe.com", []string{}, time.Now()),
		sql.NewRow("Jane Doe", "jane@doe.com", []string{}, time.Now()),
		sql.NewRow("Evil Bob", "jasonkay@gmail.com", []string{"555-666-555", "666-666-666"}, time.Now()),
	}

	for _, row := range rows {
		_ = table.Insert(ctx, row)
	}

	return db

2.2 mock sql driver

使用 DATA-DOG/go-sqlmock:该包实现了 Go 标准库 database/sql/driver 的接口,本质上是一个 mock 驱动。
安装:

go get github.com/DATA-DOG/go-sqlmock

示例代码

原生sql代码使用 示例代码见:https://github.com/DATA-DOG/go-sqlmock
编写单元测试时,通过下文定义的全局 sqlmock 对象(GlobalMock)预先声明期望执行的 SQL 及其返回结果(如 ExpectQuery、ExpectExec 等),具体使用方法见官方文档。

2.2.1 结合beego orm

// Shared Ormers and their one-time initialisers. GlobalMock is exported so
// tests in other packages can register SQL expectations on the sqlmock
// connection behind the test Ormer.
var (
	globalOrm     orm.Ormer
	once          sync.Once
	mockOnce      sync.Once
	globalMockOrm orm.Ormer
	GlobalMock    sqlmock.Sqlmock
)

// GetOrmer returns a lazily created, shared orm.Ormer.
//
// When the environment variable isTest is "true", the Ormer is backed by a
// sqlmock connection whose controller is stored in GlobalMock. Otherwise it is
// beego's default Ormer, with DefaultRowsLimit disabled. Each variant is
// created at most once. A creation error panics instead of leaving callers
// with a nil Ormer.
func GetOrmer() orm.Ormer {
	if utils.GetenvOrDefault("isTest", "") == "true" {
		mockOnce.Do(func() {
			var (
				db  *sql.DB
				err error
			)
			db, GlobalMock, err = sqlmock.New()
			if err != nil {
				panic(err)
			}
			// NOTE(review): these prepares appear to be issued by beego itself
			// while it initialises the MySQL alias; confirm against the beego
			// version in use.
			GlobalMock.ExpectPrepare("SELECT TIMEDIFF")
			GlobalMock.ExpectPrepare("SELECT ENGINE")
			globalMockOrm, err = orm.NewOrmWithDB("mysql", "default", db)
			if err != nil {
				panic(err)
			}
		})
		return globalMockOrm
	}
	once.Do(func() {
		// override the default value(1000) to return all records when setting no limit
		orm.DefaultRowsLimit = -1
		globalOrm = orm.NewOrm()
	})
	return globalOrm
}

测试代码

// Book is the model used by the sqlmock examples.
// The gorm tags serve the gorm example. The orm tags give beego's ORM the same
// column mapping and mark Id as the primary key that Ormer.Read looks up by.
type Book struct {
	Id    int64  `gorm:"column:id" orm:"column(id);pk"`
	Title string `gorm:"column:title" orm:"column(title)"`
}
// TestSqlMockBeegoOrm reads a Book through the sqlmock-backed beego Ormer and
// checks that the row returned by the mocked SELECT is mapped onto the struct.
func TestSqlMockBeegoOrm(t *testing.T) {
	os.Setenv("isTest", "true")
	ormer := driver_test.GetOrmer()
	// GlobalMock is declared next to GetOrmer in driver_test, so it must be
	// package-qualified here as well.
	driver_test.GlobalMock.ExpectQuery("SELECT").WillReturnRows(
		sqlmock.NewRows([]string{"id", "title"}).
			AddRow(1, "one"))

	book := &Book{Id: 1}
	err := ormer.Read(book)
	assert.Nil(t, err)
	assert.Equal(t, "one", book.Title)
}

2.2.2 结合gorm

// Shared gorm handles and their one-time initialisers. GlobalMock is exported
// so tests in other packages (e.g. driver_test.GlobalMock) can register SQL
// expectations on the sqlmock connection behind the test *gorm.DB.
var (
	globalGormDB     *gorm.DB
	globalMockGormDB *gorm.DB
	GlobalMock       sqlmock.Sqlmock
	once             sync.Once
	mockOnce         sync.Once
)

// GetGormDB returns a lazily created, shared *gorm.DB.
//
// When the environment variable isTest is "true", the handle wraps a sqlmock
// connection whose controller is stored in GlobalMock. Otherwise it connects to
// the hard-coded MySQL DSN. Each variant is created at most once. A creation
// error panics instead of leaving callers with a nil *gorm.DB.
func GetGormDB() *gorm.DB {
	if utils.GetenvOrDefault("isTest", "") == "true" {
		mockOnce.Do(func() {
			var (
				db  *sql.DB
				err error
			)
			db, GlobalMock, err = sqlmock.New()
			if err != nil {
				panic(err)
			}
			globalMockGormDB, err = gorm.Open(mysql.New(mysql.Config{
				Conn: db,
				// Don't query the server version on open; sqlmock has no
				// expectation registered for that statement.
				SkipInitializeWithVersion: true,
			}), &gorm.Config{})
			if err != nil {
				panic(err)
			}
		})
		return globalMockGormDB
	}
	once.Do(func() {
		var err error
		globalGormDB, err = gorm.Open(mysql.Open("user:pass@tcp(127.0.0.1:3306)/dbname"), &gorm.Config{})
		if err != nil {
			panic(err)
		}
	})
	return globalGormDB
}

测试代码

// TestSqlMockGorm checks that gorm's Find maps the rows returned by a mocked
// "SELECT * FROM `books`" onto Book values.
func TestSqlMockGorm(t *testing.T) {
	os.Setenv("isTest", "true")
	gormDB := driver_test.GetGormDB()
	// sqlmock matches the expected SQL as a regular expression, hence QuoteMeta.
	driver_test.GlobalMock.ExpectQuery(regexp.QuoteMeta("SELECT * FROM `books`")).WillReturnRows(
		sqlmock.NewRows([]string{"id", "title"}).
			AddRow(1, "one").
			AddRow(2, "two"))

	var books []Book
	err := gormDB.Find(&books).Error
	assert.Nil(t, err)
	// The mocked query must actually have been executed.
	assert.NoError(t, driver_test.GlobalMock.ExpectationsWereMet())
	// Stop before indexing if the row count is wrong, instead of panicking.
	if !assert.Len(t, books, 2) {
		return
	}
	assert.Equal(t, int64(1), books[0].Id)
	assert.Equal(t, "one", books[0].Title)
	assert.Equal(t, int64(2), books[1].Id)
	assert.Equal(t, "two", books[1].Title)
}

2.3 使用优劣

个人感觉mock mysql server最方便,对代码侵入较少,测试代码也会更少,测试范围会更广。
mock sql driver只适合对业务层mock数据库操作,测试业务代码;当使用复杂sql时,需要将orm的链式操作等转为一个复杂sql语句用于mock,需要编写大量测试代码。
而mock mysql server在此基础上还能测试数据库操作的代码是否正确。

2.4 参考文档

3. Redis Mock

使用github.com/alicebob/miniredis/v2 ,miniredis可以在内存中启动一个redis server,支持大部分redis命令,具体支持情况见github readme
示例代码:
redis/devspore_client_test.go

import (
	"context"
	"testing"

	"github.com/alicebob/miniredis/v2"
	"github.com/go-redis/redis/v8"
	"github.com/stretchr/testify/assert"
)

// TestMockRedis runs a go-redis client against an in-memory miniredis server
// and checks that a value written with SET can be read back with GET.
func TestMockRedis(t *testing.T) {
	server, err := miniredis.Run()
	if err != nil {
		t.Fatalf("start miniredis: %v", err)
	}
	// Deferred so the server is released even if the test stops early.
	defer server.Close()

	client := redis.NewClient(&redis.Options{Addr: server.Addr()})
	defer client.Close()

	ctx := context.Background()
	assert.NoError(t, client.Set(ctx, "test", "val", 0).Err())
	res := client.Get(ctx, "test")
	assert.Nil(t, res.Err())
	assert.Equal(t, "val", res.Val())
}

4. Ginkgo测试框架

Ginkgo是一个BDD(Behavior Driven Development)风格的go测试框架,与Gomega配合使用,在需要写大量单测时,特别是需要一些通用代码时,Ginkgo可以使用BeforeEach和AfterEach将每个用例的通用步骤提取出来,会让代码看起来很清爽。
具体使用方法见 https://ke-chain.github.io/ginkgodoc/

示例代码

devcloud-go/sql-driver/mysql/devspore_driver_test.go 使用Ginkgo框架编写driver的CRUD测试代码,结合mock mysql server使用。

import (
	"database/sql"
	"fmt"
	"testing"

	"github.com/huaweicloud/devcloud-go/sql-driver/rds/config"
	"github.com/huaweicloud/devcloud-go/sql-driver/rds/datasource"
	. "github.com/onsi/ginkgo"
	. "github.com/onsi/gomega"

	sqle "github.com/dolthub/go-mysql-server"
	"github.com/dolthub/go-mysql-server/auth"
	"github.com/dolthub/go-mysql-server/memory"
	"github.com/dolthub/go-mysql-server/server"
	mocksql "github.com/dolthub/go-mysql-server/sql"
	"github.com/dolthub/go-mysql-server/sql/information_schema"
)

// TestGinkgoSuite is the `go test` entry point for the Ginkgo specs in this
// file. It routes Gomega assertion failures to Ginkgo's Fail and runs the
// registered specs under the suite name "mysql".
func TestGinkgoSuite(t *testing.T) {
	RegisterFailHandler(Fail)
	RunSpecs(t, "mysql")
}

// CRUD specs for the devspore_mysql driver, run against the in-memory MySQL
// server started by startMockServer. Each spec acts through devsporeDB and
// verifies the effect through masterDB, a direct connection to the active
// node's master data source.
var _ = Describe("CRUD", func() {
	var (
		devsporeDB *sql.DB
		masterDB   *sql.DB
		err        error
		activeNode *datasource.NodeDataSource
	)
	// Call startMockServer synchronously. It already serves on its own
	// goroutine, so a synchronous call guarantees the server has been created
	// before the first BeforeEach connects. A creation panic is then also
	// reported here rather than crashing from a stray goroutine.
	// NOTE(review): assumes NewDefaultServer binds the listener — confirm for
	// the go-mysql-server version in use.
	startMockServer()

	BeforeEach(func() {
		devsporeDB, err = sql.Open("devspore_mysql", "../rds/resources/driver_test_config.yaml")
		Expect(err).NotTo(HaveOccurred())
		// initDB is defined elsewhere in this package; presumably it (re)seeds
		// table foo on every data source — confirm.
		activeNode, err = initDB()
		Expect(err).NotTo(HaveOccurred())
		masterDB, err = sql.Open("mysql", activeNode.MasterDataSource.Dsn)
		Expect(err).NotTo(HaveOccurred())
	})

	AfterEach(func() {
		Expect(devsporeDB.Close()).NotTo(HaveOccurred())
		Expect(masterDB.Close()).NotTo(HaveOccurred())
	})

	It("Test Query", func() {
		var (
			val  string
			flag bool
		)
		// The read must be served by one of the active node's slaves.
		// NOTE(review): relies on each slave storing its own Name as val for id1.
		err = devsporeDB.QueryRow("SELECT val FROM foo WHERE id=?", id1).Scan(&val)
		Expect(err).NotTo(HaveOccurred())
		for _, slave := range activeNode.SlavesDatasource {
			if slave.Name == val {
				flag = true
			}
		}
		Expect(flag).To(Equal(true))
	})

	It("Test Insert", func() {
		var val string
		// Writes go through the driver; the row must then be visible on the master.
		_, err = devsporeDB.Exec(`INSERT INTO foo (id, val) VALUES (?, ?)`, id2, "insert")
		Expect(err).NotTo(HaveOccurred())
		err = masterDB.QueryRow("SELECT val FROM foo WHERE id=?", id2).Scan(&val)
		Expect(err).NotTo(HaveOccurred())
		Expect(val).To(Equal("insert"))
	})

	It("Test Update", func() {
		var val string
		_, err = devsporeDB.Exec(`UPDATE foo set val=? where id=?`, "update", id1)
		Expect(err).NotTo(HaveOccurred())
		err = masterDB.QueryRow("SELECT val FROM foo WHERE id=?", id1).Scan(&val)
		Expect(err).NotTo(HaveOccurred())
		Expect(val).To(Equal("update"))
	})

	It("Test Delete", func() {
		var val string
		_, err = devsporeDB.Exec(`DELETE FROM foo where id=?`, id1)
		Expect(err).NotTo(HaveOccurred())
		// After the delete the master must no longer have the row.
		err = masterDB.QueryRow("SELECT val FROM foo WHERE id=?", id1).Scan(&val)
		Expect(err).To(HaveOccurred())
		Expect(err.Error()).To(Equal("sql: no rows in result set"))
	})
})


const (
	user    = "root"
	passwd  = "root"
	address = "localhost"
	port    = "13306"
)

// startMockServer creates an in-memory MySQL-compatible server on
// address:port and serves it on a background goroutine.
// The server exposes the empty databases ds0, ds1, ds0-slave0, ds0-slave1,
// ds1-slave0 and ds1-slave1 plus information_schema, and accepts the single
// account user/passwd. It panics if the server cannot be created.
func startMockServer() {
	engine := sqle.NewDefault(
		mocksql.NewDatabaseProvider(
			memory.NewDatabase("ds0"),
			memory.NewDatabase("ds1"),
			memory.NewDatabase("ds0-slave0"),
			memory.NewDatabase("ds0-slave1"),
			memory.NewDatabase("ds1-slave0"),
			memory.NewDatabase("ds1-slave1"),
			information_schema.NewInformationSchemaDatabase(),
		))

	config := server.Config{
		Protocol: "tcp",
		Address:  fmt.Sprintf("%s:%s", address, port),
		Auth:     auth.NewNativeSingle(user, passwd, auth.AllPermissions),
	}
	s, err := server.NewDefaultServer(config, engine)
	if err != nil {
		panic(err)
	}
	go func() {
		// Report why serving stopped instead of silently discarding the error.
		if err := s.Start(); err != nil {
			fmt.Printf("mysql-server stopped: %v\n", err)
		}
	}()
	fmt.Println("mysql-server started!")
}

5. 代码参考

示例代码标有文件路径的均来自devcloud-go项目,具体可看https://github.com/huaweicloud/devcloud-go。本文是在本人开发devcloud-go过程中积累而成,各位看官可以移步devcloud-go项目点个star~

【版权声明】本文为华为云社区用户原创内容,转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息, 否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。