pix2pixHD train your own dataset (win10)

1, Environment requirements

  • Linux, macOS or Windows 10 (I used Windows 10 + Anaconda)
  • Python 2 or 3 (I used Python 3.6)
  • NVIDIA GPU (11 GB of memory or more) + CUDA + cuDNN (I used CUDA 9.2 on a 12 GB GPU, which can train on 1024x512 images but cannot train or test on 2048x1024 images)

2, Environment configuration

3, Source download

  • Download address: https://github.com/NVIDIA/pix2pixHD
  • Method 1: download the zip directly from GitHub, which can be slow.
  • Method 2: import the repository into Gitee (https://gitee.com, registration required) with the "+" button, then download it from Gitee at normal speed.

4, Make your own dataset

1. Image input size:

--resize_or_crop defaults to scale_width and --loadSize defaults to 1024, i.e. every training image is scaled to a width of 1024 while keeping its aspect ratio. Use --resize_or_crop to choose other behavior. For example, scale_width_and_crop first resizes the image to width opt.loadSize and then randomly crops a (opt.fineSize, opt.fineSize) patch, while crop performs only the random crop. If you don't want any preprocessing, specify none, which does nothing except make sure the image size is divisible by 32.
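To see what size your images end up at, the arithmetic of scale_width and of the "divisible by 32" adjustment can be sketched as below. This is a pure-Python sketch of the size calculation only (no image library); the function names are hypothetical, and base=32 simply follows the description above.

```python
def scale_width_size(width, height, target_width=1024):
    """Output (w, h) of a scale_width resize: the width becomes target_width
    and the height is scaled by the same factor, truncated to an int."""
    if width == target_width:
        return width, height
    return target_width, int(target_width * height / width)


def round_to_multiple(width, height, base=32):
    """Round each side to the nearest multiple of base (hypothetical helper
    mirroring the 'divisible by 32' adjustment described above)."""
    return int(round(width / base) * base), int(round(height / base) * base)


if __name__ == "__main__":
    print(scale_width_size(2048, 1024))   # (1024, 512)
    print(scale_width_size(1280, 720))    # (1024, 576)
    print(round_to_multiple(1000, 750))   # (992, 736)
```

So with the default settings a 1280x720 photo is trained at 1024x576, which already fits in memory on a 12 GB card, while 2048x1024 inputs are halved to 1024x512.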

2. Dataset with instance maps:

   The label maps must be single-channel images whose pixel values are the label IDs (0, 1, ..., N-1, where N is the number of labels), because the label maps are converted to one-hot vectors. Pass --label_nc N for both training and testing. Use the datasets/cityscapes sample dataset in the code as a reference when building your own.
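Why the values must be exactly 0..N-1 becomes clear from a one-hot conversion: each label ID selects one of N output planes. In pix2pixHD this happens inside the model on tensors; the pure-Python function below (name hypothetical) only illustrates the idea and rejects out-of-range IDs.

```python
def one_hot(label_map, num_labels):
    """Turn an H x W list of label IDs into num_labels binary H x W planes."""
    planes = [[[0] * len(row) for row in label_map] for _ in range(num_labels)]
    for y, row in enumerate(label_map):
        for x, label in enumerate(row):
            # An ID >= num_labels would have no plane to go into.
            if not 0 <= label < num_labels:
                raise ValueError(
                    f"label {label} at ({y}, {x}) is outside 0..{num_labels - 1}"
                )
            planes[label][y][x] = 1
    return planes
```

For example, one_hot([[0, 2], [1, 1]], 3) gives three 2x2 planes, with plane 2 equal to [[0, 1], [0, 0]]. A label map saved as an RGB image or with IDs like 255 would fail this check, which is a common cause of errors when using --label_nc.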

3. Dataset without instance maps:

   Create four folders, train_A, train_B, test_A and test_B, under the --dataroot directory (./datasets/cityscapes/ by default); the model learns a mapping from A to B. The image file names in A and B must correspond one to one, e.g. the label image train_A/000000.jpg corresponds to the ground-truth image train_B/000000.jpg.
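pix2pixHD pairs the A and B images by their sorted order, so a single missing or misnamed file silently shifts every pair after it. A quick check before training can be sketched as follows; check_pairs is a hypothetical helper, and comparing names without extensions is an assumption so that, for example, a .png label still matches a .jpg photo.

```python
import os


def check_pairs(root, phase="train"):
    """Return (only_in_A, only_in_B): file stems present in just one of
    <root>/<phase>_A and <root>/<phase>_B."""
    def stems(folder):
        return {os.path.splitext(name)[0] for name in os.listdir(folder)}

    stems_a = stems(os.path.join(root, phase + "_A"))
    stems_b = stems(os.path.join(root, phase + "_B"))
    return sorted(stems_a - stems_b), sorted(stems_b - stems_a)
```

Run it as check_pairs("./datasets/cityscapes", "train") and again with "test"; both returned lists should be empty.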

5, Training and testing (my setup: --no_instance, Windows 10)

   Windows has no sh, so the .sh scripts cannot be used; run the Python scripts directly instead.
The official train.py and test.py start their data-loading worker processes from code at module level, outside any main function. On Windows, Python's multiprocessing starts each child process in a fresh interpreter that re-imports the main module, so that code must sit behind an if __name__ == '__main__': guard. train.py and test.py therefore need to be modified; otherwise the following error is reported:

        An attempt has been made to start a new process before the
        current process has finished its bootstrapping phase.

        This probably means that you are not using fork to start your
        child processes and you have forgotten to use the proper idiom
        in the main module:

            if __name__ == '__main__':
                freeze_support()
                ...

        The "freeze_support()" line can be omitted if the program
        is not going to be frozen to produce an executable.

1. Training script train.py

import time
import os
import numpy as np
import torch
from torch.autograd import Variable
from collections import OrderedDict
from subprocess import call
from math import gcd  # fractions.gcd was removed in Python 3.9; math.gcd works on 3.5+
def lcm(a,b): return abs(a * b)//gcd(a,b) if a and b else 0

from options.train_options import TrainOptions
from data.data_loader import CreateDataLoader
from models.models import create_model
import util.util as util
from util.visualizer import Visualizer

def train():
    opt = TrainOptions().parse()
    iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')
    if opt.continue_train:
        try:
            start_epoch, epoch_iter = np.loadtxt(iter_path, delimiter=',', dtype=int)
        except Exception:
            start_epoch, epoch_iter = 1, 0
        print('Resuming from epoch %d at iteration %d' % (start_epoch, epoch_iter))
    else:
        start_epoch, epoch_iter = 1, 0

    opt.print_freq = lcm(opt.print_freq, opt.batchSize)
    if opt.debug:
        opt.display_freq = 1
        opt.print_freq = 1
        opt.niter = 1
        opt.niter_decay = 0
        opt.max_dataset_size = 10

    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('#training images = %d' % dataset_size)

    model = create_model(opt)
    visualizer = Visualizer(opt)
    if opt.fp16:
        from apex import amp
        model, [optimizer_G, optimizer_D] = amp.initialize(model, [model.optimizer_G, model.optimizer_D], opt_level='O1')
        model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)
    else:
        # without fp16, create_model already wraps the model in DataParallel
        optimizer_G, optimizer_D = model.module.optimizer_G, model.module.optimizer_D

    total_steps = (start_epoch - 1) * dataset_size + epoch_iter

    display_delta = total_steps % opt.display_freq
    print_delta = total_steps % opt.print_freq
    save_delta = total_steps % opt.save_latest_freq

    for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        if epoch != start_epoch:
            epoch_iter = epoch_iter % dataset_size
        for i, data in enumerate(dataset, start=epoch_iter):
            if total_steps % opt.print_freq == print_delta:
                iter_start_time = time.time()
            total_steps += opt.batchSize
            epoch_iter += opt.batchSize

            # whether to collect output images
            save_fake = total_steps % opt.display_freq == display_delta

            ############## Forward Pass ######################
            losses, generated = model(Variable(data['label']), Variable(data['inst']),
                                      Variable(data['image']), Variable(data['feat']), infer=save_fake)

            # sum per device losses
            losses = [torch.mean(x) if not isinstance(x, int) else x for x in losses]
            loss_dict = dict(zip(model.module.loss_names, losses))

            # calculate final loss scalar
            loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5
            loss_G = loss_dict['G_GAN'] + loss_dict.get('G_GAN_Feat', 0) + loss_dict.get('G_VGG', 0)

            ############### Backward Pass ####################
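            # update generator weights
            optimizer_G.zero_grad()
            if opt.fp16:
                with amp.scale_loss(loss_G, optimizer_G) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss_G.backward()
            optimizer_G.step()

            # update discriminator weights
            optimizer_D.zero_grad()
            if opt.fp16:
                with amp.scale_loss(loss_D, optimizer_D) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss_D.backward()
            optimizer_D.step()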

            ############## Display results and errors ##########
            ### print out errors
            if total_steps % opt.print_freq == print_delta:
                errors = {k: v.data.item() if not isinstance(v, int) else v for k, v in loss_dict.items()}
                t = (time.time() - iter_start_time) / opt.print_freq
                visualizer.print_current_errors(epoch, epoch_iter, errors, t)
                visualizer.plot_current_errors(errors, total_steps)
                # call(["nvidia-smi", "--format=csv", "--query-gpu=memory.used,memory.free"])

            ### display output images
            if save_fake:
                visuals = OrderedDict([('input_label', util.tensor2label(data['label'][0], opt.label_nc)),
                                       ('synthesized_image', util.tensor2im(generated.data[0])),
                                       ('real_image', util.tensor2im(data['image'][0]))])
                visualizer.display_current_results(visuals, epoch, total_steps)

            ### save latest model
            if total_steps % opt.save_latest_freq == save_delta:
                print('saving the latest model (epoch %d, total_steps %d)' % (epoch, total_steps))
                model.module.save('latest')
                np.savetxt(iter_path, (epoch, epoch_iter), delimiter=',', fmt='%d')

            if epoch_iter >= dataset_size:
                break

        # end of epoch
        iter_end_time = time.time()
        print('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))

        ### save model for this epoch
        if epoch % opt.save_epoch_freq == 0:
            print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps))
            model.module.save('latest')
            model.module.save(epoch)
            np.savetxt(iter_path, (epoch + 1, 0), delimiter=',', fmt='%d')

        ### instead of only training the local enhancer, train the entire network after certain iterations
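        if (opt.niter_fix_global != 0) and (epoch == opt.niter_fix_global):
            model.module.update_fixed_params()

        ### linearly decay learning rate after certain iterations
        if epoch > opt.niter:
            model.module.update_learning_rate()

# On Windows, worker processes re-import this module, so training must only start here
if __name__=='__main__':
    train()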

2. Training command

python train.py --name project1024 --label_nc 0 --no_instance --gpu_ids 0
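To resume an interrupted run, add --continue_train to the same command. The modified train.py then reads checkpoints/project1024/iter.txt, which np.savetxt writes as two integers, one per line (epoch, then epoch_iter), and falls back to (1, 0) if the file cannot be read. The bookkeeping can be sketched without NumPy as follows; save_progress and load_progress are hypothetical names.

```python
def save_progress(iter_path, epoch, epoch_iter):
    """Write epoch and epoch_iter one per line, the layout that
    np.savetxt(iter_path, (epoch, epoch_iter), fmt='%d') produces."""
    with open(iter_path, "w") as f:
        f.write("%d\n%d\n" % (epoch, epoch_iter))


def load_progress(iter_path):
    """Return (start_epoch, epoch_iter), or (1, 0) if the file is missing or
    malformed, like the try/except around np.loadtxt in train()."""
    try:
        with open(iter_path) as f:
            epoch, epoch_iter = (int(line) for line in f if line.strip())
        return epoch, epoch_iter
    except (OSError, ValueError):
        return 1, 0
```

Deleting iter.txt is therefore a simple way to restart the epoch counter while keeping the saved weights.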

3. Test script test.py

import os
from collections import OrderedDict
from torch.autograd import Variable
from options.test_options import TestOptions
from data.data_loader import CreateDataLoader
from models.models import create_model
import util.util as util
from util.visualizer import Visualizer
from util import html
import torch

def test():
    opt = TestOptions().parse(save=False)
    opt.nThreads = 1  # test code only supports nThreads = 1
    opt.batchSize = 1  # test code only supports batchSize = 1
    opt.serial_batches = True  # no shuffle
    opt.no_flip = True  # no flip

    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    visualizer = Visualizer(opt)
    # create website
    web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch))
    webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch))

    # test
    if not opt.engine and not opt.onnx:
        model = create_model(opt)
        if opt.data_type == 16:
            model.half()
        elif opt.data_type == 8:
            model.type(torch.uint8)

        if opt.verbose:
            print(model)
    else:
        from run_engine import run_trt_engine, run_onnx

    for i, data in enumerate(dataset):
        if i >= opt.how_many:
            break
        if opt.data_type == 16:
            data['label'] = data['label'].half()
            data['inst'] = data['inst'].half()
        elif opt.data_type == 8:
            data['label'] = data['label'].byte()  # torch tensors have no .uint8() method
            data['inst'] = data['inst'].byte()
        if opt.export_onnx:
            print("Exporting to ONNX: ", opt.export_onnx)
            assert opt.export_onnx.endswith("onnx"), "Export model file should end with .onnx"
            torch.onnx.export(model, [data['label'], data['inst']],
                              opt.export_onnx, verbose=True)
            return  # export only; the official script calls exit(0) here
        minibatch = 1
        if opt.engine:
            generated = run_trt_engine(opt.engine, minibatch, [data['label'], data['inst']])
        elif opt.onnx:
            generated = run_onnx(opt.onnx, opt.data_type, minibatch, [data['label'], data['inst']])
        else:
            generated = model.inference(data['label'], data['inst'], data['image'])

        visuals = OrderedDict([('input_label', util.tensor2label(data['label'][0], opt.label_nc)),
                               ('synthesized_image', util.tensor2im(generated.data[0]))])
        img_path = data['path']
        print('process image... %s' % img_path)
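        visualizer.save_images(webpage, visuals, img_path)

    webpage.save()


# On Windows, worker processes re-import this module, so testing must only start here
if __name__=='__main__':
    test()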

4. Test command:

python test.py --name project1024 --ngf 64 --label_nc 0 --no_instance --how_many 464

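   --ngf (the number of generator filters in the first conv layer, 64 by default) must match the value used for training, so if you trained a configuration with a different --ngf, set the same value here. --how_many is the number of test images to process (default 50).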
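Rather than counting test images by hand for --how_many, you can count the files in test_A (./datasets/cityscapes/test_A with the default --dataroot and --label_nc 0). The sketch below is a hypothetical helper; its extension list is an assumption and, unlike pix2pixHD's own loader, it does not descend into subfolders. Any value at least this large works, because the test loop also stops when the data runs out.

```python
import os

# Assumed set of image extensions; adjust to match your data.
IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".tif", ".tiff")


def count_images(folder):
    """Count image files directly inside folder (extension check is case-insensitive)."""
    return sum(
        1
        for name in os.listdir(folder)
        if name.lower().endswith(IMG_EXTENSIONS)
        and os.path.isfile(os.path.join(folder, name))
    )
```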

5. Test results display

   The official test script generates an images folder and an index.html file (under results/<name>/test_latest/ by default). The images folder holds the input label image and the generated image for each test sample, and index.html displays the images in that folder as a table. Unfortunately, the ground-truth image is missing. After testing, you can use the following script to build an HTML file that shows label, syn and real side by side, similar to the form shown below.

Script img2Html.py:

import dominate
from dominate.tags import *
import os
import argparse
import glob

parse = argparse.ArgumentParser(description='Show label | syn | real test results in one HTML page')
parse.add_argument('--test_B',type=str,default='web/test_B',help='Path to the test_B folder')
parse.add_argument('--ima', dest='ima', type=str, default='web/images', help='Path to the images folder')
parse.add_argument('--outDir',type=str,default='./',help='Path to the output folder')
args = parse.parse_args()

class HTML:
    def __init__(self, title, refresh=0):
        self.title = title
        self.doc = dominate.document(title=title)
        if refresh > 0:
            with self.doc.head:
                meta(http_equiv="refresh", content=str(refresh))

    def add_header(self, str):
        with self.doc:
            h3(str)

    def add_table(self, border=1):
        self.t = table(border=border, style="table-layout: fixed;")
        self.doc.add(self.t)

    def add_images(self, ims, txts, links, width=512):
        self.add_table()  # one table per row of images
        with self.t:
            with tr():
                for im, txt, link in zip(ims, txts, links):
                    with td(style="word-wrap: break-word;", halign="center", valign="top"):
                        with p():
                            # real images come from test_B, label/syn images from the results images folder
                            folder = args.test_B if txt == 'real' else args.ima
                            with a(href=os.path.join(folder, link)):
                                img(style="width:%dpx" % (width), src=os.path.join(folder, im))
                            br()
                            p(txt)

    def save(self, outDir):
        html_file = os.path.join(outDir, 'testResult.html')
        with open(html_file, 'wt') as f:
            f.write(self.doc.render())

def getJpgFile(pathDir):
    if os.path.exists(pathDir)==True:
        Jpg = os.path.join(pathDir, '*.jpg')
        JpgFile = glob.glob(Jpg)
        return JpgFile
    else:
        print('Image folder {0} does not exist'.format(pathDir))
        return []
def Img2HtmlForm(images):
    html = HTML('test_html')
    html.add_header('test result')
    # Get the list of all Jpg files under the specified path
    jpgFiles = getJpgFile(images)
    if len(jpgFiles) > 0:
        for jpgFile in jpgFiles:
            ims = []
            txts = []
            links = []
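            if 'input' in jpgFile:
                jpgFile = os.path.split(jpgFile)[1]  # e.g. 000000_input_label.jpg
                name = jpgFile.replace('_input_label.jpg', '')
                syn = name + '_synthesized_image.jpg'
                real = name + '.jpg'  # ground truth in test_B, assumed to be .jpg
                ims += [jpgFile, syn, real]
                txts += ['label', 'syn', 'real']
                links += [jpgFile, syn, real]
                html.add_images(ims, txts, links)
        html.save(args.outDir)
    else:
        print('No .jpg images found in {0}'.format(images))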

if __name__ == '__main__':
    Img2HtmlForm(args.ima)
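The pairing in img2Html.py relies on how the pix2pixHD visualizer names its outputs: for an input file <name>.jpg it saves <name>_input_label.jpg and <name>_synthesized_image.jpg in the images folder, while the real image stays in test_B as <name>.jpg. That mapping can be isolated in a small function and tested on its own; result_triplet is a hypothetical name, and real_ext=".jpg" is an assumption matching the example dataset above.

```python
def result_triplet(input_label_name, real_ext=".jpg"):
    """Map '<name>_input_label.jpg' to the (label, synthesized, real) file names
    shown in one row of the HTML table."""
    suffix = "_input_label.jpg"
    if not input_label_name.endswith(suffix):
        raise ValueError("not an input_label image: %s" % input_label_name)
    name = input_label_name[: -len(suffix)]
    return (input_label_name, name + "_synthesized_image.jpg", name + real_ext)
```

If your test_B images are .png, call result_triplet(name, real_ext=".png") so the real column points at existing files.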

6. Other parameters of training and testing

   See the .py files in the options folder (base_options.py, train_options.py and test_options.py) for all training and testing parameters.


Posted on Fri, 05 Jun 2020 03:39:09 -0400 by gmccague