Do the maximum token length and maximum vocabulary size of a Unigram tokenizer affect dependency parsing?

DeBERTa models built from the Thai UD_Thai-Corpora

I investigated how the maximum token length M and the maximum vocabulary size V of a Unigram tokenizer affect UPOS/LAS/MLAS. Only the sentences of th_lst-ud-train1.conllu and th_lst-ud-train2.conllu were used to build the DeBERTa models.
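
Here M and V go straight into the Unigram trainer of the tokenizers library as max_piece_length and vocab_size (see the full script below). Isolated from the rest of the pipeline, a minimal sketch of just that step looks roughly like this:

from tokenizers import Tokenizer, models, trainers

# token.txt: one space-separated sentence per line (produced by the script below)
m, v = 4, 8000  # M = maximum token length, V = maximum vocabulary size
spt = Tokenizer(models.Unigram())
spt.train(files=["token.txt"],
          trainer=trainers.UnigramTrainer(vocab_size=v, max_piece_length=m,
            special_tokens=["[CLS]", "[PAD]", "[SEP]", "[UNK]", "[MASK]"],
            unk_token="[UNK]", n_sub_iterations=2))
spt.save("tokenizer.json")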

Evaluation on th_lst-ud-dev.conllu (each cell: UPOS/LAS/MLAS)

    V=4000            V=8000            V=16000           V=32000
M=3 90.75/81.91/75.50 91.04/81.52/75.20 90.38/79.95/73.85 90.70/81.22/74.45
M=4 91.90/84.20/77.75 92.12/82.92/76.41 93.13/83.67/77.52 92.66/82.84/76.41
M=5 91.25/83.08/76.54 90.56/82.05/75.23 91.61/81.84/75.61 91.63/82.17/76.29
M=6 90.94/82.91/75.82 92.71/84.91/79.01 92.72/83.21/76.42 91.71/82.85/75.93
M=7 91.64/83.34/77.20 92.25/83.29/76.86 92.13/83.17/76.12 92.02/83.74/77.19
M=8 91.68/82.96/77.10 92.28/84.45/77.94 92.42/83.93/77.75 91.76/83.00/76.98
M=9 91.78/83.15/76.66 92.02/84.43/78.37 93.10/84.25/78.40 92.68/84.03/78.14
M=10 91.22/82.27/75.82 92.64/83.91/78.32 92.60/83.42/77.86 92.73/82.91/77.40
M=11 91.78/83.38/76.77 92.19/84.50/78.63 92.15/84.36/78.27 92.33/84.82/78.62
M=12 90.83/82.34/75.96 92.35/83.96/78.24 91.59/82.83/76.20 92.62/84.54/78.52

Test on th_lst-ud-test.conllu (each cell: UPOS/LAS/MLAS)

    V=4000            V=8000            V=16000           V=32000
M=3 88.63/74.44/67.37 87.28/72.98/65.37 86.65/71.70/64.01 86.62/74.33/67.55
M=4 89.24/77.40/69.36 89.00/74.67/67.61 88.08/73.40/66.02 88.86/75.25/68.78
M=5 90.16/78.44/71.55 89.11/77.94/70.80 90.53/75.49/68.60 88.07/75.78/68.96
M=6 88.57/76.70/69.02 89.70/76.63/70.52 90.34/78.29/72.70 89.37/77.29/69.92
M=7 89.37/75.92/69.43 90.98/77.58/72.48 90.05/77.27/70.74 89.49/78.61/72.21
M=8 89.03/76.05/69.89 90.80/78.20/72.31 90.09/77.51/72.07 88.71/77.42/70.55
M=9 89.53/78.56/72.07 90.07/79.19/73.50 90.75/75.46/69.46 89.34/78.46/72.61
M=10 88.60/78.50/71.30 91.89/78.74/73.49 90.09/77.84/73.07 89.60/75.05/69.93
M=11 89.65/77.48/71.60 90.31/78.22/72.00 90.92/80.09/74.42 90.48/77.75/71.67
M=12 88.23/74.66/69.01 91.45/80.93/74.51 90.63/77.89/72.26 90.03/76.83/72.56

Working environment

mdx, 1 GPU (NVIDIA A100-SXM4-40GB)

/bin/sh script

#! /bin/sh
# pip3 uninstall pytokenizations
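# clone spaCy-Thai to obtain the bundled UD_Thai-Corpora treebanks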
URL=https://github.com/KoichiYasuoka/spaCy-Thai
D=`basename $URL`/UD_Thai-Corpora
test -d $D || git clone --depth=1 $URL
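# awk program: keep only sentences that contain a root token (HEAD=0, DEPREL=root)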
S='{u=u$0"\n";if($0==""){if(u~/\t0\troot\t/)printf("%s",u);u=""}}'
for F in train dev test
do nawk "$S" $D/*-$F*.conllu > $F.conllu
   sed -n 's/^# text = //p' $F.conllu > $F.txt
done
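# train.upos: raw CoNLL-U of *-train?.conllu (th_lst-ud-train1/2); train.txt: their text lines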
cat $D/*-train?.conllu | tee train.upos | sed -n 's/^# text = //p' > train.txt
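# token.txt: surface forms of each train.upos sentence, space-separated, one sentence per line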
S='{if(NF==10&&$1~/^[1-9][0-9]*$/)printf($1>1?" %s":"%s",$2);if(NF==0)print}'
nawk -F'\t' "$S" train.upos > token.txt
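# official CoNLL 2018 shared task evaluation script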
U=http://universaldependencies.org/conll18/conll18_ud_eval.py
C=`basename $U`
test -f $C || curl -LO $U
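# sweep the maximum token length M and the maximum vocabulary size V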
for M in 3 4 5 6 7 8 9 10 11 12
do for V in 4000 8000 16000 32000
   do test -d deberta$M-$V || python3 -c 'm,v='$M,$V'
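# step 1: pre-train a DeBERTa masked-language model whose Unigram tokenizer uses max_piece_length=m and vocab_size=v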
from transformers import (DataCollatorForLanguageModeling,TrainingArguments,
  DebertaV2TokenizerFast,DebertaV2Config,DebertaV2ForMaskedLM,Trainer)
from tokenizers import (Tokenizer,models,pre_tokenizers,normalizers,processors,
  decoders,trainers)
import json,unicodedata
s=["[CLS]","[PAD]","[SEP]","[UNK]","[MASK]"]
spt=Tokenizer(models.Unigram())
spt.pre_tokenizer=pre_tokenizers.Sequence([pre_tokenizers.Whitespace(),
  pre_tokenizers.Punctuation()])
spt.normalizer=normalizers.Sequence([normalizers.Nmt(),normalizers.NFKC()])
spt.post_processor=processors.TemplateProcessing(single="[CLS] $A [SEP]",
  pair="[CLS] $A [SEP] $B:1 [SEP]:1",special_tokens=[("[CLS]",0),("[SEP]",2)])
spt.decoder=decoders.WordPiece(prefix="",cleanup=True)
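# train the Unigram model on token.txt with vocab_size=v and max_piece_length=m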
spt.train(trainer=trainers.UnigramTrainer(vocab_size=v,max_piece_length=m,
  special_tokens=s,unk_token="[UNK]",n_sub_iterations=2),files=["token.txt"])
spt.save("tokenizer.json")
with open("tokenizer.json","r",encoding="utf-8") as r:
  spt=json.load(r)
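# keep single-character tokens; drop longer ones that start with a combining mark (Mn) or end with U+0E40-U+0E46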
spt["model"]["vocab"]=[t for t in spt["model"]["vocab"] if len(t[0])<2 or
  unicodedata.category(t[0][0])!="Mn" and int((ord(t[0][-1])-1)/7)!=521]
with open("tokenizer.json","w",encoding="utf-8") as w:
  json.dump(spt,w,ensure_ascii=False,indent=2)
tkz=DebertaV2TokenizerFast(tokenizer_file="tokenizer.json",split_by_punct=True,
  do_lower_case=False,keep_accents=True,bos_token="[CLS]",cls_token="[CLS]",
  pad_token="[PAD]",sep_token="[SEP]",unk_token="[UNK]",mask_token="[MASK]",
  vocab_file="/dev/null",model_max_length=512)
t=tkz.convert_tokens_to_ids(s)
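# base-sized DebertaV2 configuration (12 layers, 12 heads, hidden size 768) with relative position embeddings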
cfg=DebertaV2Config(hidden_size=768,num_hidden_layers=12,num_attention_heads=12,
  intermediate_size=3072,relative_attention=True,position_biased_input=False,
  pos_att_type=["p2c","c2p"],max_position_embeddings=tkz.model_max_length,
  vocab_size=len(tkz),tokenizer_class=type(tkz).__name__,
  bos_token_id=t[0],pad_token_id=t[1],eos_token_id=t[2])
arg=TrainingArguments(num_train_epochs=8,per_device_train_batch_size=24,
  output_dir="/tmp",overwrite_output_dir=True,save_total_limit=2)
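# minimal line-by-line dataset: every non-empty line of train.txt becomes one training example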
class ReadLineDS(object):
  def __init__(self,file,tokenizer):
    self.tokenizer=tokenizer
    with open(file,"r",encoding="utf-8") as r:
      self.lines=[s.strip() for s in r if s.strip()!=""]
  __len__=lambda self:len(self.lines)
  __getitem__=lambda self,i:self.tokenizer(self.lines[i],truncation=True,
    add_special_tokens=True,max_length=self.tokenizer.model_max_length-2)
trn=Trainer(args=arg,data_collator=DataCollatorForLanguageModeling(tkz),
  model=DebertaV2ForMaskedLM(cfg),train_dataset=ReadLineDS("train.txt",tkz))
trn.train()
trn.save_model("deberta{}-{}".format(m,v))
tkz.save_pretrained("deberta{}-{}".format(m,v))'
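      # step 2: build the UPOS/dependency model upos$M-$V on top of deberta$M-$V with esupar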
      U=upos$M-$V
      if [ ! -d $U ]
      then E=esupar.train
	   python3 -m $E deberta$M-$V $U 24 /tmp train.upos
           python3 -m $E $U $U 24 /// train.conllu dev.conllu test.conllu
      fi
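      # step 3: parse raw dev/test text with the model and score it with the CoNLL 2018 evaluation script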
      test -f result$M-$V/result && continue
      mkdir -p result$M-$V
      for F in dev test
      do cat $F.txt | python3 -c 'mdl,f="'$U'","result'$M-$V/$F'.conllu"
import esupar
nlp=esupar.load(mdl)
with open(f,"w",encoding="utf-8") as w:
  while True:
    try:
      doc=nlp(input().strip())
    except:
      quit()
    print(doc,file=w)'
      done
      ( echo '***' $U dev
        python3 $C -v dev.conllu result$M-$V/dev.conllu
        echo '***' $U test
        python3 $C -v test.conllu result$M-$V/test.conllu
      ) | tee result$M-$V/result
   done
done
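
Once the script has run, each UPOS/dependency model is left in a local directory named upos$M-$V and can be used directly with esupar; a minimal sketch, taking upos12-8000 as one of the resulting directories:

import esupar

# load one of the locally trained models produced by the script above
nlp = esupar.load("upos12-8000")
# parse a raw Thai sentence (a proverb: many heads are better than one);
# the returned document prints as CoNLL-U
doc = nlp("หลายหัวดีกว่าหัวเดียว")
print(doc)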