Skip to content

Datasets Module

ClassificationImageFolder

Bases: ImageFolder

A PyTorch Dataset for loading images from a specified directory. Each item in the dataset is a tuple containing the image data and the image's path.

Source code in PytorchWildlife/data/datasets.py
class ClassificationImageFolder(ImageFolder):
    """
    A PyTorch Dataset for loading images from a specified directory.
    Each item in the dataset is a tuple containing the image data
    and the image's path.
    """

    def __init__(self, image_dir, transform=None):
        """
        Initializes the dataset.

        Parameters:
            image_dir (str): Path to the directory containing the images.
            transform (callable, optional): Optional transform to be applied on the image.
        """
        super().__init__(image_dir, transform)

    def __getitem__(self, idx) -> tuple:
        """
        Retrieves an image from the dataset.

        Parameters:
            idx (int): Index of the image to retrieve.

        Returns:
            tuple: ``(img, img_path)``. ``img`` is the RGB image, or the output of
            ``transform`` if one is set. ``img_path`` is the image's file path.
        """
        # Get image filename and path
        img_path = self.images[idx]

        # Load and convert image to RGB. The context manager releases the file
        # handle promptly; convert() returns a new image, so it stays usable
        # after the file is closed.
        with Image.open(img_path) as img_file:
            img = img_file.convert("RGB")

        # Apply transformation if specified
        if self.transform:
            img = self.transform(img)

        return img, img_path

__getitem__(idx)

Retrieves an image from the dataset.

Parameters:

Name Type Description Default
idx int

Index of the image to retrieve.

required

Returns:

Name Type Description
tuple tuple

Contains the image data and the image's path.

Source code in PytorchWildlife/data/datasets.py
def __getitem__(self, idx) -> tuple:
    """
    Retrieves an image from the dataset.

    Parameters:
        idx (int): Index of the image to retrieve.

    Returns:
        tuple: ``(img, img_path)``. ``img`` is the RGB image, or the output of
        ``transform`` if one is set. ``img_path`` is the image's file path.
    """
    # Get image filename and path
    img_path = self.images[idx]

    # Load and convert image to RGB. The context manager releases the file
    # handle promptly; convert() returns a new image that outlives it.
    with Image.open(img_path) as img_file:
        img = img_file.convert("RGB")

    # Apply transformation if specified
    if self.transform:
        img = self.transform(img)

    return img, img_path

__init__(image_dir, transform=None)

Initializes the dataset.

Parameters:

Name Type Description Default
image_dir str

Path to the directory containing the images.

required
transform callable

Optional transform to be applied on the image.

None
Source code in PytorchWildlife/data/datasets.py
def __init__(self, image_dir, transform=None):
    """
    Set up the classification dataset for the images under ``image_dir``.

    Parameters:
        image_dir (str): Directory that holds the images.
        transform (callable, optional): Applied to every loaded image when provided.
    """
    # Directory scanning and attribute setup are handled by the parent class.
    parent = super(ClassificationImageFolder, self)
    parent.__init__(image_dir, transform)

DetectionCrops

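Bases: Dataset

A PyTorch Dataset of image crops, one per animal detection in the supplied detection results. Each item in the dataset is a tuple containing the cropped image data and the source image's path.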

Source code in PytorchWildlife/data/datasets.py
class DetectionCrops(Dataset):
    """
    A PyTorch Dataset of image crops, one per animal detection.

    Each detection whose class id equals ``animal_cls_id`` becomes one item.
    The item is the region of the source image inside the detection's ``xyxy`` box.
    """

    def __init__(self, detection_results, transform=None, path_head=None, animal_cls_id=0):
        """
        Initializes the dataset.

        Parameters:
            detection_results (iterable of dict): Detection results. Each entry must
                provide ``"img_id"`` and ``"detections"``; ``detections`` must expose
                parallel ``xyxy`` and ``class_id`` sequences.
            transform (callable, optional): Optional transform applied to each crop,
                which is passed in as a PIL image.
            path_head (str, optional): Directory joined in front of each ``img_id`` to
                form the image path. If falsy, ``img_id`` is used as the path as-is.
            animal_cls_id (int): Detection class id that represents animals. Only
                detections with this id are kept.
        """
        self.detection_results = detection_results
        self.transform = transform
        self.path_head = path_head
        self.animal_cls_id = animal_cls_id  # This determines which detection class id represents animals.
        # Parallel lists, one entry per kept detection; filled by load_detection_results().
        self.img_ids = []
        self.xyxys = []

        self.load_detection_results()

    def load_detection_results(self):
        """
        Builds the per-crop index from ``self.detection_results``.

        Rebuilds ``self.img_ids`` and ``self.xyxys`` from scratch. Calling this
        again (e.g. after changing ``detection_results``) therefore does not
        duplicate entries.
        """
        img_ids = []
        xyxys = []
        for det in self.detection_results:
            detections = det["detections"]
            for xyxy, det_id in zip(detections.xyxy, detections.class_id):
                # Only run recognition on animal detections
                if det_id == self.animal_cls_id:
                    img_ids.append(det["img_id"])
                    xyxys.append(xyxy)
        self.img_ids = img_ids
        self.xyxys = xyxys

    def __getitem__(self, idx) -> tuple:
        """
        Retrieves a cropped detection from the dataset.

        Parameters:
            idx (int): Index of the crop to retrieve.

        Returns:
            tuple: ``(img, img_path)``. ``img`` is the crop as returned by
            ``sv.crop_image``, or the output of ``transform`` if one is set.
            ``img_path`` is the path of the source image.
        """
        # Get image path and corresponding bbox xyxy for cropping
        img_id = self.img_ids[idx]
        xyxy = self.xyxys[idx]

        img_path = os.path.join(self.path_head, img_id) if self.path_head else img_id

        # Load the image as an RGB array. The context manager releases the
        # file handle as soon as the pixels have been copied out.
        with Image.open(img_path) as img_file:
            img_array = np.array(img_file.convert("RGB"))

        # Crop with supervision
        img = sv.crop_image(img_array, xyxy=xyxy)

        # Apply transformation if specified
        if self.transform:
            img = self.transform(Image.fromarray(img))

        return img, img_path

    def __len__(self) -> int:
        """
        Returns the number of animal detections (crops) in the dataset.

        Returns:
            int: Total number of crops.
        """
        return len(self.img_ids)

__getitem__(idx)

Retrieves a cropped detection from the dataset.

Parameters:

Name Type Description Default
idx int

Index of the image to retrieve.

required

Returns:

Name Type Description
tuple tuple

Contains the cropped image data and the source image's path.

Source code in PytorchWildlife/data/datasets.py
def __getitem__(self, idx) -> tuple:
    """
    Retrieves a cropped detection from the dataset.

    Parameters:
        idx (int): Index of the crop to retrieve.

    Returns:
        tuple: ``(img, img_path)``. ``img`` is the crop as returned by
        ``sv.crop_image``, or the output of ``transform`` if one is set.
        ``img_path`` is the path of the source image.
    """
    # Get image path and corresponding bbox xyxy for cropping
    img_id = self.img_ids[idx]
    xyxy = self.xyxys[idx]

    img_path = os.path.join(self.path_head, img_id) if self.path_head else img_id

    # Load the image as an RGB array. The context manager releases the
    # file handle as soon as the pixels have been copied out.
    with Image.open(img_path) as img_file:
        img_array = np.array(img_file.convert("RGB"))

    # Crop with supervision
    img = sv.crop_image(img_array, xyxy=xyxy)

    # Apply transformation if specified
    if self.transform:
        img = self.transform(Image.fromarray(img))

    return img, img_path

DetectionImageFolder

Bases: ImageFolder

A PyTorch Dataset for loading images from a specified directory. Each item in the dataset is a tuple containing the image data, the image's path, and the original size of the image.

Source code in PytorchWildlife/data/datasets.py
class DetectionImageFolder(ImageFolder):
    """
    A PyTorch Dataset for loading images from a specified directory.
    Each item in the dataset is a tuple containing the image data,
    the image's path, and the original size of the image.
    """

    def __init__(self, image_dir, transform=None):
        """
        Initializes the dataset.

        Parameters:
            image_dir (str): Path to the directory containing the images.
            transform (callable, optional): Optional transform to be applied on the image.
        """
        super().__init__(image_dir, transform)

    def __getitem__(self, idx) -> tuple:
        """
        Retrieves an image from the dataset.

        Parameters:
            idx (int): Index of the image to retrieve.

        Returns:
            tuple: ``(img, img_path, size)``.
            ``img`` is the RGB image, or the output of ``transform`` if one is set.
            ``img_path`` is the image's file path.
            ``size`` is a tensor ``[height, width]`` of the image before any transform.
        """
        # Get image filename and path
        img_path = self.images[idx]

        # Load and convert image to RGB. The context manager releases the file
        # handle promptly; convert() returns a new image that outlives it.
        with Image.open(img_path) as img_file:
            img = img_file.convert("RGB")
        # PIL reports size as (width, height); reverse it to (height, width).
        img_size_ori = img.size[::-1]

        # Apply transformation if specified
        if self.transform:
            img = self.transform(img)

        return img, img_path, torch.tensor(img_size_ori)

__getitem__(idx)

Retrieves an image from the dataset.

Parameters:

Name Type Description Default
idx int

Index of the image to retrieve.

required

Returns:

Name Type Description
tuple tuple

Contains the image data, the image's path, and its original size.

Source code in PytorchWildlife/data/datasets.py
def __getitem__(self, idx) -> tuple:
    """
    Retrieves an image from the dataset.

    Parameters:
        idx (int): Index of the image to retrieve.

    Returns:
        tuple: ``(img, img_path, size)``.
        ``img`` is the RGB image, or the output of ``transform`` if one is set.
        ``img_path`` is the image's file path.
        ``size`` is a tensor ``[height, width]`` of the image before any transform.
    """
    # Get image filename and path
    img_path = self.images[idx]

    # Load and convert image to RGB. The context manager releases the file
    # handle promptly; convert() returns a new image that outlives it.
    with Image.open(img_path) as img_file:
        img = img_file.convert("RGB")
    # PIL reports size as (width, height); reverse it to (height, width).
    img_size_ori = img.size[::-1]

    # Apply transformation if specified
    if self.transform:
        img = self.transform(img)

    return img, img_path, torch.tensor(img_size_ori)

__init__(image_dir, transform=None)

Initializes the dataset.

Parameters:

Name Type Description Default
image_dir str

Path to the directory containing the images.

required
transform callable

Optional transform to be applied on the image.

None
Source code in PytorchWildlife/data/datasets.py
def __init__(self, image_dir, transform=None):
    """
    Set up the detection dataset for the images under ``image_dir``.

    Parameters:
        image_dir (str): Directory that holds the images.
        transform (callable, optional): Applied to every loaded image when provided.
    """
    # Directory scanning and attribute setup are handled by the parent class.
    parent = super(DetectionImageFolder, self)
    parent.__init__(image_dir, transform)

ImageFolder

Bases: Dataset

A base PyTorch Dataset that recursively collects the image files found in a specified directory. It does not define the item format itself; subclasses are expected to implement __getitem__.

Source code in PytorchWildlife/data/datasets.py
class ImageFolder(Dataset):
    """
    Base PyTorch Dataset that indexes every image file under a directory.

    The directory is searched recursively. A file is kept when
    ``is_image_file`` accepts its name. Subclasses must implement
    ``__getitem__`` to define what each item looks like.
    """

    def __init__(self, image_dir, transform=None):
        """
        Initializes the dataset.

        Parameters:
            image_dir (str): Path to the directory containing the images. It is
                searched recursively.
            transform (callable, optional): Optional transform to be applied on the image.
                Stored here for use by subclasses.
        """
        super().__init__()
        self.image_dir = image_dir
        self.transform = transform
        # os.walk yields entries in filesystem-dependent order. Sorting makes
        # item indices reproducible across runs and machines.
        self.images = sorted(
            os.path.join(dirpath, fname)
            for dirpath, _dirnames, filenames in os.walk(image_dir)
            for fname in filenames
            if is_image_file(fname)
        )

    def __getitem__(self, idx) -> tuple:
        """
        Retrieves an item from the dataset. Subclasses must override this.

        Parameters:
            idx (int): Index of the item to retrieve.

        Raises:
            NotImplementedError: Always; the base class does not define an item format.
        """
        raise NotImplementedError(f"{type(self).__name__} must implement __getitem__")

    def __len__(self) -> int:
        """
        Returns the total number of images in the dataset.

        Returns:
            int: Total number of images.
        """
        return len(self.images)

__getitem__(idx)

Retrieves an image from the dataset.

Parameters:

Name Type Description Default
idx int

Index of the image to retrieve.

required

Returns:

Name Type Description
tuple tuple

Not implemented in the base class; each subclass defines the contents of the returned tuple.

Source code in PytorchWildlife/data/datasets.py
def __getitem__(self, idx) -> tuple:
    """
    Retrieves an item from the dataset. Subclasses must override this.

    Parameters:
        idx (int): Index of the item to retrieve.

    Raises:
        NotImplementedError: Always; the base class does not define an item format.
    """
    # Raise instead of silently returning None, so a missing override fails loudly.
    raise NotImplementedError(f"{type(self).__name__} must implement __getitem__")

__init__(image_dir, transform=None)

Initializes the dataset.

Parameters:

Name Type Description Default
image_dir str

Path to the directory containing the images.

required
transform callable

Optional transform to be applied on the image.

None
Source code in PytorchWildlife/data/datasets.py
def __init__(self, image_dir, transform=None):
    """
    Initializes the dataset.

    Parameters:
        image_dir (str): Path to the directory containing the images. It is
            searched recursively.
        transform (callable, optional): Optional transform to be applied on the image.
            Stored here for use by subclasses.
    """
    super(ImageFolder, self).__init__()
    self.image_dir = image_dir
    self.transform = transform
    # os.walk yields entries in filesystem-dependent order. Sorting makes
    # item indices reproducible across runs and machines.
    self.images = sorted(
        os.path.join(dirpath, fname)
        for dirpath, _dirnames, filenames in os.walk(image_dir)
        for fname in filenames
        if is_image_file(fname)
    )

__len__()

Returns the total number of images in the dataset.

Returns:

Name Type Description
int int

Total number of images.

Source code in PytorchWildlife/data/datasets.py
def __len__(self) -> int:
    """
    Report how many image files the dataset has indexed.

    Returns:
        int: Number of entries in ``self.images``.
    """
    image_paths = self.images
    return len(image_paths)

has_file_allowed_extension(filename, extensions)

Checks whether a filename ends with one of the allowed extensions.

Source code in PytorchWildlife/data/datasets.py
def has_file_allowed_extension(filename: str, extensions: tuple) -> bool:
    """
    Checks whether ``filename`` ends with one of the allowed extensions.

    Both sides are compared case-insensitively: ``"IMG.JPG"`` matches ``".jpg"``
    and ``"img.jpg"`` matches ``".JPG"``.

    Parameters:
        filename (str): File name or path to check.
        extensions (str or iterable of str): A single extension, or a collection
            of extensions such as ``(".jpg", ".png")``.

    Returns:
        bool: True if ``filename`` ends with any of ``extensions``.
    """
    # str.endswith accepts a str or a tuple, so normalize to one of those.
    if isinstance(extensions, str):
        suffixes = extensions.lower()
    else:
        suffixes = tuple(ext.lower() for ext in extensions)
    return filename.lower().endswith(suffixes)

is_image_file(filename)

Checks whether a filename has an allowed image extension.

Source code in PytorchWildlife/data/datasets.py
def is_image_file(filename: str) -> bool:
    """
    Tell whether ``filename`` carries one of the recognised image extensions.

    The accepted extensions are taken from the module-level ``IMG_EXTENSIONS``.
    """
    allowed_exts = IMG_EXTENSIONS
    return has_file_allowed_extension(filename, allowed_exts)