|
| 1 | +#!/usr/bin/env python |
| 2 | + |
| 3 | +#/************************************************************ |
| 4 | +#* |
| 5 | +#* Licensed to the Apache Software Foundation (ASF) under one |
| 6 | +#* or more contributor license agreements. See the NOTICE file |
| 7 | +#* distributed with this work for additional information |
| 8 | +#* regarding copyright ownership. The ASF licenses this file |
| 9 | +#* to you under the Apache License, Version 2.0 (the |
| 10 | +#* "License"); you may not use this file except in compliance |
| 11 | +#* with the License. You may obtain a copy of the License at |
| 12 | +#* |
| 13 | +#* http://www.apache.org/licenses/LICENSE-2.0 |
| 14 | +#* |
| 15 | +#* Unless required by applicable law or agreed to in writing, |
| 16 | +#* software distributed under the License is distributed on an |
| 17 | +#* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 18 | +#* KIND, either express or implied. See the License for the |
| 19 | +#* specific language governing permissions and limitations |
| 20 | +#* under the License. |
| 21 | +#* |
| 22 | +#*************************************************************/ |
| 23 | + |
| 24 | +import os, sys |
| 25 | +import numpy as np |
| 26 | + |
| 27 | +current_path_ = os.path.dirname(__file__) |
| 28 | +singa_root_=os.path.abspath(os.path.join(current_path_,'../..')) |
| 29 | +sys.path.append(os.path.join(singa_root_,'thirdparty','protobuf-2.6.0','python')) |
| 30 | +sys.path.append(os.path.join(singa_root_,'tool','python')) |
| 31 | + |
| 32 | +from model import neuralnet, updater |
| 33 | +from singa.driver import Driver |
| 34 | +from singa.layer import * |
| 35 | +from singa.model import save_model_parameter, load_model_parameter |
| 36 | +from singa.utils.utility import swap32 |
| 37 | + |
| 38 | +from PIL import Image |
| 39 | +import glob,random, shutil, time |
| 40 | +from flask import Flask, request, redirect, url_for |
| 41 | +from singa.utils import kvstore, imgtool |
| 42 | +app = Flask(__name__) |
| 43 | + |
| 44 | +def train(batchsize,disp_freq,check_freq,train_step,workspace,checkpoint=None): |
| 45 | + print '[Layer registration/declaration]' |
| 46 | + # TODO change layer registration methods |
| 47 | + d = Driver() |
| 48 | + d.Init(sys.argv) |
| 49 | + |
| 50 | + print '[Start training]' |
| 51 | + |
| 52 | + #if need to load checkpoint |
| 53 | + if checkpoint: |
| 54 | + load_model_parameter(workspace+checkpoint, neuralnet, batchsize) |
| 55 | + |
| 56 | + for i in range(0,train_step): |
| 57 | + |
| 58 | + for h in range(len(neuralnet)): |
| 59 | + #Fetch data for input layer |
| 60 | + if neuralnet[h].layer.type==kDummy: |
| 61 | + neuralnet[h].FetchData(batchsize) |
| 62 | + else: |
| 63 | + neuralnet[h].ComputeFeature() |
| 64 | + |
| 65 | + neuralnet[h].ComputeGradient(i+1, updater) |
| 66 | + |
| 67 | + if (i+1)%disp_freq == 0: |
| 68 | + print ' Step {:>3}: '.format(i+1), |
| 69 | + neuralnet[h].display() |
| 70 | + |
| 71 | + if (i+1)%check_freq == 0: |
| 72 | + save_model_parameter(i+1, workspace, neuralnet) |
| 73 | + |
| 74 | + |
| 75 | + print '[Finish training]' |
| 76 | + |
| 77 | + |
| 78 | +def product(workspace,checkpoint): |
| 79 | + |
| 80 | + print '[Layer registration/declaration]' |
| 81 | + # TODO change layer registration methods |
| 82 | + d = Driver() |
| 83 | + d.Init(sys.argv) |
| 84 | + |
| 85 | + load_model_parameter(workspace+checkpoint, neuralnet,1) |
| 86 | + |
| 87 | + app.debug = True |
| 88 | + app.run(host='0.0.0.0', port=80) |
| 89 | + |
| 90 | + |
| 91 | +@app.route("/") |
| 92 | +def index(): |
| 93 | + return "Hello World! This is SINGA DLAAS! Please send post request with image=file to '/predict' " |
| 94 | + |
| 95 | +def allowed_file(filename): |
| 96 | + allowd_extensions_ = set(['txt', 'pdf', 'png', 'jpg', 'jpeg', 'gif']) |
| 97 | + return '.' in filename and \ |
| 98 | + filename.rsplit('.', 1)[1] in allowd_extensions_ |
| 99 | + |
| 100 | +@app.route('/predict', methods=['POST']) |
| 101 | +def predict(): |
| 102 | + size_=(32,32) |
| 103 | + pixel_length_=3*size_[0]*size_[1] |
| 104 | + label_num_=10 |
| 105 | + if request.method == 'POST': |
| 106 | + file = request.files['image'] |
| 107 | + if file and allowed_file(file.filename): |
| 108 | + im = Image.open(file).convert("RGB") |
| 109 | + im = imgtool.resize_to_center(im,size_) |
| 110 | + pixel = floatVector(pixel_length_) |
| 111 | + byteArray = imgtool.toBin(im,size_) |
| 112 | + data = np.frombuffer(byteArray, dtype=np.uint8) |
| 113 | + data = data.reshape(1, pixel_length_) |
| 114 | + #dummy data Layer |
| 115 | + shape = intVector(4) |
| 116 | + shape[0]=1 |
| 117 | + shape[1]=3 |
| 118 | + shape[2]=size_[0] |
| 119 | + shape[3]=size_[1] |
| 120 | + |
| 121 | + for h in range(len(neuralnet)): |
| 122 | + #Fetch data for input layer |
| 123 | + if neuralnet[h].is_datalayer: |
| 124 | + if not neuralnet[h].is_label: |
| 125 | + neuralnet[h].Feed(data,3) |
| 126 | + else: |
| 127 | + neuralnet[h].FetchData(1) |
| 128 | + else: |
| 129 | + neuralnet[h].ComputeFeature() |
| 130 | + |
| 131 | + #get result |
| 132 | + #data = neuralnet[-1].get_singalayer().data(neuralnet[-1].get_singalayer()) |
| 133 | + #prop =floatArray_frompointer(data.mutable_cpu_data()) |
| 134 | + prop = neuralnet[-1].GetData() |
| 135 | + print prop |
| 136 | + result=[] |
| 137 | + for i in range(label_num_): |
| 138 | + result.append((i,prop[i])) |
| 139 | + |
| 140 | + result.sort(key=lambda tup: tup[1], reverse=True) |
| 141 | + print result |
| 142 | + response="" |
| 143 | + for r in result: |
| 144 | + response+=str(r[0])+":"+str(r[1]) |
| 145 | + |
| 146 | + return response |
| 147 | + return "error" |
| 148 | + |
| 149 | + |
| 150 | +if __name__=='__main__': |
| 151 | + |
| 152 | + if sys.argv[1]=="train": |
| 153 | + if len(sys.argv) < 6: |
| 154 | + print "argv should be more than 6" |
| 155 | + exit() |
| 156 | + if len(sys.argv) > 6: |
| 157 | + checkpoint = sys.argv[6] |
| 158 | + else: |
| 159 | + checkpoint = None |
| 160 | + #training |
| 161 | + train( |
| 162 | + batchsize = int(sys.argv[2]), |
| 163 | + disp_freq = int(sys.argv[3]), |
| 164 | + check_freq = int(sys.argv[4]), |
| 165 | + train_step = int(sys.argv[5]), |
| 166 | + workspace = '/workspace', |
| 167 | + checkpoint = checkpoint, |
| 168 | + ) |
| 169 | + else: |
| 170 | + if len(sys.argv) < 3: |
| 171 | + print "argv should be more than 2" |
| 172 | + exit() |
| 173 | + checkpoint = sys.argv[2] |
| 174 | + product( |
| 175 | + workspace = '/workspace', |
| 176 | + checkpoint = checkpoint |
| 177 | + ) |
| 178 | + |
0 commit comments