Normal view

There are new articles available, click to refresh the page.
Before yesterdayMain stream

Python:在内存受限的情况下处理超大型 JSON 文件

27 December 2022 at 16:25

本文翻译自 Processing large JSON files in Python without running out of memory

在处理的大型 JSON 文件的时候很容易发生内存爆掉的问题。即便原始数据力量上能够为内存所容纳,但是由于 Python 在内存中的记录方式,其内存消耗会比原始数据的体积要更大。如果遇到了开启 SWAP 的计算机,即便内存不爆掉,如果程序的运行内存进入缓冲区,也会导致运行速度的急剧下降。解决这个问题的方法是流式解析 (Stream parsing),也被称为 lazy parsing, iterative parsing,或者 chunked parsing。

1 问题:Python 的内存低效的 JSON 加载方式

考虑这样一个例子:一个大小是 24MB 的 JSON 文件,这个文件的内容代表了一系列的 Github 事件:

1
2
3
[{"id":"2489651045","type":"CreateEvent","actor":{"id":665991,"login":"petroav","gravatar_id":"","url":"https://api.github.com/users/petroav","avatar_url":"https://avatars.githubusercontent.com/u/665991?"},"repo":{"id":28688495,"name":"petroav/6.828","url":"https://api.github.com/repos/petroav/6.828"},"payload":{"ref":"master","ref_type":"branch","master_branch":"master","description":"Solution to homework and assignments from MIT's 6.828 (Operating Systems Engineering). Done in my spare time.","pusher_type":"user"},"public":true,"created_at":"2015-01-01T15:00:00Z"},
...
]

我们的目标是找到指定用交互过的仓库,下面这个简单的 Python 程序可以达成这一目标:

1
2
3
4
5
6
7
8
9
10
11
12
import json

with open("large-file.json", "r") as f:
data = json.load(f)

user_to_repos = {}
for record in data:
user = record["actor"]["login"]
repo = record["repo"]["name"]
if user not in user_to_repos:
user_to_repos[user] = set()
user_to_repos[user].add(repo)

程序运行结果是用户名和仓库名的映射字典。当我们使用 File memory profile 分析的时候,我们可以得到如下结果:

原文中这里是一个可交互的图表,建议在原文链接中查看

观察内存峰值,我们可以看到两处主要的内存分配行为:

  1. 读取文件;
  2. 将读取的内容转换成 Unicode 字符串。

我们来看 Python 的 json 模块的实现可以发现,这个标准库中的 json.load() 函数会先把整个文件读入内存。

1
2
3
4
5
6
7
def load(fp, *, cls=None, object_hook=None, parse_float=None,
parse_int=None, parse_constant=None, object_pairs_hook=None, **kw):
"""Deserialize ``fp`` (a ``.read()``-supporting file-like object containing
a JSON document) to a Python object.
...
"""
return loads(fp.read(), ...)

注意上面记录到的是内存峰值的现象,所以后续创建字典对象时,其内存占用已经不是峰值处。整个程序执行过程中峰值是读取文件产生的。

有意思的是,尽管文件本身只有 24MB,但是读入内存之后其产生的内存峰值却远高于 24MB。为什么呢?

2 Python 的字符串内存表达方式

Python 的字符串表达经过优化可以使用较少的内存(这取决于字符串的内容)。首先,每个字符串都由于一个固有的开销 (overhead)。其次,如果字符串能够以 ASCII 编码表达,那么每个字符都只需要占用一个字节的内存。单如果有更多种类的字符需要表示,则每个字符占用的内存就上升到 4 个字节。我们来看下面的代码执行过程:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
>>> import sys
>>> s = "a" * 1000
>>> len(s)
1000
>>> sys.getsizeof(s)
1049

>>> s2 = "❄" + "a" * 999
>>> len(s2)
1000
>>> sys.getsizeof(s2)
2074

>>> s3 = "💵" + "a" * 999
>>> len(s3)
1000
>>> sys.getsizeof(s3)
4076

这三个 case 中每个字符串的长度都是 1000,但是他们使用的内存大小是不同的,这与其内容有关。

3 流式解决方案

显而易见将整个文件加载进入内存是一种内存的浪费。如果文件的体积非常大,那么我们甚至无法将整个文件读入内存。

如果 JSON 文件是一些对象组成的列表,那么理论上我们可以分片进行加载。个很多 Python 库支持这种行为,这里我们采用 ijson 库。

1
2
3
4
5
6
7
8
9
10
11
import ijson

user_to_repos = {}

with open("large-file.json", "rb") as f:
for record in ijson.items(f, "item"):
user = record["actor"]["login"]
repo = record["repo"]["name"]
if user not in user_to_repos:
user_to_repos[user] = set()
user_to_repos[user].add(repo)

在之前的标准库版本中,当数据被读入内存之后,文件就会被关闭。但是在现在这个场景下文件必须被保持打开。 这是因为文件的内容只有部分被读取,后续内容的读取取决后续的迭代过程。

items() 接口在此处的接受一个查询字符串用来指定加载的对象。在这个例子中 "item" 输入表示返回每个顶层对象。你可以参见 ijson文档来查询接口细节。

采用上面的代码我们可以发现程序的内存峰值降低到了 3.6MB。

复刻在腾讯微博中的回忆

By: 胡中元
20 July 2018 at 21:47

大概是微博这个东西刚刚流行起来之时,也就是我初中的时候,我便用心的经营着我的腾讯微博,倒不是想要成为微博大咖,只是认为在同龄人坐在电脑前都只会打游戏时,我刷刷微博、发表一下自己的看法和见解,是更有意思的一件事。

然而腾讯微博迅速就被新浪微博超越,市场占有率几乎为 0 了。我自然也投靠了人多势众的新浪微博,但之前在腾讯微博中发的超过 1000 条微博是我的回忆 —— 中二青春。

我有一种预感,过不了多久腾讯微博就要被腾讯关停了,我可不能让之前写的那些碎碎念就这么消失,于是我用 Python 写了一个爬虫,将所有 [微博+图片+时间+转发微博+转发微博的所有信息] 都给爬到了本地数据库中,然后使用 React 做成了一个网站,名曰“复刻版腾讯微博”,将我发的微博放心地永远留在了自己的服务器中。

查看我的腾讯微博复刻网站,请点击:

https://hzy.pw/i/qqweibo/

## 基于服务器心情而工作的爬虫

截至目前,我的腾讯微博上共 1661 条微博,收听 65 人,听众 765 人。然而爬虫运行完毕之后获取到的微博数量为 1620,另外 41 条数据不翼而飞。我发布的微博和转发的微博中共包含了 1220 张图片,其中 6 张已被他们服务器丢失。微博中共包含 98 个视频,其中的 88 个均丢失(这是视频网站的锅,我们上传到优酷上的视频真的会被他们永远存放着吗,想想也是不可能的)。

微博中还包括了 785 条诸如 http://url.cn/482SZS 这样的短链接,其中 90% 均已失效,访问时直接提示 您访问的网址有误或该网址已过期 :( 此外,虽然 2011 年的微博也还给我留着,但所有微博的评论均没有了,数据被删掉了。。。

我想说的是,要是再不使用爬虫将这些宝贵的回忆取回,真说不定哪天就被腾讯给删掉了 ToT

讲真,各种复杂的情况都被我遇到了: 微博不提供 API,使用 Python 爬取 HTML 再解析,关键是 HTML 结构每次都会变,我花了很久很久的时间才适配了所有情况。另外服务器返回的数据并不可信,第一次得到的数据显示我在某一天发了 1 条微博,带有图片,再获取一次变成了发了 4 条,却无任何图片上传。(这不是腾讯为了防爬虫设计出来的,因为浏览器访问也是这样的,大概是腾讯微博在临死前,为了降低服务器负载而采用的拒绝式服务。。。)

于是我的爬虫在经过数天的完善后,拥有了应对前后数据不一致、连接握手失败、适应 HTTP 结构变化的功能。在此基础上又运行了四五天,才完成了爬取。因为对我那 1000 条微博的每一躺爬取,结果都是不一致的,直到最后连续运行十个小时也没爬出新数据后,我才认为是爬完了。

最终顺利爬取了能找到的所有数据,并存在了数据库里,真的是超级辛苦,让我激动的发了个微博(新浪微博~~)

数据清洗

数据清洗除了格式上的规范,还标记了一些重复的微博,这些微博在我的博客、空间里面重复,我的微博镜像站中没有必要包含这一部分内容。

此外为了制作微博镜像站,使用 Pillow 库将图片原图批量压缩成了 webp 格式的缩略图,在我的微博镜像站中,点击缩略图即可查看大图。 然而事实证明选择 webp 格式是错误的 ,虽然谷歌的 webp 格式拥有很高的压缩率,但是兼容性是个问题,不支持 Firefox、IE 和 iOS,几乎是只有 Chrome 能显示,所谓的 WebP JS 兼容性修复库其实是使用了 Flash 实现,然而后者本身就不值得使用。 所以说 WebP 格式的图片只适合客户端而不适合浏览器端。

最终我还是选择了 jpg 格式作为缩略图。毕竟我的服务器拥有 自动转换为 WebP 功能

愉悦的 React 开发体验

感谢 facebook/create-react-app 提供的脚手架,webpack+eslint+react 开发环境开箱即用。另外不得不感叹 React 的模块化使得逻辑相当清晰,很方便省心。

另外还要感谢 clean-blog CSS 主题lightgallery.js 图片灯箱插件

接下来

如果 QQ 空间、朋友圈、微博、豆瓣 这些网站在某一天宣布关停,我也会把自己的数据通通扒回本地,当我真心不希望这样,因为这个网站本身,就是一代回忆。

有空的话还要干几件事:试着统计下我发的微博中的一些有趣的数据,比如口头禅、文字情感之类的。再来就是把微博中的短链接替换成为长链接,因为正如上文提到的那样,很多短链接都在陆续失效了。

就酱。

现已完成,对我的腾讯微博的大数据统计挺有意思,请访问: https://hzy.pw/p/2569

回调之 Node.js VS 串行之 Python

By: 胡中元
12 April 2018 at 18:46

Node 与 Python,都是脚本语言,有着类似的使用场景,所以在各个地方早已经互相 pk、比较过无数回了。虽然我知道编程语言之间的 VS 是一个很 low 的行为,因为他们必定是各有优势的。但今天我还是特别的想说说自己的心得体会。

Node 是我曾经特别喜欢,也是非常熟练的编程语言。Python 我还处于学习阶段,不敢说深入了解。

使用场景

如果要评论手机 App 开发,自然是有着一堆框架的 Node 胜出,而如果站在科学计算领域,那胜利者绝对是 Python。所以本文主要还是对简单的日常场景进行对比 ———— 代替 Shell 的那些操作、作为服务器中间件与数据库打交道等等。

Python 脚本

前几个周我写了数个 Python 脚本,这便是其中一个。将 NAS 中我收集的漂亮壁纸的分辨率、时间信息记录到 SQLite 数据库中,再定时从数据库中随机取出近期未使用过的壁纸,让桌面壁纸换一换。

Python 对于这种事情确实在行。另外我根据爬虫创建的翻墙规则 Shadowrocket-ADBlock-Rules 也是使用 Python 生成的,5 线程爬虫协同工作,代码清晰简洁。

Node 脚本

Node 上面提到的那些,相对而言更适合较大的项目。比如我开发的 XSYU-GMS 教务系统,就是一个前后端全栈 Javascript 项目。

对于 web 中间件,例如监听 443 端口提供数据相关的 API 调用服务。Node 原生可作为一个守护进程保持运行,灵活维护一些常量(Python 也是这样,只是感觉工作量会更大一点),作为脚本语言的高维护性体现出来,并且 Node “以事务驱动单进程” 的特性在极限情况下可以实现超高的性能。

Node 成败均在回调

继续上文。虽说 Node “以事务驱动单进程” 的特性在极限情况下可以实现超高的性能。但是对于大多数情况下,Python 也是可以撑住的。低负载的情况下,Node 引以为傲的 “事务驱动” 就变成了可怕的 “回调地狱”,在性能提升不大的情况下,严重影响了开发效率,甚至是提高了程序复杂度所导致的出错率。

我曾经在用 Node 开发 RSS 爬虫时,认真思考每一步之间的相互依赖关系,哪些可以并行,哪些必须等待。然后精确地实现,代码看起来非常复杂。实际上的效率提升其实真不是很重要,全部都用串行实现,岂不一样可以达到目的。

毕竟,要最追求效率那我应该选择 C 甚至是汇编,既然选择了 Node 这样的脚本语言,那就是为了开发方便。 实际上,计算机编程语言的发展,就是在不断地用运行效率换取开发效率。当然,不会因为开发效率彻底代替运行效率。

Promise 与 Async Function

这大概就是拯救 Node 回调地狱的存在吧,确实,程序写起来体验好太多了。但是…… 这又导致了 await 地狱…… >o<

下面是我今天的代码,功能从数据库中获取部分 Email 地址,给他们发送邮件。

(async function() {
    let maxTimeStr = new Date().toUTCString();

    ret = await sqlP(`select username from user WHERE 
        d=0 AND status=1 AND created_at line.username);


    for(addr of emailList) {
        try {
            await sendMailForTK(addr);
        }
        catch(e) {
            console.log('sendMail failed', addr, e);
            await sleep(61);
            contine;
        }

        // send Done
        sqlP('update user SET d=88 WHERE username="' + addr + '"');

        await sleep(6);
    }

    process.exit();
})();

/**
三步串行操作:
1. 从数据库取得邮箱地址
2. 发送邮件
3,将记录写回数据库
*/

代码中的 sqlP(), sleep(), sendMailForTK(),都是我自己封装的 Promise Object,虽然封装不是个麻烦事,但我有一种感觉 —— 以后几乎所有的 Node 开发,都要经历这一步了。同时必须要经过的一步那就是在调用的时候用上 await,这就是 await 地狱。

于是问题来了,既然我们 80% 需要的程序逻辑都是串行,那么把这 80% 的代码特殊处理还真是一键麻烦的没必要的事情。像 Python 那样默认为串行,在 20% 的时候再使用多线程,在我看来是一种更合适、更通用的模式。

事件驱动是 JS 的核心,所以 Node 自然以后也只能保持这种模式了。

对 Node 的总结

我认为,Node 的最佳使用场景仅限于需要其特点:”异步回调” 的时候,可以带来高性能。而其他时候,这只会带来麻烦。

当然,Node 超级给力的模块机制,以及与简单易学的 JS 搭配实现的 web 开发语言前后端同一等等,使得 JS 永远充满魅力。

Python 的缺点

就并行串行而言,Python 我更偏好,但是这门语言我还是有一些需要吐槽的地方。

1. 编程风格

对于初学的我来说,完全不能赞同 Python 那些语法就是 “优雅的编程”,而 C 风格的大括号、分号就有多难看。在 class 里面,最多只能一个空行,而不能 2 个。使用 Python 编程让我有种被限制的感觉(也许还是因为我不熟悉吧)

总之我认为 “非 C 风格编程” 是一个缺点。

2. 相比 Node 包管理机制更弱

这不是 Python 太差,而是 Node 太叼了。Node 只提供底层 API,提倡使用第三方包完成你的工作(这也是导致 JS 项目层出不穷的原因),而 Python 并没有这种提倡。

让 Aria2 启动后自动继续未完成的下载 并清理已删除任务的文件

By: 胡中元
1 March 2018 at 19:23

这个假期,我做的最有趣的一件事就是将路由器改造成了一台稳定的 NAS,其中由 Aria2 实现的离线下载服务器是作为 NAS 的一个核心功能。用着非常方便,然而却有以下几个问题:

  1. 重启机器后,Aria2 在重启后并不会自动继续之前的下载。虽然保存了 sessions,但 Aria2 重启之后会自动将所有任务暂停。这就没法实现挂机下载了。
  2. 删除 Aria2 建立的下载任务后,并不会删除硬盘中对应的文件(包括只下载到一半的破损文件),这很不方便。


重要补充说明

我的代码依赖于 Aria2 编译时的 XML 库依赖,而在某些版本中是不带这个依赖的。所以本篇文章不一定适用于所有情况。

为了解决这 2 个问题,我编写了一个 Python 脚本,完美地解决了困扰。

脚本在 Python3 下运行正常,未对 Python2 测试。不依赖第三方模块。
为了实现 “让暂停的任务继续下载”,需要按照 Aria2 文档来调用 RPC,所以 需要在代码内修改相关的连接地址、密码等信息。

脚本同时会自动读取任务列表,并在下载目录找到所有不属于任务列表中的文件,删除之。
你也可以在 fileWhiteList 变量中设置不想要删除的文件的白名单。

#!/usr/bin/python
# -*- coding: UTF-8 -*-

# 1. start all paused tasks
# 2. delete other files on disk

# API: https://aria2.github.io/manual/en/html/aria2c.html#rpc-interface

from xmlrpc import client as xmlc
import os

rpcUrl = 'http://127.0.0.1:6800/rpc'
rpcToken = 'token:PASSWORD'
downloadPath = '/root/usb/nas/download/'  # same to aria2 config
fileWhiteList = ['/bypy', '/PROTECTED']   # while list for delete


s = xmlc.ServerProxy(rpcUrl)
api = s.aria2
# start all tasks
api.unpauseAll(rpcToken)


tasks = api.tellActive(rpcToken)
tasks += api.tellStopped(rpcToken, 0, 99)
tasks += api.tellWaiting(rpcToken, 0, 99)

for task in tasks:
    # started BT tasks
    if ('bittorrent' in task) and ('info' in task['bittorrent']):
        filename = task['bittorrent']['info']['name']
        fileWhiteList.append(filename)
    # other tasks
    else:
        for file in task['files']:
            path = file['path']
            if path.startswith('[METADATA]'):
                path = path.replace('[METADATA]', '')
            else:
                path = os.path.basename(path)

            fileWhiteList.append(path)

# del same items
fileWhiteList = set(fileWhiteList)

print('fileWhiteList', fileWhiteList)


def isStrContainItemInList(str, list):
    for item in list:
        if item in str:
            return True
    return False


for parent, dirnames, filenames in os.walk(downloadPath, topdown=False):
    for filename in filenames:
        path = os.path.join(parent, filename)
        if not isStrContainItemInList(path, fileWhiteList):
            os.remove(path)
            print('del file: ', filename)
    for dirname in dirnames:
        path = os.path.join(parent, dirname)
        if not isStrContainItemInList(path, fileWhiteList):
            try:
                os.rmdir(path)
                print('del dir:  ', dirname)
            finally:
                pass

一般来说,我们需要这段脚本在开机后自动运行,加入至 /etc/rc.local 即可:

sleep 1m && python /root/aria2/afterRun.py > /var/log/aria2.afterRun.log &

相关推荐

Aria2 bt-tracker 跟踪服务器列表自动更新:https://www.feng.ee/aria2-trackers-auto-update.html

TTRNN论文的UCF11实验复现

By: Erease
21 October 2019 at 08:00

对ICML2017上的使用TTRNN做视频分类的论文中的UCF11实验做了复现,作者在Github上公开了Python2代码,这里使用工具转换到Python3;另外代码缺少了预处理步骤,个人参考注释做了补充;在复现的过程中遇到并解决了部分问题:OpenCV提取视频帧序列,简单的并行处理,Tensorflow在Linux下CPU调度设置

Paper With code

Tensor-Train Recurrent Neural Networks for Video Classification

需要复现的部分是使用TT_RNN处理UCF11的数据集(使用新版本的UCF11数据集,旧版本的文件有点乱)

最开始还是梳理下论文和代码的逻辑

代码修改

通过注释和前后文猜测作者的想法…

2to3代码转换

本来在Linux环境下使用Python2是没什么问题的:

  • Windows下的Conda没有Python2的Tensorflow,而平时用Windows的居多
  • Python2和Python3的pickel导出的对象之间不兼容
  • 而在同一个Conda环境中,OpenCV和Keras之间也不兼容,所以运行代码的时候需要了两个Python2.7的环境,一个用于OpenCV处理视频,一个用于Keras
  • Github的Readme所提及的运行环境难以配置(版本问题,Conda的锅?)
  • Linux下计算只用了一个核心,这是不能接受的(训练到天荒地老)

这些显然是无故添加了很多麻烦的,在Conda环境下使用Python2成功运行后就尝试使用Pyhton3配置环境,尽管有些warning,但是一次性就解决了上述问题

2to3 - 自动将 Python 2 代码转为 Python 3 代码主要是转换TTRNN.py,其他的手动改就好(真的只转换了print…)

后端

默认使用了Tensorflow作为了Keras的后端,然而在Linux下默认只用了一个核心,还好使用Python3时给出了提示,按照提示找到了:Tips to Improve Performance for Popular Deep Learning Frameworks on CPUs

如果您有一个可以在内部并行化的操作,例如矩阵乘法(tf.matmul())或归约(例如tf.reduce_sum()),TensorFlow将通过在具有线程的线程池中调度任务来执行该intra_op_parallelism_threads操作。因此,此配置选项控制单个操作的最大并行加速。请注意,如果并行运行多个操作,则这些操作将共享此线程池。

如果TensorFlow图中有很多独立的操作-因为在数据流图中它们之间没有直接的路径-TensorFlow将尝试使用带有线程的线程池并发运行它们inter_op_parallelism_threads。如果这些操作具有多线程实现,则它们(在大多数情况下)将共享同一线程池以进行操作内并行操作。 这两个参数在Tensorflow的性能指南也有说明,其中提到了默认设置往往就有比较好的训练效果,但是为了缩短训练时间,经过几次测试之后选择调高:

#some option to improve performance in linux
import tensorflow as tf
from keras import backend as K

config = tf.compat.v1.ConfigProto(intra_op_parallelism_threads=18, inter_op_parallelism_threads=36, allow_soft_placement=True)
session = tf.compat.v1.Session(config=config)
K.set_session(session)

视频预处理

使用OpenCV提取视频帧序列,基本上做视频处理都要做这一步,因为逐个读写文件需要几分钟有点不太方便,看for循环部分可以设置并行就试了下,可以缩短运行时间到几秒(看硬盘速度~)

训练的代码Experiment_UCF11.py根据预处理的代码做了相应的修改,在导入训练和测试数据部分,因为导出视频帧使用的是list结构,而原作者的做法是导出的array,故在导入后多了一步list转array的操作

另外论文中提到的输入是RGB的通道矩阵,而OpenCV默认读取BGR,这里还用了matplot查看了下

最后导出的数据大小是16.4G

import os
import pickle
from multiprocessing.dummy import Pool as ThreadPool

import matplotlib.pyplot as plt
import numpy as np

import cv2

workspace ='./'
clips_path = workspace+'Datasets/UCF11_updated_mpg/'
frames_path =  workspace+'processed_data/'

if not os.path.isdir(frames_path):
    os.mkdir(frames_path)

classes = ['basketball', 'biking', 'diving', 'golf_swing', 'horse_riding', 'soccer_juggling',
           'swing', 'tennis_swing', 'trampoline_jumping', 'volleyball_spiking', 'walking']

def get_clips(class_name):
    files = os.listdir(clips_path + class_name)
    files.sort()
    clip_list = []
    for this_file in files:
        if '.DS_Store' not in this_file and 'Annotation' not in this_file:
            clips = os.listdir(clips_path + class_name + '/' + this_file)
            clips.sort()
            for this_clip in clips:
                if '.DS_Store' not in this_clip and 'Annotation' not in this_file:
                    clip_list.append( clips_path + class_name + '/' + this_file + '/' + this_clip )
    return clip_list

# iterate through all clips and store the length of each:
def process(par_input):
    item=par_input[1:]
    classes_name=par_input[0]
    for l in range(len(item)):
        #print(str(item[l]))
        cap = cv2.VideoCapture(item[l])
        ret = True
        clip_frames = []
        count = 0
        while(ret):
            k
            ret, frame = cap.read()
            if ret:
                #if count%2==0:
                rgb_frame = cv2.cvtColor(frame,cv2.COLOR_BGR2RGB)
                frame_resized = cv2.resize(rgb_frame,(160,120))
                clip_frames.append(frame_resized)
                count = count + 1
        #data.append(clip_frames)
        class_index=classes.index(classes_name)
        clip_index_in_class=clips[class_index].index(item[l])
        head_index_of_class=int(sum(class_sizes[:class_index]))
        length_array_index=head_index_of_class+clip_index_in_class
        length_of_frames[length_array_index]=count
        save_name = str(classes_name) + '/' + str(l)
        if not os.path.isdir(frames_path + str(classes_name)):
            os.mkdir(frames_path +  str(classes_name))
        write_out = open(frames_path + save_name +'.pkl', 'wb')
        pickle.dump(clip_frames, write_out)
        write_out.close()

clips = [None]*11
#labels = [None]*11
class_sizes = np.zeros(11)
for k in range(11):
    class_clip_paths = get_clips(classes[k])
    clips[k] = class_clip_paths
    class_sizes[k] = len(class_clip_paths)
    #labels[k] = np.repeat([k], class_sizes[k])

n_all_clips=int(class_sizes.sum())
length_of_frames = np.zeros(n_all_clips)

par_input=[]
for i in range(11):
    temp_list=[]
    temp_list.append(classes[i])
    temp_list.extend(clips[i])
    par_input.append(temp_list)

pool = ThreadPool()
pool.map(process,par_input)
pool.close()
pool.join()

print("total "+str(length_of_frames.shape[0])+" clips")
print("The lengths of frame sequences is vary from: "+str(length_of_frames.min())+" to "+str(length_of_frames.max()))
print("The average length is:"+str(length_of_frames.mean()))

其中还统计了每个片段的长度,用于训练做截断用,大大减少训练的数据量(内存占用)和运算量(训练时间)

参数修改

a resolution of 320 X 240. We generate a sequence of RGB frames of size 160 X 120 from each clip at an fps(frame per second) of 24, corresponding to the standard value in film and television production. The lengths of frame sequences vary therefore between 204 to 1492 with an average of 483.7.

片段的帧序列长度和论文中的陈述有些出入,比实际统计的最大值900要大得多,论文使用的片段转换到24FPS,貌似需要使用ffmpeg做插值,感觉必要性不大(OpenCV可以跳帧,但是训练的结果不太好看)

源代码中默认的统一的输入长度是GLOBAL_MAX_LEN=1492,前面提到实际只有900,却给计算带了些希望,设置成900,内存占用只需要110G+,如果参考平均200左右将输入长度减半到450的话,仅仅需要50G+的内存就足够了,训练时间也可以缩短到一天左右

为了复现TTRNN在UCF11上的出彩的性能,选择使用TT,GRU是默认的

use_TT = 1      # 0 for non-TT, 1 for TT

速度最快的TT_MLP的帧提取部分貌似不太清楚,这里就不做了

迭代次数按照论文提到的100 epochs把iter_range设置为101(初始代码写的1001?),论文的结果,一个Epoch需要30分钟,一百个也就是50小时,两天,这个速度和输入长度为900的情形是接近的

运行

环境

Linux和Windows下都使用了Anaconda构建的Python3.7环境,主要安装最近版本的Keras,OpenCV,scikit-learn,其他的倚赖都会自动安装,windows下跑小数据集用于测试(内存太小了),Linux使用VSCode的SSH-Remote调试和运行(Tmux居然会随着VSCode断开而终止?)

论文貌似使用CPU跑的,这里也一样,把后端从Theano换成了Tensorflow,主要的瓶颈在于内存

命令

还是用Tmux挂在init进程下运行:tmux new -s UCF11_L450

python ***.py | tee UCF11.out

运行代码之后再切出来,从文件查看输出或者使用tmux a -t UCF11_L450切入

结果

个人不太了解调参,仅仅是让代码可以运行,半天时间可以跑完的一个设置:以隔帧采样的帧序列作为输入,取截断长度为200

运行情况

运行了几次,给出一个参数,占用,训练时长和效果的参考表格

  内存占用(GB) epoch_time(s) 100 Epochs(about) Epochs - Accuracy
跳帧-L200 30 120 4hour 99 - 0.4
L450 60 400 12hour 77 - 0.38
L900 130 1700 2days 14 - 0.25
❌
❌