HfArgumentParser是transformers.hf_argparser.py内的一个类,进一步封装了python的argparse模块,主要处理命令行的参数。
代码解读
hf_argparser.py中首先使用typing模块下的NewType类定义了两个新的数据类型,NewType的第一个参数是数据类型的名称,该数据类型具体可以是Any类型(即各种类型)。
DataClass = NewType("DataClass", Any)
DataClassType = NewType("DataClassType", Any)
init函数解读
classHfArgumentParser(ArgumentParser):"""
argparse.ArgumentParser的子类,使用数据类(dataclasses)来生成参数。
这个类被设计成可以很好地与python的argparse配合使用。特别是,您可以在初始化之后向解析器添加更多(非数据类支持的)参数,并且您将在解析后获得作为附加名称空间的输出。可选:要创建子参数组,请在数据类中使用' _argument_group_name '属性。
"""
dataclass_types: Iterable[DataClassType]# 定义一个DataClassType类型的可迭代对象def__init__(self, dataclass_types: Union[DataClassType, Iterable[DataClassType]],**kwargs):"""
Args:
dataclass_types:
Dataclass type, or list of dataclass types for which we will "fill" instances with the parsed args.
kwargs (`Dict[str, Any]`, *optional*):
Passed to `argparse.ArgumentParser()` in the regular way.
"""# To make the default appear when using --helpif"formatter_class"notin kwargs:
kwargs["formatter_class"]= ArgumentDefaultsHelpFormatter
super().__init__(**kwargs)if dataclasses.is_dataclass(dataclass_types):# 判断dataclass_types是否是一个数据类
dataclass_types =[dataclass_types]# 是,则将数据类变为一个数据类列表
self.dataclass_types =list(dataclass_types)for dtype in self.dataclass_types:
self._add_dataclass_arguments(dtype)
该类在
__init__
的初始化阶段将datalass_types参数传递进来的各种数据类的参数(属性)注册到argparse中,并返回argparse.ArgumentParser对象——解析器。
***使用例子:
parser = HfArgumentParser(_TRAIN_ARGS)
parse_dict函数解读
defparse_dict(self, args: Dict[str, Any], allow_extra_keys:bool=False)-> Tuple[DataClass,...]:"""
Alternative helper method that does not use `argparse` at all, instead uses a dict and populating the dataclass
types.
Args:
args (`dict`):
dict containing config values
allow_extra_keys (`bool`, *optional*, defaults to `False`):
Defaults to False. If False, will raise an exception if the dict contains keys that are not parsed.
Returns:
Tuple consisting of:
- the dataclass instances in the same order as they were passed to the initializer.
"""
unused_keys =set(args.keys())# 去重
outputs =[]for dtype in self.dataclass_types:# parse中注册的数据类
keys ={f.name for f in dataclasses.fields(dtype)if f.init}
inputs ={k: v for k, v in args.items()if k in keys}
unused_keys.difference_update(inputs.keys())
obj = dtype(**inputs)
outputs.append(obj)ifnot allow_extra_keys and unused_keys:# 如果不允许有额外的参数但命令行传入了额外的参数的情况下,丢出错误raise ValueError(f"Some keys are not used by the HfArgumentParser: {sorted(unused_keys)}")returntuple(outputs)# 返回parse中注册数据类对象元组
使用参数字典args中的各种参数值创建并初始化parse中注册的各个数据类,并返回数据类对象元组。
***使用例子:
parser.parse_dict(args)
parse_yaml_file函数
defparse_yaml_file(
self, yaml_file: Union[str, os.PathLike], allow_extra_keys:bool=False)-> Tuple[DataClass,...]:"""
Alternative helper method that does not use `argparse` at all, instead loading a yaml file and populating the
dataclass types.
Args:
yaml_file (`str` or `os.PathLike`):
File name of the yaml file to parse
allow_extra_keys (`bool`, *optional*, defaults to `False`):
Defaults to False. If False, will raise an exception if the json file contains keys that are not
parsed.
Returns:
Tuple consisting of:
- the dataclass instances in the same order as they were passed to the initializer.
"""
outputs = self.parse_dict(yaml.safe_load(Path(yaml_file).read_text()), allow_extra_keys=allow_extra_keys)returntuple(outputs)
读取yaml文件中的各种参数值,使用获取到的参数值创建并初始化parse中注册的各个数据类,并返回数据类对象元组。
使用例子:
iflen(sys.argv)==2and sys.argv[1].endswith(".yaml"):# 如果sys.argv得到的参数个数等于2,说明命令只有py文件名和配置文件名return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
parse_json_file函数
defparse_json_file(
self, json_file: Union[str, os.PathLike], allow_extra_keys:bool=False)-> Tuple[DataClass,...]:"""
Alternative helper method that does not use `argparse` at all, instead loading a json file and populating the
dataclass types.
Args:
json_file (`str` or `os.PathLike`):
File name of the json file to parse
allow_extra_keys (`bool`, *optional*, defaults to `False`):
Defaults to False. If False, will raise an exception if the json file contains keys that are not
parsed.
Returns:
Tuple consisting of:
- the dataclass instances in the same order as they were passed to the initializer.
"""withopen(Path(json_file), encoding="utf-8")as open_json_file:
data = json.loads(open_json_file.read())
outputs = self.parse_dict(data, allow_extra_keys=allow_extra_keys)returntuple(outputs)
读取json文件中的各种参数值,使用获取到的参数值创建并初始化parse中注册的各个数据类,并返回数据类对象元组。
使用例子:
iflen(sys.argv)==2and sys.argv[1].endswith(".json"):# 如果sys.argv得到的参数个数等于2,说明命令只有py文件名和配置文件名return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
parse_arg_into_dataclasses
defparse_args_into_dataclasses(
self,
args=None,
return_remaining_strings=False,
look_for_args_file=True,
args_filename=None,
args_file_flag=None,)-> Tuple[DataClass,...]:
从命令行命令中获取参数,并用获取到的参数值创建并初始化parse中注册的各个数据类,并返回数据类对象元组。
版权归原作者 溟玖 所有, 如有侵权,请联系我们删除。