Serengeti
AI4GSnapshotSerengeti
Bases: PlainResNetInference
Snapshot Serengeti animal classifier that inherits from PlainResNetInference. It recognizes 9 animal classes plus 1 'other' class.
Source code in PytorchWildlife/models/classification/resnet_base/serengeti.py
__init__(weights=None, device='cpu', pretrained=True)
Initialize the Snapshot Serengeti animal classifier.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
weights | str | Path to the model weights. Defaults to None. | None |
device | str | Device for model inference. Defaults to "cpu". | 'cpu' |
pretrained | bool | Whether to use pretrained weights. Defaults to True. | True |
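A minimal construction sketch based on the parameters documented above. The helper names `default_init_kwargs`, `pick_device`, and `build_classifier` are hypothetical and exist only for illustration. The import path `PytorchWildlife.models.classification` is an assumption inferred from the source file location, so verify it against your installed version.

```python
def default_init_kwargs():
    """Return keyword arguments mirroring the documented constructor defaults."""
    return {"weights": None, "device": "cpu", "pretrained": True}


def pick_device():
    """Choose "cuda" when PyTorch reports a usable GPU, otherwise "cpu"."""
    try:
        import torch
    except ImportError:
        return "cpu"
    return "cuda" if torch.cuda.is_available() else "cpu"


def build_classifier(**overrides):
    """Instantiate the classifier with the documented defaults plus overrides.

    Assumption: AI4GSnapshotSerengeti is importable from
    PytorchWildlife.models.classification. The import is done lazily so this
    sketch can be loaded even when the package is not installed.
    """
    from PytorchWildlife.models import classification as pw_classification

    kwargs = {**default_init_kwargs(), **overrides}
    return pw_classification.AI4GSnapshotSerengeti(**kwargs)
```

A typical call would then be `model = build_classifier(device=pick_device())`. To load a local checkpoint instead of pretrained weights, pass a path through `weights`.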
Source code in PytorchWildlife/models/classification/resnet_base/serengeti.py
results_generation(logits, img_ids, id_strip=None)
Generate results for classification.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
logits | Tensor | Output tensor from the model. | required |
img_ids | str | Image identifier. | required |
id_strip | str | String to strip from image IDs for cleaner saved identifiers. | None |
Returns:
Name | Type | Description |
---|---|---|
dict | list[dict] | List of dictionaries, each containing the image ID, prediction, and confidence score. |
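To make the shape of the return value concrete, here is a rough, library-independent sketch of how logits could be turned into such records. It is not the package's implementation. The function `sketch_results`, the `class_names` argument, and the dictionary keys `img_id`, `prediction`, and `confidence` are all hypothetical. How the real method applies `id_strip` is not stated here, so the sketch simply assumes it removes a leading prefix.

```python
import math


def softmax(row):
    """Numerically stable softmax over one row of logits."""
    peak = max(row)
    exps = [math.exp(x - peak) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def sketch_results(logits, img_ids, class_names, id_strip=None):
    """Build one result record per image from raw logits.

    logits: list of per-image logit lists (stand-in for a model output tensor).
    img_ids: list of image identifiers, one per row of logits.
    class_names: hypothetical list mapping class index to label.
    id_strip: assumed here to be a prefix removed from each image ID.
    """
    results = []
    for row, img_id in zip(logits, img_ids):
        probs = softmax(row)
        # Index of the most probable class.
        best = max(range(len(probs)), key=probs.__getitem__)
        clean_id = str(img_id)
        if id_strip is not None and clean_id.startswith(id_strip):
            clean_id = clean_id[len(id_strip):]
        results.append(
            {
                "img_id": clean_id,
                "prediction": class_names[best],
                "confidence": probs[best],
            }
        )
    return results
```

In this sketch the confidence is the softmax probability of the top class, so it always lies between 0 and 1.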