藏字閣

閱讀時間約 3 分鐘

1099 字

Python line profiler 是一個很方便的套件,讓你很方便看到程式碼逐行執行的時間,用法可以參考拙作關於 Python profiling 的介紹。有一個致命的缺點就是不知道 multiprocess 的 profiling,Github 上也有一個 2016 年留到現在的 issue。我在這裡提供一個 hacky 的作法在 multiprocessing 下使用 line profiler。

是在一個日文網站看到的作法,雖然我看不懂日文,不過日本人寫的 Python 跟我一樣,所以我一樣可以改來用。

原本作法

假設我們今天要做的 multiprocess 是算每個數字的 1.5 次方的總和,為了說明方便,就用了不是很有效率的寫法如下:

import math
import multiprocessing as mp

def power3(num):
    return num * num * num

def child(num):
    num = math.sqrt(num)
    num = power3(num)
    return num

pool = mp.Pool()
total = 0
for res in pool.imap_unordered(child, list(range(5))):
    total += res
print(total)
pool.close()
pool.join()

使用 child function 裡面呼叫 power3 平行地計算 1 到 5 的 1.5 次方再加總。如果在兩個函式上面加上 @profile 做 profiling 會產生以下結果:

Timer unit: 1e-06 s

Total time: 0 s
File: q.py
Function: power3 at line 4

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
     4                                           @profile
     5                                           def power3(num):
     6                                               return num * num * num

Total time: 0 s
File: q.py
Function: child at line 8

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
     8                                           @profile
     9                                           def child(num):
    10                                               num = math.sqrt(num)
    11                                               num = power3(num)
    12                                               return num

什麼都沒有!

對個別 subprocesss 做 profiling

我們有一個解決辦法,就是對個別 subprocesss 做 profiling:

import math
import multiprocessing as mp
from line_profiler import LineProfiler


def wrap(num):
    prof = LineProfiler()
    prof.add_function(child)
    prof.add_function(power3)
    res = prof.runcall(child, num)
    # prof.print_stats()
    prof.dump_stats('demo_%d.lprof' % num)
    return res


def power3(num):
    return num * num * num


def child(num):
    num = math.sqrt(num)
    num = power3(num)
    return num


pool = mp.Pool()
timings = dict()
unit = None
total = 0
for res in pool.imap_unordered(wrap, list(range(5))):
    total += res
print(total)
pool.close()
pool.join()

執行完後,我們用 prof.dump_stats('demo_%d.lprof' % num) 存下每個 subprocess 的紀錄,也可以用 print_stats 印出結果,不過因為是不同 process 同時執行,所以紀錄會混在一起,沒辦法看。存下來以後就可以使用 python -m line_profiler demo_0.lprof 看到 subprocess 逐行的執行時間 (demo_0.lprof 可以換成其他檔案),以下是其中一個:

>> python -m line_profiler demo_0.lprof 
Timer unit: 1e-06 s

Total time: 1e-06 s
File: q.py
Function: power3 at line 16

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    16                                           def power3(num):
    17         1          1.0      1.0    100.0      return num * num * num

Total time: 1.4e-05 s
File: q.py
Function: child at line 20

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    20                                           def child(num):
    21         1         11.0     11.0     78.6      num = math.sqrt(num)
    22         1          3.0      3.0     21.4      num = power3(num)
    23         1          0.0      0.0      0.0      return num

這邊 hits 只有一,表示只紀錄到一個 subprocess 的時間。

合併多個 subprocess 的紀錄

我們這邊開了五個 subprocess,所以有五個 .lprof 檔案,有沒有辦法合在同一個檔案呢?可以使用 prof.get_stats() 拿出 profiling 的結果,觀察內容以後,很容易就可以把多個結果合併起來,然後再用 line_profiler 提供的 show_text 函式寫到檔案:

import math
import multiprocessing as mp
from line_profiler import LineProfiler, show_text


def wrap(num):
    prof = LineProfiler()
    prof.add_function(child)
    prof.add_function(power3)
    res = prof.runcall(child, num)
    return res, prof.get_stats()


def power3(num):
    return num * num * num


def child(num):
    num = math.sqrt(num)
    num = power3(num)
    return num


def merge_timing(timings, new_timings):
    for key, stats in new_timings.items():
        if key not in timings:
            timings[key] = stats
        else:
            cur_dict = dict((lineno, (hit, ts)) for lineno, hit, ts in timings[key])
            for lineno, hit, ts in stats:
                hit_now, ts_now = cur_dict.get(lineno, (0, 0))
                cur_dict[lineno] = (hit_now + hit, ts_now + ts)
            cur = sorted([(lineno, hit, ts) for lineno, (hit, ts) in cur_dict.items()])
            timings[key] = cur


pool = mp.Pool()
timings = dict()
unit = None
total = 0
for res, r in pool.imap_unordered(wrap, list(range(5))):
    total += res
    merge_timing(timings, r.timings)
    unit = r.unit
print(total)
show_text(timings, unit, stream=open('prof.txt', 'w'))
pool.close()
pool.join()

prof.txt 的結果如下:

Timer unit: 1e-06 s

Total time: 3e-06 s
File: qq.py
Function: power3 at line 14

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    14                                           def power3(num):
    15         5          3.0      0.6    100.0      return num * num * num

Total time: 5.8e-05 s
File: qq.py
Function: child at line 18

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    18                                           def child(num):
    19         5         47.0      9.4     81.0      num = math.sqrt(num)
    20         5         10.0      2.0     17.2      num = power3(num)
    21         5          1.0      0.2      1.7      return num

可以看到 hits 都是五,表示我們將五個 subprocess 的結果都合併起來了。至於開 pull request 到原本的 repo 呢?就留給更厲害的大大了!

參考資料

line_profiler with multiprocessing by floatnflow

comments powered by Disqus

最新文章

分類

標籤