Improving Text Classification Using the watsonx.ai Tuning Studio
One of the key promises of Large Language Models (LLMs) is the ability to rapidly provide value to a plethora of different use-cases by leveraging models that have been pre-trained on massive amounts of data. Although LLMs are powerful out-of-the-box, their performance can be improved using a framework called Prompt Tuning, which allows an LLM to become specialized for a specific task by learning relationships in a particular dataset. In this blog, I will showcase how the watsonx.ai platform can be used to quickly build classification models using LLMs, and how those results can be enhanced using the watsonx Tuning Studio.
The use-case and data
For this use-case, I used the Medical Transcriptions dataset from Kaggle. It contains notes that were written and transcribed by medical professionals after examining a patient, along with the medical category associated with each diagnosis. I chose this dataset because it is representative of the kind of use-cases being actively developed across all sorts of industries: complicated, highly specialized data needs to be analyzed and classified so that downstream users can consume it more rapidly. In this example, one can imagine a medical practitioner spending hours going over medical transcripts and manually classifying them into a set of discrete categories in order to inform other practitioners.
What if we could use an LLM to do that work for her instead?
I will focus on two columns, transcription and medical_speciality. An LLM will read each transcript for us and then classify it into one of the following four categories:
- Surgery
- Cardiovascular / Pulmonary
- Orthopedic
- ENT / Otolaryngology.
I picked these categories because they present an interesting imbalance: 55% of the samples are in Surgery, 5% are in ENT / Otolaryngology, and 20% each are in the two other categories. Achieving strong results on an imbalanced dataset can be challenging because the categories with fewer datapoints are underrepresented and therefore harder to learn. As such, we will need to look not just at accuracy but also at per-class metrics such as precision and recall, to make sure our model does not learn the patterns of the majority class (in this case, Surgery) at the expense of the minority classes (especially ENT / Otolaryngology).
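To see why accuracy alone can mislead on data this skewed, here is a minimal sketch with a toy label list that mirrors the 55/20/20/5 split (the labels are made up, not taken from the dataset). A degenerate “model” that always answers Surgery still scores 55% accuracy, while its recall on every other class is zero. The helper names are illustrative; in practice a library such as scikit-learn provides the same metrics.

```python
from collections import Counter


def per_class_metrics(y_true, y_pred):
    """Return {label: (precision, recall)} for every label seen in y_true or y_pred."""
    labels = sorted(set(y_true) | set(y_pred))
    tp = Counter(t for t, p in zip(y_true, y_pred) if t == p)
    pred_counts = Counter(y_pred)   # TP + FP per label
    true_counts = Counter(y_true)   # TP + FN per label
    metrics = {}
    for label in labels:
        precision = tp[label] / pred_counts[label] if pred_counts[label] else 0.0
        recall = tp[label] / true_counts[label] if true_counts[label] else 0.0
        metrics[label] = (precision, recall)
    return metrics


def accuracy(y_true, y_pred):
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


# Toy test set with the same class mix: 55% / 20% / 20% / 5%
y_true = (["Surgery"] * 11 + ["Cardiovascular / Pulmonary"] * 4
          + ["Orthopedic"] * 4 + ["ENT / Otolaryngology"] * 1)
# A degenerate "model" that always predicts the majority class
y_pred = ["Surgery"] * len(y_true)

print(accuracy(y_true, y_pred))
print(per_class_metrics(y_true, y_pred))
```

Accuracy comes out at 0.55 even though recall is 0.0 for three of the four classes, which is exactly the failure mode per-class metrics are meant to expose.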
To ensure a fair comparison between the ML and LLM approaches, I divided my data into a training set and a testing set. The training set will be used both to train the ML model and to tune the LLM. The testing set will then be used to calculate model performance.
Note that since the data was highly unbalanced, I sampled the training data so as to even out the number of examples from each class. As always, the number of samples to provide and how to balance them are subject to experimentation, but I have found balanced sampling to provide solid results for this use-case.
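A minimal pandas sketch of this split-then-balance step is below. The exact split ratio and per-class sample sizes aren’t stated here, so they are assumptions: a 20% test set that keeps the original class mix, and undersampling every training class down to the size of the smallest one. The helper split_and_balance is hypothetical, and the label column name follows the one mentioned above.

```python
import pandas as pd


def split_and_balance(df, label_col="medical_speciality", test_frac=0.2, seed=42):
    """Hypothetical helper: hold out a test set, then class-balance the training set."""
    # Test set: sample the same fraction from every class, so it keeps the real imbalance
    test = df.groupby(label_col, group_keys=False).sample(frac=test_frac, random_state=seed)
    train = df.drop(test.index)
    # Training set: undersample every class to the size of the smallest class
    n_min = train[label_col].value_counts().min()
    train_balanced = train.groupby(label_col, group_keys=False).sample(n=n_min, random_state=seed)
    return train_balanced.reset_index(drop=True), test.reset_index(drop=True)
```

Keeping the test set imbalanced means the reported metrics reflect the class mix the model would actually face, while the balanced training set gives the minority classes a fair share of examples.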
Let’s take a look at a typical transcript: it contains a ton of medical jargon that probably looks incomprehensible to a non-specialist (myself included). The label associated with this transcript is Cardiovascular / Pulmonary.
How can we leverage AI to help us? Before LLMs were popularized, we would have had to use a Machine Learning algorithm to learn the relationship between words and classes, which can be quite difficult due to high dimensionality (each unique word becomes its own feature), pre-processing challenges (deciding which stop words to filter out) and, perhaps most importantly, the inability of word-based features to capture any semantic meaning from text.
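The dimensionality and semantics problems are easy to see with a plain bag-of-words representation. The sketch below uses three made-up sentences (not real transcripts) and a deliberately simple tokenizer: every distinct word becomes a feature, so the vector length grows with the vocabulary, and a phrase with a related meaning but different words maps to an all-zero vector.

```python
import re
from collections import Counter


def tokenize(text):
    # Lowercase and keep alphabetic runs; real pipelines also handle stop words, numbers, etc.
    return re.findall(r"[a-z]+", text.lower())


def build_vocabulary(docs):
    # Every distinct word becomes its own feature (column)
    return sorted({tok for doc in docs for tok in tokenize(doc)})


def vectorize(doc, vocab):
    counts = Counter(tokenize(doc))
    return [counts[word] for word in vocab]


docs = [
    "Patient presents with chest pain and shortness of breath.",
    "Status post arthroscopy of the left knee.",
    "Chronic otitis media; tympanostomy tubes placed.",
]
vocab = build_vocabulary(docs)
print(len(vocab))                                      # one feature per unique word
print(sum(vectorize("myocardial infarction", vocab)))  # shares no words with "chest pain"
```

Three short sentences already produce 21 features, and “myocardial infarction” has no overlap with “chest pain” even though a clinician would see the connection immediately.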
That being said, a good data scientist should be curious, explore all available options, and then compare what works best. So let’s see how well an ML model performs and then compare the results to LLMs. I will use AutoAI, which is available in watsonx.ai, for this first pass. Let’s get started!
A first pass using AutoAI
A great feature that comes with the watsonx.ai platform is AutoAI. It allows you to perform classification and regression tasks, and it has built-in capabilities for time series analysis as well as for fairness and bias evaluation.
In our case, we are going with a straightforward classification task. After we select the input/explanatory variable (transcripts) and the output/target variable (medical specialty), AutoAI runs pipelines that evaluate different Machine Learning algorithms, such as ensemble models, SVMs, logistic regressions and so on. Each pipeline preprocesses the data (imputing missing values, encoding variables, etc.), fits a model, and iterates over different levels of feature engineering until the best score is attained.
The screenshot below shows that Pipeline 2 offered the best model: a Snap Random Forest classifier that achieves an accuracy of 0.658.
We can also unpack each Pipeline and take a look at what’s going on inside. We can see that the classification was performed by transforming the transcripts into word2vec vectors and taking their dimensions as transformed inputs. Other Pipelines (not shown here) also used methods such as TF-IDF, but the word2vec approach won out.
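One common way to turn word2vec embeddings into a fixed-length input for a classifier is to average the vectors of a document’s words, so that each embedding dimension becomes one feature. The exact aggregation AutoAI uses isn’t shown here, so averaging is an assumption, and the 3-dimensional vectors below are toy values (real word2vec vectors are learned and often have a few hundred dimensions).

```python
def document_vector(tokens, embeddings):
    """Average the vectors of in-vocabulary tokens; return a zero vector if none are known."""
    dim = len(next(iter(embeddings.values())))
    known = [embeddings[t] for t in tokens if t in embeddings]
    if not known:
        return [0.0] * dim
    # zip(*known) walks dimension by dimension across all known word vectors
    return [sum(col) / len(known) for col in zip(*known)]


# Toy "word2vec" vectors, purely for illustration
embeddings = {
    "chest": [0.9, 0.1, 0.0],
    "pain": [0.5, 0.5, 0.0],
    "knee": [0.0, 0.2, 0.8],
}
features = document_vector(["chest", "pain", "unseenword"], embeddings)
print(features)  # each position is one feature fed to the classifier
```

Unlike bag-of-words counts, the feature vector has a fixed size no matter how large the vocabulary is, and words with similar embeddings push documents toward similar regions of feature space.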
Finally, we can take a look at the confusion matrix created from the model’s classifications. Something that immediately jumps out is that the ENT / Otolaryngology predictions are way off: the model did not correctly predict a single case. Furthermore, the performance on the categories other than Surgery is underwhelming. This suggests that the model in this Pipeline was possibly biased toward Surgery because it is the dominant class. Let’s now take a look at what we can do using LLMs instead.
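For readers who want to rebuild such a matrix from their own predictions, here is a small sketch (the labels in the usage line are made up). Rows are true classes and columns are predicted classes; any unexpected label that shows up is appended as an extra row and column rather than silently dropped, which matters once an LLM starts producing free-text answers.

```python
def confusion_matrix(y_true, y_pred, labels):
    """Return (labels, matrix) with rows = true class, columns = predicted class."""
    # Append labels that were not expected so no prediction is silently discarded
    labels = list(labels) + sorted((set(y_true) | set(y_pred)) - set(labels))
    index = {label: i for i, label in enumerate(labels)}
    matrix = [[0] * len(labels) for _ in labels]
    for t, p in zip(y_true, y_pred):
        matrix[index[t]][index[p]] += 1
    return labels, matrix


labels, cm = confusion_matrix(
    ["Surgery", "Surgery", "ENT / Otolaryngology", "Orthopedic"],
    ["Surgery", "Orthopedic", "Surgery", "Orthopedic"],
    ["Surgery", "Orthopedic", "ENT / Otolaryngology"],
)
print(labels)
print(cm)
```

A zero on the diagonal for a class, as for ENT / Otolaryngology here, means that not a single example of that class was predicted correctly.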
Using the watsonx.ai Prompt Lab to build a classification prompt
Leveraging LLMs to solve use cases is an iterative process. Developing adequate prompts in watsonx.ai’s Prompt Lab is usually the first step that I take before doing any kind of downstream analysis with my data.
The Prompt Lab is structured so that you can experiment with prompts intuitively. It lets you set the top-level instructions (an explanation of the task at hand) and then provides space to add examples, also called few-shot examples.
After trying out several different prompts, I settled on:
You are an expert at classifying transcripts into a set of discreet categories:
- Surgery
- Cardiovascular / Pulmonary
- Orthopedic
- ENT / Otolaryngology
Read the following transcript and classify it into one of the categories above. Do not answer with anything other than the above categories.
I then select a random sample transcript, paste it into the UI, and click the Generate button to get the predicted output, which turns out to be correct even without providing the model with any examples! This is called zero-shot prompting. We can also provide examples for the LLM to calibrate its response: this is called few-shot prompting, and it is pictured below. When dealing with a handful of classes, it can be a good idea to help the LLM by providing some examples for each class, but one has to be mindful of the LLM’s context window length: if the examples and the top-level instructions are too long, the prompt will not fit in the context window and the model will not be able to use all of the provided information.
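Here is a minimal sketch of how a few-shot prompt can be assembled outside the UI: the instructions from above, one labeled example per class, then the new transcript. The Input/Output layout mirrors the one used in the notebook code further down; the helper name and the max_words budget are hypothetical, and word count is only a rough proxy for tokens (subword tokenizers usually produce more tokens than words, especially with medical jargon).

```python
CATEGORIES = ["Surgery", "Cardiovascular / Pulmonary", "Orthopedic", "ENT / Otolaryngology"]

INSTRUCTION = (
    "You are an expert at classifying transcripts into a set of discreet categories:\n"
    + "\n".join(f"- {c}" for c in CATEGORIES)
    + "\nRead the following transcript and classify it into one of the categories above. "
    + "Do not answer with anything other than the above categories.\n"
)


def build_few_shot_prompt(examples, transcript, max_words=1500):
    """Hypothetical helper: instructions + labeled examples + the transcript to classify.

    `examples` is a list of (transcript, label) pairs, e.g. one per class.
    `max_words` is an assumed budget, not the model's real context limit.
    """
    shots = "".join(f"\nInput: {text}\nOutput: {label}\n" for text, label in examples)
    prompt = f"{INSTRUCTION}{shots}\nInput: {transcript}\nOutput:"
    # Crude length check so an over-long prompt fails loudly instead of being truncated
    if len(prompt.split()) > max_words:
        raise ValueError("Prompt is likely too long for the context window; "
                         "use fewer or shorter examples.")
    return prompt
```

Because full transcripts can be long, it is often better to use shorter excerpts as examples than to drop whole classes from the prompt.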
Watsonx.ai comes with a suite of models to choose from, ranging from popular open-source models such as Llama and Mixtral to IBM’s own Granite models. For this project, I chose the flan-t5-xl-3b LLM, which was trained and published by Google. I chose this model in particular because of its small size, and because its encoder-decoder architecture makes it well suited to classification tasks. It is also one of the models that currently supports prompt tuning, which I will showcase later on.
Testing our prompt on the test set
We have seen that our prompt works well on a couple of examples picked at random. Naturally, the follow-up question becomes: how well will it perform when applied to our entire test set?
Again, the watsonx.ai platform makes it very simple to see this for ourselves. Simply click the floppy disk icon and select Save As, which opens a window where you can click the Notebook button to convert your Prompt Lab work into a Jupyter Notebook.
We can now wrap our model call inside a for loop, iterating over each transcript in our test set and having the model classify it.
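```python
# Assumes `df` holds the test set and `model` is the ModelInference object
# created by the notebook that Prompt Lab generated.
PROMPT_TEMPLATE = """You are an expert at classifying transcripts into a set of discreet categories:
- Surgery
- Cardiovascular / Pulmonary
- Orthopedic
- ENT / Otolaryngology
Read the following transcript and classify it into one of the categories above. Do not answer anything other than the above categories.
Input: {inp}
Output:"""

predictions = []
for inp in df['transcription']:
    prompt_input = PROMPT_TEMPLATE.format(inp=inp)
    # Zero-shot: the prompt contains no labeled examples
    predictions.append(model.generate_text(prompt=prompt_input, guardrails=False).strip())

# Assign the column once, instead of writing row by row with chained indexing
df['predicted_zero_shot'] = predictions
```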
Let’s unpack the generated classes and evaluate the zero-shot prompt’s performance. Looking at the confusion matrix, we can already see some major improvements over the ML baseline: accuracy has increased from 65.8% to 70%, even though the model produced a spurious prediction (Neurology) that is not one of our four categories! We can also see a big difference in the quality of the predictions for the other classes.
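Since an LLM returns free text, it helps to map each raw output onto the allowed label set before scoring, and to list anything that matches none of the labels (like the Neurology prediction) instead of letting it disappear. The sketch below uses exact matching after whitespace and case normalization; that matching rule and the helper names are assumptions, and looser rules (such as substring matching) are a design choice you may want to experiment with.

```python
CATEGORIES = ["Surgery", "Cardiovascular / Pulmonary", "Orthopedic", "ENT / Otolaryngology"]


def normalize_prediction(raw, categories=CATEGORIES):
    """Map raw model output onto an allowed category, or None if it matches none."""
    cleaned = " ".join(raw.split()).lower()  # collapse whitespace, ignore case
    for category in categories:
        if cleaned == category.lower():
            return category
    return None


def summarize(y_true, raw_preds):
    """Accuracy over all rows, plus the sorted set of out-of-label outputs."""
    preds = [normalize_prediction(r) for r in raw_preds]
    accuracy = sum(p == t for p, t in zip(preds, y_true)) / len(y_true)
    # Out-of-label outputs count as errors and are also reported separately
    spurious = sorted({r.strip() for r, p in zip(raw_preds, preds) if p is None})
    return accuracy, spurious


acc, spurious = summarize(
    ["Surgery", "Orthopedic", "Surgery", "ENT / Otolaryngology"],
    [" Surgery\n", "orthopedic", "Neurology", "ENT / Otolaryngology"],
)
print(acc, spurious)
```

Counting out-of-label answers as errors keeps the accuracy figure honest, while the separate list shows whether the prompt needs tightening.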
Prompt Tuning in watsonx.ai
A simple yet powerful way of boosting an LLM’s performance is a method called prompt tuning. Prompt tuning allows an LLM to become really good at a specific task by learning from labeled examples: instead of updating the model’s own weights, which stay frozen, it learns a small set of “soft prompt” vectors that are prepended to the model’s input. This is similar to traditional ML model training; a big advantage, however, is that sizable gains can be achieved with relatively few training examples.
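The core idea can be shown with a toy NumPy sketch: a “frozen model” (here just mean-pooling plus a fixed linear layer, standing in for the real LLM) receives a few trainable soft-prompt vectors prepended to the input embeddings, and gradient descent updates only those vectors. This illustrates the general technique under heavily simplified assumptions; it is not how watsonx.ai implements tuning, and all sizes and values are made up.

```python
import numpy as np

rng = np.random.default_rng(0)
d, n_classes, n_virtual = 8, 4, 5

# Frozen "model": a fixed linear classifier over mean-pooled token embeddings.
# It stands in for the (much larger) frozen LLM and is never updated.
W = rng.normal(size=(n_classes, d))
# Frozen embeddings of one toy input (6 tokens) and its class index.
X = rng.normal(size=(6, d))
y = 2


def forward(P):
    """Class probabilities for the input with soft prompt P prepended."""
    seq = np.concatenate([P, X], axis=0)   # [virtual tokens; input tokens]
    h = seq.mean(axis=0)                   # mean-pool as a stand-in for the LLM
    logits = W @ h
    probs = np.exp(logits - logits.max())  # numerically stable softmax
    return probs / probs.sum()


def loss(P):
    return -np.log(forward(P)[y])


def train(P, lr=0.5, steps=100):
    """Gradient descent on the soft prompt only; W and X are left untouched."""
    n_tokens = P.shape[0] + X.shape[0]
    for _ in range(steps):
        grad_logits = forward(P)
        grad_logits[y] -= 1.0          # d(cross-entropy)/d(logits) = probs - one_hot(y)
        grad_h = W.T @ grad_logits     # back through the frozen linear layer
        # Each prompt vector enters the mean with weight 1/n_tokens,
        # so every row of P receives the same gradient grad_h / n_tokens.
        P = P - lr * grad_h / n_tokens
    return P


P0 = np.zeros((n_virtual, d))
P_tuned = train(P0)
print(loss(P0), loss(P_tuned))
```

The number of trainable parameters is just n_virtual × d, which is why prompt tuning can be far cheaper than updating the whole model.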
The watsonx.ai platform allows you to prompt-tune rapidly and efficiently. From the project space, navigate to New Asset > Work with models > Tune a foundation model with labeled data. All you need to provide is your tuning set in JSON or JSONL format, the prompt you want to tune the model with, and the categories to classify. The platform also makes experimentation easy by letting you set the number of epochs to train on, the learning rate, and so on. As always, several experiments with different parameter values are usually needed to maximize the model’s accuracy.
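Preparing the tuning set mostly means writing one labeled example per line. The sketch below writes JSONL records with input and output fields, which is my understanding of the format Tuning Studio expects; treat the field names as an assumption and check the current watsonx.ai documentation before uploading. The helper names are hypothetical.

```python
import json


def to_jsonl(examples):
    """Serialize (transcript, label) pairs as JSON Lines: one object per line."""
    # json.dumps escapes newlines inside transcripts, so each record stays on one line
    return "".join(
        json.dumps({"input": transcript, "output": label}) + "\n"
        for transcript, label in examples
    )


def write_tuning_file(examples, path):
    with open(path, "w", encoding="utf-8") as f:
        f.write(to_jsonl(examples))
```

The examples would typically come from the balanced training set, e.g. pairs of (transcription, medical_speciality) values.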
For this scenario, I simply set the number of training epochs to its maximum value, 50. Tuning takes about an hour, but as the graph below illustrates, training the model for only 20–25 epochs would probably suffice.
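Reading the plateau off the loss curve can also be automated with a simple patience rule, the same idea behind early stopping: report the last epoch that improved the loss by more than a threshold, once several further epochs fail to do so. The threshold and patience values below are assumptions, and the loss values in the usage line are synthetic, not the curve from this experiment.

```python
def plateau_epoch(losses, min_delta=0.01, patience=3):
    """Return the 1-based epoch of the last meaningful improvement once the loss has
    failed to improve by more than `min_delta` for `patience` consecutive epochs;
    return None if that never happens."""
    best, best_epoch, wait = losses[0], 1, 0
    for epoch, loss in enumerate(losses[1:], start=2):
        if best - loss > min_delta:
            best, best_epoch, wait = loss, epoch, 0  # meaningful improvement
        else:
            wait += 1
            if wait >= patience:
                return best_epoch
    return None


# Synthetic per-epoch losses that flatten out after epoch 4
print(plateau_epoch([1.0, 0.6, 0.4, 0.3, 0.295, 0.293, 0.292, 0.291]))
```

Applied to the per-epoch losses of a real tuning run, this gives a defensible starting point for the number of epochs in the next experiment.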
With the tuned model ready, we can easily access it in the watsonx.ai platform. The model will appear in the model repository in the Prompt Lab, and it can then be accessed as a ModelInference object in Python.
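As a sketch of that last step, assuming the tuned model has been deployed to a deployment space and using the ibm_watsonx_ai SDK, a deployed model can be wrapped in a ModelInference object through its deployment ID. The URL, API key, space ID and deployment ID are placeholders you would take from your own watsonx.ai account, and the helper name is hypothetical.

```python
def load_tuned_model(api_key, url, space_id, deployment_id):
    """Hypothetical helper: wrap a deployed, prompt-tuned model as a ModelInference."""
    # Deferred imports: the SDK is only needed when the function is actually called
    from ibm_watsonx_ai import Credentials
    from ibm_watsonx_ai.foundation_models import ModelInference

    return ModelInference(
        deployment_id=deployment_id,
        credentials=Credentials(url=url, api_key=api_key),
        space_id=space_id,
    )
```

The returned object exposes the same generate_text method as before, so the zero-shot evaluation loop can be reused unchanged to score the tuned model.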
Our accuracy has now increased to 74%, not bad! We see big improvements for the majority class, but fairly big drops for ENT / Otolaryngology and Orthopedic. There are trade-offs at every step of the way, and deciding which ones are acceptable is up to the project’s stakeholders.
Summary
In conclusion, here are some of the things I’ve learned when using prompt tuning for classification. The upshot is that LLMs can provide value by outperforming sophisticated ML models while being relatively simple to leverage. Let’s break down my findings into pros and cons.
Pros:
- LLMs can outperform well-tuned ML models without seeing any labeled examples: here, the zero-shot prompt beat the best AutoAI pipeline!
- LLMs provide quicker time to value as they can be leveraged out of the box with solid results, and be tuned in a short amount of time to generate even better results.
- Consequently, LLMs can deliver satisfactory results without requiring in-depth knowledge of ML techniques.
Cons:
- As of now, estimating the confidence of an LLM classification is quite difficult. On the other hand, ML models provide confidence scores and as such are often seen as being more “trustworthy” than LLMs, since we can quantify and calibrate ML models using confidence scores and intervals.
- LLM outputs can be stochastic, and LLMs can hallucinate. They are also highly sensitive to prompts, and thus require some experimentation as well as guardrails to monitor their behavior.
I also hope this blog post conveys one of my guiding principles when it comes to working on data science use cases with clients: every step of a project involves working around assumptions and trade-offs. When working on a personal project or for a client, always ask yourself: What is the outcome I am trying to achieve? Who will consume that outcome? What are the risks and rewards of one approach versus another?
These questions might seem trivial, but they will ultimately help you decide whether it is worth trading a couple of percentage points of recall on Surgery for a couple of percentage points of precision on Orthopedic, whether you want to spend 20 or 50 epochs tuning your LLM, and so on. Hopefully this post will help you along your way.
Happy prompting!
Thank you to my colleagues Courtney Branson and Drew Letvin for their feedback and advice.