Using pre-trained models and embeddings

🎯 Learning objectives


By the end of this session, participants should be able to:

  • Understand what embeddings are and why they’re useful
  • Gain a high-level understanding of transfer learning
  • Distinguish between training a model from scratch vs. using a pre-trained one
  • Use transfer learning for your own tasks
  • Generate and visualize embeddings from images and text
  • Use pre-trained models (e.g., MiniLM, DenseNet, CLIP)
  • Identify ML tasks in your own domain that could benefit from embeddings

Why pre-trained models?

  • Training from scratch = lots of data + compute
  • Pretrained models already know a lot from massive datasets
  • You can:
    • Extract features (embeddings)
    • Fine-tune on your domain-specific task

What are embeddings?

  • Embeddings are numerical vector representations of complex objects (images, text, molecules, etc.)
  • They let us compute distances, and hence similarities, between concepts
    • Similar inputs \(\rightarrow\) nearby vectors (see the small sketch below)
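
As a toy illustration (not part of the workshop code), here is a small sketch of how vector representations make similarity computable; the three vectors below are made up:

import numpy as np
from numpy.linalg import norm

# Made-up 4-dimensional "embeddings" for three concepts
cat = np.array([0.9, 0.1, 0.3, 0.0])
kitten = np.array([0.8, 0.2, 0.4, 0.1])
truck = np.array([0.0, 0.9, 0.1, 0.8])

def cosine_similarity(a, b):
    # 1.0 = same direction (very similar), near 0 = unrelated
    return a @ b / (norm(a) * norm(b))

print(cosine_similarity(cat, kitten))  # high: similar concepts -> nearby vectors
print(cosine_similarity(cat, truck))   # low: dissimilar concepts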

Embeddings are commonly used for

  • Similarity search
  • Clustering
  • Visualization
  • Downstream ML models

Image embeddings

Introduction to computer vision


  • Computer vision refers to understanding images/videos, usually using ML/AI.
  • In the last decade this field has been dominated by deep learning. We will explore image classification and object detection.

Introduction to computer vision


  • Image classification: Is this a cat or a dog?
  • Object localization: Where is the cat in this image?
  • Object detection: What are the various objects in the image?
  • Instance segmentation: What are the shapes of these various objects in the image?
  • and much more…

Pre-trained models


  • In practice, very few people train an entire CNN from scratch because doing so requires a large dataset, powerful hardware, and a great deal of human effort.
  • Instead, a common practice is to download a pre-trained model and fine-tune it for your task. This is called transfer learning.
  • Transfer learning is one of the most common techniques used in the context of computer vision and natural language processing.
  • It refers to using a model already trained on one task as a starting point for learning to perform another task.

Pre-trained models out-of-the-box


  • Let’s first apply one of these pre-trained models to our own problem right out of the box.

Pre-trained models out-of-the-box


  • We can easily download well-known models using the torchvision.models module. All models are available with pre-trained weights (based on ImageNet’s 224 x 224 images).
  • Here we use the pre-trained vgg16 model, which is trained on the ImageNet data (see the code sketch below).
  • We preprocess the given image.
  • We get predictions from this pre-trained model on the given image, along with prediction probabilities.
  • For a given image, this model outputs one of the 1000 classes from ImageNet.
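
A minimal sketch of this out-of-the-box workflow (the image path img/cat.jpg and the top-4 printout are illustrative placeholders, not the workshop's exact code):

import torch
from PIL import Image
from torchvision import models
from torchvision.models import VGG16_Weights

weights = VGG16_Weights.IMAGENET1K_V1
vgg16 = models.vgg16(weights=weights)
vgg16.eval()

preprocess = weights.transforms()  # resize, crop, and normalize as ImageNet expects
img = preprocess(Image.open("img/cat.jpg")).unsqueeze(0)  # add a batch dimension

with torch.no_grad():
    probs = torch.softmax(vgg16(img), dim=-1)

top_probs, top_idxs = probs.topk(4)  # four most likely ImageNet classes
for p, i in zip(top_probs[0], top_idxs[0]):
    print(weights.meta["categories"][int(i)], f"{float(p):.3f}")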

Pre-trained models out-of-the-box

  • Let’s predict labels with associated probabilities for unseen images

                         Class  Probability score
                     tiger cat              0.353
              tabby, tabby cat              0.207
               lynx, catamount              0.050
Pembroke, Pembroke Welsh corgi              0.046
--------------------------------------------------------------

                                     Class  Probability score
         cheetah, chetah, Acinonyx jubatus              0.983
                  leopard, Panthera pardus              0.012
jaguar, panther, Panthera onca, Felis onca              0.004
       snow leopard, ounce, Panthera uncia              0.001
--------------------------------------------------------------

                                   Class  Probability score
                                 macaque              0.714
patas, hussar monkey, Erythrocebus patas              0.122
      proboscis monkey, Nasalis larvatus              0.098
                   guenon, guenon monkey              0.017
--------------------------------------------------------------

                        Class  Probability score
Walker hound, Walker foxhound              0.580
             English foxhound              0.091
                  EntleBucher              0.080
                       beagle              0.065
--------------------------------------------------------------

Pre-trained models out-of-the-box


  • We got these predictions without “doing the ML ourselves”.
  • We are using the pre-trained vgg16 model which is available in torchvision.
    • torchvision has many such pre-trained models available that have been very successful across a wide range of tasks: AlexNet, VGG, ResNet, Inception, MobileNet, etc.
  • Many of these models have been pre-trained on famous datasets like ImageNet.
  • So if we use them out-of-the-box, they will give us one of the ImageNet classes as the prediction.

Pre-trained models out-of-the-box


  • Let’s try some images that are unlikely to be represented in ImageNet.
  • The model doesn’t do very well here because ImageNet doesn’t have appropriate classes for these images.

         Class  Probability score
cucumber, cuke              0.146
         plate              0.117
     guacamole              0.099
  Granny Smith              0.091
--------------------------------------------------------------

                                      Class  Probability score
                                        fig              0.637
                                pomegranate              0.193
grocery store, grocery, food market, market              0.041
                                      crate              0.023
--------------------------------------------------------------

                                               Class  Probability score
                                         toilet seat              0.171
                                          safety pin              0.060
bannister, banister, balustrade, balusters, handrail              0.039
                                              bubble              0.035
--------------------------------------------------------------

                  Class  Probability score
                   vase              0.078
                thimble              0.074
             plate rack              0.049
saltshaker, salt shaker              0.047
--------------------------------------------------------------

                      Class  Probability score
           pizza, pizza pie              0.998
frying pan, frypan, skillet              0.001
                     potpie              0.000
                French loaf              0.000
--------------------------------------------------------------

              Class  Probability score
     patio, terrace              0.213
           fountain              0.164
lakeside, lakeshore              0.097
            sundial              0.088
--------------------------------------------------------------

Pre-trained models out-of-the-box


  • So far we have used pre-trained models out-of-the-box.
  • Can we use pre-trained models for our own classification problem with our own classes?
  • Yes! We have two options here:
    1. Add some extra layers to the pre-trained network to suit our particular task (see the sketch below)
    2. Pass training data through the network and save the output to use as features for training some other model
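
For option 1, a minimal sketch (illustrative only, assuming the 7 food classes used later in this session) might look like this:

import torch
from torchvision import models

vgg16 = models.vgg16(weights="VGG16_Weights.IMAGENET1K_V1")
for param in vgg16.parameters():
    param.requires_grad = False  # freeze the pre-trained weights
vgg16.classifier[-1] = torch.nn.Linear(4096, 7)  # new trainable head for our 7 classes
# ...then train only this new head on our labelled images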

Pre-trained models to extract features


  • Let’s use pre-trained models to extract features.
  • We will pass our specific data through a pre-trained network to get a feature vector for each example in the data.
  • The feature vector is usually taken from the last layer before the classification layer of the pre-trained network.
  • You can think of each layer as a transformer applying some transformations to the input it receives.

Pre-trained models to extract features


  • Once we extract these feature vectors for all images in our training data, we can train a machine learning classifier such as logistic regression or random forest.
  • This classifier will be trained on our classes using feature representations extracted from the pre-trained models.
  • Let’s try this out.
  • It’s better to train such models on a GPU, but since our dataset is quite small, we won’t have problems running it on a CPU.

Pre-trained models to extract features


Let’s look at some sample images in the dataset.

Dataset statistics


Here are some statistics for our toy dataset.

Classes: ['beet_salad', 'chocolate_cake', 'edamame', 'french_fries', 'pizza', 'spring_rolls', 'sushi']
Class count: 40, 38, 40
Samples: 283
First sample: ('data/food/train/beet_salad/104294.jpg', 0)

Extract features (embeddings)


  • Now, for each image in our dataset, we’ll extract a feature vector from a pre-trained model called densenet121, which is trained on the ImageNet dataset (a code sketch is shown below).
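
A minimal sketch of this step, assuming a DataLoader named train_loader that yields preprocessed (224 x 224, ImageNet-normalized) image batches; the loader name is a placeholder, not the workshop's exact variable:

import torch
from torchvision import models

densenet = models.densenet121(weights="DenseNet121_Weights.IMAGENET1K_V1")
densenet.classifier = torch.nn.Identity()  # drop the 1000-class ImageNet head
densenet.eval()

features, labels = [], []
with torch.no_grad():
    for X_batch, y_batch in train_loader:
        features.append(densenet(X_batch))  # each row is a 1024-dim embedding
        labels.append(y_batch)

Z_train = torch.cat(features)
y_train = torch.cat(labels)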

Shape of the embeddings


  • Now we have extracted feature vectors for all examples. What’s the shape of these features or embeddings?
torch.Size([283, 1024])
  • The size of each feature vector is 1024 because the last layer of the DenseNet-121 architecture (before the classifier) has 1024 units.


Embeddings given by densenet

 

  • Let’s examine the feature vectors.
0 1 2 3 4 5 6 7 8 9 ... 1014 1015 1016 1017 1018 1019 1020 1021 1022 1023
0 0.000290 0.003821 0.005015 0.001307 0.052690 0.063403 0.000626 0.001850 0.256254 0.000223 ... 0.229935 1.046375 2.241259 0.229641 0.033674 0.742792 1.338698 2.130880 0.625475 0.463088
1 0.000407 0.005973 0.003206 0.001932 0.090702 0.438523 0.001513 0.003906 0.166081 0.000286 ... 0.910680 1.580815 0.087191 0.606904 0.436106 0.306456 0.940102 1.159818 1.712705 1.624753
2 0.000626 0.005090 0.002887 0.001299 0.091715 0.548537 0.000491 0.003587 0.266537 0.000408 ... 0.465152 0.678276 0.946387 1.194697 2.537747 1.642383 0.701200 0.115620 0.186433 0.166605
3 0.000169 0.006087 0.002489 0.002167 0.087537 0.623212 0.000427 0.000226 0.460680 0.000388 ... 0.394083 0.700158 0.105200 0.856323 0.038457 0.023948 0.131838 1.296370 0.723323 1.915215
4 0.000286 0.005520 0.001906 0.001599 0.186034 0.850148 0.000835 0.003025 0.036309 0.000142 ... 3.313760 0.565744 0.473564 0.139446 0.029283 1.165938 0.442319 0.227593 0.884266 1.592698

5 rows × 1024 columns

  • The features are hard to interpret, but they capture important information about the images that can be useful for classification.

Logistic regression with the extracted features


  • Let’s try out logistic regression on these extracted features (a code sketch follows below).
Training score:  1.0
Validation score:  0.835820895522388
  • This is great accuracy for so little data and little effort!!!
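
A minimal sketch of the logistic regression step above, assuming Z_train/Z_valid hold the extracted features and y_train/y_valid the labels (placeholder names):

from sklearn.linear_model import LogisticRegression

clf = LogisticRegression(max_iter=1000)  # plain logistic regression on the embeddings
clf.fit(Z_train, y_train)
print("Training score: ", clf.score(Z_train, y_train))
print("Validation score: ", clf.score(Z_valid, y_valid))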

Sample predictions


Let’s examine some sample predictions on the validation set.

Let’s cluster images!!


What if we don’t have any labels and we want to cluster these images?

K-Means on food dataset


import torch
from torchvision import models
from sklearn.cluster import KMeans

densenet = models.densenet121(weights="DenseNet121_Weights.IMAGENET1K_V1")
densenet.classifier = torch.nn.Identity()  # remove the last "classification" layer
Z_food = get_features_unsup(densenet, food_inputs)
k = 5
km = KMeans(n_clusters=k, n_init='auto', random_state=123)
km.fit(Z_food)
KMeans(n_clusters=5, random_state=123)

Examining food clusters


for cluster in range(k):
    get_cluster_images(km, Z_food, X_food, cluster, n_img=6)
135
Image indices:  [135  25 133 115  17  93]

55
Image indices:  [55 20 68  1 47 13]

170
Image indices:  [170 115  72   5 161  86]

95
Image indices:  [ 95 233 136 245 266 149]

4
Image indices:  [  4 143  66 167  84 271]

Object detection


  • Another useful task and tool to know about is object detection using the YOLO model.
  • Let’s identify objects in a sample image using a pre-trained model called YOLOv8.
  • We will list the objects present in the image.

Object detection using YOLO


Let’s try this out using a pre-trained model.

from ultralytics import YOLO
model = YOLO("yolov8n.pt")  # pretrained YOLOv8n model

yolo_input = "data/yolo_test/3356700488_183566145b.jpg"
yolo_result = "data/yolo_result.jpg"
# Run inference on the image; a list of Results objects is returned
result = model(yolo_input)
result[0].save(filename=yolo_result)

image 1/1 /Users/kvarada/EL/workshops/FoS-Intro-to-ML-2025/website/slides/data/yolo_test/3356700488_183566145b.jpg: 512x640 4 persons, 2 cars, 1 stop sign, 74.5ms
Speed: 2.3ms preprocess, 74.5ms inference, 11.6ms postprocess per image at shape (1, 3, 512, 640)
'data/yolo_result.jpg'
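
If we want the detected objects programmatically rather than as a saved image, we can read them off the Results object (a small sketch using the ultralytics API):

# Each detected box carries a class id and a confidence score
for box in result[0].boxes:
    print(model.names[int(box.cls)], round(float(box.conf), 2))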

Object detection output


Sentence embeddings

Document clustering

Let’s cluster these recipe names:

0                            i yam what i yam  muffins
1                              to your health  muffins
2                        250 00 chocolate chip cookies
3                                        lplermagronen
4                              california roll   salad
5                                 chef salad  dressing
6                         ma s  oatmeal cake and icing
7              real new york style  cheesecake supreme
8                          aunt johnnie s   pound cake
9                              buffalo wing  mushrooms
10                                   funny bones  cake
11                                         green  soup
12      i coulda had a stuffed pepper  stuffed peppers
13                       kelly s chinese cabbage salad
14                            lofthouse  sugar cookies
15                                    marco    polenta
16                                     oh boy  waffles
17          oops  there it is   chocolate cake low fat
18                                        orange  soup
19              peanut butter   jam sandwich   muffins
20     perfect chocolate cake  mccall s cooking school
21      real  strawberry pie with french cream topping
22                                     ritz  y chicken
23                           starbucks  oat fudge bars
24                                            the cake
25                           the heavy one  cheesecake
26                                the ultimate brownie
27                                            57 chevy
28                                 nilla wafer martini
29                                spicy  pasta fagioli
30                       guaranteed 5 star banana cake
31                                   spaghetti squares
32                                        007 cocktail
33                       1 lb abm hawaiian sweet bread
34                                1 2 3 4 cake  orange
35                               1 bowl fudgy brownies
36    10 minute baked halibut with garlic butter sauce
37          100  honey whole wheat cracked wheat bread
38                               100  parmesan chicken
39                        100  whole grain wheat bread
40                             15 minute taco in a pan
41                              16 bean soup  crockpot
42                                            17 twist
43               1880 chocolate spice cake  with icing
44                          1905 salad dressing recipe
45                      1930 recipe for mincemeat cake
46                                     1950 s meatloaf
47                                        2 bean chili
48                        20 minute applesauce cookies
49                                 20 pound cheesecake
Name: name, dtype: object

Recipe names embeddings

import pandas as pd
from sentence_transformers import SentenceTransformer

embedder = SentenceTransformer('all-MiniLM-L6-v2')
recipe_names = recipes_df["name"].tolist()
embeddings = embedder.encode(recipe_names)
recipe_names_embeddings = pd.DataFrame(
    embeddings,
    index=recipes_df.index,
)
recipe_names_embeddings.head()
0 1 2 3 4 5 6 7 8 9 ... 374 375 376 377 378 379 380 381 382 383
0 0.019592 -0.088336 0.072677 -0.034575 -0.048741 -0.049801 0.175334 -0.055191 0.020301 0.019828 ... 0.063293 -0.067171 0.087499 -0.061550 0.039297 -0.050147 0.027708 0.056843 0.056151 -0.122506
1 -0.000567 -0.011825 0.073199 0.058176 0.031688 -0.015428 0.168134 0.000466 0.033078 -0.013923 ... -0.012926 -0.015949 0.031315 -0.059074 0.014143 -0.047270 0.007844 0.035501 0.076061 -0.078119
2 -0.022604 0.065034 -0.033065 0.014450 -0.105039 -0.050559 0.100076 0.022929 -0.037398 0.011857 ... 0.007971 -0.019165 0.004935 0.009005 0.000919 -0.040078 0.008650 -0.075781 -0.083477 -0.123240
3 -0.066915 0.025988 -0.087689 -0.006847 -0.012861 0.049035 0.035351 0.124966 -0.011697 -0.050179 ... -0.042345 -0.005794 -0.031800 0.120664 -0.057335 -0.077068 0.001653 -0.048223 0.116455 0.021789
4 -0.007068 -0.007308 -0.026629 -0.004153 -0.052810 0.011126 0.024000 -0.036993 0.023526 -0.046870 ... -0.018432 0.051918 0.036101 -0.035312 0.005817 0.101802 -0.063171 -0.007917 0.089744 0.006997

5 rows × 384 columns

Cluster embeddings

import numpy as np
from sklearn.cluster import KMeans

np.random.seed(42)
km_labels_dict = {
    k: KMeans(k, n_init='auto', random_state=123).fit(embeddings).predict(embeddings)
    for k in np.arange(4, 15)
}
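
To inspect a clustering, one option (a sketch; the choice of k = 8 and cluster 2 is arbitrary) is to list the recipe names assigned to a single cluster:

import pandas as pd

labels = km_labels_dict[8]                      # cluster labels for k = 8
members = pd.Series(recipe_names)[labels == 2]  # recipes assigned to cluster 2
print(members.head(10))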

Multimodal (CLIP)

  • Input: image + several possible captions
  • CLIP computes similarity scores between image and each caption
  • Enables:
    • Zero-shot classification
    • Text-to-image search
    • Image-to-text retrieval

Which text string is this picture most similar to?

  • Caption 1: a girl walking in the trails
  • Caption 2: two cats playing

import clip
import torch
from PIL import Image

model, preprocess = clip.load("ViT-B/32", device="cpu")
image = preprocess(Image.open("img/cats_playing.jpg")).unsqueeze(0)
text = clip.tokenize(["a girl walking in the trails", "two cats playing"])
with torch.no_grad():
    logits_per_image, logits_per_text = model(image, text)
    probs = logits_per_image.softmax(dim=-1)
probs
tensor([[1.6534e-08, 1.0000e+00]])

Captions and images

  • Caption 1: A girl walking in the trails
  • Caption 2: A woman walking her dog on a leash
  • Caption 3: Two cats playing
  • Caption 4: A group of people hiking
  • Caption 5: Canada wins hockey
  • Caption 6: Cherry blossoms on UBC campus

Similarities between captions and images

Real-world scientific use cases

  • RNA-seq clustering: gene expression \(\rightarrow\) embedding \(\rightarrow\) UMAP
  • Rock photo classification: ResNet for mineral detection
  • Literature recommendation: MiniLM for similarity search
  • Chemical reaction prediction: ChemBERTa for SMILES strings

Reflection and discussion

✨ We hope you now see how embeddings and pretrained models can power up your research.

Prompt: What kinds of data in your research could be turned into embeddings?

  • Text? Images? Molecules? Signals?
  • What would you want to compare, cluster, or visualize?

Wrap-up & resources

  • Embeddings let you encode complex input into useful numbers
  • Pretrained models save time, data, and compute
  • These tools are ready to use, even for small datasets

Resources: