trident
trident
Table Of Contents
trident
Table Of Contents

Source code for trident.models.pytorch_resnet

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import inspect
import math
import os
import uuid
from collections import *
from collections import deque
from copy import copy, deepcopy
from functools import partial
from itertools import repeat

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch._six import container_abcs
from torch.nn import init
from torch.nn.parameter import Parameter

from trident.backend.common import *
from trident.backend.pytorch_backend import to_numpy, to_tensor, Layer, Sequential
from trident.data.image_common import *
from trident.data.utils import download_model_from_google_drive
from trident.layers.pytorch_activations import get_activation, Identity
from trident.layers.pytorch_blocks import *
from trident.layers.pytorch_layers import *
from trident.layers.pytorch_normalizations import get_normalization
from trident.layers.pytorch_pooling import *
from trident.optims.pytorch_trainer import *

# Names exported by ``from trident.models.pytorch_resnet import *``.
__all__ = [
    'basic_block',
    'bottleneck',
    'ResNet',
    'ResNet50',
    'ResNet101',
    'ResNet152',
]

# Values read once from the active trident session, plus the torch device
# used by the pretrained-model loaders below.
_session = get_session()
_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
_epsilon = _session.epsilon
_trident_dir = _session.trident_dir


# Directory that downloaded model files are written to and loaded from.
dirname = os.path.join(_trident_dir, 'models')
try:
    # Best effort: an already-existing path, a permission error or a race
    # with another creator are all ignored.
    os.makedirs(dirname, exist_ok=True)
except OSError:
    pass

# Identifiers handed to ``download_model_from_google_drive`` for each
# architecture.
model_urls = {
    'resnet50': '1dYlgpFtqi87KDG54_db4ALWKLARxCWMS',
    'resnet101': '17moUOsGynsWALLHyv3yprHWbbDMrdiOP',
    'resnet152': '1BIaHb7_qunUVvt4TDAwonSKI2jYg4Ybj',
}

def basic_block(num_filters=64, base_width=64, strides=1, expansion=4, conv_shortcut=False, use_bias=False, name=''):
    """Build a residual unit made of two 3x3 convolution blocks.

    The main path is two ``Conv2d_Block`` layers with batch normalization;
    only the first one has a relu activation. The main path and the shortcut
    are combined by ``ShortCut2d`` with a relu activation.

    Args:
        num_filters: Number of filters of every convolution in the unit.
        base_width: Accepted but not used by this function.
        strides: Strides of the first 3x3 convolution and of the projection
            shortcut (when one is created).
        expansion: Accepted but not used by this function.
        conv_shortcut: When ``True`` the shortcut is a 1x1 convolution block
            instead of ``Identity``.
        use_bias: Forwarded to every convolution block.
        name: Prefix for the names of the created layers.

    Returns:
        The ``ShortCut2d`` layer wrapping the main path and the shortcut.
    """
    # Keyword arguments common to every convolution block of the unit.
    shared = dict(auto_pad=True, padding_mode='zero', normalization='batch', use_bias=use_bias)

    shortcut = Identity()
    if conv_shortcut is True or strides > 1:
        # Projection shortcut: 1x1 convolution, no activation.
        shortcut = Conv2d_Block((1, 1), num_filters=num_filters, strides=strides, activation=None,
                                name=name + '_downsample', **shared)

    first_conv = Conv2d_Block((3, 3), num_filters=num_filters, strides=strides, activation='relu',
                              name=name + '_0_conv', **shared)
    second_conv = Conv2d_Block((3, 3), num_filters=num_filters, strides=1, activation=None,
                               name=name + '_1_conv', **shared)
    main_path = Sequential(first_conv, second_conv)
    return ShortCut2d(main_path, shortcut, activation='relu')
def bottleneck(num_filters=64, strides=1, expansion=4, conv_shortcut=True, use_bias=False, name=''):
    """Build a three-convolution (1x1, 3x3, 1x1) bottleneck residual unit.

    The first two convolution blocks use ``num_filters`` filters and a relu
    activation; the last one uses ``num_filters * expansion`` filters and no
    activation. All of them use batch normalization. The main path and the
    shortcut are handed to ``ShortCut2d`` as a dict and combined with a relu
    activation.

    Args:
        num_filters: Number of filters of the first two convolutions.
        strides: Strides of the first 1x1 convolution and of the projection
            shortcut (when one is created).
        expansion: Multiplier applied to ``num_filters`` for the last
            convolution and for the projection shortcut.
        conv_shortcut: When ``True`` the shortcut is a 1x1 convolution block
            instead of ``Identity``.
        use_bias: Forwarded to every convolution block.
        name: Prefix for the names of the created layers.

    Returns:
        The ``ShortCut2d`` layer whose branches are keyed ``'branch1'`` and
        either ``'Identity'`` or ``'downsample'``.
    """
    # Keyword arguments common to every convolution block of the unit.
    shared = dict(auto_pad=True, padding_mode='zero', normalization='batch', use_bias=use_bias)
    expanded_filters = num_filters * expansion

    shortcut_name = 'Identity'
    shortcut = Identity()
    if conv_shortcut is True or strides > 1:
        # Projection shortcut so its output has the expanded filter count.
        shortcut = Conv2d_Block((1, 1), num_filters=expanded_filters, strides=strides, activation=None,
                                name=name + '_downsample', **shared)
        shortcut_name = 'downsample'

    reduce_conv = Conv2d_Block((1, 1), num_filters=num_filters, strides=strides, activation='relu',
                               name=name + '_0_conv', **shared)
    spatial_conv = Conv2d_Block((3, 3), num_filters=num_filters, strides=1, activation='relu',
                                name=name + '_1_conv', **shared)
    expand_conv = Conv2d_Block((1, 1), num_filters=expanded_filters, strides=1, activation=None,
                               name=name + '_2_conv', **shared)

    branches = {'branch1': Sequential(reduce_conv, spatial_conv, expand_conv), shortcut_name: shortcut}
    return ShortCut2d(branches, activation='relu')
# def _resnet(arch, block, layers, pretrained, progress, **kwargs): # model = ResNet(block, layers, **kwargs) # if pretrained: # state_dict = load_state_dict_from_url(model_urls[arch], # progress=progress) # model.load_state_dict(state_dict) # return model
def ResNet(block, layers, input_shape=(3, 224, 224), num_classes=1000, use_bias=False, zero_init_residual=False,
           width_per_group=64, replace_stride_with_dilation=None, include_top=True, model_name='', **kwargs):
    """Assemble a ResNet classifier from a residual-block factory.

    The network is a ``Sequential`` made of a 7x7 stride-2 convolution block,
    a 3x3 stride-2 max pooling, four stages of residual blocks (64, 128, 256
    and 512 filters; the last three stages start with a stride-2 block), a
    global average pooling and, when ``include_top`` is truthy, a ``Dense``
    layer followed by ``SoftMax``. It is wrapped in an
    ``ImageClassificationModel``.

    Args:
        block: Callable that builds one residual block. It is called with the
            keyword arguments ``num_filters``, ``strides``, ``expansion``,
            ``conv_shortcut``, ``use_bias`` and ``name`` (``basic_block`` and
            ``bottleneck`` in this module accept them).
        layers: Indexable with at least four entries; ``layers[i]`` is the
            number of residual blocks in stage ``i + 1``.
        input_shape: Shape passed to ``ImageClassificationModel``. Elements 1
            and 2 are also used as the target size of the resize step in the
            preprocessing flow.
        num_classes: Output size of the ``Dense`` layer when ``include_top``
            is truthy.
        use_bias: Forwarded to the first convolution block and to every
            residual block.
        zero_init_residual: Accepted but not used by this function.
        width_per_group: Accepted but not used by this function.
        replace_stride_with_dilation: Accepted but not used by this function.
        include_top: Whether to append the ``fc`` and ``softmax`` layers.
        model_name: Assigned to the ``name`` attribute of the network.
        **kwargs: Accepted but not used by this function.

    Returns:
        The ``ImageClassificationModel`` wrapping the network, with
        ``signature``, ``class_names`` and ``preprocess_flow`` set.

    Note:
        ``class_names`` is always read from ``imagenet_labels1.txt`` located
        next to this module, whatever the value of ``num_classes``; the call
        fails with the error raised by ``open`` if that file is not readable.
    """

    def _make_layer(block, num_filters, blocklayers, strides=1, dilate=False, use_bias=use_bias, layer_name=''):
        """Stack ``blocklayers`` residual blocks into one named stage.

        Only the first block receives ``strides``; it also gets a convolution
        shortcut when ``strides`` is not 1 or ``block`` is ``bottleneck``.
        ``dilate`` is accepted but not used.
        """
        conv_shortcut = False
        if strides != 1 or block is bottleneck:
            conv_shortcut = True
        stage_blocks = [block(num_filters=num_filters, strides=strides, expansion=4, conv_shortcut=conv_shortcut,
                              use_bias=use_bias, name=layer_name + '.0')]
        for k in range(1, blocklayers):
            stage_blocks.append(block(num_filters=num_filters, strides=1, expansion=4, conv_shortcut=False,
                                      use_bias=use_bias, name=layer_name + '.{0}'.format(k)))
        stage = Sequential(*stage_blocks)
        stage.name = layer_name
        return stage

    resnet = Sequential()
    resnet.add_module('first_block', Conv2d_Block((7, 7), 64, strides=2, use_bias=use_bias, auto_pad=True,
                                                  padding_mode='zero', normalization='batch', activation='relu',
                                                  name='first_block'))
    resnet.add_module('maxpool', MaxPool2d((3, 3), strides=2, auto_pad=True, padding_mode='zero'))

    # (filters, strides of the first block) for stages layer1 .. layer4.
    stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))
    for index, (stage_filters, stage_strides) in enumerate(stage_specs):
        stage_name = 'layer{0}'.format(index + 1)
        resnet.add_module(stage_name, _make_layer(block, stage_filters, layers[index], strides=stage_strides,
                                                  dilate=None, use_bias=use_bias, layer_name=stage_name))

    resnet.add_module('avg_pool', GlobalAvgPool2d(name='avg_pool'))
    if include_top:
        resnet.add_module('fc', Dense(num_classes, activation=None, name='fc'))
        resnet.add_module('softmax', SoftMax(name='softmax'))
    resnet.name = model_name

    model = ImageClassificationModel(input_shape=input_shape, output=resnet)
    model.signature = get_signature(model.model.forward)

    # The label file is expected to sit in the same directory as this module.
    labels_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'imagenet_labels1.txt')
    with open(labels_path, 'r', encoding='utf-8-sig') as f:
        labels = [l.rstrip() for l in f]
    model.class_names = labels

    model.preprocess_flow = [resize((input_shape[1], input_shape[2]), keep_aspect=True),
                             normalize(0, 255),
                             normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]
    return model
# # def ResNet18(include_top=True, # weights='imagenet', # input_shape=None, # classes=1000, # **kwargs): # if input_shape is not None and len(input_shape)==3: # input_shape=tuple(input_shape) # else: # input_shape=(3, 224, 224) # resnet18 = ResNet(basic_block, [2, 2, 2, 2], input_shape, model_name='resnet18')
def ResNet50(include_top=True, pretrained=True, input_shape=None, classes=1000, **kwargs):
    """Build a ResNet-50 classifier, optionally replacing it with a downloaded network.

    Args:
        include_top: Forwarded to ``ResNet``. When a pretrained network is
            loaded and this equals ``False``, the last module of the loaded
            network is deleted.
        pretrained: When equal to ``True``, ``resnet50.pth`` is downloaded
            into the module-level ``dirname`` and the network stored in it
            replaces the freshly built one.
        input_shape: Sequence of length 3. Any other value, including
            ``None``, falls back to ``(3, 224, 224)``.
        classes: Number of output classes. With a pretrained network that
            keeps its top and ``classes != 1000``, the ``fc`` layer is
            replaced by a new ``Dense`` layer of that size.
        **kwargs: Accepted but not used by this function.

    Returns:
        The model object created by ``ResNet``.
    """
    if input_shape is not None and len(input_shape) == 3:
        input_shape = tuple(input_shape)
    else:
        input_shape = (3, 224, 224)

    resnet50 = ResNet(bottleneck, [3, 4, 6, 3], input_shape, num_classes=classes, include_top=include_top,
                      model_name='resnet50')

    # Equality comparisons (rather than truthiness) are kept as they were so
    # that non-bool arguments behave exactly as before.
    if pretrained == True:
        download_model_from_google_drive(model_urls['resnet50'], dirname, 'resnet50.pth')
        # NOTE(security): torch.load unpickles the downloaded file, which can
        # execute arbitrary code; only load files from a trusted source.
        # map_location lets a file saved on a GPU be loaded on a CPU-only host.
        recovery_model = torch.load(os.path.join(dirname, 'resnet50.pth'), map_location=_device)
        recovery_model.eval()
        recovery_model.to(_device)
        if include_top == False:
            del recovery_model[-1]
        elif classes != 1000:
            new_fc = Dense(classes, activation=None, name='fc')
            # Give the new head the same input shape as the one it replaces.
            new_fc.input_shape = recovery_model.fc.input_shape
            recovery_model.fc = new_fc

        resnet50.model = recovery_model
        resnet50.rebinding_input_output(input_shape)
        resnet50.signature = get_signature(resnet50.model.forward)
    return resnet50
def ResNet101(include_top=True, pretrained=True, input_shape=None, classes=1000, **kwargs):
    """Build a ResNet-101 classifier, optionally replacing it with a downloaded network.

    Args:
        include_top: Forwarded to ``ResNet``. When a pretrained network is
            loaded and this equals ``False``, the last module of the loaded
            network is deleted.
        pretrained: When equal to ``True``, ``resnet101.pth`` is downloaded
            into the module-level ``dirname`` and the network stored in it
            replaces the freshly built one.
        input_shape: Sequence of length 3. Any other value, including
            ``None``, falls back to ``(3, 224, 224)``.
        classes: Number of output classes. With a pretrained network that
            keeps its top and ``classes != 1000``, the ``fc`` layer is
            replaced by a new ``Dense`` layer of that size.
        **kwargs: Accepted but not used by this function.

    Returns:
        The model object created by ``ResNet``.
    """
    if input_shape is not None and len(input_shape) == 3:
        input_shape = tuple(input_shape)
    else:
        input_shape = (3, 224, 224)

    resnet101 = ResNet(bottleneck, [3, 4, 23, 3], input_shape, num_classes=classes, include_top=include_top,
                       model_name='resnet101')

    # Equality comparisons (rather than truthiness) are kept as they were so
    # that non-bool arguments behave exactly as before.
    if pretrained == True:
        download_model_from_google_drive(model_urls['resnet101'], dirname, 'resnet101.pth')
        # NOTE(security): torch.load unpickles the downloaded file, which can
        # execute arbitrary code; only load files from a trusted source.
        # map_location lets a file saved on a GPU be loaded on a CPU-only host.
        recovery_model = torch.load(os.path.join(dirname, 'resnet101.pth'), map_location=_device)
        recovery_model.eval()
        recovery_model.to(_device)
        if include_top == False:
            del recovery_model[-1]
        elif classes != 1000:
            new_fc = Dense(classes, activation=None, name='fc')
            # Same handling as ResNet50: give the new head the input shape of
            # the layer it replaces.
            new_fc.input_shape = recovery_model.fc.input_shape
            recovery_model.fc = new_fc

        resnet101.model = recovery_model
        resnet101.rebinding_input_output(input_shape)
        resnet101.signature = get_signature(resnet101.model.forward)
    return resnet101
def ResNet152(include_top=True, pretrained=True, input_shape=None, classes=1000, **kwargs):
    """Build a ResNet-152 classifier, optionally replacing it with a downloaded network.

    Args:
        include_top: Forwarded to ``ResNet``. When a pretrained network is
            loaded and this equals ``False``, the last module of the loaded
            network is deleted.
        pretrained: When equal to ``True``, ``resnet152.pth`` is downloaded
            into the module-level ``dirname`` and the network stored in it
            replaces the freshly built one.
        input_shape: Sequence of length 3. Any other value, including
            ``None``, falls back to ``(3, 224, 224)``.
        classes: Number of output classes. With a pretrained network that
            keeps its top and ``classes != 1000``, the ``fc`` layer is
            replaced by a new ``Dense`` layer of that size.
        **kwargs: Accepted but not used by this function.

    Returns:
        The model object created by ``ResNet``.
    """
    if input_shape is not None and len(input_shape) == 3:
        input_shape = tuple(input_shape)
    else:
        input_shape = (3, 224, 224)

    resnet152 = ResNet(bottleneck, [3, 8, 36, 3], input_shape, num_classes=classes, include_top=include_top,
                       model_name='resnet152')

    # Equality comparisons (rather than truthiness) are kept as they were so
    # that non-bool arguments behave exactly as before.
    if pretrained == True:
        download_model_from_google_drive(model_urls['resnet152'], dirname, 'resnet152.pth')
        # NOTE(security): torch.load unpickles the downloaded file, which can
        # execute arbitrary code; only load files from a trusted source.
        # map_location lets a file saved on a GPU be loaded on a CPU-only host.
        recovery_model = torch.load(os.path.join(dirname, 'resnet152.pth'), map_location=_device)
        recovery_model.eval()
        recovery_model.to(_device)
        if include_top == False:
            del recovery_model[-1]
        elif classes != 1000:
            new_fc = Dense(classes, activation=None, name='fc')
            # Same handling as ResNet50: give the new head the input shape of
            # the layer it replaces.
            new_fc.input_shape = recovery_model.fc.input_shape
            recovery_model.fc = new_fc

        resnet152.model = recovery_model
        resnet152.rebinding_input_output(input_shape)
        resnet152.signature = get_signature(resnet152.model.forward)
    return resnet152
# # # resnet34=ResNet(basic_block, [3, 4, 6, 3], (3, 224, 224)) # resnet50=ResNet(bottleneck, [3, 4, 6, 3], (3, 224, 224)) # resnet101=ResNet(bottleneck, [3, 4, 23, 3], (3, 224, 224)) # resnet152=ResNet(bottleneck, [3, 8, 36, 3], (3, 224, 224))