しずくぶろぐ

競技ぷろぐらみんぐしたり、なんかしたりします

不定期解説記事企画 ARC008-D タコヤキオイシクナール #雫ぷよ

こんにちは。研究室の先輩と仲良くなった気がする綿谷雫です。今日は ARC008Dの解説をしたいと思います。

今回はあるものを試したいと思って解いたものなので思考回路的なものはないです。ネタバレを聞かずに過ごしたい人はブラウザバックをしてください

問題のリンクは下に貼っておきます。

atcoder.jp

問題概要

美味しさxのたこ焼きが通るとそのたこ焼きの美味しさがax+bになる摩訶不思議な箱がならんでいて稼働中にちょくだいさんがその箱のa,bの値をいたずらするので出てくるたこやきの美味しさの最大値と最小値を求めましょう、という問題です

前置き、セグメントツリーとは

ネタバレです。こういう問題はセグメントツリーで解くらしいです。というわけでセグメントツリーの解説をします。

セグメントツリーとはデータ構造です。セグメントツリーでは普通に配列に入れるよりも素早くある計算ができます。ある計算というのは「区間内の要素を使った、計算する順序によって計算結果が変わらないような計算」です。例えば\displaystyle{a _ 1 ,a _ 2 ,\cdots , a _ {16} }というものがあったとき、セグメントツリーを使うと\displaystyle{ a _ 3} から\displaystyle{ a _ 8}などある区間のものを足した値を素早く計算できます。

どうしてセグメントツリーでは素早く計算ができるんでしょうか。それはセグメントツリーでは\displaystyle{a _ 1 ,a _ 2 ,\cdots , a _ {16} }の値だけでなく\displaystyle{a _ 1 }から\displaystyle{a _ 4 }まで計算した値というのも持っておくからです。先ほどの例のように16個の要素がある場合、\displaystyle{a _ 1 ,a _ 2 ,\cdots , a _ {16} }の値と合わせて31個の値を保持することが多いです。いっぱい覚えておくことでパパッと計算できるようにしているんですね。

実際\displaystyle{ a _ 3} から\displaystyle{ a _ 8}などある区間のものを足した値を索めようとなったとき、単純な配列だと6つのものを足していかないといけませんが、セグメントツリーでは\displaystyle{ a _ 3} から\displaystyle{ a _ 4}までを足したものと\displaystyle{ a _ 5} から\displaystyle{ a _ 8}を足したものはわかっているのでその2つを足すだけで済みます。嬉しいですね。

これだったら配列を使わないでずっとセグメントツリーを使えばいいじゃないかと思う人もいるかもしれません。しかし、\displaystyle{ a _ 3} から\displaystyle{ a _ 4}まで足したものなどを前もって計算しているのである区間のものを足した値を索めようとは思わないときには単純な配列を使った方がいいです。また、\displaystyle{ op(,) }二項演算子だとして\displaystyle{ op(a _ 1, op(a _ 2, a _ 3)) }\displaystyle{ op(op(a _ 1, a _ 2), a _ 3) }の計算結果が違うときは使えないので最強のデータ構造*1というわけではないです。

ちなみに\displaystyle{ op(a _ 1, op(a _ 2, a _ 3)) }\displaystyle{ op(op(a _ 1, a _ 2), a _ 3) }が同じ値をもってくれる、集合\displaystyle{S= {a _ 1,a _ 2 ,...  } }と演算\displaystyle{ op }の組(S,op)をモノイドって言ったり、演算はまあわかるでしょってことで集合\displaystyle{ S= {a _ 1,a _ 2 ,...  }}だけでモノイドって言ったり単位元e*2の存在も合わせた組(S,op,e)をモノイドって言ったりするらしいです。

セグメントツリーの実装

まず、いっぱい記憶できるように配列を用意します。

   type SegTree_t
        type(Data_t),allocatable::D(:)
        integer(16)::len
        integer(16)::leaf
    end type

要素の数が与えられるので最初に配列の長さを十分な長さにします。initはinitiationの略です。

    function SegTree_init(n)result(st)
        type(SegTree_t)::st
        integer(16),intent(in)::n
        integer(16)::x
        
        st%len=n
 
        x=1
        do while( x < n )
            x=2*x
        end do
        allocate(st%D(2*x-1), source=Data_e())
        !セグ木の要素を全て単位元に        
        st%leaf = x
    end function

要素の値が与えられたらそれをセットして、その要素に関連する値も計算するようにします。例えば \displaystyle{ a _ 4}が与えられたら\displaystyle{ op(a _ 3,a _ 4) }とか\displaystyle{ op(op(a _ 1 ,a _2) ,op(a _ 3,a _ 4)) }とかを計算しとかないといけませんね。言ってませんでしたが\displaystyle{ op(a _ 3,a _ 4) }\displaystyle{ a _ 4}の親と呼びます。

インデックスを2で割ると親の場所にいくようにするとうまいこといきます。そのためにいろいろインデックスをいじっているんですね

    subroutine SegTree_set(st,i,s)
        class(SegTree_t),intent(inout)::st
        integer(16),value:: i
        type(Data_t),intent(in)::s
        i = i + st%leaf - 1
        
        st%D(i) = s
 
        i=rshift(i,1)!親へ移動
        do while( i > 0)
            st%d(i) = op(st%D(i*2),st%d(i*2+1))!親の計算
            i=rshift(i,1)
        end do
    end subroutine

クエリに答えるのはこんな感じです。

    type(Data_t) function Segtree_query(st,ql,qr) result(ret)
        class(segtree_t),intent(inout):: st
        integer(16), intent(in):: ql,qr
        ret = Segtree_query_in(st,ql,qr,1_16,st%leaf,1_16)
    end function

    recursive type(Data_t)function Segtree_query_in(st,ql,qr,nl,nr,i) result(ret)
        implicit none
        class(Segtree_t),intent(inout):: st
        integer(16),intent(in):: ql,qr,nl,nr,i
        integer(16):: nm
        type(Data_t)::r1,r2
        if (nr < ql .or. qr < nl) then
            ret = Data_e()
        else if (ql <= nl .and. nr <= qr) then
            ret = st%d(i)
        else
            nm = (nl+nr)/2
            r1 = Segtree_query_in(st,ql,qr,nl,  nm,i*2  )
            r2 = Segtree_query_in(st,ql,qr,nm+1,nr,i*2+1)
            ret = op(r1,r2)
        end if
    end function

素数16の\displaystyle{a _ 1 ,a _ 2 ,\cdots , a _ {16} }\displaystyle{ a _ 3} から\displaystyle{ a _ 8}の和を知りたいときなどはまず、 Segtree_query_in(st,3,8,1,16,1)がSegtree_query_in(st,3,8,1,8,2),Segtree_query_in(st,3,8,9,16,3)を呼び出して、Segtree_query_in(st,3,8,1,8,2)がSegtree_query_in(st,3,8,1,4,4),Segtree_query_in(st,3,8,5,8,5)を呼び出して、Segtree_query_in(st,3,8,9,16,3)が0を返して、Segtree_query_in(st,3,8,1,4,4)がSegtree_query_in(st,3,8,1,2,8),Segtree_query_in(st,3,8,3,4,9)を呼び出してSegtree_query_in(st,3,8,5,8,5)が\displaystyle{ a _ 5 + a _ 6 + a _ 7 + a _ 8}を返して、Segtree_query_in(st,3,8,1,2,8)が0を返して、Segtree_query_in(st,3,8,3,4,9)が\displaystyle{ a _ 3 + a _ 4}を返してで結果が出てくる感じです。

今回の場合

今回は関数ax+bとcx+dの合成を考えます、合成するとc(ax+b)+d=ac x +bc +d となります。

計算順序によって変わらないか確認してみます。 \displaystyle{f _ 1 (x)= a _ 1 x + b _ 1} \displaystyle{f _ 2 (x)= a _ 2 x + b _ 2} \displaystyle{f _ 3 (x)= a _ 3 x + b _ 3} を合成します。

まず\displaystyle{f _ 1 (x)= a _ 1 x + b _ 1}\displaystyle{f _ 2 (x)= a _ 2 x + b _ 2}を合成すると \displaystyle{ a _ 1 a _ 2 x + b _ 1 a _ 2 +b _ 2} です。これと\displaystyle{f _ 3}を合成すると \displaystyle{f _ 2 (x)= a _ 1 a _ 2 a _ 3 x + b _ 1 a _ 2 a _ 3 + b _ 2 a _ 3 + b _ 3}です。

次にまず \displaystyle{f _ 2 (x)= a _ 2 x + b _ 2}\displaystyle{f _ 3 (x)= a _ 3 x + b _ 3} とを先に合成すると \displaystyle{ a _ 2 a _ 3 x + b _ 2 a _ 3 +b _ 3} です。これと\displaystyle{f _ 1}を合成すると \displaystyle{f _ 2 (x)= a _ 1 a _ 2 a _ 3 x + b _ 1 a _ 2 a _ 3 + b _ 2 a _ 3 + b _ 3}となり計算順序によらないこと、結合法則が満たされることがわかりました。というわけで安心してtype(Data_t)と演算子opを書きます。

module SegTree_setup_mod
    type Data_t
        !セグ木のノードが持つデータ
        !初期値は単位元に
        double precision::a=1,b=0
    end type
contains
    type(Data_t) function Data_e()
        Data_e=Data_t(1,0)
    end function
 
    type(Data_t) function op(d1,d2)
        type(Data_t),intent(in)::d1,d2
        op%a = d1%a * d2%a
        op%b = d1%b * d2%a + d2%b 
    end function
 
end module SegTree_setup_mod

これでセグメントツリーができました。

問題の実装

セグ木にNを載せるのはつらいので座標圧縮します。

    do i=1,M
     read*,p(i),a(i),b(i)
        P_ind(i)=i
    end do
    if(m/=0)call HeapsortPair(M,P,P_ind)
    P_new(1)=1
    do i=2,M
     if(P(i)/=P(i-1))Pcnt=Pcnt+1
        P_new(i)=Pcnt
    end do
    if(m/=0)call HeapSortPair(M,P_ind,P_new)

座標圧縮したあとは言われた通りに箱を変えて行って最大値と最小値を更新していきます

    Ans_min=1
    Ans_max=1
    do i=1,M
        call Segtree_set(ST,P_new(i),Data_t(a(i),b(i)))
        tmp=Segtree_query(ST,1_16,N)
        Ans_min=min(Ans_min,tmp%a+tmp%b)
        Ans_max=max(Ans_max,tmp%a+tmp%b)
    end do
    print*,Ans_min
    print*,Ans_max

おしまいです。ここで作ったセグメントツリーをどんどん使っていきたいですね。

全容

module SegTree_setup_mod
    type Data_t
        !セグ木のノードが持つデータ
        !初期値は単位元に
        double precision::a=1,b=0
    end type
contains
    type(Data_t) function Data_e()
        Data_e=Data_t(1,0)
    end function
 
    type(Data_t) function op(d1,d2)
        type(Data_t),intent(in)::d1,d2
        op%a = d1%a * d2%a
        op%b = d1%b * d2%a + d2%b 
    end function

end module SegTree_setup_mod
 
module SegTree_mod
    use SegTree_setup_mod
    implicit none
    type SegTree_t
        type(Data_t),allocatable::D(:)
        integer(16)::len
        integer(16)::leaf
    end type
contains
 
    function SegTree_init(n)result(st)
        type(SegTree_t)::st
        integer(16),intent(in)::n
        integer(16)::x
        
        st%len=n
 
        x=1
        do while( x < n )
            x=2*x
        end do
        allocate(st%D(2*x-1), source=Data_e())
        !セグ木の要素を全て単位元に        
        st%leaf = x
    end function
 
    subroutine SegTree_set(st,i,s)
        class(SegTree_t),intent(inout)::st
        integer(16),value:: i
        type(Data_t),intent(in)::s
        i = i + st%leaf - 1
        
        st%D(i) = s
 
        i=rshift(i,1)!親へ移動
        do while( i > 0)
            st%d(i) = op(st%D(i*2),st%d(i*2+1))!親の計算
            i=rshift(i,1)
        end do
    end subroutine
 
    type(Data_t) function Segtree_query(st,ql,qr) result(ret)
        class(segtree_t),intent(inout):: st
        integer(16), intent(in):: ql,qr
        ret = Segtree_query_in(st,ql,qr,1_16,st%leaf,1_16)
    end function
 
    recursive type(Data_t)function Segtree_query_in(st,ql,qr,nl,nr,i) result(ret)
        implicit none
        class(Segtree_t),intent(inout):: st
        integer(16),intent(in):: ql,qr,nl,nr,i
        integer(16):: nm
        type(Data_t)::r1,r2
        if (nr < ql .or. qr < nl) then
            ret = Data_e()
        else if (ql <= nl .and. nr <= qr) then
            ret = st%d(i)
        else
            nm = (nl+nr)/2
            r1 = Segtree_query_in(st,ql,qr,nl,  nm,i*2  )
            r2 = Segtree_query_in(st,ql,qr,nm+1,nr,i*2+1)
            ret = op(r1,r2)
        end if
    end function
    
    function SegTree_to_array(st) result(ret)
        class(SegTree_t):: st
        type(Data_t):: ret(st%len)
         ret(:) = st%d(st%leaf:st%leaf+st%len-1)
    end function
end module SegTree_mod

program main
 use SegTree_mod
    implicit none
    integer(16)::N,M
    integer(16),allocatable::p(:),P_ind(:),P_new(:)
    integer(16)::Pcnt=1
    double precision,allocatable::a(:),b(:)
    type(Data_t)::tmp
    type(SegTree_t)::ST
    double precision::Ans_min,Ans_max
    integer(16)::i
    read*,N,M
    allocate(p(M),a(M),b(M),P_ind(M),P_new(M))
    
    do i=1,M
     read*,p(i),a(i),b(i)
        P_ind(i)=i
    end do
    if(m/=0)call HeapsortPair(M,P,P_ind)
    P_new(1)=1
    do i=2,M
     if(P(i)/=P(i-1))Pcnt=Pcnt+1
        P_new(i)=Pcnt
    end do
    if(m/=0)call HeapSortPair(M,P_ind,P_new)
    
    ST=segtree_init(m)
    
    Ans_min=1
    Ans_max=1
    do i=1,M
        call Segtree_set(ST,P_new(i),Data_t(a(i),b(i)))
        tmp=Segtree_query(ST,1_16,N)
        Ans_min=min(Ans_min,tmp%a+tmp%b)
        Ans_max=max(Ans_max,tmp%a+tmp%b)
    end do
    print*,Ans_min
    print*,Ans_max
contains
subroutine HeapsortPair(n,array,array2)
  implicit none
!ここの入力は状況に応じて変更すること
  integer(16),intent(in) :: n
  integer(16),intent(inout) :: array(1:n),array2(1:n)
  integer(16)::i,k,j,l
  integer(16):: t,t2
 
  l=n/2+1
  k=n
  do while(k /= 1)
     if(l > 1)then
        l=l-1
        t=array(L)
        t2=array2(L)
     else
        t=array(k)
        t2=array2(k)
        array(k)=array(1)
        array2(k)=array2(1)
        k=k-1
        if(k == 1) then
           array(1)=t
           array2(1)=t2
           exit
        endif
     endif
     i=l
     j=l+l
     do while(j<=k)
        if(j < k)then
           if(array(j) < array(j+1))j=j+1
        endif
        if (t < array(j))then
           array(i)=array(j)
           array2(i)=array2(j)
           i=j
           j=j+j
        else
           j=k+1
        endif
     enddo
     array(i)=t
     array2(i)=t2
  enddo
  return
end subroutine HeapsortPair
end program main

*1:そんなものはない

*2:op(a,e)=aってなるやつ。足し算だったら0、掛け算だったら1みたいな