package com.fs.starfarer.api.util;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;

import org.json.JSONArray;
import org.json.JSONException;

import com.fs.starfarer.api.campaign.econ.MarketAPI;

/**
 * Picks items at random with probability proportional to their weight.
 * <p>
 * Items and weights are stored in two parallel lists, and {@link #getTotal()} is the
 * running sum of the weights. {@link #add(Object, float)} ignores entries whose weight
 * is not strictly positive. When constructed with {@code ignoreWeights == true},
 * {@link #pick()} chooses uniformly among the stored items instead.
 * <p>
 * If a {@link Random} is supplied (constructor or {@link #setRandom(Random)}), picks
 * draw from it; otherwise {@link Math#random()} is used.
 * <p>
 * No synchronization is performed; confine an instance to one thread or guard it externally.
 */
public class WeightedRandomPicker<T> implements Cloneable {

    /**
     * Returns a copy with its own item and weight lists, so adding to or removing from
     * the copy does not affect this picker. The {@link Random} instance (if any) is
     * shared, not copied.
     */
    @Override
    @SuppressWarnings("unchecked")
    public WeightedRandomPicker<T> clone() {
        try {
            WeightedRandomPicker<T> copy = (WeightedRandomPicker<T>) super.clone();
            copy.items = new ArrayList<T>(items);
            copy.weights = new ArrayList<Float>(weights);
            return copy;
        } catch (CloneNotSupportedException e) {
            // Cannot happen: this class implements Cloneable.
            throw new AssertionError(e);
        }
    }

    private List<T> items = new ArrayList<T>();
    /** Parallel to {@code items}; transient because it is persisted through {@code w} instead. */
    transient private List<Float> weights = new ArrayList<Float>();
    /** JSON-encoded copy of {@code weights}, written by {@link #writeReplace()}. */
    private String w;

    /** Running sum of {@code weights}. */
    private float total = 0f;
    private final boolean ignoreWeights;

    /** Optional source of randomness; when null, {@link Math#random()} is used. */
    private Random random = null;

    public WeightedRandomPicker() {
        this(false);
    }

    /**
     * @param ignoreWeights if true, {@link #pick()} selects uniformly and ignores weights
     */
    public WeightedRandomPicker(boolean ignoreWeights) {
        this.ignoreWeights = ignoreWeights;
    }

    /**
     * @param random generator used for picks (may be null to use {@link Math#random()})
     */
    public WeightedRandomPicker(Random random) {
        this(false);
        this.random = random;
    }

    /**
     * Deserialization hook: rebuilds the transient {@code weights} list from the JSON
     * string {@code w}. If {@code w} is null, {@code weights} is left empty.
     * NOTE(review): presumably invoked by the save-game serializer (XStream honors
     * {@code readResolve} on non-Serializable classes) — confirm.
     *
     * @throws RuntimeException wrapping a {@link JSONException} if {@code w} is malformed
     */
    Object readResolve() {
        try {
            weights = new ArrayList<Float>();
            if (w != null) {
                JSONArray arr = new JSONArray(w);
                for (int i = 0; i < arr.length(); i++) {
                    weights.add((float) arr.getDouble(i));
                }
            }
        } catch (JSONException e) {
            throw new RuntimeException(e);
        }
        return this;
    }

    /**
     * Serialization hook: encodes {@code weights} as a JSON array string into {@code w},
     * since {@code weights} itself is transient.
     */
    Object writeReplace() {
        JSONArray arr = new JSONArray();
        for (Float f : weights) {
            arr.put(f);
        }
        w = arr.toString();

        return this;
    }

    /** Removes all items and resets the total weight to zero. */
    public void clear() {
        items.clear();
        weights.clear();
        total = 0;
    }

    /**
     * Adds every element of {@code items} with weight 1.
     * <p>
     * Passing this picker's own {@link #getItems()} list is supported. That list is
     * copied first, because iterating it while adding would otherwise throw
     * {@link java.util.ConcurrentModificationException}.
     */
    public void addAll(Collection<T> items) {
        Collection<T> source = (items == this.items) ? new ArrayList<T>(items) : items;
        for (T item : source) {
            add(item);
        }
    }

    /**
     * Adds every entry of {@code other} with its weight. {@code other} may be this
     * picker, in which case each existing entry is added once more.
     */
    public void addAll(WeightedRandomPicker<T> other) {
        // Fix the bound up front: if other == this, the lists grow while we add, and a
        // bound re-read on every iteration would never be reached.
        int count = other.items.size();
        for (int i = 0; i < count; i++) {
            add(other.items.get(i), other.weights.get(i));
        }
    }

    /** Adds {@code item} with weight 1. */
    public void add(T item) {
        add(item, 1f);
    }

    /**
     * Adds {@code item} with the given weight. Weights that are not strictly positive
     * (including NaN) are ignored and the item is not added.
     */
    public void add(T item, float weight) {
        // Written as !(weight > 0) rather than weight <= 0 so NaN is rejected too;
        // a NaN weight would otherwise poison the running total.
        if (!(weight > 0f)) return;
        items.add(item);
        weights.add(weight);
        total += weight;
    }

    /**
     * Removes the first entry equal to {@code item} (by {@code equals}) and subtracts
     * its weight from the total. Does nothing if {@code item} is null or absent.
     */
    public void remove(T item) {
        if (item == null) return;
        int index = items.indexOf(item);
        if (index != -1) {
            removeAt(index);
        }
    }

    /** Removes the entry at {@code index}, keeps {@code total} in sync, and returns its item. */
    private T removeAt(int index) {
        T item = items.remove(index);
        float weight = weights.remove(index);
        total -= weight;
        // Repeated float subtraction can leave a small residue; an empty picker has exactly zero total.
        if (items.isEmpty()) total = 0f;
        return item;
    }

    public boolean isEmpty() {
        return items.isEmpty();
    }

    /**
     * Returns the backing list of items itself, not a copy. Structural changes made
     * through it are not mirrored in the weights or the total.
     */
    public List<T> getItems() {
        return items;
    }

    /** Returns the weight of the first entry equal to {@code item}, or 0 if absent. */
    public float getWeight(T item) {
        int index = items.indexOf(item);
        if (index < 0) return 0;
        return getWeight(index);
    }

    /** Returns the weight of the entry at {@code index}. */
    public float getWeight(int index) {
        return weights.get(index);
    }

    /**
     * Replaces the weight of the entry at {@code index} and adjusts the total.
     * Unlike {@link #add(Object, float)}, non-positive weights are not rejected.
     */
    public void setWeight(int index, float weight) {
        float old = getWeight(index);
        weights.set(index, weight);
        total += weight - old;
    }

    /**
     * Returns the item with the largest weight, breaking ties uniformly at random.
     * Ties draw from this picker's {@link Random} (if set), so seeded pickers stay
     * reproducible.
     *
     * @return the heaviest item, or {@code null} if the picker is empty or no weight is positive
     */
    public T getItemWithHighestWeight() {
        float maxW = 0;
        for (int i = 0; i < items.size(); i++) {
            float weight = getWeight(i);
            if (weight > maxW) maxW = weight;
        }
        if (maxW <= 0) return null;

        WeightedRandomPicker<T> tied = new WeightedRandomPicker<T>(random);
        for (int i = 0; i < items.size(); i++) {
            if (getWeight(i) >= maxW) {
                tied.add(items.get(i));
            }
        }
        // A unique maximum needs no randomness; don't consume a value from the generator.
        if (tied.items.size() == 1) return tied.items.get(0);
        return tied.pick();
    }

    /**
     * Picks an item as {@link #pick()} does and removes that exact entry.
     * <p>
     * Removal is by the picked index, not by {@code equals}. When the same item was
     * added more than once, the entry actually chosen is removed and its own weight
     * subtracted.
     *
     * @return the picked item, or {@code null} if the picker is empty
     */
    public T pickAndRemove() {
        int index = pickIndex(random);
        if (index < 0) return null;
        return removeAt(index);
    }

    /**
     * Picks an item using {@code random} for this call only ({@code null} means
     * {@link Math#random()}). The picker's own {@link Random} is left untouched.
     *
     * @return the picked item, or {@code null} if the picker is empty
     */
    public T pick(Random random) {
        int index = pickIndex(random);
        return index < 0 ? null : items.get(index);
    }

    /**
     * Picks an item with probability proportional to its weight, or uniformly if this
     * picker ignores weights.
     *
     * @return the picked item, or {@code null} if the picker is empty
     */
    public T pick() {
        return pick(random);
    }

    /**
     * Chooses an index into {@code items}.
     *
     * @param rng generator to draw from, or null to use {@link Math#random()}
     * @return the chosen index, or -1 if the picker is empty
     */
    private int pickIndex(Random rng) {
        int size = items.size();
        if (size == 0) return -1;

        if (ignoreWeights) {
            double u = (rng != null) ? rng.nextDouble() : Math.random();
            return (int) (u * size);
        }

        float roll;
        if (rng != null) {
            roll = rng.nextFloat() * total;
        } else {
            roll = (float) (Math.random() * total);
        }
        if (roll > total) roll = total;

        // Walk the cumulative weights and stop at the first bucket containing the roll.
        float weightSoFar = 0f;
        int index = 0;
        for (Float weight : weights) {
            weightSoFar += weight;
            if (roll <= weightSoFar) break;
            index++;
        }
        // total is maintained incrementally, so the summed weights can end slightly below
        // the roll; clamp instead of running off the end of the list.
        return Math.min(index, size - 1);
    }

    public Random getRandom() {
        return random;
    }

    public void setRandom(Random random) {
        this.random = random;
    }

    /**
     * Prints {@code title}, then one line per entry to standard out, sorted by the item's
     * {@code toString()}. Each line shows the name (the market name for {@link MarketAPI}
     * items), the integer percentage of the total weight, and the rounded weight.
     */
    public void print(String title) {
        System.out.println(title);

        // Sort entry indices rather than items, so duplicate items each report their own weight.
        List<Integer> order = new ArrayList<Integer>(items.size());
        for (int i = 0; i < items.size(); i++) {
            order.add(i);
        }
        Collections.sort(order, new Comparator<Integer>() {
            @Override
            public int compare(Integer a, Integer b) {
                return items.get(a).toString().compareTo(items.get(b).toString());
            }
        });

        for (int index : order) {
            T item = items.get(index);
            float weight = weights.get(index);
            String percent = "" + (int) ((weight / total) * 100f) + "%";

            String itemStr;
            if (item instanceof MarketAPI) {
                itemStr = ((MarketAPI) item).getName();
            } else {
                itemStr = item.toString();
            }
            System.out.println(String.format(" %-30s%10s%10s", itemStr, percent, Misc.getRoundedValue(weight)));
        }
    }

    /** Returns the running sum of all stored weights. */
    public float getTotal() {
        return total;
    }

}