Fungal Detection in Vine Images: Using Google’s ViT-Base Patch-16 Vision Transformer
- ganesh90
Introduction
In this comprehensive tutorial, we'll build a binary image classification system to detect fungal infections in microscopy images of vine wood. We'll use Vision Transformers (ViT), a state-of-the-art deep learning architecture that applies transformer concepts to image classification.

Dataset Overview
Dataset: "An Eye on the Vine"
This dataset comes from research on pathogen segmentation in vinewood fluorescence microscopy images.
The dataset is available at: https://archive.ics.uci.edu/dataset/966/an+eye+on+the+vine+-+a+dataset+for+fungi+segmentation+in+mi
Dataset Composition
The full research dataset includes:
Dataset A: 427 synthetic images created by blending healthy wood snippets with vessel images from the DRIVE database (Digital Retinal Images for Vessel Extraction)
Dataset B (Our focus):
B1: 247 images with fungi in lower quality (higher blur)
B2: 312 images without fungi
B3: 569 images with fungi
Dataset C: 128 mixed images from A, B1, and B3
For this project, we use B2 (No Fungi) and B3 (Fungi) for binary classification.
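Once the archive is extracted, it helps to confirm the folder layout before writing anything else. The quick check below assumes the two class folders are named exactly "No Fungi" and "Fungi" and sit directly under the dataset root (the path is a placeholder; adjust it to your local copy):
from pathlib import Path

data_dir = Path("an_eye_on_the_vine")  # placeholder path to the extracted dataset

# Count the .tif files in each assumed class folder
for class_name in ["No Fungi", "Fungi"]:
    n_images = len(list((data_dir / class_name).glob("*.tif")))
    print(f"{class_name}: {n_images} images")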
Import Dependencies
import os
from pathlib import Path
import random

import numpy as np
import torch
from PIL import Image
# Note: datasets.Dataset (not torch.utils.data.Dataset) is used for the HuggingFace pipeline below
from datasets import Dataset, Features, ClassLabel, Image as HFImage, load_metric
from transformers import (
    ViTFeatureExtractor,
    ViTForImageClassification,
    TrainingArguments,
    Trainer
)
from sklearn.metrics import f1_score, classification_report
import matplotlib.pyplot as plt
Key imports explained:
Path: Modern Python path handling
PIL.Image: Image loading and processing
ViTFeatureExtractor: Preprocesses images for Vision Transformer
ViTForImageClassification: Pre-trained ViT model for classification
Trainer: HuggingFace's high-level training API
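Note: in recent releases of the transformers library, ViTFeatureExtractor is deprecated in favour of ViTImageProcessor. If your installed version no longer provides the feature extractor, the drop-in replacement looks like this (the rest of the tutorial code is unchanged apart from the class name):
from transformers import ViTImageProcessor

# Drop-in replacement for ViTFeatureExtractor in newer transformers releases
feature_extractor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")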
Data Exploration and Visualization
Before training, we need to understand our data. Let's create three visualization functions.
Step 1: Visualize Sample Images
def visualize_fungi_samples(data_dir, num_samples=4):
    """
    Visualize sample images from both classes in the fungi dataset
    Parameters:
        data_dir (str): Path to the dataset directory
        num_samples (int): Number of samples to visualize per class
    """
    data_dir = Path(data_dir)
    classes = ['No Fungi', 'Fungi']
    # Create a figure with subplots
    fig, axes = plt.subplots(2, num_samples, figsize=(16, 8))
    fig.suptitle('Sample Images from Fungi Dataset', fontsize=16)
    # Loop through each class
    for class_idx, class_name in enumerate(classes):
        # Get paths to all images in this class
        class_dir = data_dir / class_name
        image_paths = list(class_dir.glob('*.tif'))
        # Select random samples
        selected_samples = random.sample(image_paths, min(num_samples, len(image_paths)))
        # Plot each sample
        for i, img_path in enumerate(selected_samples):
            with Image.open(img_path) as img:
                # Convert to RGB as some images might be grayscale
                if img.mode != 'RGB':
                    img = img.convert('RGB')
                # Display the image
                axes[class_idx, i].imshow(img)
                axes[class_idx, i].set_title(f"{class_name}\n{img_path.name}")
                axes[class_idx, i].axis('off')
    plt.tight_layout()
    plt.subplots_adjust(top=0.9)
    plt.show()

What this does:
Creates a 2-row grid (one row per class)
Randomly selects sample images from each class folder
Converts grayscale images to RGB (ViT requires 3 channels)
Displays images with their filenames
This function visualizes example images from a fungi image dataset to give a quick qualitative overview of both classes. It takes a dataset directory and randomly selects a specified number of .tif images from each class folder (“No Fungi” and “Fungi”). The function creates a grid of subplots with one row per class and displays the selected images, converting them to RGB if they are grayscale to ensure correct visualization. Each image is shown with its class label and filename, the layout is adjusted for readability, and the final figure is displayed to help users visually inspect differences between the two classes.
Usage:
data_dir = "an_eye_on_the_vine"
# Run the visualizations
print("Visualizing sample images...")
visualize_fungi_samples(data_dir, num_samples=4)
Step 2: Analyze Image Properties
def analyze_image_properties(data_dir):
    """
    Analyze basic properties of the images in the dataset
    Parameters:
        data_dir (str): Path to the dataset directory
    """
    data_dir = Path(data_dir)
    classes = ['No Fungi', 'Fungi']
    # Initialize counters and data collectors
    class_counts = {}
    image_sizes = []
    image_modes = {}
    sample_size = 150
    print(f'Testing on {sample_size * 2} sample images.')
    print('')
    # Loop through each class
    for class_name in classes:
        class_dir = data_dir / class_name
        image_paths = list(class_dir.glob('*.tif'))
        class_counts[class_name] = len(image_paths)
        # Analyze a sample of images from each class
        samples = random.sample(image_paths, min(sample_size, len(image_paths)))
        for img_path in samples:
            with Image.open(img_path) as img:
                # Record image size
                image_sizes.append(img.size)
                # Record image mode
                if img.mode in image_modes:
                    image_modes[img.mode] += 1
                else:
                    image_modes[img.mode] = 1
    # Print analysis results
    print("Dataset Analysis:")
    print("=" * 35)
    print(f"Total classes: {len(classes)}")
    print("")
    print("Class distribution:")
    for class_name, count in class_counts.items():
        print(f" - {class_name}: {count} images")
    print("")
    # Analyze image sizes
    if image_sizes:
        unique_sizes = set(image_sizes)
        print(f"Unique image sizes: {len(unique_sizes)}")
        for size in unique_sizes:
            count = image_sizes.count(size)
            print(f" - {size}: {count} images")
        print("")
    # Analyze image modes
    print("Image color modes:")
    for mode, count in image_modes.items():
        print(f" - {mode}: {count} images")

What this does:
Counts images in each class
Analyzes image dimensions
Checks color modes (RGB, grayscale, etc.)
Helps identify any preprocessing needs
This function performs a basic exploratory analysis of a fungi image dataset by examining class balance and image characteristics. It scans the dataset directories for the two classes (“No Fungi” and “Fungi”), counts the total number of images per class, and randomly samples up to 150 images from each class for analysis. For these sampled images, it records image dimensions and color modes (such as RGB or grayscale). The function then prints a summary showing the number of classes, class distribution, the number and frequency of unique image sizes, and the distribution of image color modes, helping assess dataset consistency and potential preprocessing needs.
Usage:
print("\nAnalyzing image properties...\n")
analyze_image_properties(data_dir)
Output:
Analyzing image properties...
Testing on 300 sample images.
Dataset Analysis:
===================================
Total classes: 2
Class distribution:
- No Fungi: 312 images
- Fungi: 569 images
Unique image sizes: 1
- (256, 256): 300 images
Image color modes:
- RGB: 300 images
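The class distribution above (312 vs. 569 images) shows a moderate imbalance. The training below accounts for it through the weighted F1 metric, but if you also want to weight the loss, scikit-learn's class-weight helper gives a starting point (a standalone sketch, not part of the original pipeline):
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

# 0 = No Fungi, 1 = Fungi; counts taken from the analysis above
labels = np.array([0] * 312 + [1] * 569)
weights = compute_class_weight(class_weight="balanced", classes=np.array([0, 1]), y=labels)
print(dict(zip(["No Fungi", "Fungi"], weights)))  # roughly {'No Fungi': 1.41, 'Fungi': 0.77}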
Step 3: Plot Color Histograms
def plot_color_histograms(data_dir, num_samples=3):
    """
    Plot RGB histograms for sample images from each class
    Parameters:
        data_dir (str): Path to the dataset directory
        num_samples (int): Number of samples to analyze per class
    """
    data_dir = Path(data_dir)
    classes = ['No Fungi', 'Fungi']
    fig, axes = plt.subplots(2, num_samples, figsize=(16, 10))
    fig.suptitle('RGB Histograms of Sample Images', fontsize=16)
    # Loop through each class
    for class_idx, class_name in enumerate(classes):
        class_dir = data_dir / class_name
        image_paths = list(class_dir.glob('*.tif'))
        selected_samples = random.sample(image_paths, min(num_samples, len(image_paths)))
        # For each sample, plot the histogram
        for i, img_path in enumerate(selected_samples):
            with Image.open(img_path) as img:
                # Convert to RGB
                if img.mode != 'RGB':
                    img = img.convert('RGB')
                img_array = np.array(img)
                ax = axes[class_idx, i]
                # Plot histograms for each channel
                ax.hist(img_array[:,:,0].flatten(), bins=50, color='r', alpha=0.5, label='Red')
                ax.hist(img_array[:,:,1].flatten(), bins=50, color='g', alpha=0.5, label='Green')
                ax.hist(img_array[:,:,2].flatten(), bins=50, color='b', alpha=0.5, label='Blue')
                ax.set_title(f"{class_name}\n{img_path.name}")
                ax.legend()
                ax.set_xlabel('Pixel Value')
                ax.set_ylabel('Frequency')
    plt.tight_layout()
    plt.subplots_adjust(top=0.9)
    plt.show()
print("\nPlotting color histograms...")
plot_color_histograms(data_dir, num_samples=3)
Output:

This function visually compares color distributions in the fungi dataset by plotting RGB pixel-value histograms for sample images from each class. It randomly selects a specified number of .tif images from the “No Fungi” and “Fungi” folders, converts each image to RGB if needed, and computes histograms for the red, green, and blue channels. These histograms are displayed in a grid of subplots, with one row per class and one column per sampled image, allowing quick visual inspection of color intensity patterns and differences between classes that may be informative for model training.
What this does:
Converts each image to a NumPy array
Separates RGB channels
Plots pixel value distribution for each channel
Shows how color intensities differ between classes
Key observations from histograms:
No Fungi images: Often show concentrated distributions in lower pixel values (darker images)
Fungi images: May show different distributions due to fluorescence patterns
These color patterns help the model distinguish between classes
Building the Dataset Loader
Create Custom Dataset Class
# Custom dataset class to load TIF images
class FungiDataset:
    def __init__(self, data_dir):
        self.data_dir = Path(data_dir)
        self.images = []
        self.labels = []
        self.label_map = {'No Fungi': 0, 'Fungi': 1}
        # Load all images and labels
        for label in self.label_map.keys():
            label_dir = self.data_dir / label
            for img_path in label_dir.glob('*.tif'):
                self.images.append(str(img_path))
                self.labels.append(self.label_map[label])

    def to_huggingface_dataset(self):
        # Convert to HuggingFace dataset format
        def load_image(path):
            with Image.open(path) as img:
                # Convert to RGB as ViT expects 3 channels
                if img.mode != 'RGB':
                    img = img.convert('RGB')
                # Force pixel data into memory before the file handle closes
                img.load()
                return img

        dataset_dict = {
            'img': [load_image(path) for path in self.images],
            'label': self.labels
        }
        features = Features({
            'img': HFImage(),
            'label': ClassLabel(names=['No Fungi', 'Fungi'])
        })
        return Dataset.from_dict(dataset_dict, features=features)
This custom dataset class loads a fungi image dataset organized into “No Fungi” and “Fungi” folders and prepares it for use with Hugging Face tools. When initialized, it scans the dataset directory, collects paths to all .tif images, and assigns numeric labels based on a predefined label map. The to_huggingface_dataset method then loads each image, converts it to RGB to meet Vision Transformer input requirements, and constructs a Hugging Face Dataset object with image and label features. This enables seamless integration with Hugging Face preprocessing, training, and evaluation pipelines.
Breaking it down:
`__init__` method:
Scans the directory structure
Maps folder names to numeric labels (No Fungi = 0, Fungi = 1)
Stores image paths and corresponding labels
`to_huggingface_dataset` method:
Loads all images into memory
Ensures RGB format (ViT requirement)
Creates a HuggingFace Dataset object with proper features:
HFImage(): Specialized image feature
ClassLabel: Categorical label with class names
Why HuggingFace format?
Seamless integration with Transformers library
Built-in data preprocessing and batching
Easy train/test splitting
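A minimal usage sketch (assuming the same folder layout as in the exploration section above):
dataset = FungiDataset("an_eye_on_the_vine")
hf_dataset = dataset.to_huggingface_dataset()
print(hf_dataset)              # Dataset with features ['img', 'label'] and 881 rows (312 + 569 images)
print(hf_dataset[0]['label'])  # 0 for 'No Fungi', 1 for 'Fungi'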
Preparing the Vision Transformer
Understanding the Preprocessing Pipeline
# Function to preprocess images for ViT
def preprocess(batch, feature_extractor):
    inputs = feature_extractor(
        batch['img'],
        return_tensors='pt'
    )
    inputs['label'] = batch['label']
    return inputs
What the feature extractor does:
Resizes images to 224x224 (ViT standard input size)
Normalizes pixel values using ImageNet statistics
Converts to PyTorch tensors
Adds label information to the processed batch
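To see this concretely, you can run the extractor on a single image; the pixel_values tensor should come out with shape (1, 3, 224, 224) for one RGB image. A quick standalone sanity check (the image path is hypothetical):
from transformers import ViTFeatureExtractor
from PIL import Image

feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
img = Image.open("an_eye_on_the_vine/Fungi/example.tif").convert("RGB")  # hypothetical file name
encoded = feature_extractor(img, return_tensors="pt")
print(encoded["pixel_values"].shape)  # torch.Size([1, 3, 224, 224])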
Create Collate Function for Batching
# Collate function for batching
def collate_fn(batch):
    return {
        'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
        'labels': torch.tensor([x['label'] for x in batch])
    }
Purpose:
Combines individual samples into batches
Stacks pixel values into a single tensor (batch_size, 3, 224, 224)
Creates a label tensor (batch_size,)
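As a rough illustration (not part of the training script), this is what the collator does to two already-preprocessed samples:
import torch

# Two dummy preprocessed samples with the keys produced by the preprocessing step
fake_batch = [
    {"pixel_values": torch.zeros(3, 224, 224), "label": 0},
    {"pixel_values": torch.ones(3, 224, 224), "label": 1},
]
out = collate_fn(fake_batch)
print(out["pixel_values"].shape)  # torch.Size([2, 3, 224, 224])
print(out["labels"])              # tensor([0, 1])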
Define Comprehensive Metrics
# Compute metrics function with F1 score
def compute_metrics(eval_pred):
    predictions = np.argmax(eval_pred.predictions, axis=1)
    references = eval_pred.label_ids
    # Calculate accuracy
    accuracy_metric = load_metric("accuracy")
    accuracy = accuracy_metric.compute(predictions=predictions, references=references)
    # Calculate F1 score
    f1 = f1_score(references, predictions, average='weighted')
    # Get detailed classification report
    report = classification_report(references, predictions,
                                   target_names=['No Fungi', 'Fungi'],
                                   output_dict=True)
    # Return all metrics
    return {
        'accuracy': accuracy['accuracy'],
        'f1_score': f1,
        'f1_no_fungi': report['No Fungi']['f1-score'],
        'f1_fungi': report['Fungi']['f1-score'],
        'precision_no_fungi': report['No Fungi']['precision'],
        'precision_fungi': report['Fungi']['precision'],
        'recall_no_fungi': report['No Fungi']['recall'],
        'recall_fungi': report['Fungi']['recall']
    }
This function computes evaluation metrics for the fungi classification model during validation or testing. It converts the model’s raw prediction scores into class labels using argmax and compares them with the true labels. The function calculates overall accuracy using Hugging Face’s accuracy metric and computes a weighted F1 score to account for class imbalance. It also generates a detailed classification report to extract per-class precision, recall, and F1 scores for both “No Fungi” and “Fungi.” All these metrics are returned in a dictionary, enabling comprehensive performance tracking during model training and evaluation.
Metrics explained:
Accuracy: Overall fraction of correct predictions (94.9% achieved)
F1 Score: Harmonic mean of precision and recall
Weighted F1: 0.949 (accounts for class imbalance)
Per-class metrics:
F1 No Fungi: 0.934 (93.4%)
F1 Fungi: 0.963 (96.3%)
Precision: Of predicted positives, how many are correct?
Precision Fungi: 0.975 (97.5% of fungi predictions are correct)
Recall: Of actual positives, how many did we find?
Recall Fungi: 0.952 (95.2% of actual fungi cases detected)
Why these metrics matter:
Accuracy alone can be misleading with imbalanced data
F1 score balances false positives and false negatives
Class-specific metrics ensure both classes perform well
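One caveat: datasets.load_metric has been removed from recent releases of the datasets library (metrics moved to the separate evaluate package). If load_metric fails in your environment, the same accuracy and F1 numbers can be computed with scikit-learn alone; a minimal alternative sketch:
import numpy as np
from sklearn.metrics import accuracy_score, f1_score

# Equivalent accuracy/F1 computation without datasets.load_metric
def compute_metrics_sklearn(eval_pred):
    predictions = np.argmax(eval_pred.predictions, axis=1)
    references = eval_pred.label_ids
    return {
        "accuracy": accuracy_score(references, predictions),
        "f1_score": f1_score(references, predictions, average="weighted"),
    }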
Training the Model
Main Training Function
def train_fungi_classifier(data_dir, output_dir="./fungi_model", num_epochs=4):
    # Load and prepare datasets
    dataset = FungiDataset(data_dir)
    full_dataset = dataset.to_huggingface_dataset()
    # Split into train and test
    full_dataset = full_dataset.shuffle(seed=42)
    split_dataset = full_dataset.train_test_split(test_size=0.2)
    # Load ViT feature extractor
    model_id = 'google/vit-base-patch16-224-in21k'
    feature_extractor = ViTFeatureExtractor.from_pretrained(model_id)
    # Prepare datasets
    def preprocess_function(batch):
        return preprocess(batch, feature_extractor)
    prepared_train = split_dataset['train'].with_transform(preprocess_function)
    prepared_test = split_dataset['test'].with_transform(preprocess_function)
    # Training arguments
    training_args = TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=16,
        evaluation_strategy="epoch",  # Evaluate at the end of each epoch
        save_strategy="epoch",        # Must match evaluation_strategy for load_best_model_at_end
        num_train_epochs=num_epochs,
        logging_steps=10,
        learning_rate=2e-4,
        save_total_limit=2,
        remove_unused_columns=False,
        push_to_hub=False,
        load_best_model_at_end=True,
    )
    # Load model
    model = ViTForImageClassification.from_pretrained(
        model_id,
        num_labels=2,  # Binary classification
        ignore_mismatched_sizes=True
    )
    # Initialize trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=collate_fn,
        compute_metrics=compute_metrics,
        train_dataset=prepared_train,
        eval_dataset=prepared_test,
        tokenizer=feature_extractor,
    )
    # Train model
    train_results = trainer.train()
    # Save model and evaluate
    trainer.save_model()
    metrics = trainer.evaluate(prepared_test)
    return trainer, metrics
This function trains a binary image classification model to distinguish fungi images using a Vision Transformer (ViT). It loads image data from a directory, converts it into a Hugging Face dataset, shuffles it, and splits it into training and testing sets. The code uses a pre-trained ViT feature extractor to preprocess images and applies this preprocessing dynamically during training. It then configures training parameters such as batch size, learning rate, number of epochs, and evaluation strategy, and loads a pre-trained ViT model adapted for two classes. Using Hugging Face’s Trainer API, the function trains the model, evaluates it on the test set, saves the best-performing model, and finally returns the trained trainer object along with the evaluation metrics.
Breaking down the setup:
Dataset creation: Load TIF images into HuggingFace format
Shuffling: Random shuffle with seed 42 (reproducibility)
Train/test split: 80% training, 20% testing
Model selection: `google/vit-base-patch16-224-in21k`
vit-base: Medium-sized model (86M parameters)
patch16: Divides images into 16x16 patches
224: Input size 224x224 pixels
in21k: Pre-trained on ImageNet-21k (14M images, 21k classes)
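If you want to confirm the model size and the shape of the new classification head after loading, a quick standalone check like this can help (the exact parameter count may vary slightly with the transformers version):
from transformers import ViTForImageClassification

model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224-in21k",
    num_labels=2,
    ignore_mismatched_sizes=True,
)
n_params = sum(p.numel() for p in model.parameters())
print(f"{n_params / 1e6:.1f}M parameters")  # roughly 86M for ViT-Base
print(model.classifier)                     # Linear layer mapping 768 hidden features to 2 classes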
Configure Training Arguments
training_args = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=16,
    evaluation_strategy="epoch",  # Evaluate at the end of each epoch
    save_strategy="epoch",        # Must match evaluation_strategy for load_best_model_at_end
    num_train_epochs=num_epochs,
    logging_steps=10,
    learning_rate=2e-4,
    save_total_limit=2,
    remove_unused_columns=False,
    push_to_hub=False,
    load_best_model_at_end=True,
)
This code configures the training setup for a Hugging Face model using the TrainingArguments class. It specifies where to save model outputs, sets the per-device training batch size, and defines the number of training epochs and the learning rate. The model is evaluated and checkpointed at the end of each epoch, with at most two checkpoints retained to save disk space. Logging occurs every ten steps, unused dataset columns are preserved to support custom preprocessing, uploads to the Hugging Face Hub are disabled, and the best-performing model (based on evaluation metrics) is automatically reloaded at the end of training.
Key hyperparameters explained:
batch_size=16: Process 16 images at a time
evaluation_strategy="epoch": Check performance after each complete pass
num_train_epochs=4: Four complete passes through the data
learning_rate=2e-4: Step size for weight updates
load_best_model_at_end=True: Automatically loads the checkpoint with best validation performance
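For reference, with roughly 705 training images (80% of the 881 B2+B3 images) and a batch size of 16, each epoch is about 45 optimizer steps, so 4 epochs is on the order of 180 steps. A back-of-the-envelope check, assuming that split:
# Rough training-length estimate, assuming the 80/20 split of the 881 B2+B3 images
train_images = 881 - int(0.2 * 881)        # ≈ 705 training images
steps_per_epoch = -(-train_images // 16)   # ceil(705 / 16) = 45 steps
print(steps_per_epoch * 4)                 # ≈ 180 optimizer steps over 4 epochs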
Initialize Model and Trainer
# Load model
model = ViTForImageClassification.from_pretrained(
    model_id,
    num_labels=2,  # Binary classification
    ignore_mismatched_sizes=True
)
# Initialize trainer
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=prepared_train,
    eval_dataset=prepared_test,
    tokenizer=feature_extractor,
)
# Train model
train_results = trainer.train()
# Save model and evaluate
trainer.save_model()
metrics = trainer.evaluate(prepared_test)
return trainer, metrics
What happens during training:
1. Model initialization:
Loads pre-trained weights from ImageNet-21k
Replaces classification head with 2-class output
ignore_mismatched_sizes=True: Allows head replacement
2. Trainer setup:
Handles training loop, gradient computation, optimizer
Automatically logs metrics
Manages checkpointing
3. Training process:
Forward pass: Image → ViT → Predictions
Loss computation: Cross-entropy between predictions and labels
Backward pass: Compute gradients
Optimizer step: Update weights
Repeat for all batches and epochs
4. Evaluation:
Run model on test set
Compute all metrics (accuracy, F1, precision, recall)
Run Training with Experiment Tracking
Login to Weights & Biases (optional but recommended)
import wandb
wandb.login(key="<wandb_key>")
wandb.init(project="vt_001")
This code integrates Weights & Biases (W&B) for experiment tracking and logging. It first authenticates the user with W&B using an API key, then initializes a new run under the project named "vt_001". Once initialized, any compatible training process (such as Hugging Face’s Trainer) can automatically log metrics, losses, model parameters, and other training artifacts to W&B, enabling real-time monitoring, comparison of experiments, and reproducibility through a centralized dashboard.
Weights & Biases (wandb):
Tracks experiments automatically
Logs metrics, hyperparameters, system info
Provides visualization dashboards
Optional but highly recommended for ML projects
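Once wandb is installed and you are logged in, the Trainer picks it up automatically; you can also make the reporting explicit through the report_to argument (a small, optional addition to the TrainingArguments shown earlier):
training_args = TrainingArguments(
    output_dir="./fungi_model",
    # ... same arguments as above ...
    report_to="wandb",   # send Trainer logs to Weights & Biases ("none" disables logging)
    run_name="vt_001",   # optional W&B run name
)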
Run training
if __name__ == "__main__":
    data_dir = "/kaggle/input/an-eye-on-the-vine/an_eye_on_the_vine"
    trainer, metrics = train_fungi_classifier(data_dir)
    print(f"Evaluation metrics: {metrics}")
Results Analysis
Training Progression

Key observations:
Steady improvement: Accuracy improves from 90% to 92% across epochs
Good generalization: Small gap between training and validation loss
High precision on Fungi: 92% precision means very few false positives
Interpreting the Color Histograms
No Fungi samples:
Often concentrated in lower pixel ranges (darker overall)
More uniform distributions across channels
Red channel sometimes shows peaks at low values
Fungi samples:
Broader distributions due to fluorescence
Green/Blue channels often dominant (fluorescence marker)
Multiple peaks indicate varied tissue structures
Conclusion
What We Achieved
Built an end-to-end image classification pipeline
Achieved over 90% accuracy on fungi detection
Maintained balanced performance across both classes
Created reusable visualization and analysis tools
Full code is available at:
Transform Your AI Workflows with Codersarts
Whether you're building intelligent systems with AI, implementing RAG for smart information retrieval, or developing robust multi-agent architectures, the experts at Codersarts are here to support your vision. From academic prototypes to enterprise-grade solutions, we provide:
Custom RAG Implementation: Build retrieval-augmented generation systems tailored to your domain
AI-Based Agent Systems: Design and deploy modular, coordinated AI agents
End-to-End AI Development: From setup and orchestration to deployment and optimization
Do not let architectural complexity or tooling challenges slow down your progress. Partner with Codersarts and bring your next-generation AI systems to life.
Ready to get started? Visit Codersarts.com or connect with our team to discuss your AI-based project. The future of modular, intelligent automation is here – let's build it together!



