001package com.fs.starfarer.api.util;
002
003import java.util.ArrayList;
004import java.util.Collection;
005import java.util.Collections;
006import java.util.Comparator;
007import java.util.HashMap;
008import java.util.List;
009import java.util.Map;
010import java.util.Random;
011
012import org.json.JSONArray;
013import org.json.JSONException;
014
015import com.fs.starfarer.api.campaign.econ.MarketAPI;
016
017public class WeightedRandomPicker<T> implements Cloneable {
018
019        @Override
020        public WeightedRandomPicker<T> clone() {
021                try {
022                        WeightedRandomPicker<T> copy = (WeightedRandomPicker<T>) super.clone();
023                        copy.items = new ArrayList<T>(items);
024                        copy.weights = new ArrayList<Float>(weights);
025                        return copy;
026                } catch (CloneNotSupportedException e) {
027                        return null;
028                }
029        }
030
031        private List<T> items = new ArrayList<T>();
032        transient private List<Float> weights = new ArrayList<Float>();
033        private String w;
034        
035        private float total = 0f;
036        private final boolean ignoreWeights;
037        
038        private Random random = null;
039        
040        public WeightedRandomPicker() {
041                this(false);
042        }
043        
044        public WeightedRandomPicker(boolean ignoreWeights) {
045                this.ignoreWeights = ignoreWeights;
046        }
047
048        public WeightedRandomPicker(Random random) {
049                this(false);
050                this.random = random;
051        }
052        
053        Object readResolve() {
054                try {
055                        weights = new ArrayList<Float>();
056                        if (w != null) {
057                                JSONArray arr = new JSONArray(w);
058                                for (int i = 0; i < arr.length(); i++) {
059                                        weights.add((float)arr.getDouble(i));
060                                }
061                        }
062                } catch (JSONException e) {
063                        throw new RuntimeException(e);
064                }
065                return this;
066        }
067        
068        Object writeReplace() {
069                JSONArray arr = new JSONArray();
070                for (Float f : weights) {
071                        arr.put(f);
072                }
073                w = arr.toString();
074                
075                return this;
076        }
077
078
079        public void clear() {
080                items.clear();
081                weights.clear();
082                total = 0;
083        }
084
085//      public void addAll(List<T> items) {
086//              for (T item : items) {
087//                      add(item);
088//              }
089//      }
090        
091        public void addAll(Collection<T> items) {
092                for (T item : items) {
093                        add(item);
094                }
095        }
096        
097        public void addAll(WeightedRandomPicker<T> other) {
098                for (int i = 0; i < other.items.size(); i++) {
099                        add(other.items.get(i), other.weights.get(i));
100                }
101        }
102        
103        public void add(T item) {
104                add(item, 1f);
105        }
106        public void add(T item, float weight) {
107                //if (weight < 0) weight = 0;
108                if (weight <= 0) return;
109                items.add(item);
110                weights.add(weight); // + (weights.isEmpty() ? 0 : weights.get(weights.size() - 1)));
111                total += weight;
112        }
113        
114        public void remove(T item) {
115                if (item == null) return;
116                int index = items.indexOf(item);
117                if (index != -1) {
118                        items.remove(index);
119                        float weight = weights.remove(index);
120                        total -= weight;
121                }
122        }
123        
124        public boolean isEmpty() {
125                return items.isEmpty();
126        }
127        
128        public List<T> getItems() {
129                return items;
130        }
131        
132        public float getWeight(T item) {
133                int index = items.indexOf(item);
134                if (index < 0) return 0;
135                return getWeight(index);
136        }
137        
138        public float getWeight(int index) {
139                return weights.get(index);
140        }
141        public void setWeight(int index, float weight) {
142                float w = getWeight(index);
143                weights.set(index, weight);
144                total += weight - w;
145        }
146        
147        public T getItemWithHighestWeight() {
148                float maxW = 0;
149                for (int i = 0; i < items.size(); i++) {
150                        float w = getWeight(i);
151                        if (w > maxW) maxW = w;
152                }
153                if (maxW <= 0) return null;
154                
155                WeightedRandomPicker<T> other = new WeightedRandomPicker<>();
156                for (int i = 0; i < items.size(); i++) {
157                        float w = getWeight(i);
158                        if (w >= maxW) {
159                                other.add(items.get(i));
160                        }
161                }
162                return other.pick();
163        }
164
165        public T pickAndRemove() {
166                T pick = pick();
167                remove(pick);
168                return pick;
169        }
170        
171        public T pick(Random random) {
172                Random orig = this.random;
173                this.random = random;
174                T pick = pick();
175                this.random = orig;
176                return pick;
177        }
178        
179        public T pick() {
180                if (items.isEmpty()) return null;
181                
182                if (ignoreWeights) {
183                        int index;
184                        if (random != null) {
185                                index = (int) (random.nextDouble() * items.size());
186                        } else {
187                                index = (int) (Math.random() * items.size());
188                        }
189                        return items.get(index);
190                }
191                
192                float random;
193                if (this.random != null) {
194                        random = this.random.nextFloat() * total;
195                } else {
196                        random = (float) (Math.random() * total);
197                }
198                if (random > total) random = total;
199                //random = 0.1f;
200                //random = total - 0.001f;
201                float weightSoFar = 0f;
202                int index = 0;
203                for (Float weight : weights) {
204                        weightSoFar += weight;
205                        if (random <= weightSoFar) break;
206                        index++;
207                }
208                return items.get(Math.min(index, items.size() - 1));
209        }
210
211        public Random getRandom() {
212                return random;
213        }
214
215        public void setRandom(Random random) {
216                this.random = random;
217        }
218
219        
220        
221        public void print(String title) {
222                System.out.println(title);
223                
224                Map<T, Integer> indices = new HashMap<T, Integer>();
225                for (int i = 0; i < items.size(); i++) {
226                        T item = items.get(i);
227                        indices.put(item, i);
228                }
229                
230                List<T> sorted = new ArrayList<T>(items);
231                Collections.sort(sorted, new Comparator<T>() {
232                        public int compare(T o1, T o2) {
233                                return o1.toString().compareTo(o2.toString());
234                        }
235                });
236                
237                for (T item : sorted) {
238                        int index = indices.get(item);
239                        float weight = weights.get(index);
240                        //String percent = Misc.getRoundedValueMaxOneAfterDecimal((weight / total) * 100f) + "%";
241                        String percent = "" + (int)((weight / total) * 100f) + "%";
242                        
243                        //System.out.println("    " + item.toString() + ": " + percent + " (" + Misc.getRoundedValue(weight) + ")");
244                        String itemStr = "";
245                        if (item instanceof MarketAPI) {
246                                itemStr = ((MarketAPI)item).getName();
247                        } else {
248                                itemStr = item.toString();
249                        }
250                        System.out.println(String.format("    %-30s%10s%10s", itemStr, percent, Misc.getRoundedValue(weight)));
251                }
252                //System.out.println("  Total: " + (int) total);
253                //System.out.println();
254        }
255
256        public float getTotal() {
257                return total;
258        }
259        
260}
261
262
263
264
265
266