# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.
import torch
import torch.nn as nn
import logging
from ..BaseLayer import BaseConf,BaseLayer
from utils.DocInherit import DocInherit
from utils.exceptions import ConfigurationError
import copy
class Concat3DConf(BaseConf):
    """ Configuration of Concat3D layer

    Args:
        concat3D_axis(1 or 2): which axis to conduct Concat3D, default is 2.
            Axis 1 sums the second dimension of the inputs and keeps the last
            dimension of the first input; axis 2 sums the last dimension and
            keeps the second dimension of the first input (see ``inference``).
    """

    # init the args
    def __init__(self, **kwargs):
        super(Concat3DConf, self).__init__(**kwargs)

    # set default params
    @DocInherit
    def default(self):
        self.concat3D_axis = 2

    @DocInherit
    def declare(self):
        # NOTE(review): -1 presumably means "any number of inputs" -- confirm
        # against BaseConf.
        self.num_of_inputs = -1
        self.input_ranks = [3]

    @DocInherit
    def inference(self):
        # Start from a private copy of the first input's dims so that the
        # assignments below never mutate self.input_dims.
        self.output_dim = copy.deepcopy(self.input_dims[0])
        # Both candidate axes start at 0; they stay 0 when concat3D_axis is
        # neither 1 nor 2 (verify() rejects that configuration).
        self.output_dim[-1] = 0
        self.output_dim[1] = 0
        if self.concat3D_axis == 1:
            # Axis 1 grows by the sum of the inputs; the last axis is kept.
            self.output_dim[-1] = self.input_dims[0][-1]
            self.output_dim[1] = sum(input_dim[1] for input_dim in self.input_dims)
        if self.concat3D_axis == 2:
            # The last axis grows by the sum of the inputs; axis 1 is kept.
            self.output_dim[-1] = sum(input_dim[-1] for input_dim in self.input_dims)
            self.output_dim[1] = self.input_dims[0][1]
        super(Concat3DConf, self).inference()

    @DocInherit
    def verify(self):
        # Raises ConfigurationError when the input ranks differ, when axis 1
        # is requested with differing last dimensions, or when concat3D_axis
        # is not 1 or 2.
        super(Concat3DConf, self).verify()

        # to check if the ranks of all the inputs are equal
        if any(rank != self.input_ranks[0] for rank in self.input_ranks):
            raise ConfigurationError("For layer Concat3D, the ranks of each inputs should be equal!")

        if self.concat3D_axis == 1:
            # to check if the (last) dimensions of all the inputs are equal
            input_dims = list(self.input_dims)
            if any(input_dim[-1] != input_dims[0][-1] for input_dim in input_dims):
                # Raised as ConfigurationError (not a bare Exception) so that
                # every configuration failure of this layer has the same type.
                raise ConfigurationError("Concat3D with axis = 1 require that the input dimensions should be the same!")

        # to check if the concat3D_axis is legal
        if self.concat3D_axis not in [1, 2]:
            raise ConfigurationError("For layer Concat3D, the concat axis must be 1 or 2!")
class Concat3D(nn.Module):
    """ Concat3D layer to merge sum of sequences(3D representation)

    Args:
        layer_conf (Concat3DConf): configuration of a layer
    """

    def __init__(self, layer_conf):
        super(Concat3D, self).__init__()
        self.layer_conf = layer_conf
        logging.warning("The length Concat3D layer returns is the length of first input")

    def forward(self, *args):
        """ process inputs

        Args:
            *args: (Tensor): string, string_len, string2, string2_len, ...
                e.g. string (Tensor): [batch_size, seq_len, dim], string_len (Tensor): [batch_size]

        Returns:
            Tensor: [batch_size, seq_len, output_dim], [batch_size]
                The concatenation of the even-positioned arguments along
                ``layer_conf.concat3D_axis``, together with ``args[1]``
                (the length of the first input) returned unchanged.

        Raises:
            Exception: if ``concat3D_axis`` is 2 and the sequence tensors do
                not all share the same size on axis 1.
        """
        axis = self.layer_conf.concat3D_axis
        # Arguments alternate (sequence, length); only the even-positioned
        # sequence tensors take part in the concatenation.
        sequences = args[::2]

        result = []
        if axis == 1:
            result = list(sequences)
        elif axis == 2:
            expected_len = sequences[0].shape[1]
            # Validate the sequence tensors only. The length tensors at odd
            # positions must not be checked: they would always fail the test
            # and make every axis-2 call raise.
            for sequence in sequences:
                if sequence.shape[1] != expected_len:
                    raise Exception("Concat3D with axis = 2 require that the input sequences length should be the same!")
            result = list(sequences)

        # For any other axis `result` stays empty and torch.cat raises; the
        # configuration's verify() is expected to reject such an axis earlier.
        return torch.cat(result, axis), args[1]