Object Detection to Classification
  • 21 Dec 2022


Article summary

You may want to evaluate an Object Detection model with precision/recall metrics at the image level rather than at the box level. A custom post-processing transform that converts the object detection output into classification labels lets you summarize whether the model detected any object in each image.
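Image-level precision and recall treat each image as a single prediction. As a minimal sketch, assuming label 1 (NG) is the positive class, as in the transform below, and that every image has a ground-truth image label, the metrics can be computed like this. The function name is hypothetical and not part of the LandingLens SDK.

```python
from typing import Sequence, Tuple


def image_level_precision_recall(
    y_true: Sequence[int], y_pred: Sequence[int]
) -> Tuple[float, float]:
    """Compute precision and recall for the NG class (label 1) over images.

    y_true: ground-truth image labels, 1 = NG, 0 = OK.
    y_pred: labels produced by the detection-to-classification step.
    """
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    # Report 0.0 when a metric's denominator is empty.
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return precision, recall


# Five images: three NG and two OK in the ground truth.
truth = [1, 1, 1, 0, 0]
preds = [1, 1, 0, 1, 0]
precision, recall = image_level_precision_recall(truth, preds)
print(precision, recall)  # both equal 2/3
```

Here two of the three NG images are flagged (recall 2/3) and two of the three flagged images are really NG (precision 2/3).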

Folder Structure

.
├── custom
│   ├── __init__.py
│   └── detection_to_classification.py
├── train.yaml
└── transforms.yaml

ObjectDetectionToClassification Class

Content of custom/detection_to_classification.py.

from landinglens.model_iteration.sdk import BaseTransform, DataItem


class ObjectDetectionToClassification(BaseTransform):
    """Transforms an object detection output into a classification output."""

    def __init__(self, boxes_threshold: int = 1, **params):
        """
        Parameters
        ----------
        boxes_threshold: int
            Minimum number of detected bounding boxes for an image to be
            classified as NG (label 1). Must be a positive integer.
        """
        assert isinstance(
            boxes_threshold, int
        ), f"boxes_threshold is not an integer. Got {boxes_threshold}"
        assert (
            boxes_threshold > 0
        ), f"boxes_threshold is not positive. Got {boxes_threshold}"
        self.boxes_threshold = boxes_threshold

    def __call__(self, inputs: DataItem) -> DataItem:
        """Return a new DataItem with transformed attributes. DataItem has the
        following attributes:
        image - input image.
        label, score - input label and its score.
        mask_scores, mask_labels - segmentation mask probabilities and classes.
        bboxes, bboxes_labels, bboxes_scores - object detection bounding boxes.
        user_data - any additional data that you want to store for subsequent transforms.

        Returns
        -------
        A namedtuple class DataItem with the modified attributes.
        """
        bboxes = inputs.bboxes

        # No detections at all: classify the image as OK.
        if bboxes is None:
            return DataItem(image=inputs.image, label=0)

        # shape[0] is the number of detected boxes.
        if bboxes.shape[0] >= self.boxes_threshold:
            label = 1  # NG: at least boxes_threshold objects detected
        else:
            label = 0  # OK

        # Keep the detections so they remain available downstream.
        return DataItem(
            image=inputs.image,
            label=label,
            bboxes=inputs.bboxes,
            bboxes_labels=inputs.bboxes_labels,
            bboxes_scores=inputs.bboxes_scores,
        )
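Because the transform above needs the LandingLens SDK to run, it can help to check the thresholding rule on its own first. The sketch below is a hypothetical, SDK-free restatement of the same rule: `boxes_to_label` and the list-of-boxes input are illustrative assumptions, whereas the real transform receives a `bboxes` array and counts its rows with `shape[0]`.

```python
from typing import Optional, Sequence


def boxes_to_label(
    bboxes: Optional[Sequence[Sequence[float]]], boxes_threshold: int = 1
) -> int:
    """Hypothetical helper mirroring ObjectDetectionToClassification.__call__.

    Returns 1 (NG) when at least `boxes_threshold` boxes were detected,
    otherwise 0 (OK). `None` means the model produced no boxes.
    """
    if not isinstance(boxes_threshold, int) or boxes_threshold <= 0:
        raise ValueError(
            f"boxes_threshold must be a positive int, got {boxes_threshold}"
        )
    if bboxes is None:
        return 0  # no detections: OK
    return 1 if len(bboxes) >= boxes_threshold else 0


two_boxes = [[10, 10, 50, 50], [60, 60, 90, 90]]
print(boxes_to_label(None))                          # 0
print(boxes_to_label(two_boxes))                     # 1
print(boxes_to_label(two_boxes, boxes_threshold=3))  # 0
```

Raising `boxes_threshold` makes the image-level classifier more tolerant: an image is flagged as NG only once the detector finds at least that many objects.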

Use ObjectDetectionToClassification in train.yaml

dataset:
  test_split_key: dev
  train_split_key: train
  val_split_key: dev
eval:
  postprocessing:
    output_type: classification
    iou_threshold: 0.25
    transforms:
      - CustomTransform:
          params:
            boxes_threshold: 1
          transform: custom.detection_to_classification.ObjectDetectionToClassification
loss:
  classification:
    RetinaNetFocal:
      gamma: 2
  regression:
    RetinaNetSmoothL1:
      sigma: 3
model:
  avi:
    RetinaNetOD:
      backbone: ResNet34
      backbone_weights: imagenet
      class_specific_filter: false
      input_shape:
        - 256
        - 256
        - 3
      nms_threshold: 0.1
      output_depth: 5
      score_threshold: 0.3
train:
  batch_size: 8
  epochs: 170
  learning_rate: 0.0001
  validation_run_freq: 1
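The `transform` value is a dotted path: the module `custom.detection_to_classification` (the file from the folder structure) followed by the class name. The article does not say how LandingLens loads it. A common Python approach, shown here only as an assumption-laden sketch, splits the path at the last dot and imports the module with `importlib`; `load_class` is a hypothetical name, not an SDK function.

```python
import importlib
from typing import Any


def load_class(dotted_path: str) -> Any:
    """Resolve 'package.module.ClassName' to the named object (hypothetical helper)."""
    module_path, _, class_name = dotted_path.rpartition(".")
    if not module_path:
        raise ValueError(f"Expected 'module.ClassName', got {dotted_path!r}")
    module = importlib.import_module(module_path)
    return getattr(module, class_name)


# Demonstrated with a standard-library class so the sketch runs anywhere.
# In a project with the folder layout above you would pass
# "custom.detection_to_classification.ObjectDetectionToClassification";
# presumably the YAML params are then passed as keyword arguments,
# e.g. cls(boxes_threshold=1), but that is an assumption.
cls = load_class("collections.OrderedDict")
print(cls.__name__)  # OrderedDict
```

The `custom/__init__.py` file in the folder structure makes `custom` a regular package, so a dotted path like this one can be imported from the project root.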