该系列的一大特色是其采用的Hook机制。同样作为计算机视觉算法框架,相比于简单易懂、新手友好的Gluon-CV,Hook机制无疑增加了初学者入门级工具箱系列的难度。一个自然的问题是,为什么要引入Hook机制呢?
事实上,AOP中的Hook机制是面向切面编程(AOP)编程思想的体现。从抽象的层面上来说,问为什么要引入Hook机制,实际上就是在问软件开发中为什么要采用面向切面编程的设计模式?
我个人粗略的理解是,之所以提出面向方面编程,是为了解决面向对象编程(OOP)代码重复的问题。面向对象编程的思想是分配职责,将功能分散到不同的对象类中,在不同的类中设计不同的方法。如果两个类A和B需要使用同一个方法,那么可以将该方法写在一个独立的类C中,然后两个类继承这个类C。但是这样做有两个问题。
第一个问题是,对于没有多重继承特性的语言,比如Java,如果继承C类,就无法继承其他类。如果C类中的功能不是A类和B类的主要功能,那么通过继承C类的方法来获取它是行不通的。简单粗暴的解决办法就是分别在A类和B类中实现C类中的子功能。这样的话,完全相同的代码存在于两个地方,代码的重复度大大增加。如果要修改这个方法,就必须修改这两个地方。两个地方都可以,但是如果有 m 个这样的情况,每个情况重复 n 个地方怎么办?
第二个问题是,即使可以继承,A类和B类又和C类耦合在一起,如果有一个D类与C类有类似但不同的子功能,我希望A类和B类可以配置通过用户选项动态选择是调用C类还是D类中的子函数,因此这种直接继承的方案无法提供这种动态选择的灵活性。
本质上,除了继承之外,面向对象编程所追求的封装特性切断了类之间的联系和共享。然而,为了减少代码的重复,提高软件的模块化水平,需要将分散在各个类别中的重复代码统一起来,两者之间存在矛盾。
这种在程序运行时动态地将所需代码切入类的指定方法和指定位置的编程思想就是面向切面编程。其中,提取出来的需要被几个类调用的代码片段称为切面。它们在程序运行时会被切分成指定类的指定方法。切入的类和方法称为切入。观点。面向切面编程可以让我们将与当前业务逻辑无关的部分抽离到单独的一层,实现非侵入式的功能扩展。
正是通过Hook机制,该系列可以对网络实现、算法训练和测试过程进行抽象和解耦,从而实现相当高的模块化程度,即重复代码的数量大大减少。
2.Hook机制的工作流程
Hook机制其实并不是特例。这只是我第一次看到它,因为我的编码经验太少了。钩子编程()是一个计算机编程术语,是指通过拦截软件模块之间的函数调用、消息传递和事件传递来修改或扩展操作系统、应用程序或其他软件组件的程序执行过程。其中处理拦截的函数调用、事件、消息的代码称为钩子,这应该就是上面提到的AOP编程的方面。
其中,Hook机制是由类(例如)和HOOK类(例如)配合完成的,它们共同构成了一套训练框架的架构规范。
首先,在 中,负责网络训练和测试整个过程的类定义了训练和测试周期中的一系列触发器,如下所示:
# 省略 ... self.call_hook('before_train_epoch') for i, data_batch in enumerate(self.data_loader): # 省略 ... self.call_hook('before_train_iter') # 省略 ... self.call_hook('after_train_iter') # 省略 ... self.call_hook('after_train_epoch')
其次,在Hook类以及与该类配合的子类中,还定义了一堆与上述类的触发器中的步骤/次/节点同名的函数, , , , , ,称为hook函数,如下所示:
class Hook: def before_run(self, runner): pass def after_run(self, runner): pass def before_epoch(self, runner): pass def after_epoch(self, runner): pass def before_iter(self, runner): pass def after_iter(self, runner): pass # ... 省略
当然,上面的Hook类是最原始的实现,也就是基本没有实现任何功能。如果想要定义一些操作,实现一些功能,可以继承这个类,自定义我们需要的功能,比如mmcv..hooks。模块中的类继承了最原始的Hook类,基本实现了里面的子功能。点击;以及 mmseg.core 中的类。 进一步继承了前面的类并重写了 和 两个子功能。
配合类和Hook类,当类实例运行到特定时刻时,会通过触发函数调用各个Hook类中的钩子函数,完成特定的功能。例如,在每个或每隔几个触发时刻,可以通过调用该函数来完成设置。
个人感觉这个Hook机制很像通信系统中的轮流查询机制。它是一套训练框架规范,规定了算法生命周期中的各种操作。之所以起作用,是因为在类的被调用方法中,每个节点指定了调用对应钩子函数的操作。训练过程中,类会轮流请求端口,即依次调用每个节点的钩子函数。如果专门定制了对应的钩子函数,则执行该函数。如果不是,则为空函数,直接传递,继续下一步,从而实现模块间函数调用、消息传递、事件传递的拦截,从而修改或扩展组件的行为。
3、Hook机制的底层实现
明确了类和Hook类共同实现Hook机制的工作流程后,还剩下两个问题。第一个问题是,如何让类实例知道调用特定Hook类实例的子函数,即如何将类实例与Hook类实例关联起来?第二个问题是,一个类实例可能调用多个Hook对象,每个Hook对象都会有自己的同名子函数。例如,这种情况如何处理?
对于第一个问题,HOOK类实例是通过类函数注册到类实例中的。让我们举个例子。训练模型时,会调用mmseg.apis模块的函数。其中两个步骤是注册钩子和类实例的钩子:
runner.register_training_hooks(cfg) runner.register_hook(eval_hook(val_dataloader, eval_cfg))
该类提供了两种注册钩子的方法:
- 方法是直接传入一个实例化的HOOK对象,插入到self中。类实例列表;
- fg方法是传入一个配置项cfg,根据配置项实例化HOOK对象,然后插入到self中。列表。
其实第二种方法就是先调用mmcv。方法生成实例化的HOOK对象,然后调用第一个方法将实例化的HOOK对象插入到self.列表。
与自我。包含已注册的 Hook 类实例的列表,该类在运行时调用已注册的 Hook 类实例的子函数是合乎逻辑的。看一下类中函数的定义,其中传入了self.('')。 (hook, )(self) 其实就是调用self。列表中钩子对象命名的函数,例如类实例的方法。至此,第一个问题,如何将想要的Hook类实例的某个方法动态插入到类实例的运行过程中就已经实现了。
def call_hook(self, fn_name): """Call all hooks. Args: fn_name (str): The function name in each hook to be called, such as "before_train_epoch". """ for hook in self._hooks: getattr(hook, fn_name)(self)
关于第二个问题,从上面函数的定义也可以看出,实例的run函数运行过程中,在设置该函数的每个节点,都会依次执行self。列表方法中所有钩子实例中对应的时间。例如,此时,它是遍历所有钩子实例的方法。如果只有一个Hook实例重写了这个方法,而其他实例的方法都是pass的,那也没关系。但如果有两个或多个实例的方法实现不通过,那么这就涉及到一个应该先调用哪个实例的方法的问题。具体来说,在程序中,每个Hook实例都插入在self的位置之前和之后。列表,因为函数是按顺序调用的。
优先级在注册钩子时就已实现,并且是默认变量。从下面函数的定义可以看出,对于一个新注册的Hook实例,根据其指定的优先级,如果不指定,则默认优先级为'',并插入到self.中,优先级越高这将是。如果新注册的Hook实例的优先级与已有的Hook实例的优先级相同,则按照先到先得的原则,先到先得。至此,第二个问题也解决了。
def register_hook(self, hook, priority='NORMAL'): """Register a hook into the hook list. The hook will be inserted into a priority queue, with the specified priority (See :class:`Priority` for details of priorities). For hooks with the same priority, they will be triggered in the same order as they are registered. Args: hook (:obj:`Hook`): The hook to be registered. priority (int or str or :obj:`Priority`): Hook priority. Lower value means higher priority. """ assert isinstance(hook, Hook) if hasattr(hook, 'priority'): raise ValueError('"priority" is a reserved attribute for hooks') priority = get_priority(priority) hook.priority = priority # insert the hook to a sorted list inserted = False for i in range(len(self._hooks) - 1, -1, -1): if priority >= self._hooks[i].priority: self._hooks.insert(i + 1, hook) inserted = True break if not inserted: self._hooks.insert(0, hook)
4. mmseg 中的钩子
下图中,我整理了mmseg的tools/train.py整个运行周期中会用到的所有hook对应的具体Hook类以及调用的对应次数。
另外,以这些Hook的调用时间为例,整理一下对应的优先级(顺序)。