Update seg_train.py

main
rajeevyasarla 6 years ago committed by GitHub
parent a11634acbd
commit 3416e36a1e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -63,6 +63,7 @@ parser.add_argument('--netG', default='', help="path to netG (to continue traini
parser.add_argument('--netD', default='', help="path to netD (to continue training)")
parser.add_argument('--workers', type=int, help='number of data loading workers', default=1)
parser.add_argument('--exp', default='sample', help='folder to output images and model checkpoints')
parser.add_argument('--modeclean', type=int,default= 1, help='segmentation network training mode, by it is default trained using clean images')
parser.add_argument('--display', type=int, default=5, help='interval for displaying train-logs')
parser.add_argument('--evalIter', type=int, default=500, help='interval for evauating(generating) images from valDataroot')
opt = parser.parse_args()
@ -245,6 +246,10 @@ ganIterations = 0
count = 1
Best_Fs = 0
Best_epoch = 0
if opt.modeclean == 1:
Num_rn = 0
else:
Num_rn = 34
for epoch in range(1000):
if epoch%60 == 0 and epoch>0:
opt.lrG = opt.lrG/1.25
@ -288,7 +293,7 @@ for epoch in range(1000):
y1 = int((160-opt.imageSize)/2)
input_cpu = input_cpu.numpy()
target_cpu = target_cpu.numpy()
if (random.randint(0,100)<0) :
if (random.randint(0,100)<Num_rn) :
for j in range(10):
index = random.randint(0,24000)
@ -303,6 +308,8 @@ for epoch in range(1000):
input_cpu = input_cpu[:,:,x1:x1+opt.imageSize,y1:y1+opt.imageSize]
target_cpu = target_cpu[:,:,x1:x1+opt.imageSize,y1:y1+opt.imageSize]
target_cpu[target_cpu>0.49]=1
target_cpu[target_cpu<=0.49]=0
input_cpu = torch.from_numpy(input_cpu)
target_cpu = torch.from_numpy(target_cpu)
target_cpu, input_cpu = target_cpu.float().cuda(), input_cpu.float().cuda()

Loading…
Cancel
Save