# This file was created by NVIDIA Corporation.
# Modified by the PPQ development team.
# 
# Copyright 2020 NVIDIA Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List

import torch
from ppq.core import (EXPORT_DEVICE_SWITCHER, DataType, OperationMeta,
                      TensorMeta, TensorQuantizationConfig,
                      convert_any_to_torch_tensor, ppq_warning)
from ppq.IR import BaseGraph, Operation, Variable
from ppq.IR.morph import GraphDeviceSwitcher
from ppq.IR.quantize import QuantableOperation, QuantableVariable
from ppq.utils.round import ppq_tensor_round

try:
    import tensorrt as trt
except ImportError:
    ppq_warning('TensorRT is not installed, TRT Exporter is disabled.')

from .onnxruntime_exporter import ONNXRUNTIMExporter


class TensorRTExporter(ONNXRUNTIMExporter):

    def insert_quant_dequant_on_variable(
        self, graph: BaseGraph, var: QuantableVariable, op: QuantableOperation,
        config: TensorQuantizationConfig) -> None:
        """
        Insert Quant & Dequant Operation to graph
        This insertion will strictly follows tensorRT format requirement.
        
        Inserted Quant & Dequant op will just between upstream variable and downstream operation,
        
        Example 1, Insert quant & dequant between var1 and op1:
        
        Before insertion:
            var1 --> op1
        
        After insertion:
            var1 --> quant --> generated_var --> dequant --> generated_var --> op1

        Args:
            graph (BaseGraph): PPQ IR graph.
            var (Variable): upstream variable.
            config (TensorQuantizationConfig, optional): quantization config.
            op (Operation, optional): downstream operation.
        
        """
        meta = var.meta

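        # ONNX QuantizeLinear / DequantizeLinear expect a fp32 scale and an int8 zero point
        # for int8 quantization, so convert both values accordingly before building the parameter variables.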
        scale  = convert_any_to_torch_tensor(config.scale, dtype=torch.float32)
        offset = ppq_tensor_round(config.offset).type(torch.int8)

        qt_svar = Variable(name=f'{op.name}_{var.name}' + '_qt_scale', value=scale, is_parameter=True)
        qt_zvar = Variable(name=f'{op.name}_{var.name}' + '_qt_zeropoint', value=offset, is_parameter=True)
        dq_svar = Variable(name=f'{op.name}_{var.name}' + '_dq_scale', value=scale, is_parameter=True)
        dq_zvar = Variable(name=f'{op.name}_{var.name}' + '_dq_zeropoint', value=offset, is_parameter=True)

        qt_op = Operation(name=f'{op.name}_{var.name}' + '_QuantizeLinear', op_type='QuantizeLinear', attributes={})
        dq_op = Operation(name=f'{op.name}_{var.name}' + '_DequantizeLinear', op_type='DequantizeLinear', attributes={})

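        # insert dequant first (var --> dequant --> op), then insert quant between var and dequant,
        # which yields the final chain: var --> quant --> dequant --> op.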
        graph.insert_op_between_var_and_op(dq_op, up_var=var, down_op=op)
        graph.insert_op_between_var_and_op(qt_op, up_var=var, down_op=dq_op)

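        # scale and zero point become the 2nd and 3rd inputs of QuantizeLinear / DequantizeLinear.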
        qt_op.inputs.extend([qt_svar, qt_zvar])
        dq_op.inputs.extend([dq_svar, dq_zvar])

        qt_svar.dest_ops.append(qt_op)
        qt_zvar.dest_ops.append(qt_op)
        dq_svar.dest_ops.append(dq_op)
        dq_zvar.dest_ops.append(dq_op)

        graph.append_variable(qt_svar)
        graph.append_variable(qt_zvar)
        graph.append_variable(dq_svar)
        graph.append_variable(dq_zvar)

        # create meta data for qt_op, dq_op
        qt_meta = OperationMeta(
            input_metas    = [TensorMeta(dtype=DataType.FP32, shape=meta.shape), 
                              TensorMeta(dtype=DataType.FP32, shape=config.scale.shape), 
                              TensorMeta(dtype=DataType.FP32, shape=config.offset.shape)],
            output_metas   = [TensorMeta(dtype=DataType.INT8, shape=meta.shape)],
            operation_name = qt_op.name, operation_type=qt_op.type, executing_order=-1)
        dq_meta = OperationMeta(
            input_metas    = [TensorMeta(dtype=DataType.INT8, shape=meta.shape), 
                              TensorMeta(dtype=DataType.FP32, shape=config.scale.shape), 
                              TensorMeta(dtype=DataType.FP32, shape=config.offset.shape)],
            output_metas   = [TensorMeta(dtype=DataType.FP32, shape=meta.shape)],
            operation_name = dq_op.name, operation_type=dq_op.type, executing_order=-1)

        qt_op.meta_data = qt_meta
        dq_op.meta_data = dq_meta

    def prepare_graph(self, graph: BaseGraph) -> BaseGraph:
        """
        TensorRT demands a customized QAT model format as its input.
            With this particular format, we only need to export the input quant config from PPQ,
            and only a subset of operations is required to dump its quant config.

        The operation types that need Q/DQ insertion correspond to:
            _DEFAULT_QUANT_MAP = [
                _quant_entry(torch.nn, "Conv1d", quant_nn.QuantConv1d),
                _quant_entry(torch.nn, "Conv2d", quant_nn.QuantConv2d),
                _quant_entry(torch.nn, "Conv3d", quant_nn.QuantConv3d),
                _quant_entry(torch.nn, "ConvTranspose1d", quant_nn.QuantConvTranspose1d),
                _quant_entry(torch.nn, "ConvTranspose2d", quant_nn.QuantConvTranspose2d),
                _quant_entry(torch.nn, "ConvTranspose3d", quant_nn.QuantConvTranspose3d),
                _quant_entry(torch.nn, "Linear", quant_nn.QuantLinear),
                _quant_entry(torch.nn, "LSTM", quant_nn.QuantLSTM),
                _quant_entry(torch.nn, "LSTMCell", quant_nn.QuantLSTMCell),
                _quant_entry(torch.nn, "AvgPool1d", quant_nn.QuantAvgPool1d),
                _quant_entry(torch.nn, "AvgPool2d", quant_nn.QuantAvgPool2d),
                _quant_entry(torch.nn, "AvgPool3d", quant_nn.QuantAvgPool3d),
                _quant_entry(torch.nn, "AdaptiveAvgPool1d", quant_nn.QuantAdaptiveAvgPool1d),
                _quant_entry(torch.nn, "AdaptiveAvgPool2d", quant_nn.QuantAdaptiveAvgPool2d),
                _quant_entry(torch.nn, "AdaptiveAvgPool3d", quant_nn.QuantAdaptiveAvgPool3d),
            ]

        Reference:
        https://github.com/NVIDIA/TensorRT/blob/main/tools/pytorch-quantization/pytorch_quantization/quant_modules.py
        
        ATTENTION: MUST USE TENSORRT QUANTIZER TO GENERATE A TENSORRT MODEL.
        """
        # remove switchers.
        if not EXPORT_DEVICE_SWITCHER:
            processer = GraphDeviceSwitcher(graph)
            processer.remove_switcher()
        
        # find all quantable operations:
        for operation in list(graph.operations.values()): # snapshot, Q/DQ insertion mutates the graph
            if not isinstance(operation, QuantableOperation): continue
            if operation.type in {'Conv', 'Gemm', 'ConvTranspose'}:
                # for Conv, Gemm, ConvTranspose, TensorRT wants their weight to be quantized,
                # however bias remains fp32.
                assert len(operation.config.input_quantization_config) >= 2, (
                    f'Oops, it seems operation {operation.name} has fewer than 2 inputs.')
                
                i_config, w_config = operation.config.input_quantization_config[: 2]
                i_var, w_var       = operation.inputs[: 2]

                self.insert_quant_dequant_on_variable(graph=graph, var=i_var, config=i_config, op=operation)
                self.insert_quant_dequant_on_variable(graph=graph, var=w_var, config=w_config, op=operation)

            elif operation.type in {'AveragePool', 'GlobalAveragePool'}:
                # for average pooling, TensorRT requires the input quant config.

                assert len(operation.config.input_quantization_config) >= 1, (
                    f'Oops, it seems operation {operation.name} has no input.')
                i_config = operation.config.input_quantization_config[0]
                i_var    = operation.inputs[0]
                
                self.insert_quant_dequant_on_variable(graph=graph, var=i_var, config=i_config, op=operation)

            else:
                ppq_warning(f'Exporting quantized operation {operation.name} to TensorRT is not supported, '
                            'this operation is expected to run in fp32 mode.')
        return graph

    @property
    def required_opsets(self) -> Dict[str, int]:
        extra_domain_versions = [("ai.onnx", 11)] # must be opset 11
        return dict(extra_domain_versions)

    def export(self, file_path: str, graph: BaseGraph, 
               config_path: str = None, input_shapes: Dict[str, list] = None) -> None:
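        """
        Export the quantized graph: first dump it as an ONNX file at file_path,
        then build and serialize a TensorRT engine from that file.

        The engine is written next to the ONNX file, with the '.onnx' suffix replaced by '.engine'.
        config_path is ignored here, since all quantization information is embedded as Q/DQ nodes.
        input_shapes optionally maps each graph input name to [min shape, opt shape, max shape]
        for the TensorRT optimization profile; when omitted, the static shape recorded in the
        graph meta is used for all three.
        """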
        # step 1, export onnx file.
        super().export(file_path=file_path, graph=graph, config_path=None)

        # step 2, convert onnx file to tensorRT engine.
        try:
            TRT_LOGGER = trt.Logger(trt.Logger.INFO)
        except Exception as e:
            raise Exception('TensorRT is not successfully loaded, therefore PPQ can not export a TensorRT engine directly, '
                            f'a model named {file_path} has been created so that you can send it to TensorRT manually.') from e
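        # EXPLICIT_BATCH is required by the ONNX parser,
        # EXPLICIT_PRECISION marks the network as an explicit-precision (Q/DQ) network.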
        network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
        network_flags = network_flags | (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_PRECISION))
        
        # step 3, build profile input shape
        # Notice that for each input you should give 3 shapes: (min shape), (opt shape), (max shape)
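        # e.g. input_shapes = {'input': [(1, 3, 224, 224), (4, 3, 224, 224), (8, 3, 224, 224)]}
        #      (the name 'input' and the shapes above are illustrative only).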
        if input_shapes is None:
            input_shapes = {input_var.name: [input_var.meta.shape, input_var.meta.shape, input_var.meta.shape] 
                            for input_var in graph.inputs.values()}

        with trt.Builder(TRT_LOGGER) as builder, builder.create_network(flags=network_flags) as network, trt.OnnxParser(network, TRT_LOGGER) as parser:
            
            with open(file_path, 'rb') as model:
                if not parser.parse(model.read()):
                    print('ERROR: Failed to parse the ONNX file.')
                    for error in range(parser.num_errors):
                        print(parser.get_error(error))
                    return None

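            # builder config: 1 GiB workspace, INT8 kernels enabled.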
            config = builder.create_builder_config()
            config.max_workspace_size = 1 << 30
            config.flags = config.flags | 1 << int(trt.BuilderFlag.INT8)
            
            profile = builder.create_optimization_profile()
            
            # build TensorRT Profile
            for idx in range(network.num_inputs):
                inp = network.get_input(idx)

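                # shape tensors need an explicit set_shape_input entry in the profile,
                # execution tensors with dynamic (-1) dims need set_shape instead.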
                if inp.is_shape_tensor:
                    shapes = input_shapes.get(inp.name)

                    if not shapes:
                        shapes = [(1, ) * inp.shape[0]] * 3
                        print("Setting shape input to {:}. "
                              "If this is incorrect, for shape input: {:}, "
                              "please provide tuples for min, opt, "
                              "and max shapes".format(shapes[0], inp.name))

                    if not isinstance(shapes, list) or len(shapes) != 3:
                        raise ValueError(f'Profiling shape must be a list with exactly 3 shapes (tuples of int), '
                                         f'while received a {type(shapes)} for input {inp.name}, check your input again.')

                    min_shape, opt_shape, max_shape = shapes
                    profile.set_shape_input(inp.name, min_shape, opt_shape, max_shape)
                
                elif -1 in inp.shape:
                    shapes = input_shapes.get(inp.name)

                    if not shapes:
                        # fall back to replacing every dynamic (-1) dim with 1.
                        shapes = [tuple(1 if s <= 0 else s for s in inp.shape)] * 3

                    min_shape, opt_shape, max_shape = shapes
                    profile.set_shape(inp.name, min_shape, opt_shape, max_shape)

            config.add_optimization_profile(profile)
            trt_engine = builder.build_engine(network, config)
            
            # end for
        
        # end with

        engine_file = file_path.replace('.onnx', '.engine')
        with open(engine_file, "wb") as file:
            file.write(trt_engine.serialize())