Timm Base
Model class for loading the DFNE classifier.
TIMM_BaseClassifierInference
Bases: BaseClassifierInference
Base class for the DINOv2 classifier. This class provides utility methods for loading the model, classifying single images and batches of images, and formatting the results. Make sure the appropriate model weights file has been downloaded to the "models" folder before running DFNE.
Source code in PytorchWildlife/models/classification/timm_base/base_classifier.py
__init__(weights=None, device='cpu', url=None, transform=None, weights_key='model_state_dict', weights_prefix='')
Initialize the model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
weights | str | Path to the model weights. Defaults to None. | None |
device | str | Device for model inference. Defaults to "cpu". | 'cpu' |
url | str | URL to fetch the model weights. Defaults to None. | None |
weights_key | str | Key used to fetch the model weights. Defaults to 'model_state_dict'. | 'model_state_dict' |
weights_prefix | str | Prefix of the model weight keys. Defaults to ''. | '' |
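The `weights_key` and `weights_prefix` parameters suggest that the checkpoint is a dictionary holding the state dict under a key, with parameter names that may carry a prefix. The sketch below illustrates that idea with plain Python dictionaries. It is not the library's implementation: the helper name `extract_state_dict`, the sample checkpoint layout, and the assumption that the prefix is removed from the start of each key are all hypothetical.

```python
def extract_state_dict(checkpoint, weights_key="model_state_dict", weights_prefix=""):
    """Hypothetical helper: unwrap a checkpoint and drop a key prefix."""
    # Assumption: the weights may be nested under `weights_key`;
    # if that key is absent, treat the checkpoint itself as the state dict.
    state = checkpoint[weights_key] if weights_key in checkpoint else checkpoint
    cleaned = {}
    for name, value in state.items():
        # Assumption: the prefix is dropped only when a key starts with it.
        if weights_prefix and name.startswith(weights_prefix):
            name = name[len(weights_prefix):]
        cleaned[name] = value
    return cleaned


# Plain lists stand in for tensors in this illustration.
checkpoint = {
    "epoch": 12,
    "model_state_dict": {
        "net.head.weight": [0.1, 0.2],
        "net.head.bias": [0.0],
    },
}
state_dict = extract_state_dict(checkpoint, weights_prefix="net.")
print(sorted(state_dict))  # ['head.bias', 'head.weight']
```

With real weights, the resulting dictionary would typically be passed to a PyTorch module's `load_state_dict`.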
batch_image_classification(data_path=None, det_results=None, id_strip=None, batch_size=32, num_workers=0, **kwargs)
Perform classification on a batch of images.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data_path | str | Path containing all images for inference. Defaults to None. | None |
det_results | dict | Direct outputs from detectors. Defaults to None. | None |
id_strip | str | String to strip from image IDs. Defaults to None. | None |
batch_size | int | Batch size for inference. Defaults to 32. | 32 |
num_workers | int | Number of workers for the dataloader. Defaults to 0. | 0 |
Returns:
Type | Description |
---|---|
dict | Classification results. |
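To make the effect of `batch_size` concrete, the sketch below splits a list of image paths into consecutive batches. It is a standard-library illustration only: the helper `iter_batches` and the file names are hypothetical, and the library itself presumably relies on a PyTorch dataloader, as the `num_workers` parameter indicates.

```python
from itertools import islice


def iter_batches(items, batch_size=32):
    """Hypothetical helper: yield consecutive batches of at most batch_size items."""
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")
    iterator = iter(items)
    while True:
        batch = list(islice(iterator, batch_size))
        if not batch:
            return
        yield batch


# Illustrative file names; 70 images with the default batch size of 32.
image_paths = [f"img_{i:03d}.jpg" for i in range(70)]
batch_sizes = [len(batch) for batch in iter_batches(image_paths, batch_size=32)]
print(batch_sizes)  # [32, 32, 6]
```

The last batch is smaller than `batch_size` whenever the number of images is not an exact multiple of it.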
results_generation(logits, img_ids, id_strip=None)
Generate results for classification.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
logits | Tensor | Output tensor from the model. | required |
img_ids | list[str] | List of image identifiers. | required |
id_strip | str | String to strip from image IDs before they are saved. | None |
Returns:
Type | Description |
---|---|
list[dict] | List of dictionaries containing image ID, prediction, and confidence score. |
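As an illustration of what turning logits into per-image results can look like, the sketch below applies a softmax to each row of logits and keeps the top class and its probability. It uses nested Python lists instead of tensors so that it runs without PyTorch. The function name `generate_results`, the dictionary keys, the `class_names` argument, the sample labels, and the use of `str.replace` for `id_strip` are assumptions, not the library's actual code.

```python
import math


def softmax(row):
    """Numerically stable softmax over a list of floats."""
    peak = max(row)
    exps = [math.exp(value - peak) for value in row]
    total = sum(exps)
    return [e / total for e in exps]


def generate_results(logits, img_ids, class_names, id_strip=None):
    """Hypothetical sketch: one result dict per image."""
    results = []
    for row, img_id in zip(logits, img_ids):
        probs = softmax(row)
        # Index of the most probable class.
        best = max(range(len(probs)), key=probs.__getitem__)
        if id_strip:
            # Assumption: id_strip is removed from the identifier.
            img_id = img_id.replace(id_strip, "")
        results.append({
            "img_id": img_id,
            "prediction": class_names[best],
            "confidence": probs[best],
        })
    return results


# Illustrative inputs: two images, three hypothetical classes.
logits = [[2.0, 0.5, 0.1], [0.0, 0.0, 3.0]]
img_ids = ["/data/run1/a.jpg", "/data/run1/b.jpg"]
results = generate_results(logits, img_ids, ["deer", "fox", "boar"], id_strip="/data/run1/")
for item in results:
    print(item["img_id"], item["prediction"], round(item["confidence"], 3))
```

Each confidence is a softmax probability, so it always lies between 0 and 1.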
single_image_classification(img, img_id=None, id_strip=None)
Perform classification on a single image.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
img | str or ndarray | Image path or ndarray of the image. | required |
img_id | str | Image path or identifier. | None |
id_strip | str | String to strip from the image ID. Defaults to None. | None |
Returns:
Type | Description |
---|---|
dict | Classification results. |