Making and Deploying an AI Web App in 2023 (Part 4)
Develop an App with Test-Driven Development
In this post we’ll have a look at test-driven development (TDD). The idea is to first write some simple unit tests that describe how our code should work. Doing this forces you to think about which functions you actually need to write and, if you’re part of a team, helps ensure that everyone is on the same page. Only after we have a few simple tests will we actually start writing code and make it pass them.
For a smaller project that you do not wish to maintain for a long time, having these simple unit tests is not strictly necessary. But be warned: your project will eventually break if used enough, and unit tests will likely be your only warning.
Setup
In our example we have a database with articles about COVID-19 (see Part 2). For our unit tests, we want to use the same kind of data, but much smaller, so that tests run fast and we have control over the outputs.
Let’s then think about what kind of simple examples we need.
The database has an `articles` table and a `sections` table.
Let’s open up the database and extract 2 articles, and a couple of sections from each of them.
It’s important here to choose examples that are helpful. In this case I chose two sections from each article, a TITLE and an ABSTRACT, because the titles won’t be indexed in our use case (see Part 2 of this series for details). I also chose sections that are different enough from each other that I can make a query for each of them and verify that the result is the section I wanted.
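If you want to extract these examples yourself, here is a minimal sketch (assuming the full database from Part 2 lives at `data/articles.sqlite`; adjust the path and ids to whatever you picked):

```python
import sqlite3

import pandas as pd

# Path to the full database from Part 2 (hypothetical location)
conn = sqlite3.connect("data/articles.sqlite")

# The two article ids chosen for the test fixtures
ids = ("071hvhme", "08as6hga")

# Export the chosen articles and their TITLE/ABSTRACT sections to CSV
# (the tests/assets directory must already exist)
pd.read_sql(
    "SELECT * FROM articles WHERE id IN (?, ?)", conn, params=ids
).to_csv("tests/assets/articles.csv", index=False)

pd.read_sql(
    "SELECT * FROM sections WHERE article IN (?, ?) "
    "AND name IN ('TITLE', 'ABSTRACT')",
    conn,
    params=ids,
).to_csv("tests/assets/sections.csv", index=False)

conn.close()
```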
Put these two files in `tests/assets`:
`articles.csv`:

```csv
id,source,published,publication,authors,title,tags,design,size,sample,method,reference,entry
071hvhme,WHO,2020-01-01 00:00:00,MSMR,"Kebisek, Julianna; Forrest, Lanna J; Maule, Alexis L; Steelman, Ryan A; Ambrose, John F","Special report: Prevalence of selected underlying health conditions among active component Army service members with coronavirus disease 2019, 11 February-6 April 2020",COVID-19,0,873,"COVID-19 has been a reportable condition for the Department of Defense since 5 February 2020, and, as of the morning of 6 April, a total of 873 cases were reported to the Disease Reporting System internet from Army installations.",,https://doi.org/,2020-06-07
08as6hga,WHO,2020-01-01 00:00:00,Int. j. sports med,"Stokes, Keith A; Jones, Ben; Bennett, Mark; Close, Graeme L; Gill, Nicholas; Hull, James H; Kasper, Andreas M; Kemp, Simon P T; Mellalieu, Stephen D; Peirce, Nicholas; Stewart, Bob; Wall, Benjamin T; West, Stephen W; Cross, Matthew",Returning to Play after Prolonged Training Restrictions in Professional Collision Sports,COVID-19,0,,The COVID-19 pandemic in 2020 has resulted in widespread training disruption in many sports.,,https://doi.org/,2020-06-08
```
`sections.csv`:

```csv
id,article,tags,design,name,text,labels
1749694,071hvhme,COVID-19,0,TITLE,"Special report: Prevalence of selected underlying health conditions among active component Army service members with coronavirus disease 2019, 11 February-6 April 2020",FRAGMENT
1749697,071hvhme,COVID-19,0,ABSTRACT,"COVID-19 has been a reportable condition for the Department of Defense since 5 February 2020, and, as of the morning of 6 April, a total of 873 cases were reported to the Disease Reporting System internet from Army installations.",SAMPLE_SIZE
1867855,08as6hga,COVID-19,0,TITLE,Returning to Play after Prolonged Training Restrictions in Professional Collision Sports,FRAGMENT
1867856,08as6hga,COVID-19,0,ABSTRACT,The COVID-19 pandemic in 2020 has resulted in widespread training disruption in many sports.,SAMPLE_SIZE
```
We also need to set up some fixtures to access these files. The original data is a SQLite file with two tables, so we write a fixture that reads the CSV files and generates a SQLite file.
Create a `tests/conftest.py` file:
```python
import sqlite3
from pathlib import Path
from tempfile import NamedTemporaryFile

import pandas as pd
import pytest


@pytest.fixture(scope="module")
def assets_path():
    return Path(__file__).resolve().parent / "assets"


@pytest.fixture(scope="module")
def articles_database(assets_path) -> Path:
    with NamedTemporaryFile() as f:
        conn = sqlite3.connect(f.name)
        # load CSVs
        articles = pd.read_csv(assets_path / "articles.csv", sep=",")
        sections = pd.read_csv(assets_path / "sections.csv", sep=",")
        # write CSVs to DB
        articles.to_sql("articles", conn, index=False)
        sections.to_sql("sections", conn, index=False)
        # close the write connection so tests can open their own
        conn.close()
        yield Path(f.name)
```
You can then use these fixtures for all your tests.
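For example, a quick sanity check that the fixture actually produces a database with both tables could look like this (a hypothetical test, not part of the final suite):

```python
import sqlite3


def test_database_has_tables(articles_database):
    conn = sqlite3.connect(articles_database)
    # sqlite_master lists all tables in the database
    tables = {
        row[0]
        for row in conn.execute("SELECT name FROM sqlite_master WHERE type = 'table'")
    }
    conn.close()
    assert {"articles", "sections"} <= tables
```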
Writing tests
We can already make a skeleton of what we want, before actually starting programming.
Let’s then make an `ai_web_app/main.py` file, where we will later implement the functions that we used in Part 2:
```python
from pathlib import Path

from txtai.embeddings import Embeddings


def index_embeddings(database: Path) -> Embeddings:
    raise NotImplementedError


def search(embeddings: Embeddings, database: Path, query: str, topn: int = 5):
    raise NotImplementedError
```
If you use an IDE, it will start complaining that `txtai` isn’t installed. That’s because we haven’t yet added it to our package dependencies. So open up `pyproject.toml` and add the latest version as a dependency:
```toml
dependencies = [
  "txtai[similarity]==5.3.0",
]
```
Now we’re ready to implement our tests. At this stage we just need some very simple tests, to make sure the indexing and querying are doing what we expect.
Let’s create a new file `tests/test_main.py` and add this:
```python
import pytest

from ai_web_app.main import index_embeddings, search


@pytest.fixture
def example_embeddings(articles_database):
    return index_embeddings(articles_database)


def test_index_embeddings_count(example_embeddings):
    assert example_embeddings.count() == 2


def test_search(example_embeddings, articles_database):
    res_sports = search(example_embeddings, articles_database, "sports", 1)
    assert res_sports[0].id == "08as6hga"

    res_military = search(example_embeddings, articles_database, "army", 1)
    assert res_military[0].id == "071hvhme"
```
There are two tests there:

- `test_index_embeddings_count` tests the size of the generated index. We expect the index to have only 2 entries, as sections of type TITLE are not indexed.
- `test_search` tests the correctness of the queries. We have one example about sports and another about the army, and in this test we just make sure that the correct section is returned for each query.
Testing
To run the tests, use:

```
hatch run cov
```
This command will fail with `NotImplementedError`, as the functions we’re testing aren’t implemented yet. If you’re following along and got a different error, check the link below for the full code.
Implementing Functions
We now need to implement the `index_embeddings` and `search` functions so that they pass the tests above. We take most of the code from Part 2, but change a couple of things to make it more production-ready.
The biggest changes are: (1) moving away from `print` towards a proper logging library (loguru in this case), and (2) removing the dependency on the `pandas` library, as it wasn’t really necessary and would only add overhead to our functions.
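If you haven’t used loguru before, its appeal is that it needs zero configuration: import the logger and call it, and you get timestamped, leveled output on stderr. A minimal sketch:

```python
from loguru import logger

# Drop-in replacements for the print calls from the notebook
logger.debug("Streamed {} documents", 1000)
logger.info("Indexing finished")
```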
The first thing to do is to add `loguru` to the dependencies section of our `pyproject.toml`. We can then add the imports we need for `index_embeddings`:
```python
import sqlite3
from pathlib import Path

import regex as re  # the stdlib re module doesn't support the lookbehind used below
from loguru import logger
from txtai.embeddings import Embeddings
from txtai.pipeline import Tokenizer
```
The function itself is copied from Part 2, but uses `loguru` instead of `print`:
```python
def index_embeddings(database: Path) -> Embeddings:
    def stream():
        # Connection to database file
        db = sqlite3.connect(database)
        cur = db.cursor()

        # Select tagged sentences without an NLP label.
        # NLP labels are set for non-informative sentences.
        cur.execute(
            "SELECT Id, Name, Text FROM sections "
            "WHERE (labels is null or labels NOT IN ('FRAGMENT', 'QUESTION')) "
            "AND tags is not null"
        )

        count = 0
        for row in cur:
            # Unpack row
            uid, name, text = row

            # Only process certain document sections
            if not name or not re.search(
                r"background|(?<!.*?results.*?)discussion|introduction|reference",
                name.lower(),
            ):
                # Tokenize text
                tokens = Tokenizer.tokenize(text)
                document = (uid, tokens, None)

                count += 1
                if count % 1000 == 0:
                    logger.debug(f"Streamed {count} documents")

                # Skip documents with no tokens parsed
                if tokens:
                    yield document

        logger.info(f"Iterated over {count} total rows")

        # Free database resources
        db.close()

    # BM25 scoring + sentence-transformers vectors
    embeddings = Embeddings(
        {
            "method": "sentence-transformers",
            "path": "all-MiniLM-L6-v2",
            "scoring": "bm25",
        }
    )
    embeddings.index(stream())

    return embeddings
```
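Before writing `search`, it helps to know what `embeddings.search` itself returns: with this configuration (no content storage), txtai gives back a list of `(id, score)` tuples rather than full documents, which is why our function will have to go back to the SQLite database for the section text and article metadata. A quick sketch, assuming `embeddings` was built from our two-article fixture:

```python
# embeddings.search returns (id, score) tuples; the ids are the
# section ids we passed in during indexing, not article ids.
for uid, score in embeddings.search("sports", 2):
    print(uid, score)
```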
For the `search` function, the notebook version outputs a `pandas` DataFrame. In a notebook that’s a nice way to visualize the output of the function, but we don’t really need the DataFrame and can use something simpler. Furthermore, as this function will be exposed as an API, it would be nice to have a schema of what the output will actually look like. For that we can use Python’s `dataclasses` (an alternative would be pydantic’s `BaseModel`, which has more functionality but requires one more library to be installed).
We can have a look at what fields we want to fetch from our database and return as a result from `search`, and make our own `Result` class:
```python
from dataclasses import dataclass
from datetime import datetime


@dataclass
class Result:
    id: str
    title: str
    published: datetime
    reference: str
    text: str
    score: float
```
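One nice property of dataclasses for an API is that they convert cleanly to plain dicts, and from there to JSON. A small sketch using one of our fixture rows and a made-up score:

```python
import json
from dataclasses import asdict
from datetime import datetime

result = Result(
    id="08as6hga",
    title="Returning to Play after Prolonged Training Restrictions in Professional Collision Sports",
    published=datetime(2020, 1, 1),
    reference="https://doi.org/",
    text="The COVID-19 pandemic in 2020 has resulted in widespread training disruption in many sports.",
    score=0.87,  # example value
)

# default=str takes care of the datetime field
print(json.dumps(asdict(result), default=str))
```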
We then change our original `search` function to the following (note the `List` import from `typing`, which we need for the return type annotation):
```python
from typing import List


def search(
    embeddings: Embeddings, database: Path, query: str, topn: int = 5
) -> List[Result]:
    db = sqlite3.connect(database)
    cur = db.cursor()

    results: List[Result] = []
    for uid, score in embeddings.search(query, topn):
        # Fetch the matched section and the id of the article it belongs to
        cur.execute("SELECT article, text FROM sections WHERE id = ?", [uid])
        article_id, text = cur.fetchone()

        # Fetch the article metadata
        cur.execute(
            "SELECT Title, Published, Reference FROM articles WHERE id = ?",
            [article_id],
        )
        res = cur.fetchone()

        results.append(
            Result(
                id=article_id,
                title=res[0],
                published=res[1],
                reference=res[2],
                text=text,
                score=score,
            )
        )

    db.close()
    return results
```
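Putting the two functions together, usage looks like this (a sketch; the database path is a hypothetical stand-in for your own copy of the Part 2 data):

```python
from pathlib import Path

from ai_web_app.main import index_embeddings, search

db_path = Path("data/articles.sqlite")  # adjust to your setup
embeddings = index_embeddings(db_path)

for res in search(embeddings, db_path, "sports", topn=1):
    print(res.id, res.score, res.title)
```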
Running `hatch run cov` should now show the tests as passing.
To continue this tutorial, go to Part 5.
For comments or questions, use the Reddit discussion or reach out to me directly via email.