Timm Base
Model class for loading the DFNE classifier.
TIMM_BaseClassifierInference
Bases: BaseClassifierInference
Base class for the DINOv2 classifier. It provides utility methods for loading the model, classifying single images and batches of images, and formatting results. Make sure the model weights file has been downloaded to the "models" folder before running DFNE.
Source code in PytorchWildlife/models/classification/timm_base/base_classifier.py
__init__(weights=None, device='cpu', url=None, transform=None, weights_key='model_state_dict', weights_prefix='')
Initialize the model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
weights | str | Path to the model weights. Defaults to None. | None |
device | str | Device for model inference. Defaults to "cpu". | 'cpu' |
url | str | URL to fetch the model weights from. Defaults to None. | None |
weights_key | str | Key used to fetch the model weights. Defaults to "model_state_dict". | 'model_state_dict' |
weights_prefix | str | Prefix of the model weight keys. Defaults to an empty string. | '' |
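The role of `weights_key` and `weights_prefix` can be illustrated with a minimal sketch. It assumes the checkpoint is a dictionary that nests the weights under `weights_key`, and that `weights_prefix` is a leading substring to drop from each weight name. Both are assumptions about how the parameters are interpreted, and `extract_state_dict` is a hypothetical helper, not part of PytorchWildlife.

```python
def extract_state_dict(checkpoint, weights_key="model_state_dict", weights_prefix=""):
    """Hypothetical helper: pull the weights out of a checkpoint dictionary."""
    # Assumption: the weights are nested under `weights_key` in the checkpoint.
    state_dict = checkpoint[weights_key] if weights_key else checkpoint
    if not weights_prefix:
        return dict(state_dict)
    # Assumption: `weights_prefix` is a leading substring to remove from each key.
    return {
        (name[len(weights_prefix):] if name.startswith(weights_prefix) else name): value
        for name, value in state_dict.items()
    }


# Toy checkpoint with made-up keys and values, for illustration only.
checkpoint = {
    "model_state_dict": {"net.head.weight": [0.1, 0.2], "net.head.bias": [0.0]}
}
cleaned = extract_state_dict(checkpoint, weights_prefix="net.")
print(sorted(cleaned))  # ['head.bias', 'head.weight']
```

With the default empty prefix, the sketch returns the nested dictionary unchanged.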
batch_image_classification(data_path=None, det_results=None, id_strip=None, batch_size=32, num_workers=0, **kwargs)
Perform classification on a batch of images.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data_path | str | Path to the folder containing all images for inference. Defaults to None. | None |
det_results | dict | Direct outputs from detectors. Defaults to None. | None |
id_strip | str | String to strip from image IDs. Defaults to None. | None |
batch_size | int | Batch size for inference. Defaults to 32. | 32 |
num_workers | int | Number of workers for the dataloader. Defaults to 0. | 0 |
Returns:
Type | Description |
---|---|
dict | Classification results. |
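To make the effect of `batch_size` concrete, the following sketch shows how a list of inputs is partitioned into batches. The parameter table mentions a dataloader, so the real method presumably delegates batching to one; `iter_batches` and the image paths below are hypothetical and only illustrate the partitioning.

```python
def iter_batches(items, batch_size=32):
    """Yield consecutive slices of `items` with at most `batch_size` elements."""
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


# Hypothetical image paths, for illustration only.
image_paths = [f"images/img_{i:03d}.jpg" for i in range(70)]
batch_sizes = [len(batch) for batch in iter_batches(image_paths, batch_size=32)]
print(batch_sizes)  # [32, 32, 6]
```

The last batch is smaller whenever the number of images is not a multiple of `batch_size`.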
results_generation(logits, img_ids, id_strip=None)
Generate results for classification.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
logits | Tensor | Output tensor from the model. | required |
img_ids | list[str] | List of image identifiers. | required |
id_strip | str | String to strip from image IDs before they are saved. | None |
Returns:
Type | Description |
---|---|
list[dict] | List of dictionaries containing image ID, prediction, and confidence score. |
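The conversion from logits to a list of result dictionaries can be sketched in plain Python. This is not the library's implementation. It assumes that confidence is the softmax probability of the top class and that `id_strip` is removed with `str.replace`. The class names and the dictionary keys (`img_id`, `prediction`, `confidence`) are hypothetical.

```python
import math

# Hypothetical class names; the real classifier defines its own label set.
CLASS_NAMES = ["class_a", "class_b", "class_c"]


def softmax(row):
    """Numerically stable softmax over a list of floats."""
    peak = max(row)
    exps = [math.exp(x - peak) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def sketch_results(logits, img_ids, id_strip=None):
    """Build one result dictionary per image (key names are assumptions)."""
    results = []
    for row, img_id in zip(logits, img_ids):
        probs = softmax(row)
        best = max(range(len(probs)), key=probs.__getitem__)
        if id_strip is not None:
            # Assumption: the strip string is simply removed from the ID.
            img_id = img_id.replace(id_strip, "")
        results.append(
            {"img_id": img_id, "prediction": CLASS_NAMES[best], "confidence": probs[best]}
        )
    return results


results = sketch_results(
    [[2.0, 0.5, 0.1], [0.0, 0.0, 3.0]],
    ["data/a.jpg", "data/b.jpg"],
    id_strip="data/",
)
for item in results:
    print(item["img_id"], item["prediction"], round(item["confidence"], 3))
```

In this sketch, stripping `"data/"` shortens the stored IDs to `a.jpg` and `b.jpg`.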
single_image_classification(img, img_id=None, id_strip=None)
Perform classification on a single image.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
img | str or ndarray | Image path or image as an ndarray. | required |
img_id | str | Image path or identifier. | None |
id_strip | str | String to strip from image IDs. Defaults to None. | None |
Returns:
Type | Description |
---|---|
dict | Classification results. |