!-----------------------------------------------------------------------------------------------------------------------------------
! This file is part of ReMKiT1D.
!
! ReMKiT1D is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as
! published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version.
!
! ReMKiT1D is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of
! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.
!
! You should have received a copy of the GNU General Public License along with ReMKiT1D. If not, see <https://www.gnu.org/licenses/>.
!
! Copyright 2023 United Kingdom Atomic Energy Authority (stefan.mijin@ukaea.uk)
!-----------------------------------------------------------------------------------------------------------------------------------
submodule (calculation_tree_class) calculation_tree_procedures
!! author: Stefan Mijin
!!
!! Contains module procedures associated with calculation tree and node classes

implicit none

!-----------------------------------------------------------------------------------------------------------------------------------
contains
!-----------------------------------------------------------------------------------------------------------------------------------
module subroutine initNode(this,additiveMode,constant,leafVarIndex,unaryRealParams,&
                           unaryIntParams,unaryLogicalParams,unaryTransformTag)
    !! Calculation node initialization routine.
    !!
    !! Defaults applied when the optional arguments are absent:
    !!
    !! - additiveMode: .false.
    !! - constant: 1, or 0 when the node is in additive mode
    !! - leafVarIndex: 0
    !!
    !! The unary parameter arrays are stored only when unaryTransformTag is present; without a tag they are ignored.

    class(CalculationNode)          ,intent(inout) :: this
    logical ,optional               ,intent(in)    :: additiveMode
    real(rk) ,optional              ,intent(in)    :: constant
    integer(ik) ,optional           ,intent(in)    :: leafVarIndex
    real(rk) ,dimension(:) ,optional    ,intent(in) :: unaryRealParams
    integer(ik) ,dimension(:) ,optional ,intent(in) :: unaryIntParams
    logical ,dimension(:) ,optional     ,intent(in) :: unaryLogicalParams
    character(*) ,optional          ,intent(in)    :: unaryTransformTag

    this%kernel%additiveMode = .false.
    if (present(additiveMode)) then
        this%kernel%additiveMode = additiveMode
    end if

    ! The default constant is the identity of the node's own operation (1 for multiplication, 0 for addition)
    this%kernel%constant = real(1,kind=rk)
    if (this%kernel%additiveMode) this%kernel%constant = real(0,kind=rk)
    if (present(constant)) this%kernel%constant = constant

    this%kernel%leafVarIndex = 0
    if (present(leafVarIndex)) this%kernel%leafVarIndex = leafVarIndex

    if (present(unaryTransformTag)) then
        this%kernel%unaryTransformationTag = unaryTransformTag
        if (allocated(this%kernel%unaryTransformationTag)) then
            ! Set the node's unary transform procedure pointer from the supplied tag
            call associateFunctionPointer(unaryTransformTag,this%unaryTransform)
            if (present(unaryIntParams)) this%kernel%unaryIntParams = unaryIntParams
            if (present(unaryRealParams)) this%kernel%unaryRealParams = unaryRealParams
            if (present(unaryLogicalParams)) this%kernel%unaryLogicalParams = unaryLogicalParams
        end if
    end if

    call this%makeDefined()

end subroutine initNode
!-----------------------------------------------------------------------------------------------------------------------------------
module subroutine initTree(this,additiveMode,constant,leafVarIndex,unaryRealParams,&
                           unaryIntParams,unaryLogicalParams,unaryTransformTag)
    !! Calculation tree initialization routine. Allocates the root node and initializes it with the passed
    !! (optional) properties, which are forwarded unchanged to the node init routine.

    class(CalculationTree)          ,intent(inout) :: this
    logical ,optional               ,intent(in)    :: additiveMode
    real(rk) ,optional              ,intent(in)    :: constant
    integer(ik) ,optional           ,intent(in)    :: leafVarIndex
    real(rk) ,dimension(:) ,optional    ,intent(in) :: unaryRealParams
    integer(ik) ,dimension(:) ,optional ,intent(in) :: unaryIntParams
    logical ,dimension(:) ,optional     ,intent(in) :: unaryLogicalParams
    character(*) ,optional          ,intent(in)    :: unaryTransformTag

    allocate(this%root)
    call this%root%init(additiveMode,constant,leafVarIndex,unaryRealParams,&
                        unaryIntParams,unaryLogicalParams,unaryTransformTag)

    call this%makeDefined()

end subroutine initTree
!-----------------------------------------------------------------------------------------------------------------------------------
module subroutine addChild(this,additiveMode,constant,leafVarIndex,unaryRealParams,&
                           unaryIntParams,unaryLogicalParams,unaryTransformTag)
    !! Initialize a child node of this node with given properties. The new node becomes the left child if this
    !! node has no children, otherwise it is appended at the end of the existing children's right sibling chain.

    class(CalculationNode) ,target  ,intent(inout) :: this
    logical ,optional               ,intent(in)    :: additiveMode
    real(rk) ,optional              ,intent(in)    :: constant
    integer(ik) ,optional           ,intent(in)    :: leafVarIndex
    real(rk) ,dimension(:) ,optional    ,intent(in) :: unaryRealParams
    integer(ik) ,dimension(:) ,optional ,intent(in) :: unaryIntParams
    logical ,dimension(:) ,optional     ,intent(in) :: unaryLogicalParams
    character(*) ,optional          ,intent(in)    :: unaryTransformTag

    type(CalculationNode) ,pointer :: nodePointer

    if (.not. associated(this%leftChild)) then

        allocate(this%leftChild)
        call this%leftChild%init(additiveMode,constant,leafVarIndex,unaryRealParams,&
                                 unaryIntParams,unaryLogicalParams,unaryTransformTag)
        this%leftChild%parent => this

    else

        ! Walk to the last existing child and attach the new node as its right sibling
        nodePointer => this%leftChild
        do
            if (associated(nodePointer%rightSibling)) then
                nodePointer => nodePointer%rightSibling
                cycle
            end if

            allocate(nodePointer%rightSibling)
            call nodePointer%rightSibling%init(additiveMode,constant,leafVarIndex,unaryRealParams,&
                                               unaryIntParams,unaryLogicalParams,unaryTransformTag)
            nodePointer%rightSibling%parent => this
            exit
        end do

    end if

end subroutine addChild
!-----------------------------------------------------------------------------------------------------------------------------------
pure module recursive function evaluateNode(this,inputArray) result(res)
    !! Recursively evaluate nodes, using the inputArray variables for leaf values.
    !!
    !! The node's own value is the result of its left child (which already includes all of the child's right
    !! siblings), or the input variable with index leafVarIndex if it has no children. The node's constant is then
    !! added or multiplied in depending on the node's additiveMode, and the unary transform is applied if associated.
    !! Finally, the result is combined with the evaluated right sibling, using the parent's additiveMode to choose
    !! between addition and multiplication.

    class(CalculationNode)         ,intent(in) :: this
    type(RealArray) ,dimension(:)  ,intent(in) :: inputArray

    real(rk) ,allocatable ,dimension(:) :: res

    if (associated(this%leftChild)) then
        res = this%leftChild%evaluate(inputArray)
    else
        !No check here to make sure that the leaf has a valid variable index
        res = inputArray(this%kernel%leafVarIndex)%entry
    end if

    if (this%kernel%additiveMode) then
        res = res + this%kernel%constant
    else
        res = res * this%kernel%constant
    end if

    if (associated(this%unaryTransform)) &
        res = this%unaryTransform(res,this%kernel%unaryRealParams,&
                                  this%kernel%unaryIntParams,this%kernel%unaryLogicalParams)

    ! Siblings are combined using the operation of their common parent
    if (associated(this%rightSibling)) then
        if (this%parent%kernel%additiveMode) then
            res = res + this%rightSibling%evaluate(inputArray)
        else
            res = res * this%rightSibling%evaluate(inputArray)
        end if
    end if

end function evaluateNode
!-----------------------------------------------------------------------------------------------------------------------------------
pure module function evaluateTree(this,inputArray) result(res)
    !! Call tree's root node evaluate

    class(CalculationTree)         ,intent(in) :: this
    type(RealArray) ,dimension(:)  ,intent(in) :: inputArray

    real(rk) ,allocatable ,dimension(:) :: res

    res = this%root%evaluate(inputArray)

end function evaluateTree
!-----------------------------------------------------------------------------------------------------------------------------------
module function flattenTree(this) result(res)
    !! Flatten tree into FlatTree object.
    !!
    !! Nodes are numbered in the order they are first visited by a depth-first walk (left child first, then right
    !! sibling), with the root at index 1. For each node the kernel, the parent index (0 for the root) and the list of
    !! children indices are stored.

    class(CalculationTree) ,intent(in) :: this
    type(FlatTree)                     :: res

    type(CalculationNode) ,pointer :: nodePointer

    integer(ik) :: numNodes ,parentIndex, currentIndex ,i

    if (assertions) call assert(this%isDefined(),"Attempted to flatten undefined CalculationTree")

    numNodes = 1
    nodePointer => this%root

    !Count numNodes
    do
        if (associated(nodePointer%leftChild)) then
            numNodes = numNodes + 1
            nodePointer => nodePointer%leftChild
            cycle
        end if

        if (associated(nodePointer%rightSibling)) then
            numNodes = numNodes + 1
            nodePointer => nodePointer%rightSibling
            cycle
        end if

        ! Leaf without right sibling: climb until an ancestor with an unvisited right sibling is found,
        ! or until the root is reached
        do
            if (.not. associated(nodePointer%parent)) exit
            if (associated(nodePointer%parent%rightSibling)) then
                nodePointer => nodePointer%parent%rightSibling
                numNodes = numNodes + 1
                exit
            else
                nodePointer => nodePointer%parent
            end if
        end do

        ! Back at the root means the whole tree has been visited
        if (.not. associated(nodePointer%parent)) exit
    end do

    allocate(res%kernels(numNodes))
    allocate(res%children(numNodes))
    do i = 1,numNodes
        allocate(res%children(i)%entry(0))
    end do
    allocate(res%parent(numNodes))

    ! Second walk in the same order as the count above. currentIndex is the index of the most recently
    ! numbered node and parentIndex the flat index of the parent of the node nodePointer points to.
    parentIndex = 0
    currentIndex = 1
    res%kernels(currentIndex) = this%root%kernel
    res%parent(currentIndex) = 0
    nodePointer => this%root

    do
        if (associated(nodePointer%leftChild)) then
            res%children(currentIndex)%entry = [res%children(currentIndex)%entry,currentIndex+1]
            parentIndex = currentIndex
            currentIndex = currentIndex + 1
            nodePointer => nodePointer%leftChild
            res%kernels(currentIndex) = nodePointer%kernel
            res%parent(currentIndex) = parentIndex
            cycle
        end if

        if (associated(nodePointer%rightSibling)) then
            currentIndex = currentIndex + 1
            res%children(parentIndex)%entry = [res%children(parentIndex)%entry,currentIndex]
            nodePointer => nodePointer%rightSibling
            res%kernels(currentIndex) = nodePointer%kernel
            res%parent(currentIndex) = parentIndex
            cycle
        end if

        do
            if (.not. associated(nodePointer%parent)) exit
            if (associated(nodePointer%parent%rightSibling)) then
                nodePointer => nodePointer%parent%rightSibling
                ! The new node shares its parent with the node we climbed out of
                parentIndex = res%parent(parentIndex)
                currentIndex = currentIndex + 1
                res%children(parentIndex)%entry = [res%children(parentIndex)%entry,currentIndex]
                res%kernels(currentIndex) = nodePointer%kernel
                res%parent(currentIndex) = parentIndex
                exit
            else
                nodePointer => nodePointer%parent
                parentIndex = res%parent(parentIndex)
            end if
        end do

        if (.not. associated(nodePointer%parent)) exit
    end do

end function flattenTree
!-----------------------------------------------------------------------------------------------------------------------------------
module subroutine initFromFlatTree(this,fTree)
    !! Calculation tree initialization routine using a FlatTree object.
    !!
    !! The root is initialized from the first kernel of the flat tree. The tree is then walked depth-first, and on
    !! first arrival at each node all of its children are added using the kernels listed in the flat tree's
    !! children entry for that node.

    class(CalculationTree) ,intent(inout) :: this
    type(FlatTree)         ,intent(in)    :: fTree

    type(CalculationNode) ,pointer :: nodePointer

    integer(ik) :: i ,currentIndex ,parentIndex ,siblingIndex ,childIndex

    integer(ik) ,dimension(:) ,allocatable :: indexLookup

    if (assertions) &
        call assert(.not. this%isDefined(),"Cannot initialize CalculationTree from FlatTree if it is already defined")

    call this%init(fTree%kernels(1)%additiveMode,&
                   fTree%kernels(1)%constant,&
                   fTree%kernels(1)%leafVarIndex,&
                   fTree%kernels(1)%unaryRealParams,&
                   fTree%kernels(1)%unaryIntParams,&
                   fTree%kernels(1)%unaryLogicalParams,&
                   fTree%kernels(1)%unaryTransformationTag)

    ! currentIndex/parentIndex are the flat tree indices of the node nodePointer points to and of its parent
    parentIndex = 0
    currentIndex = 1
    siblingIndex = 1
    nodePointer => this%root

    do
        ! Every pass through this point is the first arrival at a node, so its children are created here
        do i = 1,size(fTree%children(currentIndex)%entry)
            childIndex = fTree%children(currentIndex)%entry(i)
            call nodePointer%addChild(fTree%kernels(childIndex)%additiveMode,&
                                      fTree%kernels(childIndex)%constant,&
                                      fTree%kernels(childIndex)%leafVarIndex,&
                                      fTree%kernels(childIndex)%unaryRealParams,&
                                      fTree%kernels(childIndex)%unaryIntParams,&
                                      fTree%kernels(childIndex)%unaryLogicalParams,&
                                      fTree%kernels(childIndex)%unaryTransformationTag)
        end do

        if (associated(nodePointer%leftChild)) then
            nodePointer => nodePointer%leftChild
            parentIndex = currentIndex
            currentIndex = fTree%children(currentIndex)%entry(1)
            cycle
        end if

        if (associated(nodePointer%rightSibling)) then
            nodePointer => nodePointer%rightSibling
            ! The sibling's flat index is the entry following the current node in the parent's children list
            indexLookup = findIndices(fTree%children(parentIndex)%entry == currentIndex)
            siblingIndex = indexLookup(1) + 1
            currentIndex = fTree%children(parentIndex)%entry(siblingIndex)
            cycle
        end if

        ! Leaf without right sibling: climb until an ancestor with a right sibling is found or the root is reached
        do
            if (.not. associated(nodePointer%parent)) exit
            if (associated(nodePointer%parent%rightSibling)) then
                nodePointer => nodePointer%parent%rightSibling
                currentIndex = parentIndex
                parentIndex = fTree%parent(parentIndex)
                indexLookup = findIndices(fTree%children(parentIndex)%entry == currentIndex)
                siblingIndex = indexLookup(1) + 1
                currentIndex = fTree%children(parentIndex)%entry(siblingIndex)
                exit
            else
                nodePointer => nodePointer%parent
                currentIndex = parentIndex
                parentIndex = fTree%parent(parentIndex)
            end if
        end do

        if (.not. associated(nodePointer%parent)) exit
    end do

end subroutine initFromFlatTree
!-----------------------------------------------------------------------------------------------------------------------------------
pure module recursive subroutine destroyNode(this)
    !! Recursively destroy the subtree below this node and the chain of right siblings following it.
    !!
    !! Child and sibling nodes are allocated by addChild and are owned by the tree, so they are deallocated here after
    !! their own subtrees have been destroyed. The parent pointer is a non-owning back reference and is only nullified.
    !! NOTE(review): assumes nodes are not shared between trees through shallow pointer copies - confirm against callers.

    class(CalculationNode) ,intent(inout) :: this

    if (associated(this%leftChild)) then
        call this%leftChild%destroy()
        deallocate(this%leftChild)
    end if

    if (associated(this%rightSibling)) then
        call this%rightSibling%destroy()
        deallocate(this%rightSibling)
    end if

    nullify(this%leftChild)
    nullify(this%rightSibling)
    nullify(this%parent)
    nullify(this%unaryTransform)

end subroutine destroyNode
!-----------------------------------------------------------------------------------------------------------------------------------
elemental module subroutine finalizeCalculationTree(this)
    !! Release all nodes owned by the tree. Does nothing for a tree whose root was never allocated.

    type(CalculationTree) ,intent(inout) :: this

    ! The root is only associated once the tree has been initialized, so it must be checked before use
    if (associated(this%root)) then
        call this%root%destroy()
        deallocate(this%root)
    end if

    nullify(this%root)

end subroutine finalizeCalculationTree
!-----------------------------------------------------------------------------------------------------------------------------------
end submodule calculation_tree_procedures
!-----------------------------------------------------------------------------------------------------------------------------------