import torchvision
import torch.nn as nn
from torchvision.models.detection.faster_rcnn import (
    FastRCNNPredictor,
    FasterRCNN_ResNet50_FPN_Weights,
    FasterRCNN_ResNet50_FPN_V2_Weights,
    FasterRCNN_MobileNet_V3_Large_FPN_Weights,
)


def _build_faster_rcnn(builder, pretrained_weights, num_classes, minsize, maxsize, pretrain):
    """Build a torchvision Faster R-CNN and swap in a fresh box predictor.

    Args:
        builder: torchvision factory function that creates the detector
            (e.g. ``torchvision.models.detection.fasterrcnn_resnet50_fpn``).
        pretrained_weights: weights enum member passed to ``builder`` when
            ``pretrain`` is truthy.
        num_classes: number of output classes for the new box predictor.
        minsize: forwarded to ``builder`` as ``min_size``.
        maxsize: forwarded to ``builder`` as ``max_size``.
        pretrain: any truthy value loads ``pretrained_weights``; any falsy
            value (including ``None``) passes ``weights=None``.

    Returns:
        The detector with ``roi_heads.box_predictor`` replaced by a new
        ``FastRCNNPredictor`` sized for ``num_classes``.
    """
    # Decide on truthiness rather than `== True` / `== False`: the old
    # if/elif pair left `model` unbound (UnboundLocalError) for values such
    # as None that compare equal to neither.
    weights = pretrained_weights if pretrain else None
    # NOTE(review): only detector weights are controlled here; the builder's
    # own default for backbone weights is left untouched, as in the original
    # code -- confirm that pretrain=False is not expected to disable it too.
    model = builder(min_size=minsize, max_size=maxsize, weights=weights)

    # The replacement head must accept the same feature width as the head
    # that shipped with the model.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    return model


def FasterRCNN_V1(num_classes, minsize, maxsize, pretrain=False):
    """Return a Faster R-CNN (ResNet-50 FPN) with a ``num_classes`` box head.

    Args:
        num_classes: number of classes for the new box predictor.
        minsize: forwarded to torchvision as ``min_size``.
        maxsize: forwarded to torchvision as ``max_size``.
        pretrain: when truthy, load ``FasterRCNN_ResNet50_FPN_Weights.DEFAULT``.

    Returns:
        The configured detection model.
    """
    return _build_faster_rcnn(
        torchvision.models.detection.fasterrcnn_resnet50_fpn,
        FasterRCNN_ResNet50_FPN_Weights.DEFAULT,
        num_classes,
        minsize,
        maxsize,
        pretrain,
    )


def FasterRCNN_V2(num_classes, minsize, maxsize, pretrain=False):
    """Return a Faster R-CNN (ResNet-50 FPN v2) with a ``num_classes`` box head.

    Args:
        num_classes: number of classes for the new box predictor.
        minsize: forwarded to torchvision as ``min_size``.
        maxsize: forwarded to torchvision as ``max_size``.
        pretrain: when truthy, load
            ``FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT``.

    Returns:
        The configured detection model.
    """
    return _build_faster_rcnn(
        torchvision.models.detection.fasterrcnn_resnet50_fpn_v2,
        FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT,
        num_classes,
        minsize,
        maxsize,
        pretrain,
    )


def FasterRCNN_V3(num_classes, minsize, maxsize, pretrain=False):
    """Return a Faster R-CNN (MobileNetV3-Large FPN) with a ``num_classes`` box head.

    Despite the ``V3`` name this is the MobileNetV3-Large backbone variant,
    not a third revision of the ResNet-50 model.

    Args:
        num_classes: number of classes for the new box predictor.
        minsize: forwarded to torchvision as ``min_size``.
        maxsize: forwarded to torchvision as ``max_size``.
        pretrain: when truthy, load
            ``FasterRCNN_MobileNet_V3_Large_FPN_Weights.DEFAULT``.

    Returns:
        The configured detection model.
    """
    return _build_faster_rcnn(
        torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn,
        FasterRCNN_MobileNet_V3_Large_FPN_Weights.DEFAULT,
        num_classes,
        minsize,
        maxsize,
        pretrain,
    )