Using GPT-4 for classification

GPT is a powerful new tool in the data science toolkit. Used well, it can increase productivity while reducing the “drudgery” of tedious tasks like data cleaning. I’ve been looking for ways to integrate GPT into my data science workflow, and I thought it would be fun to try it on the latest Tidy Tuesday dataset (US Government Grant Opportunities). In this post, I use GPT to classify US grant-funding agencies into categories based on their names.

Load libraries and data

library(tidyverse)
library(DescTools)
library(keyring)
library(httr)
library(jsonlite)
library(numform)
library(scales)
library(snakecase)

grants <- readr::read_csv('https://raw.githubusercontent.com/rfordatascience/tidytuesday/master/data/2023/2023-10-03/grants.csv')
grant_opportunity_details <- readr::read_csv('https://raw.githubusercontent.com/rfordatascience/tidytuesday/master/data/2023/2023-10-03/grant_opportunity_details.csv')

# Load cached ChatGPT results from disk, if they exist
if (file.exists('categories.Rdata')) load('categories.Rdata')
if (file.exists('agencies_categorized.Rdata')) load('agencies_categorized.Rdata')
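
The caching idea used throughout this post (only call the API when no saved result exists) can be sketched as a small helper. This is an illustration rather than the code used in this post: `cached` and `expensive_query` are hypothetical names, and the sketch uses `saveRDS()`/`readRDS()` instead of `save()`/`load()`.

```r
# Hypothetical helper: return the cached value if the file exists,
# otherwise compute it, save it to disk, and return it.
cached <- function(path, compute) {
  if (file.exists(path)) {
    return(readRDS(path))
  }
  value <- compute()
  saveRDS(value, path)
  value
}

# Toy usage: a temporary file and a counter standing in for an API call
cache_file <- tempfile(fileext = ".rds")
n_calls <- 0
expensive_query <- function() {
  n_calls <<- n_calls + 1
  c("Health and Medicine", "Education")
}

first  <- cached(cache_file, expensive_query)  # computes and saves
second <- cached(cache_file, expensive_query)  # reads from disk
```

The second call never reaches `expensive_query()`, which is the behaviour I want when re-rendering the post without paying for another GPT request.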

GPT function

I’ll define a function for querying ChatGPT (source). It uses GPT-4 as the model, with the temperature set to 0. Temperature controls how much randomness goes into the responses. A temperature closer to 1 suits creative tasks where high variability is desired, like creative writing. A temperature of 0 is more “boring” but also more stable, which is what I want for a classification task.

# ChatGPT function
chatGPT <- function(prompt,
                    modelName = "gpt-4",
                    temperature = 0,
                    apiKey = keyring::key_get('openai')) {

  response <- POST(
    url = "https://api.openai.com/v1/chat/completions",
    add_headers(Authorization = paste("Bearer", apiKey)),
    content_type_json(),
    encode = "json",
    body = list(
      model = modelName,
      temperature = temperature,
      messages = list(list(
        role = "user",
        content = prompt
      ))
    )
  )

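  # Stop on any HTTP error (status 400 and above) and show the response body
  if (http_error(response)) {
    stop("OpenAI API request failed: ",
         content(response, as = "text", encoding = "UTF-8"))
  }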

  trimws(content(response)$choices[[1]]$message$content)
}
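
To make the request more concrete, here is a sketch of roughly what the JSON payload looks like once the `body` list is serialized. It assumes the serialization behaves like `jsonlite::toJSON()` with `auto_unbox = TRUE`; the prompt text is a placeholder, and no request is actually sent.

```r
library(jsonlite)

# Build the same body structure that chatGPT() passes to POST()
body <- list(
  model = "gpt-4",
  temperature = 0,
  messages = list(list(
    role = "user",
    content = "Say hello"  # placeholder prompt
  ))
)

# auto_unbox = TRUE turns length-1 vectors into JSON scalars rather than arrays
payload <- toJSON(body, auto_unbox = TRUE, pretty = TRUE)
cat(payload)

# Round-trip the payload to confirm its structure
parsed <- fromJSON(payload, simplifyVector = FALSE)
```

The nested `list(list(...))` matters: `messages` has to be an array of message objects, even when there is only one message.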

List of agencies

I’ll generate a distinct list of US government agencies in the dataset, and then format them as a comma-separated string. This will make it easy to append to a prompt string when querying GPT.

agencies <- sort(unique(grants$agency_name))
agencies_str <- paste(agencies, collapse = ', ')
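
One caveat with a comma-separated string: if an agency name itself contains a comma, splitting on commas later will not recover the original names. The sketch below shows this with made-up agency names; I haven't checked here whether the real dataset has such names.

```r
# Hypothetical agency names; the second one contains a comma
toy_agencies <- c("Agency for Examples", "Office of Trade, Tariffs and Testing")

toy_str <- paste(toy_agencies, collapse = ", ")

# Splitting on commas no longer recovers the original names
round_trip <- trimws(strsplit(toy_str, ",")[[1]])
length(round_trip)  # 3 pieces from 2 names

# Flag the names that contain a comma
has_commas <- grepl(",", toy_agencies, fixed = TRUE)
```

A quick check on the real data would be `any(grepl(",", agencies, fixed = TRUE))`. If it returns `TRUE`, a different delimiter (a newline, say) might be a safer choice.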

Using GPT for classification

I’ll split the classification task into two parts:

  1. Give GPT the full list of agency names, and ask it to come up with N categories

  2. For each of the N categories, give GPT the full list of agency names, and ask it to identify the agencies that belong in each category

1. Ask it to come up with categories

First I’ll ask GPT to define a fixed number of categories. Somewhat arbitrarily, I’ll ask for 10.

if(!exists('categories')){
  query <-
  paste0(
    "I'm going to present a list of US agencies.
    Your task is to identify the 10 best categories
    to represent them. The categories should be based
    on activities they're involved with, such as Defense
    and Military, Science and Technology, Arts and Culture,
    Health and Medicine, Education, International Aid,
    etc... I want you to give me the list of categories
    as a comma-separated list.

    ```", agencies_str, "```"
  )
  results <- chatGPT(query)

  categories <- trimws(strsplit(results, ",")[[1]])

  # Save to disk
  save(categories, file = 'categories.Rdata')
}

print(categories)

2. Ask it to classify agencies using agency names and categories

Next, using the agency names and the categories generated above, I’ll ask GPT to do the classification. First, I’ll define a function that takes a category as input, asks GPT which agencies belong in it, and returns those agencies as a vector. Although I could ask GPT to classify agencies across all categories in a single query, I expect that would be less reliable than tackling each category separately. One trade-off is that the categories won’t be mutually exclusive, since GPT won’t know how it has already classified agencies in earlier queries. For this analysis, let’s assume I don’t want mutual exclusivity.

get_agencies_in_category <- function(category) {
  query <-paste0(
    "I'm going to present a comma-separated list of US agencies,
    delimited by triple backticks (```), and a single category to
    which some of the agencies belong, delimited by triple hashtags (###).
    Your task is to provide a comma-separated list of agency names that
    belong in the category provided.
    ```", agencies_str,  "```
    ###", category, "###"
  )
  results <- chatGPT(query)

  # Return as a vector
  trimws(strsplit(results, ",")[[1]])
}
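
Since GPT may not return agency names exactly as they were given (or may return names that weren't in the list at all), it can be worth checking the parsed response against the input. The sketch below uses a made-up agency list and a made-up raw response; it is not part of the pipeline above.

```r
# Hypothetical agency list and a hypothetical raw model response
known_agencies <- c("Agency A", "Agency B", "Agency C")
raw_response   <- " Agency A,  Agency C , Agency Z"

# Same parsing as get_agencies_in_category()
parsed <- trimws(strsplit(raw_response, ",")[[1]])

# Keep only names that match the input list exactly
matched <- intersect(parsed, known_agencies)

# Names the model returned that are not in the input list
unmatched <- setdiff(parsed, known_agencies)
```

Anything that ends up in `unmatched` would silently fail to match in a later `%in%` lookup, so it's a useful thing to eyeball before trusting the results.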

Using the function above, I’ll query GPT for each of the categories, saving the results into a list called agencies_categorized.

if(!exists('agencies_categorized')){
  agencies_categorized <- list()
  for (category in categories){
    if(category != 'Other'){
      agencies_categorized[[category]] <- get_agencies_in_category(category)
    }
  }

  # Save to disk
  save(agencies_categorized, file = 'agencies_categorized.Rdata')
}

print(agencies_categorized)

Add categories to dataframe

Finally, I’ll add the category labels to the grant_opportunity_details dataframe. Since the categories are not mutually exclusive, I’ll encode them as one boolean column per category.

categories_snakecase <- snakecase::to_snake_case(categories)

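for(i in seq_along(categories)){
  # Look up agencies by category name rather than by position, since
  # 'Other' was skipped above; a category with no entry yields NULL,
  # which makes the column all FALSE
  grant_opportunity_details <- grant_opportunity_details %>%
    mutate(!!categories_snakecase[i] := agency_name %in% agencies_categorized[[categories[i]]])
}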
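
To show what this encoding produces, here is a small base R sketch on made-up data. The agency names and categories are hypothetical, and `gsub()`/`tolower()` is only a simple stand-in for `snakecase::to_snake_case()`.

```r
# Hypothetical categorization; "Agency B" appears in both categories
toy_categorized <- list(
  "Health and Medicine"    = c("Agency A", "Agency B"),
  "Science and Technology" = c("Agency B", "Agency C")
)

toy_grants <- data.frame(
  agency_name = c("Agency A", "Agency B", "Agency C", "Agency D"),
  stringsAsFactors = FALSE
)

# One logical column per category, looked up by category name
for (category in names(toy_categorized)) {
  column <- gsub(" ", "_", tolower(category))  # stand-in for to_snake_case()
  toy_grants[[column]] <- toy_grants$agency_name %in% toy_categorized[[category]]
}

# Number of categories each agency falls into (0 = uncategorized)
toy_grants$n_categories <- rowSums(toy_grants[, -1])
```

An agency can be `TRUE` in several columns or in none, which is exactly the lack of mutual exclusivity discussed earlier.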

Data analysis

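Now that I have a list of agency categories, I can do some analysis. One question I might ask: how much total estimated funding is allocated to each agency category? Because the categories are not mutually exclusive, a grant from an agency that sits in several categories counts toward each of them.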

estimated_total_funding <- data.frame('category' = c(), 'estimated_total_funding' = c())

# Use a descriptive loop variable (avoids shadowing base::c) and spell out TRUE
for (category in categories_snakecase){
  estimated_total_funding <-
    estimated_total_funding %>%
      bind_rows(data.frame(
        'category' = category,
        'estimated_total_funding' = sum((grant_opportunity_details %>% filter(!!sym(category)))$estimated_total_program_funding, na.rm = TRUE)
      ))
}

estimated_total_funding <-
  estimated_total_funding %>%
    mutate(funding_billions = numform::f_bills(estimated_total_funding, digits=-8)) %>%
    arrange(desc(estimated_total_funding))

estimated_total_funding %>%
  select(category, funding_billions)
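
The per-category totals could also be computed without growing a data frame inside a loop. Here is a minimal base R sketch on made-up data; the column names `health` and `science` and the funding values are hypothetical.

```r
# Hypothetical data: two boolean category columns and a funding column
toy <- data.frame(
  health  = c(TRUE, TRUE, FALSE),
  science = c(FALSE, TRUE, TRUE),
  estimated_total_program_funding = c(1e9, 2e9, NA)
)

category_cols <- c("health", "science")

# Sum funding over the rows flagged TRUE for each category
totals <- vapply(
  category_cols,
  function(col) sum(toy$estimated_total_program_funding[toy[[col]]], na.rm = TRUE),
  numeric(1)
)

result <- data.frame(
  category = names(totals),
  estimated_total_funding = unname(totals)
)

# Largest total first
result <- result[order(-result$estimated_total_funding), ]
```

Note that the second row is flagged for both categories, so its funding is counted in both totals.
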
estimated_total_funding_ordered <- estimated_total_funding %>%
  arrange(estimated_total_funding)

estimated_total_funding_ordered %>%
  ggplot(aes(x = fct_inorder(category), y = estimated_total_funding)) +
    geom_col() +
    coord_flip() +
    # Convert back from snake case to sentence case for better human readability
    scale_x_discrete(labels = snakecase::to_sentence_case) +
    scale_y_continuous(labels = scales::unit_format(unit = "B", scale = 1e-9)) +
    labs(x = '', y = 'Funding in Billions', title = 'Funding by US Government Agency Category')
#ggsave("social-image.png", height=960, width=1344, units = 'px', dpi=150)