diff --git a/finetune/blip.py b/finetune/blip/blip.py
similarity index 98%
rename from finetune/blip.py
rename to finetune/blip/blip.py
index 0f776e19..7851fb08 100644
--- a/finetune/blip.py
+++ b/finetune/blip/blip.py
@@ -10,8 +10,8 @@ warnings.filterwarnings("ignore")
 
 # from models.vit import VisionTransformer, interpolate_pos_embed
 # from models.med import BertConfig, BertModel, BertLMHeadModel
-from vit import VisionTransformer, interpolate_pos_embed
-from med import BertConfig, BertModel, BertLMHeadModel
+from blip.vit import VisionTransformer, interpolate_pos_embed
+from blip.med import BertConfig, BertModel, BertLMHeadModel
 
 from transformers import BertTokenizer
 import torch
diff --git a/finetune/med.py b/finetune/blip/med.py
similarity index 100%
rename from finetune/med.py
rename to finetune/blip/med.py
diff --git a/finetune/med_config.json b/finetune/blip/med_config.json
similarity index 100%
rename from finetune/med_config.json
rename to finetune/blip/med_config.json
diff --git a/finetune/vit.py b/finetune/blip/vit.py
similarity index 100%
rename from finetune/vit.py
rename to finetune/blip/vit.py
diff --git a/finetune/make_captions.py b/finetune/make_captions.py
index 9896545a..5808051e 100644
--- a/finetune/make_captions.py
+++ b/finetune/make_captions.py
@@ -2,6 +2,7 @@ import argparse
 import glob
 import os
 import json
+import random
 
 from PIL import Image
 from tqdm import tqdm
@@ -9,20 +10,31 @@ import numpy as np
 import torch
 from torchvision import transforms
 from torchvision.transforms.functional import InterpolationMode
-from blip import blip_decoder
+from blip.blip import blip_decoder
 # from Salesforce_BLIP.models.blip import blip_decoder
 
 DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
 
 def main(args):
+  # fix the seed for reproducibility
+  seed = args.seed  # + utils.get_rank()
+  torch.manual_seed(seed)
+  np.random.seed(seed)
+  random.seed(seed)
+
+  if not os.path.exists("blip"):
+    cwd = os.getcwd()
+    print('Current Working Directory is: ', cwd)
+    os.chdir('finetune')
+
   image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \
       glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
   print(f"found {len(image_paths)} images.")
 
   print(f"loading BLIP caption: {args.caption_weights}")
   image_size = 384
-  model = blip_decoder(pretrained=args.caption_weights, image_size=image_size, vit='large', med_config="./med_config.json")
+  model = blip_decoder(pretrained=args.caption_weights, image_size=image_size, vit='large', med_config="./blip/med_config.json")
   model.eval()
   model = model.to(DEVICE)
   print("BLIP loaded")
@@ -72,7 +84,7 @@ def main(args):
 if __name__ == '__main__':
   parser = argparse.ArgumentParser()
   parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
-  parser.add_argument("--caption_weights", type=str, default="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth",
+  parser.add_argument("--caption_weights", type=str, default="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth",
                       help="BLIP caption weights (model_large_caption.pth) / BLIP captionの重みファイル(model_large_caption.pth)")
   parser.add_argument("--caption_extention", type=str, default=None,
                       help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
@@ -84,6 +96,7 @@ if __name__ == '__main__':
   parser.add_argument("--top_p", type=float, default=0.9,
                       help="top_p in Nucleus sampling / Nucleus sampling時のtop_p")
   parser.add_argument("--max_length", type=int, default=75, help="max length of caption / captionの最大長")
   parser.add_argument("--min_length", type=int, default=5, help="min length of caption / captionの最小長")
+  parser.add_argument('--seed', default=42, type=int, help='seed for reproducibility / 再現性を確保するための乱数seed')
   parser.add_argument("--debug", action="store_true", help="debug mode")
   args = parser.parse_args()
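Note on the two behavioral changes in make_captions.py: the new seed block makes the nucleus-sampled captions repeatable for a given --seed (torch.manual_seed covers PyTorch's CPU and CUDA generators, np.random.seed covers NumPy, and random.seed covers Python's built-in RNG), and the os.chdir('finetune') fallback lets the script resolve the relative ./blip/med_config.json path when launched from the repository root. A minimal standalone sketch of the same seeding pattern (the helper name seed_everything is illustrative, not part of this patch):

import random

import numpy as np
import torch


def seed_everything(seed: int) -> None:
  # Seed every RNG that can influence caption sampling:
  # Python's built-in random module, NumPy, and PyTorch (CPU and CUDA).
  random.seed(seed)
  np.random.seed(seed)
  torch.manual_seed(seed)


seed_everything(42)  # 42 matches the script's default --seed

Because the script may chdir into finetune/, an absolute path for train_data_dir is the safer choice when invoking it from the repository root, e.g. python finetune/make_captions.py /path/to/images --seed 42.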