trident
trident
Table Of Contents
trident
Table Of Contents

Source code for trident.models.pytorch_mtcnn

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import inspect
import itertools
import math
import os
import uuid
from collections import *
from collections import deque
from copy import copy, deepcopy
from functools import partial
from itertools import repeat

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from matplotlib.collections import PolyCollection
from torch._six import container_abcs
from torch.nn import init
from torch.nn.parameter import Parameter

from trident.backend.common import *
from trident.backend.pytorch_backend import *
from trident.backend.pytorch_backend import to_numpy, to_tensor, Layer, Sequential, Combine
from trident.backend.pytorch_ops import meshgrid
from trident.data.bbox_common import clip_boxes_to_image, nms
from trident.data.image_common import *
from trident.data.utils import download_model_from_google_drive
from trident.layers.pytorch_activations import get_activation, Identity, PRelu
from trident.layers.pytorch_blocks import *
from trident.layers.pytorch_layers import *
from trident.layers.pytorch_normalizations import get_normalization
from trident.layers.pytorch_pooling import *
from trident.optims.pytorch_trainer import *
from trident.optims.pytorch_trainer import ImageDetectionModel

__all__ = ['Pnet', 'Rnet', 'Onet', 'Mtcnn']

# Values shared by the whole module, read once from the trident session.
_session = get_session()
_epsilon = _session.epsilon
_trident_dir = _session.trident_dir
_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Folder that holds the downloaded checkpoints. Creation is best effort:
# any OSError (permission denied, lost race with another process, ...) is
# deliberately ignored, exactly as an already existing path is.
dirname = os.path.join(_trident_dir, 'models')
try:
    os.makedirs(dirname, exist_ok=True)
except OSError:
    pass


def p_net():
    """Assemble the layer stack named ``'pnet'``.

    The trunk is three unpadded 3x3 convolutions (10, 16 and 32 filters),
    each followed by ``PRelu``, with a 2x2 max-pool after the first one.
    It feeds three parallel 1x1 convolutions wrapped in ``Combine``: a
    1-channel head with a sigmoid activation, a 4-channel head and a
    10-channel head.

    Returns:
        Sequential: the freshly constructed network.
    """

    def conv(kernel, filters, name, **extra):
        # Every convolution of this net shares stride, padding and bias settings.
        return Conv2d(kernel, filters, strides=1, auto_pad=False, use_bias=True, name=name, **extra)

    trunk = [
        conv((3, 3), 10, 'conv1'),
        PRelu(),
        MaxPool2d((2, 2), strides=2, auto_pad=False),
        conv((3, 3), 16, 'conv2'),
        PRelu(),
        conv((3, 3), 32, 'conv3'),
        PRelu(),
    ]
    heads = Combine(
        conv((1, 1), 1, 'conv4_1', activation='sigmoid'),
        conv((1, 1), 4, 'conv4_2'),
        conv((1, 1), 10, 'conv4_3'),
    )
    return Sequential(*trunk, heads, name='pnet')



def r_net():
    """Assemble the layer stack named ``'rnet'``.

    The trunk is three unpadded convolutions (3x3/28, 3x3/48, 2x2/64) with
    ``PRelu`` activations and a 3x3 stride-2 max-pool after each of the
    first two, then ``Flatten`` and a 128-unit dense layer. It feeds three
    parallel dense heads wrapped in ``Combine``: 1 unit with a sigmoid
    activation, 4 units and 10 units.

    Returns:
        Sequential: the freshly constructed network.
    """

    def conv(kernel, filters, name):
        # Every convolution of this net shares stride, padding and bias settings.
        return Conv2d(kernel, filters, strides=1, auto_pad=False, use_bias=True, name=name)

    def pool():
        return MaxPool2d((3, 3), strides=2, auto_pad=False)

    trunk = [
        conv((3, 3), 28, 'conv1'),
        PRelu(),
        pool(),
        conv((3, 3), 48, 'conv2'),
        PRelu(),
        pool(),
        conv((2, 2), 64, 'conv3'),
        PRelu(),
        Flatten(),
        Dense(128, activation=None, use_bias=True, name='conv4'),
        PRelu(),
    ]
    heads = Combine(
        Dense(1, activation='sigmoid', use_bias=True, name='conv5_1'),
        Dense(4, activation=None, use_bias=True, name='conv5_2'),
        Dense(10, activation=None, use_bias=True, name='conv5_3'),
    )
    return Sequential(*trunk, heads, name='rnet')



def o_net():
    """Assemble the layer stack named ``'onet'``.

    The trunk is four unpadded convolutions (3x3/32, 3x3/64, 3x3/64,
    2x2/128) with ``PRelu`` activations and a stride-2 max-pool after each
    of the first three (3x3, 3x3, 2x2), then ``Flatten`` and a 256-unit
    dense layer. It feeds three parallel dense heads wrapped in ``Combine``:
    1 unit with a sigmoid activation, 4 units and 10 units.

    Returns:
        Sequential: the freshly constructed network.
    """

    def conv(kernel, filters, name):
        # Every convolution of this net shares stride, padding and bias settings.
        return Conv2d(kernel, filters, strides=1, auto_pad=False, use_bias=True, name=name)

    def pool(window):
        return MaxPool2d(window, strides=2, auto_pad=False)

    trunk = [
        conv((3, 3), 32, 'conv1'),
        PRelu(),
        pool((3, 3)),
        conv((3, 3), 64, 'conv2'),
        PRelu(),
        pool((3, 3)),
        conv((3, 3), 64, 'conv3'),
        PRelu(),
        pool((2, 2)),
        conv((2, 2), 128, 'conv4'),
        PRelu(),
        Flatten(),
        Dense(256, activation=None, use_bias=True, name='conv5'),
        PRelu(),
    ]
    heads = Combine(
        Dense(1, activation='sigmoid', use_bias=True, name='conv6_1'),
        Dense(4, activation=None, use_bias=True, name='conv6_2'),
        Dense(10, activation=None, use_bias=True, name='conv6_3'),
    )
    return Sequential(*trunk, heads, name='onet')



def Pnet(pretrained=True, input_shape=(3,12,12), **kwargs):
    """Create the ``'pnet'`` detection model, optionally with pretrained weights.

    Args:
        pretrained: when truthy, download ``pnet.pth`` into the module's
            model folder and use the network stored in it instead of the
            freshly initialised one.
        input_shape: accepted for API compatibility only. The model is
            always built for ``(3, 12, 12)``; the value passed here is not
            used.
        **kwargs: ignored.

    Returns:
        ImageDetectionModel: wrapper whose ``preprocess_flow`` is
        ``[normalize(0, 255), image_backend_adaption]``.
    """
    pnet = ImageDetectionModel(input_shape=(3, 12, 12), output=p_net())
    pnet.preprocess_flow = [normalize(0, 255), image_backend_adaption]
    if pretrained:
        download_model_from_google_drive('1w9ahipO8D9U1dAXMc2BewuL0UqIBYWSX', dirname, 'pnet.pth')
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code; only load checkpoints from a trusted source.
        # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
        recovery_model = torch.load(os.path.join(dirname, 'pnet.pth'), map_location=_device)
        recovery_model.to(_device)
        pnet.model = recovery_model
    return pnet
def Rnet(pretrained=True, input_shape=(3,24,24), **kwargs):
    """Create the ``'rnet'`` detection model, optionally with pretrained weights.

    Args:
        pretrained: when truthy, download ``rnet.pth`` into the module's
            model folder and use the network stored in it instead of the
            freshly initialised one.
        input_shape: accepted for API compatibility only. The model is
            always built for ``(3, 24, 24)``; the value passed here is not
            used.
        **kwargs: ignored.

    Returns:
        ImageDetectionModel: wrapper whose ``preprocess_flow`` is
        ``[normalize(0, 255), image_backend_adaption]``.
    """
    rnet = ImageDetectionModel(input_shape=(3, 24, 24), output=r_net())
    rnet.preprocess_flow = [normalize(0, 255), image_backend_adaption]
    if pretrained:
        download_model_from_google_drive('1CH7z133_KrcWMx9zXAblMCV8luiQ3wph', dirname, 'rnet.pth')
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code; only load checkpoints from a trusted source.
        # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
        recovery_model = torch.load(os.path.join(dirname, 'rnet.pth'), map_location=_device)
        recovery_model.to(_device)
        rnet.model = recovery_model
    return rnet
def Onet(pretrained=True, input_shape=(3,48,48), **kwargs):
    """Create the ``'onet'`` detection model, optionally with pretrained weights.

    Args:
        pretrained: when truthy, download ``onet.pth`` into the module's
            model folder and use the network stored in it instead of the
            freshly initialised one.
        input_shape: accepted for API compatibility only. The model is
            always built for ``(3, 48, 48)``; the value passed here is not
            used.
        **kwargs: ignored.

    Returns:
        ImageDetectionModel: wrapper whose ``preprocess_flow`` is
        ``[normalize(0, 255), image_backend_adaption]``.
    """
    onet = ImageDetectionModel(input_shape=(3, 48, 48), output=o_net())
    onet.preprocess_flow = [normalize(0, 255), image_backend_adaption]
    if pretrained:
        download_model_from_google_drive('1a1dAlSzJOAfIz77Ic38JMQJYWDG_b7-_', dirname, 'onet.pth')
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code; only load checkpoints from a trusted source.
        # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
        recovery_model = torch.load(os.path.join(dirname, 'onet.pth'), map_location=_device)
        recovery_model.to(_device)
        onet.model = recovery_model
    return onet
class DetectorHead(Layer):
    """Turn the three raw network outputs into scored candidate boxes.

    Args:
        cellsize: side length, in input pixels, of the window each output
            cell stands for.
        threshould: minimum score a cell needs to produce a box.
        min_size: stored on the instance; ``forward`` does not use it.
        **kwargs: forwarded to ``Layer.__init__``.
    """

    def __init__(self, cellsize=12, threshould=0.5, min_size=10, **kwargs):
        super(DetectorHead, self).__init__(**kwargs)
        self.cellsize = cellsize
        self.threshould = threshould
        self.min_size = min_size
        self._built = True

    def forward(self, input, **kwargs):
        """Decode the first sample of a batch into boxes.

        Args:
            input: sequence of exactly three tensors: scores, box
                regressions and a third output that is not used here.
                Only index 0 along the first axis of the first two is read.
            **kwargs: ignored.

        Returns:
            Tensor with one row per cell whose score is at least
            ``self.threshould``; the columns are ``x1, y1, x2, y2, score``.
            When no cell qualifies the tensor has zero rows.
        """
        boxprobs, boxregs, _unused = input
        boxprobs = boxprobs[0]
        if boxprobs.size(0) == 2:
            # Two score channels: keep only the second one.
            boxprobs = boxprobs[1:, :, :]
        # Distance in input pixels between neighbouring output cells.
        strides = 2
        boxregs = boxregs[0].permute(1, 2, 0)

        score_map = boxprobs[0]
        y, x = torch.where(score_map >= self.threshould)
        score = score_map[(y, x)]
        reg = boxregs[(y, x)].transpose(1, 0)

        # Corners of the window behind every selected cell.
        bb = torch.stack([x, y], dim=0)
        q1 = strides * bb + 1
        q2 = strides * bb + self.cellsize - 1 + 1
        w = q2[0, :] - q1[0, :] + 1
        h = q2[1, :] - q1[1, :] + 1

        # Shift each corner by its regression offset scaled by the window size.
        b1 = q1[0, :] + reg[0, :] * w
        b2 = q1[1, :] + reg[1, :] * h
        b3 = q2[0, :] + reg[2, :] * w
        b4 = q2[1, :] + reg[3, :] * h
        # Stacking 1-D tensors on the last axis always gives a 2-D result,
        # so no None / 0-D / 1-D special cases are needed.
        return torch.stack([b1, b2, b3, b4, score], dim=-1)


def remove_useless_boxes(boxes, image_size=None, min_size=5):
    """Drop degenerate boxes.

    A row is kept only when all of the following hold:
    ``(x2 - x1 + 1) * (y2 - y1 + 1) > min_size ** 2``, every coordinate is
    greater than 0, and both ``x2 - x1`` and ``y2 - y1`` are greater than 1.

    Args:
        boxes: 2-D tensor whose first five columns are
            ``x1, y1, x2, y2, score``.
        image_size: accepted for API compatibility; currently unused.
        min_size: side length whose square is the minimum box area.

    Returns:
        The surviving rows of ``boxes``, in their original order.
    """
    x1, y1, x2, y2 = [boxes[:, i] for i in range(4)]
    keep = (x2 - x1 + 1) * (y2 - y1 + 1) > min_size * min_size
    # The original tested y1 twice and never y2.
    keep = keep & (x1 > 0) & (x2 > 0) & (y1 > 0) & (y2 > 0)
    keep = keep & (x2 - x1 > 1) & (y2 - y1 > 1)
    return boxes[keep]


def calibrate_box(bboxes, offsets):
    """Transform bounding boxes to be more like true bounding boxes.

    Each of the four coordinates is shifted by its offset multiplied by the
    box width (x coordinates) or height (y coordinates), where width and
    height are computed as ``x2 - x1 + 1`` and ``y2 - y1 + 1``.

    Args:
        bboxes: 2-D tensor whose first four columns are ``x1, y1, x2, y2``.
            It is modified in place.
        offsets: one of the outputs of the nets; must be multipliable with
            a tensor of one row per box and four columns.

    Returns:
        The same ``bboxes`` object, after the update.
    """
    x1, y1, x2, y2 = [bboxes[:, i] for i in range(4)]
    w = torch.unsqueeze(x2 - x1 + 1.0, 1)
    h = torch.unsqueeze(y2 - y1 + 1.0, 1)
    translation = torch.cat([w, h, w, h], -1) * offsets
    bboxes[:, 0:4] = bboxes[:, 0:4] + translation
    return bboxes
def _crop_and_resize(img, boxes, size):
    """Cut every box out of ``img`` and resize it to ``size`` x ``size``.

    Args:
        img: image array indexed as ``[row, column, channel]``.
        boxes: one row per box; the first four values are read as
            ``x1, y1, x2, y2`` and truncated with ``int``.
        size: side length of the square patches.

    Returns:
        ``to_tensor`` of an array shaped ``(len(boxes), 3, size, size)``.
        Patches are divided by 255 and stored channel-first; a box whose
        crop is empty leaves its patch filled with zeros.
    """
    patches = np.zeros((boxes.shape[0], 3, size, size))
    for k in range(boxes.shape[0]):
        box = boxes[k]
        # The slice is only read (the division below allocates a new array),
        # so the image does not have to be copied for every box.
        # NOTE(review): a negative coordinate is treated by numpy as an
        # index from the end, which normally yields an empty crop.
        crop_img = img[int(box[1]):int(box[3]), int(box[0]):int(box[2]), :]
        if crop_img.shape[0] > 0 and crop_img.shape[1] > 0:
            patches[k] = resize((size, size))(crop_img / 255.0).transpose([2, 0, 1])
    return to_tensor(patches)


class Mtcnn(ImageDetectionModel):
    """Three-stage detector chaining the ``pnet``, ``rnet`` and ``onet`` models.

    ``pnet`` (with a ``DetectorHead`` appended) is the wrapped ``model`` and
    proposes boxes on an image pyramid; ``rnet`` and ``onet`` re-score and
    refine those boxes on 24x24 and 48x48 crops.

    Args:
        pretrained: when truthy the three networks come from ``Pnet()``,
            ``Rnet()`` and ``Onet()``; otherwise they are freshly
            initialised.
        min_size: used to derive the first pyramid scale (``12 / min_size``)
            and passed to ``nms`` after the second and third stage.
        **kwargs: ignored.
    """

    def __init__(self, pretrained=True, min_size=10, **kwargs):
        # Build only the variant that is going to be used.
        if pretrained:
            pnet = Pnet().model
            self.rnet = Rnet().model
            self.onet = Onet().model
        else:
            pnet = ImageDetectionModel(input_shape=(3, 12, 12), output=p_net()).model
            self.rnet = ImageDetectionModel(input_shape=(3, 24, 24), output=r_net()).model
            self.onet = ImageDetectionModel(input_shape=(3, 48, 48), output=o_net()).model
        self.min_size = min_size
        # NOTE(review): self.model is read here before the base initialiser
        # has run; confirm the base class makes that attribute available.
        self.signature = get_signature(self.model.forward)
        super(Mtcnn, self).__init__(input_shape=(3, 224, 224), output=pnet)
        self.pnet = pnet
        self.model = pnet
        self.preprocess_flow = [normalize(0, 255)]
        # NOTE(review): nms_threshould is stored but infer_single_image
        # passes detection_threshould to nms; confirm which is intended.
        self.nms_threshould = [0.9, 0.9, 0.3]
        self.detection_threshould = [0.5, 0.6, 0.9]
        pnet.add_module('pnet_detector', DetectorHead(cellsize=12, threshould=0.5, min_size=self.min_size))

    def get_image_pyrimid(self, img, min_size=None, factor=0.709):
        """Build the list of rescaled, preprocessed copies of ``img``.

        The first scale is ``12 / min_size``; every following scale is the
        previous one multiplied by ``factor``. Scales are added while the
        shorter image side, multiplied by the scale, is at least 12.

        Args:
            img: image array indexed as ``[row, column, ...]``.
            min_size: overrides ``self.min_size`` when given.
            factor: ratio between consecutive scales.

        Returns:
            ``(images, scales)``: two lists of equal length.
        """
        if min_size is None:
            min_size = self.min_size
        m = 12.0 / min_size
        minl = np.amin([img.shape[0], img.shape[1]]) * m

        scales = []
        images = []
        factor_count = 0
        while minl >= 12:
            scales.append(m * np.power(factor, factor_count))
            scaled_img = rescale(scales[-1])(img.copy())
            # Only plain functions in the preprocess flow are applied.
            for func in self.preprocess_flow:
                if inspect.isfunction(func):
                    scaled_img = func(scaled_img)
            images.append(image_backend_adaption(scaled_img))
            minl = minl * factor
            factor_count += 1
        return images, scales

    def rerec(self, bboxA, img_shape):
        """Convert bboxA to square.

        The top-left corner is moved by half the difference between the
        longer side and each side, then the bottom-right corner is set to
        the new top-left corner plus the longer side.

        Args:
            bboxA: boxes whose first four columns are ``x1, y1, x2, y2``.
                The array returned by ``to_numpy`` is modified in place.
            img_shape: currently unused.

        Returns:
            ``to_tensor`` of the adjusted array.
        """
        bboxA = to_numpy(bboxA)
        h = bboxA[:, 3] - bboxA[:, 1]
        w = bboxA[:, 2] - bboxA[:, 0]
        max_len = np.maximum(w, h)
        bboxA[:, 0] = bboxA[:, 0] - 0.5 * (max_len - w)
        bboxA[:, 1] = bboxA[:, 1] - 0.5 * (max_len - h)
        bboxA[:, 2] = bboxA[:, 0] + max_len
        bboxA[:, 3] = bboxA[:, 1] + max_len
        return to_tensor(bboxA)

    def infer_single_image(self, img, **kwargs):
        """Run the three stages on one image.

        Args:
            img: anything ``image2array`` accepts; a fourth channel is dropped.
            **kwargs: ignored.

        Returns:
            ``None`` when the first stage proposes no box. Otherwise the
            result of the final ``nms`` call on a tensor whose columns are
            ``x1, y1, x2, y2, score`` followed by the landmark x values and
            then the landmark y values.

        Raises:
            ValueError: if the wrapped model is not built.
        """
        if not self.model.built:
            raise ValueError('the model is not built yet.')

        self.model.to(self.device)
        self.model.eval()
        img = image2array(img)
        if img.shape[-1] == 4:
            img = img[:, :, :3]
        image_size = (img.shape[0], img.shape[1])

        # ---- stage 1: pnet on every pyramid level -------------------------
        imgs, scales = self.get_image_pyrimid(img)
        boxes_list = []
        for scaled_img, scale in zip(imgs, scales):
            # Match the device and dtype of the pnet weights.
            pnet_weight = self.pnet.weights[0].data
            inp = to_tensor(np.expand_dims(scaled_img, 0))
            inp = inp.to(torch.device("cuda" if pnet_weight.is_cuda else "cpu")).to(pnet_weight.dtype)
            boxes = self.pnet(inp)
            if boxes is not None and len(boxes) > 0:
                # Map the coordinates back to the original image.
                box = boxes[:, :4] / scale
                score = boxes[:, 4:]
                boxes_list.append(torch.cat([box.round_(), score], dim=1))

        if len(boxes_list) == 0:
            return None

        boxes = to_tensor(torch.cat(boxes_list, dim=0))
        boxes = clip_boxes_to_image(boxes, image_size)
        boxes = nms(boxes, threshold=self.detection_threshould[0])
        print('pnet:{0} boxes '.format(len(boxes)))
        if boxes is None:
            return None

        # ---- stage 2: rnet on 24x24 crops, at most 16 per forward pass ----
        boxes = self.rerec(boxes, img.shape)
        new_arr = _crop_and_resize(img, boxes, 24)
        if len(new_arr) > 16:
            r_output1_list = []
            r_output2_list = []
            r_output3_list = []
            for start in range(0, len(new_arr), 16):
                r_out1, r_out2, r_out3 = self.rnet(new_arr[start:start + 16, :, :, :])
                r_output1_list.append(r_out1)
                r_output2_list.append(r_out2)
                r_output3_list.append(r_out3)
            r_out1 = torch.cat(r_output1_list, dim=0)
            r_out2 = torch.cat(r_output2_list, dim=0)
            r_out3 = torch.cat(r_output3_list, dim=0)
        else:
            r_out1, r_out2, r_out3 = self.rnet(new_arr)

        probs = to_numpy(r_out1)
        keep = np.where(probs[:, 0] > self.detection_threshould[1])[0]
        r_out1 = r_out1[keep]
        boxes = boxes[keep]
        boxes[:, 4] = r_out1[:, 0]
        r_out2 = r_out2[keep]
        boxes = calibrate_box(boxes, r_out2)
        boxes = nms(boxes, threshold=self.detection_threshould[1], image_size=image_size, min_size=self.min_size)
        print('rnet:{0} boxes '.format(len(boxes)))

        # ---- stage 3: onet on 48x48 crops ---------------------------------
        boxes = clip_boxes_to_image(boxes, image_size)
        boxes = self.rerec(boxes, img.shape)
        new_arr = _crop_and_resize(img, boxes, 48)
        o_out1, o_out2, o_out3 = self.onet(new_arr)

        probs = to_numpy(o_out1)
        keep = np.where(probs[:, 0] > self.detection_threshould[2])[0]
        o_out1 = o_out1[keep]
        boxes = boxes[keep]
        boxes[:, 4] = o_out1[:, 0]
        o_out2 = o_out2[keep]
        o_out3 = o_out3[keep]
        boxes = calibrate_box(boxes, o_out2)
        # Even columns of the third output are read as x, odd columns as y,
        # both relative to the (already calibrated) box.
        landmarks_x = boxes[:, 0:1] + o_out3[:, 0::2] * (boxes[:, 2:3] - boxes[:, 0:1] + 1)
        landmarks_y = boxes[:, 1:2] + o_out3[:, 1::2] * (boxes[:, 3:4] - boxes[:, 1:2] + 1)
        boxes = torch.cat([boxes, landmarks_x, landmarks_y], dim=-1)
        boxes = nms(boxes, threshold=self.detection_threshould[2], image_size=image_size, min_size=self.min_size)
        print('onet:{0} boxes '.format(len(boxes)))
        return boxes

    def generate_bboxes(self, *outputs, threshould=0.5, scale=1):
        """Not implemented for this model."""
        raise NotImplementedError

    def nms(self, bboxes):
        """Not implemented for this model; ``infer_single_image`` calls the
        module-level ``nms`` function instead."""
        raise NotImplementedError