! matrix_term_abstract_procedures.f90 Source File


! Source Code

!-----------------------------------------------------------------------------------------------------------------------------------
! This file is part of ReMKiT1D.
!
! ReMKiT1D is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as 
! published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version.
!
! ReMKiT1D is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of 
! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.
!
! You should have received a copy of the GNU General Public License along with ReMKiT1D. If not, see <https://www.gnu.org/licenses/>. 
!
! Copyright 2023 United Kingdom Atomic Energy Authority (stefan.mijin@ukaea.uk)
!-----------------------------------------------------------------------------------------------------------------------------------
submodule (matrix_term_abstract_class) matrix_term_abstract_procedures
    !! author: Stefan Mijin 
    !! 
    !! Contains module procedures associated with the abstract matrix term class

implicit none

!-----------------------------------------------------------------------------------------------------------------------------------
contains
!-----------------------------------------------------------------------------------------------------------------------------------
pure module subroutine unityRow(this,varCont,rowVals,indexingData) 
    !! Default (trivial) row function: every entry of the row buffer is set to 1.
    !! A non-default row function would use the variable container, row buffer and matrix indexing data to produce
    !! values depending only on the row. The indexingData argument is not used by this default.

    class(MatrixTerm)             ,intent(inout)   :: this
    type(VariableContainer)       ,intent(in)      :: varCont
    real(rk) ,dimension(:)        ,intent(inout)   :: rowVals
    type(MatrixTermIndexingData)  ,intent(in)      :: indexingData

    if (assertions) then
        call assertPure(this%isDefined(),"Called unityRow with undefined matrix term object")
        call assertPure(varCont%isDefined(),"Called unityRow by passing undefined variable container")
    end if

    ! Whole-array fill with the multiplicative identity
    rowVals(:) = 1.0_rk

end subroutine unityRow 
!-----------------------------------------------------------------------------------------------------------------------------------
pure module subroutine unityCol(this,varCont,colVals,indexingData) 
    !! Default (trivial) column function: every entry of every per-row column buffer is set to 1.
    !! Each colVals(k)%entry must already be allocated; this routine only fills it. The indexingData argument is
    !! not used by this default.

    class(MatrixTerm)             ,intent(inout) :: this
    type(VariableContainer)       ,intent(in)    :: varCont
    type(RealArray) ,dimension(:) ,intent(inout) :: colVals
    type(MatrixTermIndexingData)  ,intent(in)    :: indexingData

    integer(ik) :: rowNum

    if (assertions) then
        call assertPure(this%isDefined(),"Called unityCol with undefined matrix term object")
        call assertPure(varCont%isDefined(),"Called unityCol by passing undefined variable container")
    end if

    ! Rows are independent, so the fill can be expressed as a concurrent loop
    do concurrent (rowNum = 1:size(colVals))
        colVals(rowNum)%entry = 1.0_rk
    end do

end subroutine unityCol  
!-----------------------------------------------------------------------------------------------------------------------------------
pure module subroutine calculateValues(this,varCont) 
    !! Default matrix entry calculation routine. For every row i the stored sparse values become
    !!     normalizationConst * multConst(i) [* rowBuffer(i)] [* colBuffer(i)]
    !! The row factor is included only when hasNonTrivialRowFun is set, after refreshing rowBuffer via calculateRows.
    !! The column factor is included only when hasNonTrivialColFun is set, after refreshing colBuffer via calculateCols.
    !! When the matrix is flagged as fixed, values are only computed on the first call.

    class(MatrixTerm)       ,intent(inout)  :: this
    type(VariableContainer) ,intent(in)     :: varCont

    integer(ik) :: rowInd
    logical     :: useRowFun ,useColFun

    if (assertions) then
        call assertPure(this%isDefined(),"Attempted to calculate values of undefined matrix term object")
        call assertPure(varCont%isDefined(),&
        "Attempted to calculate values of matrix term object by passing undefined variable container")
    end if

    if (.not. this%fixedMatrixCalculated) then

        ! Snapshot the flags before invoking any (overridable) type-bound procedure on this
        useRowFun = this%hasNonTrivialRowFun
        useColFun = this%hasNonTrivialColFun

        ! Refresh only the buffers that will be used; rows are computed before columns
        if (useRowFun) call this%calculateRows(varCont,this%rowBuffer,this%indexingData)
        if (useColFun) call this%calculateCols(varCont,this%colBuffer,this%indexingData)

        do rowInd = 1,size(this%rowData%values)
            ! Factors are applied left to right: norm * mult, then * row factor, then * column factor
            this%rowData%values(rowInd)%entry = this%normalizationConst * this%multConst(rowInd)%entry
            if (useRowFun) &
                this%rowData%values(rowInd)%entry = this%rowData%values(rowInd)%entry * this%rowBuffer(rowInd)
            if (useColFun) &
                this%rowData%values(rowInd)%entry = this%rowData%values(rowInd)%entry * this%colBuffer(rowInd)%entry
        end do

    end if

    ! A fixed matrix is marked as calculated so later calls skip the recomputation above
    if (this%fixedMatrix) this%fixedMatrixCalculated = .true.

end subroutine calculateValues
!-----------------------------------------------------------------------------------------------------------------------------------
pure module subroutine setNormalizationConst(this,norm) 
    !! Setter for normalizationConst, the scalar factor that calculateValues multiplies into every matrix entry.

    class(MatrixTerm)       ,intent(inout)  :: this
    real(rk)                ,intent(in)     :: norm

    ! Plain overwrite - no definedness check, so this may be used while the term is still being set up
    this%normalizationConst = norm

end subroutine setNormalizationConst
!-----------------------------------------------------------------------------------------------------------------------------------
pure module function getNormalizationConst (this) result(norm)
    !! Getter for normalizationConst, the scalar factor that calculateValues multiplies into every matrix entry.

    class(MatrixTerm)  ,intent(in) :: this
    real(rk)                       :: norm

    if (assertions) then
        call assertPure(this%isDefined(),"Requested normalization constant from undefined matrix term object")
    end if

    norm = this%normalizationConst

end function getNormalizationConst
!-----------------------------------------------------------------------------------------------------------------------------------
pure module subroutine setReqVars(this,rowReqVars,colReqVars,varCont) 
    !! Set variable names required by the row and column functions and find their indices in variable container.
    !! The indices are stored in indexingData%rowReqVarIndices and indexingData%colReqVarIndices, in the same order
    !! as the passed name arrays. Calling this again replaces any previously stored index lists.

    class(MatrixTerm)               ,intent(inout)  :: this
    type(StringArray) ,dimension(:) ,intent(in)     :: rowReqVars
    type(StringArray) ,dimension(:) ,intent(in)     :: colReqVars
    type(VariableContainer)         ,intent(in)     :: varCont

    integer(ik) :: i

    if (assertions) then 
        call assertPure(varCont%isDefined(),"Called setReqVars for matrix term by passing undefined variable container")
    end if

    ! Release any earlier allocation so a repeated call does not fail with "already allocated"
    if (allocated(this%indexingData%rowReqVarIndices)) deallocate(this%indexingData%rowReqVarIndices)
    if (allocated(this%indexingData%colReqVarIndices)) deallocate(this%indexingData%colReqVarIndices)

    allocate(this%indexingData%rowReqVarIndices(size(rowReqVars)))
    allocate(this%indexingData%colReqVarIndices(size(colReqVars)))

    do i = 1, size(rowReqVars)
        this%indexingData%rowReqVarIndices(i) = varCont%getVarIndex(rowReqVars(i)%string)
    end do

    do i = 1, size(colReqVars)
        this%indexingData%colReqVarIndices(i) = varCont%getVarIndex(colReqVars(i)%string)
    end do

end subroutine setReqVars
!-----------------------------------------------------------------------------------------------------------------------------------
pure module subroutine setEvolvedAndImplicitVar(this,rowVarName,colVarName,varCont) 
    !! Record the evolved (row) and implicit (column) variable names in the indexing data.
    !! Also cache whether the evolved variable is stationary, as reported by varCont%isStationary.

    class(MatrixTerm)         ,intent(inout)  :: this
    character(*)              ,intent(in)     :: rowVarName
    character(*)              ,intent(in)     :: colVarName
    type(VariableContainer)   ,intent(in)     :: varCont

    ! Stationarity only depends on the row variable name and the container
    this%stationaryEvolvedVar = varCont%isStationary(rowVarName)

    this%indexingData%colVarName = colVarName 
    this%indexingData%rowVarName = rowVarName

end subroutine setEvolvedAndImplicitVar
!-----------------------------------------------------------------------------------------------------------------------------------
module subroutine initRowData(this,rowCoords,rowToColMapping,indexingObj) 
    !! Initialize row data objects based on a set of evolved global coordinates, the row/col var name, and a function that returns 
    !! column coordinates given a row coordinate input. Requires a reference Indexing object.
    !!
    !! rowCoords has shape (1,:) for a non-distribution row variable (one x coordinate per row), or (3,:) for a
    !! distribution row variable (three coordinates per row). Only the global row index and the local row index
    !! lookups depend on this. Every other piece of per-row setup is shared:
    !! - global/local row and column coordinates,
    !! - local column indices,
    !! - multConst and colBuffer entries, both initialised to 1.

    class(MatrixTerm)            ,intent(inout)  :: this
    integer(ik)  ,dimension(:,:) ,intent(in)     :: rowCoords
    procedure(coordMapping)                      :: rowToColMapping
    type(Indexing)               ,intent(in)     :: indexingObj

    integer(ik) :: i ,j ,numRows ,numEntries
    logical     :: distRow

    type(IntArray) ,allocatable ,dimension(:) :: mappedColIndices ,mappedColIndicesLocal
    integer(ik) ,allocatable ,dimension(:,:)  :: allCombinationsMappedInd 

    if (assertions) then 
        call assert(indexingObj%isDefined(),"initRowData for matrix term called with undefined indexing object")
        call assert((size(rowCoords,1) == 1) .or. (size(rowCoords,1) == 3),"Row coordinate entries passed to initRowData must&
        & be of dimension (1,:) or (3,:) depending on whether the row variable is a distribution")
    end if

    numRows = size(rowCoords,2)
    distRow = size(rowCoords,1) /= 1

    call this%rowData%init()
    allocate(this%indexingData%rowDataCoordsGlobal(numRows))
    allocate(this%indexingData%rowDataCoordsLocal(numRows))
    allocate(this%multConst(numRows))
    allocate(this%colBuffer(numRows))
    allocate(this%rowBuffer(numRows))
    allocate(this%indexingData%localRowIndices(numRows))
    allocate(this%indexingData%localColIndices(numRows))

    this%rowBuffer = real(1,kind=rk)

    do i = 1,numRows
        mappedColIndices = rowToColMapping(rowCoords(:,i))
        mappedColIndicesLocal = mappedColIndices
        allCombinationsMappedInd = allCombinations(mappedColIndices)

        ! Add the sparse row (global row index + global column indices) and record the local row index;
        ! distribution rows need all three coordinates for both lookups
        if (distRow) then
            call this%rowData%addRow(indexingObj%findIndex(this%indexingData%rowVarName,rowCoords(1,i),&
                                                           rowCoords(2,i),&
                                                           rowCoords(3,i)),&
                                     indexingObj%mapToGlobalIndices(this%indexingData%colVarName,mappedColIndices))
            this%indexingData%localRowIndices(i) = indexingObj%findDistIndex(rowCoords(1,i),&
                                                                             rowCoords(2,i),&
                                                                             rowCoords(3,i),&
                                                                             .true.)
        else
            call this%rowData%addRow(indexingObj%findIndex(this%indexingData%rowVarName,rowCoords(1,i)),&
                                     indexingObj%mapToGlobalIndices(this%indexingData%colVarName,mappedColIndices))
            this%indexingData%localRowIndices(i) = indexingObj%findLocalXIndex(rowCoords(1,i))
        end if

        this%indexingData%rowDataCoordsGlobal(i)%rowCoords = rowCoords(:,i)
        this%indexingData%rowDataCoordsGlobal(i)%colCoords = allCombinationsMappedInd

        ! Local row coordinates: only the first coordinate is converted with findLocalXIndex
        this%indexingData%rowDataCoordsLocal(i)%rowCoords = rowCoords(:,i) 
        this%indexingData%rowDataCoordsLocal(i)%rowCoords(1) = indexingObj%findLocalXIndex(rowCoords(1,i))

        ! Local column coordinates: convert the first mapped coordinate array relative to this row's x coordinate
        do j = 1, size(mappedColIndicesLocal(1)%entry)
            mappedColIndicesLocal(1)%entry(j) = indexingObj%findLocalXIndex(mappedColIndicesLocal(1)%entry(j),&
                                                                            locXInd=rowCoords(1,i))
        end do

        this%indexingData%rowDataCoordsLocal(i)%colCoords = allCombinations(mappedColIndicesLocal)

        numEntries = size(this%rowData%values(i)%entry)

        ! Per-row multiplicative constants and column buffer, both defaulting to 1. The column buffer is allocated
        ! for distribution rows as well, since calculateCols/calculateValues use it regardless of the row type.
        allocate(this%multConst(i)%entry(numEntries))
        allocate(this%colBuffer(i)%entry(numEntries))
        this%multConst(i)%entry = real(1,kind=rk)
        this%colBuffer(i)%entry = real(1,kind=rk)

        allocate(this%indexingData%localColIndices(i)%entry(numEntries))

        ! Local column indices: a single mapped coordinate array means a non-distribution column variable
        if (size(mappedColIndices) == 1) then
            do j = 1, numEntries
                this%indexingData%localColIndices(i)%entry(j) = indexingObj%findLocalXIndex(allCombinationsMappedInd(1,j),&
                                                                                            locXInd=rowCoords(1,i))
            end do
        else 
            do j = 1, numEntries
                this%indexingData%localColIndices(i)%entry(j) = indexingObj%findDistIndex(allCombinationsMappedInd(1,j),&
                                                                                          allCombinationsMappedInd(2,j),&
                                                                                          allCombinationsMappedInd(3,j),&
                                                                                          .true.,&
                                                                                          locXInd=rowCoords(1,i))
            end do
        end if

    end do

end subroutine initRowData
!-----------------------------------------------------------------------------------------------------------------------------------
module subroutine addRowDataToPreallocationData(this,petscPreallocData) 
    !! Hand this term's sparse row data to a PETSc preallocation object.
    !! The row data is added to the object's pattern via addRowDataToPattern.

    class(MatrixTerm)            ,intent(inout)  :: this
    type(PETScPreallocationData) ,intent(inout)  :: petscPreallocData

    if (assertions) then 
        ! Both the term and the target preallocation object must be defined
        call assert(this%isDefined(),"Attempted to add undefined matrix term row data to PETSc preallocation data")
        call assert(petscPreallocData%isDefined(),"Attempted to add matrix term row data to undefined PETSc preallocation data")
    end if

    call petscPreallocData%addRowDataToPattern(this%rowData)

end subroutine addRowDataToPreallocationData
!-----------------------------------------------------------------------------------------------------------------------------------
module subroutine addRowDataPatternToController(this,petscCont) 
    !! Forward this term's sparse row data to the PETSc controller's preallocation object.
    !! This is a thin delegation to petscCont%addRowDataToPreallocation; no definedness check is performed here.

    class(MatrixTerm)     ,intent(in)     :: this
    type(PETScController) ,intent(inout)  :: petscCont

    call petscCont%addRowDataToPreallocation(this%rowData)

end subroutine addRowDataPatternToController
!-----------------------------------------------------------------------------------------------------------------------------------
module subroutine addRowValuesToPETScMatrix(this,petscCont,mult,petscGroup) 
    !! Add this term's row values to a petsc matrix object in PETSc controller, multiplied by mult. If evolved variable is
    !! stationary mult is ignored and set to -1.
    !!
    !! petscGroup selects the target PETSc group/matrix in the controller and defaults to 1 when absent.

    class(MatrixTerm)     ,intent(in)     :: this
    type(PETScController) ,intent(inout)  :: petscCont
    real(rk)              ,intent(in)     :: mult
    integer(ik) ,optional ,intent(in)     :: petscGroup

    real(rk)                              :: usedMult
    integer(ik)                           :: usedGroup

    ! Values go into a PETSc matrix (not preallocation data), so the message names the matrix
    if (assertions) call assert(this%isDefined(),"Attempted to add undefined matrix term values to PETSc matrix")

    ! For a stationary evolved variable the caller's multiplier is overridden with -1, per this routine's contract
    if (this%stationaryEvolvedVar) then
        usedMult = real(-1,kind=rk)
    else
        usedMult = mult
    end if

    usedGroup = 1
    if (present(petscGroup)) usedGroup = petscGroup

    call petscCont%addRowValuesToMatrix(this%rowData,usedMult,usedGroup) 

end subroutine addRowValuesToPETScMatrix
!-----------------------------------------------------------------------------------------------------------------------------------
pure module function evaluateMatTerm (this,varCont) result(res)
    !! Explicit evaluation of the term: each sparse row is dotted with the implicit (column) variable values from varCont.
    !! The result has the shape and bounds of the evolved (row) variable in varCont. Entries not touched by any row of
    !! this term are zero.

    class(MatrixTerm)                    ,intent(in) :: this
    type(VariableContainer)              ,intent(in) :: varCont
    real(rk) ,allocatable ,dimension(:)              :: res

    integer(ik) :: row ,rowVarInd ,colVarInd

    if (assertions) then
        call assertPure(this%isDefined(),"Attempted to calculate explicit version of undefined matrix term object")
        call assertPure(varCont%isDefined(),&
        "Attempted to calculate explicit version of matrix term object by passing undefined variable container")
    end if

    rowVarInd = varCont%getVarIndex(this%indexingData%rowVarName)
    colVarInd = varCont%getVarIndex(this%indexingData%colVarName) 

    ! mold= copies the bounds of the evolved variable's storage
    allocate(res,mold=varCont%variables(rowVarInd)%entry)
    res = real(0,kind=rk)

    do row = 1, size(this%rowData%values)
        associate (cols => this%indexingData%localColIndices(row)%entry)
            res(this%indexingData%localRowIndices(row)) = &
                dot_product(this%rowData%values(row)%entry, varCont%variables(colVarInd)%entry(cols))
        end associate
    end do

end function evaluateMatTerm
!-----------------------------------------------------------------------------------------------------------------------------------
pure module subroutine setMultConst(this,multConst) 
    !! Setter for multConst, the per-row arrays of multiplicative constants applied entry-wise in calculateValues.
    !! The whole array is replaced by assignment (reallocated to match the argument if needed).

    class(MatrixTerm)              ,intent(inout)  :: this
    type(RealArray) ,dimension(:)  ,intent(in)     :: multConst

    this%multConst = multConst

end subroutine setMultConst
!-----------------------------------------------------------------------------------------------------------------------------------
pure module function getMultConst (this) result(multConst)
    !! Getter for multConst: returns a copy of the per-row multiplicative constant arrays.

    class(MatrixTerm)                          ,intent(in) :: this
    type(RealArray) ,allocatable ,dimension(:)             :: multConst

    if (assertions) then
        call assertPure(this%isDefined(),"Called getMultConst on undefined matrix term")
    end if

    multConst = this%multConst

end function getMultConst
!-----------------------------------------------------------------------------------------------------------------------------------
pure module function getRowVarName(this) result(name)
    !! Name of the evolved (row) variable of this term, as stored in its indexing data.

    class(MatrixTerm)                    ,intent(in) :: this
    character(:) ,allocatable                        :: name

    if (assertions) then
        call assertPure(this%isDefined(),"Called getRowVarName on undefined matrix term")
    end if

    name = this%indexingData%rowVarName 

end function getRowVarName
!-----------------------------------------------------------------------------------------------------------------------------------
pure module subroutine setFixedMatrix(this,isFixed) 
    !! Mark the matrix as fixed (or not). A fixed matrix has its values computed by calculateValues on the first call
    !! only; later calls skip the recomputation.

    class(MatrixTerm)              ,intent(inout)  :: this
    logical                        ,intent(in)     :: isFixed

    this%fixedMatrix = isFixed

end subroutine setFixedMatrix
!-----------------------------------------------------------------------------------------------------------------------------------
pure module subroutine setNonTrivialRowFun(this,nontriv) 
    !! Flag whether the term has a non-trivial row function. When the flag is false, calculateValues neither calls
    !! calculateRows nor applies the row buffer.

    class(MatrixTerm)              ,intent(inout)  :: this
    logical                        ,intent(in)     :: nontriv

    this%hasNonTrivialRowFun = nontriv

end subroutine setNonTrivialRowFun
!-----------------------------------------------------------------------------------------------------------------------------------
pure module subroutine setNonTrivialColFun(this,nontriv) 
    !! Flag whether the term has a non-trivial column function. When the flag is false, calculateValues neither calls
    !! calculateCols nor applies the column buffer.

    class(MatrixTerm)              ,intent(inout)  :: this
    logical                        ,intent(in)     :: nontriv

    this%hasNonTrivialColFun = nontriv

end subroutine setNonTrivialColFun
!-----------------------------------------------------------------------------------------------------------------------------------
pure module function getIndexingData(this) result(indData)
    !! Returns a copy of this term's matrix indexing data (variable names, required variable indices,
    !! row/column coordinates and local indices).

    class(MatrixTerm)                    ,intent(in) :: this
    type(MatrixTermIndexingData)                     :: indData

    if (assertions) then
        call assertPure(this%isDefined(),"Called getIndexingData on undefined matrix term")
    end if

    indData = this%indexingData 

end function getIndexingData
!-----------------------------------------------------------------------------------------------------------------------------------
pure module function getRowData(this) result(rowData)
    !! Returns a copy of this term's sparse row data (row indices, column indices and values).

    class(MatrixTerm)           ,intent(in) :: this
    type(SparseRowData)                     :: rowData

    if (assertions) then
        call assertPure(this%isDefined(),"Called getRowData on undefined matrix term")
    end if

    rowData = this%rowData 

end function getRowData
!-----------------------------------------------------------------------------------------------------------------------------------
module subroutine matTermUpdate(this,varCont,modelData,hostModel) 
    !! Default matrix term update: forwards to the matrixTermUpdate binding together with this term's indexing data.
    !! Absent optional arguments are forwarded as-is. Assuming the callee's modelData/hostModel dummies are optional
    !! (as in matrixNoUpdate), they remain absent there, so one call covers every presence combination.

    class(MatrixTerm)               ,intent(inout)  :: this
    type(VariableContainer)         ,intent(in)     :: varCont
    class(ModelboundData) ,optional ,intent(in)     :: modelData
    class(ModelSurrogate) ,optional ,intent(in)     :: hostModel

    call this%matrixTermUpdate(varCont,this%indexingData,modelData=modelData,hostModel=hostModel)

end subroutine matTermUpdate
!-----------------------------------------------------------------------------------------------------------------------------------
module subroutine matrixNoUpdate(this,varCont,indexingData,modelData,hostModel) 
    !! No-op default for the matrix term update hook. All arguments are accepted and ignored;
    !! nothing in this or elsewhere is modified.

    class(MatrixTerm)               ,intent(inout)  :: this
    type(VariableContainer)         ,intent(in)     :: varCont
    type(MatrixTermIndexingData)    ,intent(in)     :: indexingData
    class(ModelboundData) ,optional ,intent(in)     :: modelData
    class(ModelSurrogate) ,optional ,intent(in)     :: hostModel

    ! Intentionally empty

end subroutine matrixNoUpdate
!-----------------------------------------------------------------------------------------------------------------------------------
!-----------------------------------------------------------------------------------------------------------------------------------
end submodule matrix_term_abstract_procedures
!-----------------------------------------------------------------------------------------------------------------------------------