Postgres for RAG

10/2024 | Aman Azad

Embrace tradition, embrace Postgres maximalism. pg/acc.

Specialized vector DB solutions like Pinecone are great, and I got started doing AI RAG solutions with them because, as usual, they win on developer experience. I remember the path to my first RAG query was so easy I ended up spending more time on auth than on RAG.

But there are risks to tacking on more and more vendor solutions for things you really should be doing with the tools already available. I already have a BAA signed with AWS, and I already have all my customers' data isolated and clamped down with IAM roles and network security groups. With Pinecone, I'd have to talk to Enterprise Sales to go through that same process all over again.

So let's do everything in Postgres, specifically in AWS RDS. Let's get pgvector set up locally, and on RDS, and go through the code full stack.

#pg/acc

pgvector

Local Setup with Docker

As a starting point let's get pgvector set up locally assuming we've already been running a Postgres instance via docker.

I usually have a simple docker-compose.yml file that, before pgvector, looked something like this:

services:
  postgres:
    image: "postgres:16.4"
    container_name: casely-postgres
    ports:
      - 5432:5432
    environment:
      POSTGRES_DB: casely
      POSTGRES_USER: user
      POSTGRES_PASSWORD: password
    volumes:
      - casely-db:/var/lib/postgresql/data

volumes:
  casely-db:

The default Docker postgres image postgres:16.4 does not come with the pgvector extension installed. We'll define a custom Dockerfile that builds off the base Postgres image and installs pgvector, then point docker-compose at it:

services:
  postgres:
    build:
      context: .
      dockerfile: ./packages/db/Dockerfile.db # path to the custom Dockerfile
    container_name: casely-postgres
    ports:
      - 5432:5432
    environment:
      POSTGRES_DB: casely
      POSTGRES_USER: user
      POSTGRES_PASSWORD: password
    volumes:
      - casely-db:/var/lib/postgresql/data

volumes:
  casely-db:

And our Dockerfile.db would look like this:

# Use the official PostgreSQL image as the base image
FROM postgres:16.4

# Install the pgvector extension and clean up apt caches in the same layer
# (a separate cleanup RUN wouldn't shrink the image, since the previous layer keeps the files)
RUN apt-get update && \
    apt-get install -y postgresql-16-pgvector && \
    apt-get clean && \
    rm -rf /var/lib/apt/lists/*

# Default command to run PostgreSQL
CMD ["postgres"]

Great. Now running docker-compose up -d --build will rebuild the image with pgvector installed, and since we're referencing the same volume, our data will persist.
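
To sanity check the build, a quick script like this (my own sketch, not part of the original setup, using the pg package and the credentials from the compose file above) should show pgvector in the server's available extensions:

// verify-pgvector.js - quick local sanity check
const { Client } = require("pg");

async function main() {
  const client = new Client({
    connectionString: "postgresql://user:password@localhost:5432/casely",
  });
  await client.connect();

  // pg_available_extensions lists every extension the server can install
  const { rows } = await client.query(
    "select name, default_version from pg_available_extensions where name = 'vector'",
  );
  console.log(rows); // expect one row if the image was built with pgvector

  await client.end();
}

main();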

RDS Setup

RDS makes this easy: recent RDS for PostgreSQL versions (including 16.4) ship with pgvector as a supported extension, so there's nothing to install on the instance - we just enable it with create extension in a migration below.

So let's do a few things here - set up RDS with Terraform, and add some security groups so we don't expose our data to the public internet (called "Private Link" in Pinecone land, only available on Enterprise).

First, let's define a security group that only allows traffic from our ECS containers and our local IP.

There are a few Terraform variables referenced below - var.my_ip, var.db_password, and var.multi_az_public_subnet_ids - that parameterize both the security group and the RDS instance:

resource "aws_security_group" "rds_sg" {
  name        = "rds_sg"
  description = "Security group for RDS instance"
  vpc_id      = data.aws_vpc.default.id

  ingress {
    description = "PostgreSQL access from my IP"
    from_port   = 5432
    to_port     = 5432
    protocol    = "tcp"
    cidr_blocks = ["${var.my_ip}/32"]
  }

  ingress {
    description     = "PostgreSQL access from ECS"
    from_port       = 5432
    to_port         = 5432
    protocol        = "tcp"
    security_groups = [aws_security_group.ecs_sg.id]
  }

  # (RDS instances don't accept SSH connections, so this rule isn't strictly needed)
  ingress {
    description = "SSH access from My IP"
    from_port   = 22
    to_port     = 22
    protocol    = "tcp"
    cidr_blocks = ["${var.my_ip}/32"]
  }

  egress {
    description = "Allow all outbound traffic"
    from_port   = 0
    to_port     = 0
    protocol    = "-1"
    cidr_blocks = ["0.0.0.0/0"]
  }

  tags = {
    Name = "rds_sg",
  }
}

And our actual RDS instance setup is here:

resource "aws_db_subnet_group" "rds_subnet_group" {
  name       = "rds_subnet_group"
  subnet_ids = var.multi_az_public_subnet_ids

  tags = {
    Name = "rds_subnet_group"
  }
}

resource "aws_db_instance" "casely_qa_db" {
  allocated_storage      = 20
  engine                 = "postgres"
  engine_version         = "16.4"
  instance_class         = "db.t4g.micro"
  db_name                = "casely_qa"
  username               = "casely"
  password               = var.db_password
  parameter_group_name   = "default.postgres16"
  publicly_accessible    = true # gives the instance a public endpoint; the security group above restricts ingress to my IP and ECS
  db_subnet_group_name   = aws_db_subnet_group.rds_subnet_group.name
  vpc_security_group_ids = [aws_security_group.rds_sg.id]
  skip_final_snapshot    = true

  tags = {
    Name = "casely_qa_db"
  }
}

Migrations, Setting up Tables

Flyway is a great tool for managing migrations. I'm a quick and lazy type of developer - no DB branches, no migration rollbacks, just forward progress and assuming null values everywhere. Here's how I had Flyway set up.

Install flyway via brew:

brew install flyway

Then back at the root of my project I run this command:

p run migrate:local

Which is essentially a package.json script that runs the following:

# For running local DB migrations:
# For running migrations on RDS, replace the database URL with the RDS endpoint, username, and password

flyway -locations=filesystem:packages/db/migrations -url=jdbc:postgresql://localhost:5432/casely -user=user -password=password migrate
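
For reference, the scripts entry in package.json is roughly this (paths and credentials as above; your setup may differ):

{
  "scripts": {
    "migrate:local": "flyway -locations=filesystem:packages/db/migrations -url=jdbc:postgresql://localhost:5432/casely -user=user -password=password migrate"
  }
}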

I've placed all my migration SQL files in the packages/db/migrations directory. The migration file to set up pgvector and an embedding table looks like this:

-- V1__setup_pgvector.sql

-- Install pgvector
create extension if not exists vector with schema public;

-- Set up table, 1024 dimensions for cohere's embed model
-- https://docs.cohere.com/docs/models#embed
create table law_embedding (
  id serial primary key,
  created_at timestamptz default current_timestamp,
  updated_at timestamptz,
  deleted_at timestamptz,
  law_text text,
  embedding vector (1024),
  name text,
  tag_state varchar(255),
  tag_federal boolean,
  tag_act varchar(255)
);

-- HNSW indexes can be created immediately on an empty table.
-- vector_cosine_ops matches the <=> (cosine distance) operator used in the search query later
create index on law_embedding using hnsw (embedding vector_cosine_ops);

And there it is - a fully vectorized database, with "metadata" columns for the data that's specific to my use case.

All of this is set up nice and safely in my AWS VPC, only accessible from my IP address and the apps I have running in ECS.

Also, I've got all the benefits of Postgres, as well as a very easy way to set up custom indexes, try different embedding models, distance operators, and so on.
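
For example, comparing pgvector's distance operators on the same data is a one-query experiment. Here's a hedged sketch, reusing the pg client from the upload script below; embeddingLiteral is a hypothetical variable holding an embedding formatted as a pgvector literal (a bracketed string like "[0.1, 0.2, ...]"):

// <=> cosine distance, <#> negative inner product, <-> L2 distance
const { rows } = await client.query(
  `select name,
          embedding <=> $1 as cosine_distance,
          embedding <#> $1 as neg_inner_product,
          embedding <-> $1 as l2_distance
     from law_embedding
    limit 5`,
  [embeddingLiteral],
);
console.log(rows);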

Vectorizing and Uploading a CSV

I've got a simple script that reads in a CSV, processes the text into an embedding, and does a very simple insert into our DB using the pg node package.

There's a lot and not a lot here. In summary, I'm using AWS Bedrock's JavaScript SDK to call the Cohere Embed model. There are a couple of helper packages to read in the CSV, and a simple pg client to insert into the DB. I use a DB transaction to make sure I'm not borking the database if something goes wrong.

Note, and this is specific to Cohere's API, these embeddings are created with input_type: "search_document", which tells Cohere I'll want to search against these records later (the query side would use input_type: "search_query").

const { Client } = require("pg");
const fs = require("fs");
const fastcsv = require("fast-csv");
const path = require("path");
const { finished } = require("stream/promises");
const {
  BedrockRuntimeClient,
  InvokeModelCommand,
} = require("@aws-sdk/client-bedrock-runtime");

// Dotenv here loads the env file storing my AWS/DB credentials
require("dotenv").config({
  path: path.resolve(__dirname, ".env.local"),
});

// Setup AWS Bedrock client
const BEDROCK_EMBED_MODEL = "cohere.embed-english-v3";
const bedrock = new BedrockRuntimeClient({
  region: process.env.AWS_REGION,
  credentials: {
    accessKeyId: process.env.AWS_ACCESS_KEY_ID,
    secretAccessKey: process.env.AWS_SECRET_ACCESS_KEY,
  },
});

// Set up database connection using environment variables
const client = new Client({
  connectionString: process.env.DATABASE_URL,
});

// --
async function embed(text) {
  if (!text) {
    throw new Error("Text is required for embedding");
  }

  // Reference - AWS Bedrock -> Providers -> Cohere -> Embed English
  try {
    const input = {
      modelId: BEDROCK_EMBED_MODEL,
      body: JSON.stringify({
        input_type: "search_document",
        embedding_types: ["float"],
        texts: [text],
      }),
    };

    // Invoke the Cohere model via AWS Bedrock
    const command = new InvokeModelCommand(input);
    const response = await bedrock.send(command);
    console.log("response", response);

    // Convert the Uint8Array to a string
    const decoder = new TextDecoder("utf-8");
    const responseBody = decoder.decode(response.body);

    // Extract the embeddings from the response
    const data = JSON.parse(responseBody);
    console.log("data", data);

    const embedding = data?.embeddings?.float?.[0];
    console.log("embedding", embedding.length, embedding?.[0]);

    if (!embedding) {
      throw new Error("No embeddings found in response");
    }

    return `[${String(embedding)}]`;
  } catch (error) {
    console.error("Error calling Cohere embedding model via Bedrock: ", error);
    throw new Error("Error calling Cohere embedding model via Bedrock");
  }
}

async function uploadLaws() {
  try {
    // Connect to the database
    await client.connect();
    console.log("Connected to database.");

    // Read the CSV file
    const laws = [];
    const csvStream = fs
      .createReadStream("embeddings/casely-laws.csv")
      .pipe(fastcsv.parse({ headers: true }));

    // Push each row to the laws array as they are read
    csvStream.on("data", (row) => {
      // Collect each row; embeddings are generated from law_text further down
      console.log("row", row);
      laws.push(row);
    });

    // Await the end of the CSV stream processing
    await finished(csvStream);
    console.log(`Parsed ${laws.length} rows from CSV.`);

    for (const law of laws) {
      // Embed the law name using the AI SDK
      const embedding = await embed(law.law_text);
      if (!embedding) {
        console.error("Could not embed law:", law);
        continue;
      }

      // Query DB for matching law name
      const existingLaw = await client.query(
        `
            SELECT * FROM law_embedding WHERE name = $1
          `,
        [law.name],
      );
      if (existingLaw.rows.length > 0) {
        console.error("Law already exists in database:", existingLaw.rows[0]);
        continue;
      }

      // Insert the law into the database
      await client.query(
        `
            INSERT INTO law_embedding (name, law_text, embedding, tag_state, tag_federal, tag_act)
            VALUES ($1, $2, $3, $4, $5, $6)
          `,
        [
          law.name,
          law.law_text,
          embedding, // This should match the VECTOR column type, ensure correct format.
          law.tag_state,
          law.tag_federal === "TRUE",
          law.tag_act,
        ],
      );
    }

    // Commit the transaction once every row has gone through
    await client.query("COMMIT");
    console.log(`Uploaded ${laws.length} laws successfully.`);
  } catch (err) {
    // In case of any errors, rollback the transaction
    console.error("Error occurred:", err);
    await client.query("ROLLBACK");
  } finally {
    // Disconnect from the database
    await client.end();
    console.log("Disconnected from database.");
  }
}

// Run the upload function
uploadLaws();

NextJS Code to Embed Queries and Fetch Results

And now to tie it all up, we mostly just need one more SQL query to fetch the closest matches and complete our RAG pipeline.

One call out here with the sorting - <=> is pgvector's cosine distance operator (1 minus cosine similarity), so 1 - distance / 2 rescales the usual -1 to 1 similarity range onto 0 to 1 (a distance of 0 becomes a score of 1.0, a distance of 2 becomes 0.0). Makes for easier rendering on the UI. Ordering by the raw embedding <=> $1 expression, rather than the aliased score, is what lets Postgres use the HNSW index.

// Embed the input
const inputEmbedding = await embed(input);

// Cosine similarity search
const result = await query<LawEmbedding & { similarity_score: number }>(
  `select id,
    law_text,
    created_at,
    updated_at,
    deleted_at,
    name,
    tag_state,
    tag_federal,
    tag_act,
    1 - (embedding <=> $1) / 2 as similarity_score
   from law_embedding
   where deleted_at is null
   -- ordering by the raw distance expression (ascending) keeps the HNSW index usable
   order by embedding <=> $1
   limit $2
  `,
  [`[${String(inputEmbedding)}]`, numberOfResults],
);
const ragItems = result?.rows || [];
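
From here, completing the RAG loop is mostly prompt assembly. Here's a rough sketch (not from the original code) of stuffing the retrieved rows into a prompt for whatever generation model you're calling; generate() is a hypothetical helper standing in for that call:

// Build a numbered context block out of the retrieved laws
const context = ragItems
  .map(
    (item, i) =>
      `[${i + 1}] ${item.name} (score: ${Number(item.similarity_score).toFixed(2)})\n${item.law_text}`,
  )
  .join("\n\n");

// Ground the model's answer in the retrieved context
const prompt = `Answer the question using only the laws below, and cite them by number.

${context}

Question: ${input}`;

// generate() is a placeholder for your LLM call (e.g. another Bedrock InvokeModelCommand)
const answer = await generate(prompt);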