砂場で遊ぼう

c++/python/mathematica などの練習帳

Project Euler 196 / 三つ子素数

問題概略

正の整数すべてを使って下の図のような三角形を作る。

f:id:variee:20211219033551p:plain

各数は最大 8 個の数と隣接している。

次の条件をみたす 3 つの素数の組を「三つ子素数」(prime triplet)と呼ぶ。

「3 つの素数のうち 1 つが他の 2 つと隣接している」

たとえば 2 行目の 2 と 3 は三つ子素数の要素である。
8 行目は 29 と 31 が三つ子素数の要素であり,9 行目は 37 だけが三つ子素数の要素である。

三つ子素数の要素のうち  n 行目にあるものの総和を  S(n) とおく。
 S(8)=60,  S(9)=37 であり  S(10000)=950007619 である。
 S(5678027) + S(7208785) を求めよ。

https://projecteuler.net/problem=196

解説の pdf も作りました。きれいなレイアウトで読みたい方はこちらをどうぞ。

drive.google.com

基本方針

上から  i 行目,左から  j 列目の数を  f(i,\, j) とおきます。
群数列の要領で数えると一般項は次のようになります。

 f(i,\, j)=\frac{1}{2}(i-1)i+j

素数判定は isprime するだけです。

三つ子素数の判定は union-find(dsu)でやりました。
隣接する 2 つの素数を merge して同じグループに入れていきます。 n 行目の素数のうち大きさ 3 以上のグループに属するものの総和が  S(n) です。

実際の計算

この解法では merge の回数を減らすのが時間短縮の鍵です。
すぐわかるのは「左右は調べなくてよい」です。
いま考えている素数は 3 以上の素数です。
左右は 2 より大きい偶数で,それは素数ではありません。

では上下はどうでしょうか?
 i=n=\text{(奇数)} とおいて真上,真下の項との差を調べるとわかります。

\begin{align*}
&f(n,\, j)-f(n-1,\,j)=i-1=\text{(偶数)}\\
&f(n+1,\, j)-f(n,\,j)=i=\text{(奇数)}
\end{align*}

 f(n,\,j) が奇素数のとき  f(n-1,\, j) は奇数なので素数の可能性がありますが, f(n+1,\, j) は偶数なので素数ではありません。
同様に  n 行目の素数を中心とする  5\times 5 の領域の偶奇を調べると,素数の可能性があるのは図の  \bigcirc の部分だけだとわかります。

f:id:variee:20211219034138p:plain

次はこれらをどう調べるかです。
たとえば  f(n,\, j) の真上が素数でなかったら  n-2 行を調べるのは無駄です。
下図の「A」→「B1→B2」→「C1→C2」の順に調べました。図中の  \bigcirc はこのサイクルでは調べません。

f:id:variee:20211219034208p:plain

C1 は A の 2 つ隣の数から始まるサイクルでも調べる可能性がありますが, n 行目の素数の割合は約  3\% しかなく,影響はないと判断しました。

実装

python のコードは次の通りです。dsu は ACL(AtCoder Library) の python 移植版を使いました。計算時間は約 50 秒でした。
github.com

    import time
    from atcoder.dsu import DSU
    from sympy import isprime
    
    def nums(i, j):
        return (i - 1) * i // 2 + j if 1 <= j <= i else 0
    
    def solve(n):
        g = DSU(6 * n) 
        '''
        辺で結べるnums(i,j)のインデックスをgにmergeする。
        衝突を避けるためn-2,n-1,n,n+1,n+2行のjに
        それぞれ0,n,2n,3n,4nを足したものをmergeした。
        '''
        lst1 = [i for i in range(n + 1) if isprime(nums(n, i))]
        lst2 = []
        for j in lst1:  # n行目から辺をのばす
            if isprime(nums(n - 1, j)):
                g.merge(j + 2 * n, j + n)  # n行目とn-1行目
                for k in [-1, 1]:
                    if isprime(nums(n - 2, j + k)):
                        g.merge(j + n, j + k)  # n-1行目とn-2行目
            for k in [-1, 1]:
                if isprime(nums(n + 1, j + k)):
                    g.merge(j + 2 * n, j + k + 3 * n)  # n行目とn+1行目
                    lst2.append(j + k)
        for j in lst2:  # n+2行目から辺をのばす
            if isprime(nums(n + 2, j)):
                g.merge(j + 3 * n, j + 4 * n)  # n+1行目とn+2行目
    
        lst3 = [i for i in lst1 if g.size(i + 2 * n) >= 3]
        return sum(map(lambda i: nums(n, i), lst3))
    
    if __name__ == '__main__':
        start = time.time()
        print(solve(5678027) + solve(7208785))
        elapsed_time = time.time() - start
        print("elapsed_time: {0} [sec]".format(elapsed_time))