package launcher.request.update;
import launcher.Launcher.Config;
import launcher.LauncherAPI;
import launcher.hasher.FileNameMatcher;
import launcher.hasher.HashedDir;
import launcher.hasher.HashedDir.Diff;
import launcher.hasher.HashedEntry;
import launcher.hasher.HashedFile;
import launcher.helper.IOHelper;
import launcher.helper.SecurityHelper;
import launcher.helper.SecurityHelper.DigestAlgorithm;
import launcher.request.Request;
import launcher.request.update.UpdateRequest.State.Callback;
import launcher.serialize.HInput;
import launcher.serialize.HOutput;
import launcher.serialize.signed.SignedObjectHolder;
import launcher.serialize.stream.EnumSerializer;
import launcher.serialize.stream.EnumSerializer.Itf;
import launcher.serialize.stream.StreamObject;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.security.MessageDigest;
import java.security.SignatureException;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayDeque;
import java.util.LinkedList;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.Queue;
import java.util.zip.Inflater;
import java.util.zip.InflaterInputStream;
/**
 * Request that brings the local directory {@code dir} in sync with the remote update directory
 * {@code dirName}.
 *
 * <p>The local directory is hashed into a {@link HashedDir}. A signed remote {@link HashedDir} is then
 * read from the server and diffed against the local one. Every mismatched file is downloaded, and
 * every entry reported as extra is deleted. The signed remote {@link HashedDir} is returned.
 */
public final class UpdateRequest extends Request<SignedObjectHolder<HashedDir>>
{
    /** Maximum number of actions written to the server in a single slice. */
    @LauncherAPI
    public static final int MAX_QUEUE_SIZE = 128;

    // Instance
    private final String dirName;
    private final Path dir;
    private final FileNameMatcher matcher;
    private final boolean digest;
    // volatile, presumably so the callback can be (re)set from a thread other than the one running the request
    private volatile Callback stateCallback;

    // State
    private HashedDir localDir;
    private long totalDownloaded;
    private long totalSize;
    private Instant startTime;

    /**
     * Creates an update request.
     *
     * @param config   launcher config passed to {@link Request}; its {@code publicKey} is used to read the
     *                 signed remote dir (null is passed by the convenience constructor — presumably meaning
     *                 "default config"; TODO confirm in {@code Request})
     * @param dirName  remote update directory name; validated with {@link IOHelper#verifyFileName}
     * @param dir      local target directory; must not be null
     * @param matcher  matcher passed to {@link HashedDir} hashing and diffing (nullability not checked here)
     * @param digest   if true, local files are hashed with digests and every downloaded file is MD5-verified
     */
    @LauncherAPI
    public UpdateRequest(Config config, String dirName, Path dir, FileNameMatcher matcher, boolean digest)
    {
        super(config);
        this.dirName = IOHelper.verifyFileName(dirName);
        this.dir = Objects.requireNonNull(dir, "dir");
        this.matcher = matcher;
        this.digest = digest;
    }

    /**
     * Creates an update request with a {@code null} config.
     *
     * @see #UpdateRequest(Config, String, Path, FileNameMatcher, boolean)
     */
    @LauncherAPI
    public UpdateRequest(String dirName, Path dir, FileNameMatcher matcher, boolean digest)
    {
        this(null, dirName, dir, matcher, digest);
    }

    /**
     * Recursively appends the actions needed to fetch every entry of {@code mismatch} to {@code queue}.
     * A directory becomes "CD name", then its contents, then "CD_BACK". A file becomes "GET name".
     */
    private static void fillActionsQueue(Queue<Action> queue, HashedDir mismatch)
    {
        for (Entry<String, HashedEntry> mapEntry : mismatch.map().entrySet())
        {
            String name = mapEntry.getKey();
            HashedEntry entry = mapEntry.getValue();
            HashedEntry.Type entryType = entry.getType();
            switch (entryType)
            {
                case DIR: // cd - get - cd ..
                    queue.add(new Action(Action.Type.CD, name, entry));
                    fillActionsQueue(queue, (HashedDir) entry);
                    queue.add(Action.CD_BACK);
                    break;
                case FILE: // get
                    queue.add(new Action(Action.Type.GET, name, entry));
                    break;
                default:
                    throw new AssertionError("Unsupported hashed entry type: " + entryType.name());
            }
        }
    }

    @Override
    public Type getType()
    {
        return Type.UPDATE;
    }

    /**
     * Creates {@code dir} if it is missing and hashes it into the local snapshot.
     * Then delegates to {@link Request#request()}, which presumably opens the connection and calls
     * {@link #requestDo} (TODO confirm).
     */
    @Override
    public SignedObjectHolder<HashedDir> request() throws Throwable
    {
        Files.createDirectories(dir);
        localDir = new HashedDir(dir, matcher, false, digest);

        // Start request
        return super.request();
    }

    /**
     * Runs the update protocol.
     *
     * <p>Steps: send the dir name, check for an error reply, read the signed remote dir and the
     * compression flag, exchange and perform the actions for every mismatched entry, then delete the
     * extra local entries.
     *
     * @return the signed remote {@link HashedDir}
     * @throws IOException        on I/O failure, unexpected EOF or a missing per-file marker byte
     * @throws SignatureException presumably if the remote dir signature is invalid
     */
    @Override
    protected SignedObjectHolder<HashedDir> requestDo(HInput input, HOutput output) throws IOException, SignatureException
    {
        // Write update dir name
        output.writeString(dirName, 255);
        output.flush();
        readError(input);

        // Get diff between local and remote dir
        SignedObjectHolder<HashedDir> remoteHDirHolder = new SignedObjectHolder<>(input, config.publicKey, HashedDir::new);
        Diff diff = remoteHDirHolder.object.diff(localDir, matcher);
        totalSize = diff.mismatch.size();
        boolean compress = input.readBoolean();

        // Build actions queue
        Queue<Action> queue = new ArrayDeque<>();
        fillActionsQueue(queue, diff.mismatch);
        queue.add(Action.FINISH);

        // Download missing files before deleting extra ones
        // (original note: otherwise it will cause mustdie indexing bug)
        startTime = Instant.now();
        Inflater inflater = compress ? IOHelper.newInflater() : null;
        try
        {
            // Intentionally not closed: closing an InflaterInputStream would also close input.stream
            // noinspection IOResourceOpenedButNotSafelyClosed
            InputStream fileInput = inflater == null ? input.stream
                : new InflaterInputStream(input.stream, inflater, IOHelper.BUFFER_SIZE);
            processActions(queue, fileInput, output);
        }
        finally
        {
            // InflaterInputStream never ends a caller-supplied Inflater, so release its native memory here
            if (inflater != null)
            {
                inflater.end();
            }
        }

        // Delete local files and directories that are not present remotely
        deleteExtraDir(dir, diff.extra, diff.extra.flag);
        return remoteHDirHolder;
    }

    /**
     * Drains {@code queue} in slices of at most {@link #MAX_QUEUE_SIZE}.
     *
     * <p>Each slice is first written to the server (length, then the actions) and flushed. It is then
     * performed locally. CD/CD_BACK move the current directory, and GET reads a 0xFF marker followed by
     * the file body from {@code fileInput}.
     */
    private void processActions(Queue<Action> queue, InputStream fileInput, HOutput output) throws IOException
    {
        Path currentDir = dir;
        Action[] actionsSlice = new Action[MAX_QUEUE_SIZE];
        while (!queue.isEmpty())
        {
            int length = Math.min(queue.size(), MAX_QUEUE_SIZE);

            // Write actions slice
            output.writeLength(length, MAX_QUEUE_SIZE);
            for (int i = 0; i < length; i++)
            {
                Action action = queue.remove();
                actionsSlice[i] = action;
                action.write(output);
            }
            output.flush();

            // Perform actions
            for (int i = 0; i < length; i++)
            {
                Action action = actionsSlice[i];
                switch (action.type)
                {
                    case CD:
                        currentDir = currentDir.resolve(action.name);
                        Files.createDirectories(currentDir);
                        break;
                    case GET:
                        Path targetFile = currentDir.resolve(action.name);
                        // Each file body is preceded by a 0xFF marker byte
                        if (fileInput.read() != 0xFF)
                        {
                            throw new IOException("Serverside cached size mismatch for file " + action.name);
                        }
                        downloadFile(targetFile, (HashedFile) action.entry, fileInput);
                        break;
                    case CD_BACK:
                        currentDir = currentDir.getParent();
                        break;
                    case FINISH:
                        break;
                    default:
                        throw new AssertionError(String.format("Unsupported action type: '%s'", action.type.name()));
                }
            }
        }
    }

    /**
     * Sets the progress callback invoked during download and deletion; {@code null} disables it.
     */
    @LauncherAPI
    public void setStateCallback(Callback callback)
    {
        stateCallback = callback;
    }

    /**
     * Recursively deletes every file listed in {@code subHDir} under {@code subDir}.
     * Directories are deleted too when their flag (or an ancestor's, via {@code flag}) is set.
     */
    private void deleteExtraDir(Path subDir, HashedDir subHDir, boolean flag) throws IOException
    {
        for (Entry<String, HashedEntry> mapEntry : subHDir.map().entrySet())
        {
            String name = mapEntry.getKey();
            Path path = subDir.resolve(name);

            // Delete files and dirs based on type
            HashedEntry entry = mapEntry.getValue();
            HashedEntry.Type entryType = entry.getType();
            switch (entryType)
            {
                case FILE:
                    updateState(IOHelper.toString(path), 0, 0);
                    Files.delete(path);
                    break;
                case DIR:
                    deleteExtraDir(path, (HashedDir) entry, flag || entry.flag);
                    break;
                default:
                    throw new AssertionError("Unsupported hashed entry type: " + entryType.name());
            }
        }

        // Delete the directory itself once its contents are gone
        if (flag)
        {
            updateState(IOHelper.toString(subDir), 0, 0);
            Files.delete(subDir);
        }
    }

    /**
     * Copies exactly {@code hFile.size} bytes from {@code input} into {@code file} and reports progress.
     * When digesting is enabled, the MD5 of the bytes is checked afterwards.
     *
     * @throws EOFException      if the stream ends before the expected size is reached
     * @throws SecurityException if digesting is enabled and the MD5 does not match {@code hFile}
     */
    private void downloadFile(Path file, HashedFile hFile, InputStream input) throws IOException
    {
        String filePath = IOHelper.toString(dir.relativize(file));
        updateState(filePath, 0L, hFile.size);

        // Start file update
        MessageDigest md = digest ? SecurityHelper.newDigest(DigestAlgorithm.MD5) : null;
        try (OutputStream fileOutput = IOHelper.newOutput(file))
        {
            long downloaded = 0L;

            // Download with digest update
            byte[] bytes = IOHelper.newBuffer();
            while (downloaded < hFile.size)
            {
                int remaining = (int) Math.min(hFile.size - downloaded, bytes.length);
                int length = input.read(bytes, 0, remaining);
                if (length < 0)
                {
                    throw new EOFException(String.format("%d bytes remaining", hFile.size - downloaded));
                }

                // Update file
                fileOutput.write(bytes, 0, length);
                if (md != null)
                {
                    md.update(bytes, 0, length);
                }

                // Update state
                downloaded += length;
                totalDownloaded += length;
                updateState(filePath, downloaded, hFile.size);
            }
        }

        // Verify digest
        if (md != null)
        {
            byte[] digestBytes = md.digest();
            if (!hFile.isSameDigest(digestBytes))
            {
                throw new SecurityException(String.format("File digest mismatch: '%s'", filePath));
            }
        }
    }

    /** Publishes a {@link State} snapshot to the current callback, if one is set. */
    private void updateState(String filePath, long fileDownloaded, long fileSize)
    {
        // Read the volatile field once: a concurrent setStateCallback(null) between a null check
        // and the call would otherwise cause a NullPointerException
        Callback callback = stateCallback;
        if (callback != null)
        {
            callback.call(new State(filePath, fileDownloaded, fileSize,
                totalDownloaded, totalSize, Duration.between(startTime, Instant.now())));
        }
    }

    /**
     * One step of the update protocol as exchanged with the server.
     * CD/GET carry an entry name; CD_BACK/FINISH carry none.
     */
    public static final class Action extends StreamObject
    {
        public static final Action CD_BACK = new Action(Type.CD_BACK, null, null);
        public static final Action FINISH = new Action(Type.FINISH, null, null);

        // Instance
        public final Type type;
        public final String name;
        public final HashedEntry entry;

        public Action(Type type, String name, HashedEntry entry)
        {
            this.type = type;
            this.name = name;
            this.entry = entry;
        }

        /**
         * Reads an action from {@code input}; the name is validated with {@link IOHelper#verifyFileName}.
         * {@code entry} is always null for deserialized actions.
         */
        public Action(HInput input) throws IOException
        {
            type = Type.read(input);
            name = type == Type.CD || type == Type.GET ? IOHelper.verifyFileName(input.readString(255)) : null;
            entry = null;
        }

        @Override
        public void write(HOutput output) throws IOException
        {
            EnumSerializer.write(output, type);
            if (type == Type.CD || type == Type.GET)
            {
                output.writeString(name, 255);
            }
        }

        /** Action kinds with their wire numbers. */
        public enum Type implements Itf
        {
            CD(1), CD_BACK(2), GET(3), FINISH(255);
            private static final EnumSerializer<Type> SERIALIZER = new EnumSerializer<>(Type.class);
            private final int n;

            Type(int n)
            {
                this.n = n;
            }

            public static Type read(HInput input) throws IOException
            {
                return SERIALIZER.read(input);
            }

            @Override
            public int getNumber()
            {
                return n;
            }
        }
    }

    /** Immutable progress snapshot passed to {@link Callback}. */
    public static final class State
    {
        /** Bytes of the current file downloaded so far. */
        @LauncherAPI
        public final long fileDownloaded;
        /** Expected size of the current file (0 for deletions). */
        @LauncherAPI
        public final long fileSize;
        /** Bytes downloaded across all files so far. */
        @LauncherAPI
        public final long totalDownloaded;
        /** Total size reported by the mismatch diff. */
        @LauncherAPI
        public final long totalSize;
        /** Path of the file or directory being processed. */
        @LauncherAPI
        public final String filePath;
        /** Time elapsed since the download phase started. */
        @LauncherAPI
        public final Duration duration;

        public State(String filePath, long fileDownloaded, long fileSize, long totalDownloaded, long totalSize, Duration duration)
        {
            this.filePath = filePath;
            this.fileDownloaded = fileDownloaded;
            this.fileSize = fileSize;
            this.totalDownloaded = totalDownloaded;
            this.totalSize = totalSize;
            // Elapsed time, not a timestamp
            this.duration = duration;
        }

        /** Average download rate in bytes per whole elapsed second, or -1 during the first second. */
        @LauncherAPI
        public double getBps()
        {
            long seconds = duration.getSeconds();
            if (seconds == 0)
            {
                return -1.0D; // Dividing by zero would yield Infinity/NaN
            }
            return totalDownloaded / (double) seconds;
        }

        /** Estimated time to finish at the current average rate, or null if no positive rate is known yet. */
        @LauncherAPI
        public Duration getEstimatedTime()
        {
            double bps = getBps();
            if (bps <= 0.0D)
            {
                return null; // Rate unknown or zero: no meaningful estimate
            }
            return Duration.ofSeconds((long) (getTotalRemaining() / bps));
        }

        @LauncherAPI
        public double getFileDownloadedKiB()
        {
            return fileDownloaded / 1024.0D;
        }

        @LauncherAPI
        public double getFileDownloadedMiB()
        {
            return getFileDownloadedKiB() / 1024.0D;
        }

        /** Fraction (0..1) of the current file downloaded; 0 when the file size is 0. */
        @LauncherAPI
        public double getFileDownloadedPart()
        {
            if (fileSize == 0)
            {
                return 0.0D;
            }
            return (double) fileDownloaded / fileSize;
        }

        @LauncherAPI
        public long getFileRemaining()
        {
            return fileSize - fileDownloaded;
        }

        @LauncherAPI
        public double getFileRemainingKiB()
        {
            return getFileRemaining() / 1024.0D;
        }

        @LauncherAPI
        public double getFileRemainingMiB()
        {
            return getFileRemainingKiB() / 1024.0D;
        }

        @LauncherAPI
        public double getFileSizeKiB()
        {
            return fileSize / 1024.0D;
        }

        @LauncherAPI
        public double getFileSizeMiB()
        {
            return getFileSizeKiB() / 1024.0D;
        }

        @LauncherAPI
        public double getTotalDownloadedKiB()
        {
            return totalDownloaded / 1024.0D;
        }

        @LauncherAPI
        public double getTotalDownloadedMiB()
        {
            return getTotalDownloadedKiB() / 1024.0D;
        }

        /** Fraction (0..1) of the total downloaded; 0 when the total size is 0. */
        @LauncherAPI
        public double getTotalDownloadedPart()
        {
            if (totalSize == 0)
            {
                return 0.0D;
            }
            return (double) totalDownloaded / totalSize;
        }

        @LauncherAPI
        public long getTotalRemaining()
        {
            return totalSize - totalDownloaded;
        }

        @LauncherAPI
        public double getTotalRemainingKiB()
        {
            return getTotalRemaining() / 1024.0D;
        }

        @LauncherAPI
        public double getTotalRemainingMiB()
        {
            return getTotalRemainingKiB() / 1024.0D;
        }

        @LauncherAPI
        public double getTotalSizeKiB()
        {
            return totalSize / 1024.0D;
        }

        @LauncherAPI
        public double getTotalSizeMiB()
        {
            return getTotalSizeKiB() / 1024.0D;
        }

        /** Receives progress snapshots. */
        @FunctionalInterface
        public interface Callback
        {
            void call(State state);
        }
    }
}