在flask中编写单元测试

为什么要编写单元测试以及单元测试的重要性,这里就不再强调了,可以搜索相关文章或者查看github上多数有影响力的开源软件都是标配单元测试的,你只需要记住,单元测试是非常非常重要的就OK了。

在flask中编写单元测试

python的单元测试框架众多,如何挑选合适自己的单元测试框架?开发人员可能会对这个问题有不同的见解,答案是,你用着哪个顺手就用哪一个。

下面就flask+sqlalchemy+mysql的常见组合来说明一下,如何对在编写web接口的过程中编写单元测试。

常规的python项目结构如下:

我们的测试脚本通常坐落于tests文件夹下。

准备工作

我们的flask项目ORM使用的是sqlalchemy,数据库是mysql。由于进行单元测试需要插入和修改数据,因此不能在生产的数据库中进行操作。这里我们引入一个第三方的库flask_testing.

flask_testing提供了一个重要的功能,就是对sqlalchemy的支持,我们需要在TestCase中引入一个叫create_app的function并在该函数中返回app,用于接下来的测试。

1
2
def create_app(self):
return app

显然这里的app不能用flask项目中的app,否则连接的就是项目运行中的数据库了。这里我们需要新创建一个空的数据库专门用于单元测试,这个数据库不会有任何的表和数据,表的创建由测试框架自动完成,测试完成后会自动销毁。

创建app的代码如下:

1
2
3
4
5
6
7
app = Flask(__name__)
app.config['SQLALCHEMY_DATABASE_URI'] = "mysql+pymysql://{}:{}@{}:{}/{}?charset=utf8mb4".format(
"***", "***", "***", 3306, "unittest")
app.config['TESTING'] = True
app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = True
app.register_blueprint(v1_recharge.bp, url_prefix='/v1/recharge')
db.init_app(app)

创建数据库结构

我们需要在测试开始前创建好表结构,这里我们使用sqlalchemy的create_all方法,它会自动帮我们创建好表结果,并检测你的model类型定义的是不是正确。

这里我们把建表的方法放到了setUpClass中,没有放到setUp方法中,因为setUp和tearDown方法会在每个test_func中运行,如果放到setUp中,就会每次都重写建表和销毁表,有些冗余。

放在setUpClass中会引起这个问题:

1
No application found. Either work inside a view function or push an application context.

问题的原因是我们的setUpClass方法是类方法,缺少application context导致的不能正常运行,解决方案就把application context加上去:

1
with app.app_context():

同样的,销毁表的操作,我们也放到tearDownClass方法而不涉及tearDown方法中。

完整的代码示例如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
app = Flask(__name__)
app.config['SQLALCHEMY_DATABASE_URI'] = "mysql+pymysql://{}:{}@{}:{}/{}?charset=utf8mb4".format(
"***", "***", "***", 3306, "unittest")
app.config['TESTING'] = True
app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = True
app.register_blueprint(v1_recharge.bp, url_prefix='/v1/recharge')
db.init_app(app)
class RechargeTest(TestCase):
member_id = "201810241615509296"
def create_app(self):
return app
@classmethod
def setUpClass(cls):
with app.app_context():
db.create_all()
m = Member(member_id=cls.member_id)
db.session.add(m)
db.session.commit()
@classmethod
def tearDownClass(cls):
with app.app_context():
db.session.remove()
db.drop_all()

这里的member_id是我的一个测试数据,用于后边的测试用例测试。

组织测试数据

我们可以把测试数据库放到一个方法中以备复用,这里我拿一个插入数据库的方法作为示例:

1
2
3
4
5
6
7
8
9
def recharge(self):
res = self.client.post("/v1/recharge/recharge", json={
"member_id": self.member_id,
"payment": 1,
"amount": random.choice(range(100)),
"bonus": random.choice(range(20)),
"date": "2019-02-01"
})
return res

这个方法的作用是给某个会员充值,调用的就是我们充值的接口。单独拿出来是因为我们后边会测试获取充值记录等接口,直接在测试获取充值的代码中调用即可。

编写测试用例

测试,不仅仅是对接口的测试,还包括对于内部方法的测试。比如,我们的充值模型中包含如下一个方法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
class RechargeModel(db.Model):
__tablename__ = 'member_recharge'
id = db.Column(db.Integer, primary_key=True)
member_id = db.Column(db.String(40))
order_no = db.Column(db.String(20))
payment = db.Column(db.Integer)
amount = db.Column(db.DECIMAL)
bonus = db.Column(db.DECIMAL)
status = db.Column(db.Integer)
date = db.Column(db.DATETIME, default=datetime.now)
...
@classmethod
def get_recharge_data(cls, member_id, start_date="2001-01-01", end_date="2099-01-01", page=1, perpage=20):
"""
获取充值记录
参数:
member_id: 会员号
start_date: 查询开始时间
end_date: 查询结束时间
page: 页码
perpage: 每页条数
"""
return RechargeModel.query.filter(
RechargeModel.member_id == member_id,
RechargeModel.date >= start_date,
RechargeModel.date <= end_date
).paginate(
int(page), int(perpage), False)
...

获取充值记录的方法是一个类方法,这个方法我们应该用单元测试覆盖到:

1
2
3
4
def test_get_recharge_data(self):
"""测试获取充值记录"""
self.recharge()
self.assertTrue(bool(RechargeModel.get_recharge_data(self.member_id)))

这个示例非常简单,就是看我们能不能获取到插入到数据库中的充值记录。

运行测试

一般我们会用python -m tests.test_xx 的方式在开发中进行测试:

1
2
3
4
5
python3 -m tests.test_recharge
----------------------------------------------------------------------
Ran 7 tests in 1.399s
OK

也可以单独对某个方法进行测试:

1
2
3
4
5
6
python3 -m unittest tests.test_recharge.RechargeTest.test_get_recharge_data
.
----------------------------------------------------------------------
Ran 1 test in 0.772s
OK

很多人(我也是)在一开始会用诸如python3 tests/test_xxx.py的方式运行,这里会出错,这个后边会讲。

这只是开发的测试,更推荐的一种测试方法是把你的代码打包,然后通过编写setup.py来进行测试。这里我写一个简单的setup脚本用来说明。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
from collections import OrderedDict
from setuptools import setup, find_packages
setup(
name='my_project',
version="1.0.0",
url='https://www.palletsprojects.com/p/flask/',
project_urls=OrderedDict((
('Documentation', 'http://flask.pocoo.org/docs/'),
('Code', 'https://github.com/pallets/flask'),
('Issue tracker', 'https://github.com/pallets/flask/issues'),
)),
license='BSD',
author='Backend Team',
maintainer='Pallets team',
description='xxx project',
packages=find_packages(),
include_package_data=True,
zip_safe=False,
platforms='any',
python_requires='>=3.5',
test_suite="tests"
)

编写好脚本,使用python3 setup.py test命令进行测试:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
...
test_check_member_id (tests.test_recharge.RechargeTest)
测试会员检测接口 ... ok
test_consume (tests.test_recharge.RechargeTest)
添加会员消费记录 ... ok
test_form_date (tests.test_recharge.RechargeTest)
测试输出数据 ... ok
test_get_consume (tests.test_recharge.RechargeTest)
获取消费记录 ... ok
test_get_recharge_date (tests.test_recharge.RechargeTest)
测试获取充值记录 ... ok
test_getrecharge (tests.test_recharge.RechargeTest)
获取充值记录 ... ok
test_recharge (tests.test_recharge.RechargeTest)
充值测试 ... ok
...

这是github上更推荐的一种方式,具体可以看这里的讨论

覆盖率

编写完测试,我们可以运行一下coverage来看一下我们的代码覆盖率:

1
coverage run -m pytest tests/test_recharge.py

会输出一个类似的提示:

1
2
3
4
5
6
7
8
9
10
11
ests/test_recharge.py::RechargeTest::test_recharge
tests/test_recharge.py::RechargeTest::test_recharge
tests/test_recharge.py::RechargeTest::test_recharge
tests/test_recharge.py::RechargeTest::test_recharge
tests/test_recharge.py::RechargeTest::test_recharge
tests/test_recharge.py::RechargeTest::test_recharge
/home/kevin/.local/lib/python3.6/site-packages/flask_sqlalchemy/__init__.py:157: SADeprecationWarning: Use .persist_selectable
info = getattr(mapper.mapped_table, 'info', {})
-- Docs: https://docs.pytest.org/en/latest/warnings.html
===================== 7 passed, 36 warnings in 3.50 seconds =======================

然后使用coverage report查看报告,report默认会加载全部的代码,包括第三方库,我们可以编写一个配置文件来过滤掉不用测试的文件,或者指定一个源代码文件夹:

1
2
3
4
5
[run]
source =
commom
model
v1

这里我们只要指定源代码路径就可以了。运行coverage report来查看覆盖率:

1
2
3
4
5
6
7
8
9
10
11
12
...
v1/v1_recharge/api/consume.py 24 5 79%
v1/v1_recharge/api/getconsume.py 24 6 75%
v1/v1_recharge/api/getrecharges.py 25 5 80%
v1/v1_recharge/api/recharge.py 27 4 85%
v1/v1_times_card/api/buy.py 8 8 0%
v1/v1_times_card/api/delete.py 8 8 0%
v1/v1_times_card/api/edit.py 8 8 0%
v1/v1_times_card/api/list.py 8 8 0%
v1/v1_times_card/api/show.py 8 8 0%
v1/v1_times_card/api/using_records.py 8 8 0%
...

这样源代码的覆盖率就一目了然了。

另外,你可以使用report html来输出html格式的报告:

html

路径引入的问题

我们上面说过,在常规的python项目结构中,如果使用python3 tests/test_xx.py的方式运行单元测试,会报module找不到的问题:

1
ModuleNotFoundError: No module named 'app'

解决的方案一般有一下两种:

  1. hack点的方式:
    from future import absolute_import
    import os
    import sys
    sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(file), ‘..’)))

    import mypkg

  2. 创建setup.py,以安装包的形式进行测试。

这里我们强烈推荐以第2种方式进行。