Fungal Detection in Vine Images: Using Google’s ViT-Base Patch-16 Vision Transformer

Introduction

In this comprehensive tutorial, we'll build a binary image classification system to detect fungal infections in microscopy images of vine wood. We'll use Vision Transformers (ViT), a state-of-the-art deep learning architecture that applies transformer concepts to image classification.







Dataset Overview

Dataset: "An Eye on the Vine"

This dataset comes from research on pathogen segmentation in vinewood fluorescence microscopy images.





Dataset Composition

The full research dataset includes:


Dataset A: 427 synthetic images created by blending healthy wood snippets with vessel images from the DRIVE database (Digital Retinal Images for Vessel Extraction)


Dataset B (Our focus):

  • B1: 247 images with fungi in lower quality (higher blur)

  • B2: 312 images without fungi

  • B3: 569 images with fungi


Dataset C: 128 mixed images from A, B1, and B3


For this project, we use B2 (No Fungi) and B3 (Fungi) for binary classification.
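
The loaders and visualization functions below assume these two subsets are arranged in class-named folders. A layout like the following works (the root folder name is whatever you extracted the dataset to; the class folder names are the ones the code expects):

an_eye_on_the_vine/
├── No Fungi/    (Dataset B2: 312 .tif images)
└── Fungi/       (Dataset B3: 569 .tif images)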






Import Dependencies


import os
import random
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
import torch
from PIL import Image

from datasets import Dataset, Features, ClassLabel, Image as HFImage, load_metric
from transformers import (
    ViTFeatureExtractor,
    ViTForImageClassification,
    TrainingArguments,
    Trainer
)
from sklearn.metrics import f1_score, classification_report


Key imports explained:

  • Path: Modern Python path handling

  • PIL.Image: Image loading and processing

  • ViTFeatureExtractor: Preprocesses images for Vision Transformer

  • ViTForImageClassification: Pre-trained ViT model for classification

  • Trainer: HuggingFace's high-level training API




Data Exploration and Visualization


Before training, we need to understand our data. Let's create three exploration functions: one to visualize sample images, one to analyze image properties, and one to plot color histograms.


Step 1: Visualize Sample Images


def visualize_fungi_samples(data_dir, num_samples=4):
    """
    Visualize sample images from both classes in the fungi dataset

    Parameters:
    data_dir (str): Path to the dataset directory
    num_samples (int): Number of samples to visualize per class
    """
    data_dir = Path(data_dir)
    classes = ['No Fungi', 'Fungi']

    # Create a figure with subplots
    fig, axes = plt.subplots(2, num_samples, figsize=(16, 8))
    fig.suptitle('Sample Images from Fungi Dataset', fontsize=16)

    # Loop through each class
    for class_idx, class_name in enumerate(classes):
        # Get paths to all images in this class
        class_dir = data_dir / class_name
        image_paths = list(class_dir.glob('*.tif'))

        # Select random samples
        selected_samples = random.sample(image_paths, min(num_samples, len(image_paths)))

        # Plot each sample
        for i, img_path in enumerate(selected_samples):
            with Image.open(img_path) as img:
                # Convert to RGB as some images might be grayscale
                if img.mode != 'RGB':
                    img = img.convert('RGB')

                # Display the image
                axes[class_idx, i].imshow(img)
                axes[class_idx, i].set_title(f"{class_name}\n{img_path.name}")
                axes[class_idx, i].axis('off')

    plt.tight_layout()
    plt.subplots_adjust(top=0.9)
    plt.show()


What this does:

  • Creates a 2-row grid (one row per class)

  • Randomly selects sample images from each class folder

  • Converts grayscale images to RGB (ViT requires 3 channels)

  • Displays images with their filenames


This function visualizes example images from a fungi image dataset to give a quick qualitative overview of both classes. It takes a dataset directory and randomly selects a specified number of .tif images from each class folder (“No Fungi” and “Fungi”). The function creates a grid of subplots with one row per class and displays the selected images, converting them to RGB if they are grayscale to ensure correct visualization. Each image is shown with its class label and filename, the layout is adjusted for readability, and the final figure is displayed to help users visually inspect differences between the two classes.


Usage:

data_dir = "an_eye_on_the_vine"

# Run the visualizations
print("Visualizing sample images...")
visualize_fungi_samples(data_dir, num_samples=4)
[Output figure: 2x4 grid of randomly selected sample images from the No Fungi and Fungi classes]



Step 2: Analyze Image Properties


def analyze_image_properties(data_dir):
    """
    Analyze basic properties of the images in the dataset

    Parameters:
    data_dir (str): Path to the dataset directory
    """
    data_dir = Path(data_dir)
    classes = ['No Fungi', 'Fungi']

    # Initialize counters and data collectors
    class_counts = {}
    image_sizes = []
    image_modes = {}

    sample_size = 150
    print(f'Testing on {sample_size * 2} sample images.')
    print('')
    # Loop through each class
    for class_name in classes:
        class_dir = data_dir / class_name
        image_paths = list(class_dir.glob('*.tif'))
        class_counts[class_name] = len(image_paths)

        # Analyze a sample of images from each class
        samples = random.sample(image_paths, min(sample_size, len(image_paths)))
        for img_path in samples:
            with Image.open(img_path) as img:
                # Record image size
                image_sizes.append(img.size)

                # Record image mode
                if img.mode in image_modes:
                    image_modes[img.mode] += 1
                else:
                    image_modes[img.mode] = 1

    # Print analysis results
    print("Dataset Analysis:")
    print("=" * 35)
    print(f"Total classes: {len(classes)}")
    print("")
    print("Class distribution:")
    for class_name, count in class_counts.items():
        print(f"  - {class_name}: {count} images")
    print("")
    # Analyze image sizes
    if image_sizes:
        unique_sizes = set(image_sizes)
        print(f"Unique image sizes: {len(unique_sizes)}")
        for size in unique_sizes:
            count = image_sizes.count(size)
            print(f"  - {size}: {count} images")
    print("")
    # Analyze image modes
    print("Image color modes:")
    for mode, count in image_modes.items():
        print(f"  - {mode}: {count} images")

What this does:

  • Counts images in each class

  • Analyzes image dimensions

  • Checks color modes (RGB, grayscale, etc.)

  • Helps identify any preprocessing needs


This function performs a basic exploratory analysis of a fungi image dataset by examining class balance and image characteristics. It scans the dataset directories for the two classes (“No Fungi” and “Fungi”), counts the total number of images per class, and randomly samples up to 150 images from each class for analysis. For these sampled images, it records image dimensions and color modes (such as RGB or grayscale). The function then prints a summary showing the number of classes, class distribution, the number and frequency of unique image sizes, and the distribution of image color modes, helping assess dataset consistency and potential preprocessing needs.


Usage:

print("\nAnalyzing image properties...\n")
analyze_image_properties(data_dir)

Output:


Analyzing image properties...

Testing on 300 sample images.

Dataset Analysis:
===================================
Total classes: 2

Class distribution:
  - No Fungi: 312 images
  - Fungi: 569 images

Unique image sizes: 1
  - (256, 256): 300 images

Image color modes:
  - RGB: 300 images



Step 3: Plot Color Histograms


def plot_color_histograms(data_dir, num_samples=3):
    """
    Plot RGB histograms for sample images from each class

    Parameters:
    data_dir (str): Path to the dataset directory
    num_samples (int): Number of samples to analyze per class
    """
    data_dir = Path(data_dir)
    classes = ['No Fungi', 'Fungi']

    fig, axes = plt.subplots(2, num_samples, figsize=(16, 10))
    fig.suptitle('RGB Histograms of Sample Images', fontsize=16)

    # Loop through each class
    for class_idx, class_name in enumerate(classes):
        class_dir = data_dir / class_name
        image_paths = list(class_dir.glob('*.tif'))
        selected_samples = random.sample(image_paths, min(num_samples, len(image_paths)))

        # For each sample, plot the histogram
        for i, img_path in enumerate(selected_samples):
            with Image.open(img_path) as img:
                # Convert to RGB
                if img.mode != 'RGB':
                    img = img.convert('RGB')

                img_array = np.array(img)
                ax = axes[class_idx, i]

                # Plot histograms for each channel
                ax.hist(img_array[:,:,0].flatten(), bins=50, color='r', alpha=0.5, label='Red')
                ax.hist(img_array[:,:,1].flatten(), bins=50, color='g', alpha=0.5, label='Green')
                ax.hist(img_array[:,:,2].flatten(), bins=50, color='b', alpha=0.5, label='Blue')

                ax.set_title(f"{class_name}\n{img_path.name}")
                ax.legend()
                ax.set_xlabel('Pixel Value')
                ax.set_ylabel('Frequency')

    plt.tight_layout()
    plt.subplots_adjust(top=0.9)
    plt.show()



print("\nPlotting color histograms...")
plot_color_histograms(data_dir, num_samples=3)

Output:


[Output figure: RGB histograms for three randomly selected images from each class]


This function visually compares color distributions in the fungi dataset by plotting RGB pixel-value histograms for sample images from each class. It randomly selects a specified number of .tif images from the “No Fungi” and “Fungi” folders, converts each image to RGB if needed, and computes histograms for the red, green, and blue channels. These histograms are displayed in a grid of subplots, with one row per class and one column per sampled image, allowing quick visual inspection of color intensity patterns and differences between classes that may be informative for model training.


What this does:

  1. Converts each image to a NumPy array

  2. Separates RGB channels

  3. Plots pixel value distribution for each channel

  4. Shows how color intensities differ between classes


Key observations from histograms:

  • No Fungi images: Often show concentrated distributions in lower pixel values (darker images)

  • Fungi images: May show different distributions due to fluorescence patterns

  • These color patterns help the model distinguish between classes






Building the Dataset Loader

Create Custom Dataset Class



# Custom dataset class to load TIF images
class FungiDataset:
    def __init__(self, data_dir):
        self.data_dir = Path(data_dir)
        self.images = []
        self.labels = []
        self.label_map = {'No Fungi': 0, 'Fungi': 1}

        # Load all images and labels
        for label in self.label_map.keys():
            label_dir = self.data_dir / label
            for img_path in label_dir.glob('*.tif'):
                self.images.append(str(img_path))
                self.labels.append(self.label_map[label])

    def to_huggingface_dataset(self):
        # Convert to HuggingFace dataset format
        def load_image(path):
            with Image.open(path) as img:
                # Convert to RGB as ViT expects 3 channels
                if img.mode != 'RGB':
                    img = img.convert('RGB')
                return img

        dataset_dict = {
            'img': [load_image(path) for path in self.images],
            'label': self.labels
        }

        features = Features({
            'img': HFImage(),
            'label': ClassLabel(names=['No Fungi', 'Fungi'])
        })

        return Dataset.from_dict(dataset_dict, features=features)

This custom dataset class loads a fungi image dataset organized into “No Fungi” and “Fungi” folders and prepares it for use with Hugging Face tools. When initialized, it scans the dataset directory, collects paths to all .tif images, and assigns numeric labels based on a predefined label map. The to_huggingface_dataset method then loads each image, converts it to RGB to meet Vision Transformer input requirements, and constructs a Hugging Face Dataset object with image and label features. This enables seamless integration with Hugging Face preprocessing, training, and evaluation pipelines.



Breaking it down:


`__init__` method:

  • Scans the directory structure

  • Maps folder names to numeric labels (No Fungi = 0, Fungi = 1)

  • Stores image paths and corresponding labels



`to_huggingface_dataset` method:

  • Loads all images into memory

  • Ensures RGB format (ViT requirement)

  • Creates a HuggingFace Dataset object with proper features:

    • HFImage(): Specialized image feature

    • ClassLabel: Categorical label with class names



Why HuggingFace format?

  • Seamless integration with Transformers library

  • Built-in data preprocessing and batching

  • Easy train/test splitting

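As a quick usage sketch (assuming the same data_dir used in the visualizations above), the class plugs directly into the Hugging Face splitting utilities that the training function relies on later:


# Illustrative only: build the dataset and inspect it before training
fungi_dataset = FungiDataset("an_eye_on_the_vine")
hf_dataset = fungi_dataset.to_huggingface_dataset()

print(hf_dataset)                              # row count and feature schema
print(hf_dataset.features['label'].names)      # ['No Fungi', 'Fungi']

# The same shuffle + 80/20 split used inside train_fungi_classifier
split = hf_dataset.shuffle(seed=42).train_test_split(test_size=0.2)
print(len(split['train']), len(split['test']))
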





Preparing the Vision Transformer


Understanding the Preprocessing Pipeline



# Function to preprocess images for ViT
def preprocess(batch, feature_extractor):
    inputs = feature_extractor(
        batch['img'],
        return_tensors='pt'
    )
    inputs['label'] = batch['label']
    return inputs

What this preprocessing step does:

  1. The feature extractor resizes images to 224x224 (ViT's standard input size)

  2. It normalizes pixel values with the mean and standard deviation the checkpoint expects

  3. It converts the images to PyTorch tensors

  4. The wrapper function then copies the labels into the processed batch

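To make the output concrete, here is a minimal shape check (a sketch only; the blank image simply stands in for one of the 256x256 .tif samples):


from transformers import ViTFeatureExtractor
from PIL import Image

feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')

dummy_img = Image.new('RGB', (256, 256))        # placeholder for a real sample
encoded = feature_extractor(dummy_img, return_tensors='pt')

print(encoded['pixel_values'].shape)            # torch.Size([1, 3, 224, 224])
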



Create Collate Function for Batching



# Collate function for batching
def collate_fn(batch):
    return {
        'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
        'labels': torch.tensor([x['label'] for x in batch])
    }

Purpose:

  • Combines individual samples into batches

  • Stacks pixel values into a single tensor (batch_size, 3, 224, 224)

  • Creates a label tensor (batch_size,)

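A quick sanity check with two hand-made samples shows the shapes the Trainer will receive (random tensors stand in for preprocessed images):


import torch

fake_batch = [
    {'pixel_values': torch.randn(3, 224, 224), 'label': 0},
    {'pixel_values': torch.randn(3, 224, 224), 'label': 1},
]

collated = collate_fn(fake_batch)
print(collated['pixel_values'].shape)   # torch.Size([2, 3, 224, 224])
print(collated['labels'])               # tensor([0, 1])
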



Define Comprehensive Metrics


# Compute metrics function with F1 score
def compute_metrics(eval_pred):
    predictions = np.argmax(eval_pred.predictions, axis=1)
    references = eval_pred.label_ids

    # Calculate metrics
    accuracy_metric = load_metric("accuracy")
    accuracy = accuracy_metric.compute(predictions=predictions, references=references)

    # Calculate F1 score
    f1 = f1_score(references, predictions, average='weighted')

    # Get detailed classification report
    report = classification_report(references, predictions,
                                 target_names=['No Fungi', 'Fungi'],
                                 output_dict=True)

    # Return all metrics
    return {
        'accuracy': accuracy['accuracy'],
        'f1_score': f1,
        'f1_no_fungi': report['No Fungi']['f1-score'],
        'f1_fungi': report['Fungi']['f1-score'],
        'precision_no_fungi': report['No Fungi']['precision'],
        'precision_fungi': report['Fungi']['precision'],
        'recall_no_fungi': report['No Fungi']['recall'],
        'recall_fungi': report['Fungi']['recall']
    }


This function computes evaluation metrics for the fungi classification model during validation or testing. It converts the model’s raw prediction scores into class labels using argmax and compares them with the true labels. The function calculates overall accuracy using Hugging Face’s accuracy metric and computes a weighted F1 score to account for class imbalance. It also generates a detailed classification report to extract per-class precision, recall, and F1 scores for both “No Fungi” and “Fungi.” All these metrics are returned in a dictionary, enabling comprehensive performance tracking during model training and evaluation.



Metrics explained

  1. Accuracy: Overall correct predictions (94.9% achieved)

  2. F1 Score: Harmonic mean of precision and recall

    1. Weighted F1: 0.949 (accounts for class imbalance)

  3. Per-class metrics:

    1. F1 No Fungi: 0.934 (93.4% balanced performance)

    2. F1 Fungi: 0.963 (96.3% balanced performance)

  4. Precision: Of predicted positives, how many are correct?

    1. Precision Fungi: 0.975 (97.5% of fungi predictions are correct)

  5. Recall: Of actual positives, how many did we find?

    1. Recall Fungi: 0.952 (95.2% of actual fungi cases detected)



Why these metrics matter:

  • Accuracy alone can be misleading with imbalanced data

  • F1 score balances false positives and false negatives

  • Class-specific metrics ensure both classes perform well

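To see exactly which keys the Trainer will log, compute_metrics can be exercised on a few hand-made predictions (the logits and labels below are arbitrary and only for illustration):


from types import SimpleNamespace
import numpy as np

# Fake logits for four images: the last one is a "Fungi" false positive
fake_logits = np.array([[2.0, 0.1], [1.5, 0.3], [0.2, 1.8], [0.4, 2.2]])
fake_labels = np.array([0, 0, 1, 0])

eval_pred = SimpleNamespace(predictions=fake_logits, label_ids=fake_labels)
print(compute_metrics(eval_pred))
# -> dict with accuracy, f1_score, and per-class precision/recall/F1 entries
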





Training the Model


Main Training Function



def train_fungi_classifier(data_dir, output_dir="./fungi_model", num_epochs=4):
    # Load and prepare datasets
    dataset = FungiDataset(data_dir)
    full_dataset = dataset.to_huggingface_dataset()

    # Split into train and test
    full_dataset = full_dataset.shuffle(seed=42)
    split_dataset = full_dataset.train_test_split(test_size=0.2)

    # Load ViT feature extractor
    model_id = 'google/vit-base-patch16-224-in21k'
    feature_extractor = ViTFeatureExtractor.from_pretrained(model_id)

    # Prepare datasets
    def preprocess_function(batch):
        return preprocess(batch, feature_extractor)

    prepared_train = split_dataset['train'].with_transform(preprocess_function)
    prepared_test = split_dataset['test'].with_transform(preprocess_function)

    # Training arguments
    training_args = TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=16,
        evaluation_strategy="epoch",    # evaluate at the end of each epoch
        save_strategy="epoch",          # checkpoint on the same schedule as evaluation
        num_train_epochs=num_epochs,
        logging_steps=10,
        learning_rate=2e-4,
        save_total_limit=2,
        remove_unused_columns=False,
        push_to_hub=False,
        load_best_model_at_end=True,
    )
    # Load model
    model = ViTForImageClassification.from_pretrained(
        model_id,
        num_labels=2,  # Binary classification
        ignore_mismatched_sizes=True
    )

    # Initialize trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=collate_fn,
        compute_metrics=compute_metrics,
        train_dataset=prepared_train,
        eval_dataset=prepared_test,
        tokenizer=feature_extractor,
    )

    # Train model
    train_results = trainer.train()

    # Save model and evaluate
    trainer.save_model()
    metrics = trainer.evaluate(prepared_test)

    return trainer, metrics

This function trains a binary image classification model to distinguish fungi images using a Vision Transformer (ViT). It loads image data from a directory, converts it into a Hugging Face dataset, shuffles it, and splits it into training and testing sets. The code uses a pre-trained ViT feature extractor to preprocess images and applies this preprocessing dynamically during training. It then configures training parameters such as batch size, learning rate, number of epochs, and evaluation strategy, and loads a pre-trained ViT model adapted for two classes. Using Hugging Face’s Trainer API, the function trains the model, evaluates it on the test set, saves the best-performing model, and finally returns the trained trainer object along with the evaluation metrics.


Breaking down the setup:

  1. Dataset creation: Load TIF images into HuggingFace format

  2. Shuffling: Random shuffle with seed 42 (reproducibility)

  3. Train/test split: 80% training, 20% testing

  4. Model selection: `google/vit-base-patch16-224-in21k`

    1. vit-base: Medium-sized model (86M parameters)

    2. patch16: Divides images into 16x16 patches

    3. 224: Input size 224x224 pixels

    4. in21k: Pre-trained on ImageNet-21k (14M images, 21k classes)

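The patch arithmetic behind those names can be verified directly from the checkpoint's configuration (a small sketch, independent of the fungi data):


from transformers import ViTConfig

config = ViTConfig.from_pretrained('google/vit-base-patch16-224-in21k')

patches_per_side = config.image_size // config.patch_size   # 224 // 16 = 14
num_patches = patches_per_side ** 2                          # 14 * 14 = 196

print(num_patches)           # 196 patches, plus one [CLS] token = 197 inputs
print(config.hidden_size)    # 768-dimensional embeddings for ViT-Base
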



Configure Training Arguments

    

training_args = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=16,
    evaluation_strategy="epoch",    # evaluate at the end of each epoch
    save_strategy="epoch",          # checkpoint on the same schedule as evaluation
    num_train_epochs=num_epochs,
    logging_steps=10,
    learning_rate=2e-4,
    save_total_limit=2,
    remove_unused_columns=False,
    push_to_hub=False,
    load_best_model_at_end=True,
)

This code configures the training setup for a Hugging Face model using the TrainingArguments class. It specifies where to save model outputs, sets the training batch size per device, and defines the number of training epochs and learning rate. The model is evaluated and checkpointed at the end of each epoch, with only the two most recent checkpoints retained to save disk space. Logging occurs every ten steps, unused dataset columns are preserved to support custom preprocessing, model uploads to the Hugging Face Hub are disabled, and the best-performing model (based on evaluation metrics) is automatically reloaded at the end of training.



Key hyperparameters explained:

  • batch_size=16: Process 16 images at a time

  • evaluation_strategy="epoch": Check performance after each complete pass

  • num_train_epochs=4: Four complete passes through the data

  • learning_rate=2e-4: Step size for weight updates

  • load_best_model_at_end=True: Automatically loads the checkpoint with best validation performance

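Using the class counts from the earlier analysis (312 + 569 = 881 images) and the 80/20 split, these settings translate into roughly the following number of optimizer steps (approximate, since the exact split size depends on rounding):


import math

total_images = 312 + 569                         # counts from the dataset analysis
train_images = int(total_images * 0.8)           # ~704 training images
steps_per_epoch = math.ceil(train_images / 16)   # batch size 16 -> ~44 steps per epoch
total_steps = steps_per_epoch * 4                # 4 epochs -> ~176 optimizer steps

print(train_images, steps_per_epoch, total_steps)
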



Initialize Model and Trainer

    

    # Load model
    model = ViTForImageClassification.from_pretrained(
        model_id,
        num_labels=2,  # Binary classification
        ignore_mismatched_sizes=True
    )

    # Initialize trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=collate_fn,
        compute_metrics=compute_metrics,
        train_dataset=prepared_train,
        eval_dataset=prepared_test,
        tokenizer=feature_extractor,
    )

    # Train model
    train_results = trainer.train()

    # Save model and evaluate
    trainer.save_model()
    metrics = trainer.evaluate(prepared_test)

    return trainer, metrics

What happens during training:


1. Model initialization:

  • Loads pre-trained weights from ImageNet-21k

  • Replaces classification head with 2-class output

  • ignore_mismatched_sizes=True: Allows head replacement



2. Trainer setup:

  • Handles training loop, gradient computation, optimizer

  • Automatically logs metrics

  • Manages checkpointing



3. Training process:

  • Forward pass: Image → ViT → Predictions

  • Loss computation: Cross-entropy between predictions and labels

  • Backward pass: Compute gradients

  • Optimizer step: Update weights

  • Repeat for all batches and epochs



4. Evaluation:

  • Run model on test set

  • Compute all metrics (accuracy, F1, precision, recall)




Run Training with Experiment Tracking


Login to Weights & Biases (optional but recommended)


import wandb
wandb.login(key="<wandb_key>")
wandb.init(project="vt_001")

This code integrates Weights & Biases (W&B) for experiment tracking and logging. It first authenticates the user with W&B using an API key, then initializes a new run under the project named "vt_001". Once initialized, any compatible training process (such as Hugging Face’s Trainer) can automatically log metrics, losses, model parameters, and other training artifacts to W&B, enabling real-time monitoring, comparison of experiments, and reproducibility through a centralized dashboard.


Weights & Biases (wandb):

  • Tracks experiments automatically

  • Logs metrics, hyperparameters, system info

  • Provides visualization dashboards

  • Optional but highly recommended for ML projects

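If you prefer to opt in explicitly rather than rely on the Trainer's auto-detection of wandb, the reporting target can also be set through TrainingArguments. A hedged sketch (the run name below is purely illustrative; WANDB_PROJECT mirrors the project passed to wandb.init):


import os
from transformers import TrainingArguments

os.environ["WANDB_PROJECT"] = "vt_001"   # same project name as wandb.init above

training_args = TrainingArguments(
    output_dir="./fungi_model",
    report_to="wandb",             # send Trainer logs and metrics to W&B
    run_name="vit-fungi-binary",   # illustrative run name
    # ...plus the batch size, epochs, and evaluation settings shown earlier
)
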


Run training


if __name__ == "__main__":
    data_dir = "/kaggle/input/an-eye-on-the-vine/an_eye_on_the_vine"
    trainer, metrics = train_fungi_classifier(data_dir)
    print(f"Evaluation metrics: {metrics}")






Results Analysis


Training Progression

[Figure: training and evaluation metrics logged over the four training epochs]


Key observations:

  • Steady improvement: Accuracy improves from 90% to 92% across epochs

  • Good generalization: Small gap between training and validation loss

  • High precision on Fungi: 92% precision means very few false positives




Interpreting the Color Histograms

No Fungi samples:

  • Often concentrated in lower pixel ranges (darker overall)

  • More uniform distributions across channels

  • Red channel sometimes shows peaks at low values



Fungi samples:

  • Broader distributions due to fluorescence

  • Green/Blue channels often dominant (fluorescence marker)

  • Multiple peaks indicate varied tissue structures






Conclusion

What We Achieved

  • Built an end-to-end image classification pipeline

  • Achieved over 90% accuracy on fungi detection

  • Maintained balanced performance across both classes

  • Created reusable visualization and analysis tools


Full code is available at:






Transform Your AI Workflows with Codersarts

Whether you're building intelligent systems with AI, implementing RAG for smart information retrieval, or developing robust multi-agent architectures, the experts at Codersarts are here to support your vision. From academic prototypes to enterprise-grade solutions, we provide:


  • Custom RAG Implementation: Build retrieval-augmented generation systems tailored to your domain

  • AI-Based Agent Systems: Design and deploy modular, coordinated AI agents

  • End-to-End AI Development: From setup and orchestration to deployment and optimization


Do not let architectural complexity or tooling challenges slow down your progress. Partner with Codersarts and bring your next-generation AI systems to life.


Ready to get started? Visit Codersarts.com or connect with our team to discuss your AI-based project. The future of modular, intelligent automation is here – let’s build it together!

