Moved the CSRF token client over into main commons code (#5471)

This commit is contained in:
Paul Hawke 2024-01-23 19:36:43 -06:00 committed by GitHub
parent 3d0e65c92c
commit 8b8eb84fae
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 252 additions and 37 deletions

View file

@ -2,7 +2,7 @@ package fr.free.nrw.commons.actions
import io.reactivex.Observable
import io.reactivex.Single
import org.wikipedia.csrf.CsrfTokenClient
import fr.free.nrw.commons.auth.csrf.CsrfTokenClient
/**
* This class acts as a Client to facilitate wiki page editing

View file

@ -3,7 +3,7 @@ package fr.free.nrw.commons.actions
import fr.free.nrw.commons.CommonsApplication
import fr.free.nrw.commons.di.NetworkingModule.NAMED_COMMONS_CSRF
import io.reactivex.Observable
import org.wikipedia.csrf.CsrfTokenClient
import fr.free.nrw.commons.auth.csrf.CsrfTokenClient
import javax.inject.Inject
import javax.inject.Named
import javax.inject.Singleton

View file

@ -0,0 +1,247 @@
package fr.free.nrw.commons.auth.csrf;
import android.text.TextUtils;
import androidx.annotation.NonNull;
import androidx.annotation.Nullable;
import androidx.annotation.VisibleForTesting;
import java.util.concurrent.Callable;
import java.util.concurrent.Executors;
import org.wikipedia.AppAdapter;
import org.wikipedia.dataclient.Service;
import org.wikipedia.dataclient.ServiceFactory;
import org.wikipedia.dataclient.SharedPreferenceCookieManager;
import org.wikipedia.dataclient.WikiSite;
import org.wikipedia.dataclient.mwapi.MwQueryResponse;
import org.wikipedia.login.LoginClient;
import org.wikipedia.login.LoginResult;
import org.wikipedia.util.log.L;
import java.io.IOException;
import retrofit2.Call;
import retrofit2.Response;
public class CsrfTokenClient {
private static final String ANON_TOKEN = "+\\";
private static final int MAX_RETRIES = 1;
private static final int MAX_RETRIES_OF_LOGIN_BLOCKING = 2;
@NonNull private final WikiSite csrfWikiSite;
@NonNull private final WikiSite loginWikiSite;
private int retries = 0;
@Nullable private Call<MwQueryResponse> csrfTokenCall;
@NonNull private LoginClient loginClient = new LoginClient();
public CsrfTokenClient(@NonNull WikiSite csrfWikiSite, @NonNull WikiSite loginWikiSite) {
this.csrfWikiSite = csrfWikiSite;
this.loginWikiSite = loginWikiSite;
}
public void request(@NonNull final Callback callback) {
request(false, callback);
}
public void request(boolean forceLogin, @NonNull final Callback callback) {
cancel();
if (forceLogin) {
retryWithLogin(new RuntimeException("Forcing login..."), callback);
return;
}
csrfTokenCall = request(ServiceFactory.get(csrfWikiSite), callback);
}
public void cancel() {
loginClient.cancel();
if (csrfTokenCall != null) {
csrfTokenCall.cancel();
csrfTokenCall = null;
}
}
@VisibleForTesting
@NonNull
Call<MwQueryResponse> request(@NonNull Service service, @NonNull final Callback cb) {
return requestToken(service, new CsrfTokenClient.Callback() {
@Override public void success(@NonNull String token) {
if (AppAdapter.get().isLoggedIn() && token.equals(ANON_TOKEN)) {
retryWithLogin(new RuntimeException("App believes we're logged in, but got anonymous token."), cb);
} else {
cb.success(token);
}
}
@Override public void failure(@NonNull Throwable caught) {
retryWithLogin(caught, cb);
}
@Override
public void twoFactorPrompt() {
cb.twoFactorPrompt();
}
});
}
private void retryWithLogin(@NonNull Throwable caught, @NonNull final Callback callback) {
if (retries < MAX_RETRIES
&& !TextUtils.isEmpty(AppAdapter.get().getUserName())
&& !TextUtils.isEmpty(AppAdapter.get().getPassword())) {
retries++;
SharedPreferenceCookieManager.getInstance().clearAllCookies();
login(AppAdapter.get().getUserName(), AppAdapter.get().getPassword(), () -> {
L.i("retrying...");
request(callback);
}, callback);
} else {
callback.failure(caught);
}
}
private void login(@NonNull final String username, @NonNull final String password,
@NonNull final RetryCallback retryCallback,
@NonNull final Callback callback) {
new LoginClient().request(loginWikiSite, username, password,
new LoginClient.LoginCallback() {
@Override
public void success(@NonNull LoginResult loginResult) {
if (loginResult.pass()) {
AppAdapter.get().updateAccount(loginResult);
retryCallback.retry();
} else {
callback.failure(new LoginClient.LoginFailedException(loginResult.getMessage()));
}
}
@Override
public void twoFactorPrompt(@NonNull Throwable caught, @Nullable String token) {
callback.twoFactorPrompt();
}
@Override public void passwordResetPrompt(@Nullable String token) {
// Should not happen here, but call the callback just in case.
callback.failure(new LoginClient.LoginFailedException("Logged in with temporary password."));
}
@Override
public void error(@NonNull Throwable caught) {
callback.failure(caught);
}
});
}
@NonNull public String getTokenBlocking() throws Throwable {
String token = "";
Service service = ServiceFactory.get(csrfWikiSite);
for (int retry = 0; retry < MAX_RETRIES_OF_LOGIN_BLOCKING; retry++) {
try {
if (retry > 0) {
// Log in explicitly
new LoginClient().loginBlocking(loginWikiSite, AppAdapter.get().getUserName(),
AppAdapter.get().getPassword(), "");
}
// Get CSRFToken response off the main thread.
Response<MwQueryResponse> response = Executors.newSingleThreadExecutor().submit(new CsrfTokenCallExecutor(service)).get();
if (response.body() == null || response.body().query() == null
|| TextUtils.isEmpty(response.body().query().csrfToken())) {
continue;
}
token = response.body().query().csrfToken();
if (AppAdapter.get().isLoggedIn() && token.equals(ANON_TOKEN)) {
throw new RuntimeException("App believes we're logged in, but got anonymous token.");
}
break;
} catch (Throwable t) {
L.w(t);
}
}
if (TextUtils.isEmpty(token) || token.equals(ANON_TOKEN)) {
throw new IOException("Invalid token, or login failure.");
}
return token;
}
@VisibleForTesting @NonNull Call<MwQueryResponse> requestToken(@NonNull Service service,
@NonNull final Callback cb) {
Call<MwQueryResponse> call = service.getCsrfTokenCall();
call.enqueue(new retrofit2.Callback<MwQueryResponse>() {
@Override
public void onResponse(@NonNull Call<MwQueryResponse> call, @NonNull Response<MwQueryResponse> response) {
if (call.isCanceled()) {
return;
}
cb.success(response.body().query().csrfToken());
}
@Override
public void onFailure(@NonNull Call<MwQueryResponse> call, @NonNull Throwable t) {
if (call.isCanceled()) {
return;
}
cb.failure(t);
}
});
return call;
}
public interface Callback {
void success(@NonNull String token);
void failure(@NonNull Throwable caught);
void twoFactorPrompt();
}
public static class DefaultCallback implements Callback {
@Override
public void success(@NonNull String token) {
}
@Override
public void failure(@NonNull Throwable caught) {
L.e(caught);
}
@Override
public void twoFactorPrompt() {
// TODO:
}
}
private interface RetryCallback {
void retry();
}
/**
* Class CsrfTokenCallExecutor which implement callable interface to get CsrfTokenCall.
*/
class CsrfTokenCallExecutor implements Callable<Response<MwQueryResponse>> {
/**
* Service for token call.
*/
private Service service;
/**
* Default Constructor.
* @param service
*/
public CsrfTokenCallExecutor(Service service){
this.service = service;
}
/**
* Computes a result, or throws an exception if unable to do so.
*
* @return computed result
* @throws Exception if unable to compute a result
*/
@Override
public Response<MwQueryResponse> call() throws Exception {
return service.getCsrfTokenCall().execute();
}
}
}

View file

@ -35,7 +35,7 @@ import okhttp3.HttpUrl;
import okhttp3.OkHttpClient;
import okhttp3.logging.HttpLoggingInterceptor;
import okhttp3.logging.HttpLoggingInterceptor.Level;
import org.wikipedia.csrf.CsrfTokenClient;
import fr.free.nrw.commons.auth.csrf.CsrfTokenClient;
import org.wikipedia.dataclient.Service;
import org.wikipedia.dataclient.ServiceFactory;
import org.wikipedia.dataclient.WikiSite;

View file

@ -5,7 +5,7 @@ import fr.free.nrw.commons.notification.models.Notification
import fr.free.nrw.commons.notification.models.NotificationType
import io.reactivex.Observable
import io.reactivex.Single
import org.wikipedia.csrf.CsrfTokenClient
import fr.free.nrw.commons.auth.csrf.CsrfTokenClient
import org.wikipedia.dataclient.mwapi.MwQueryResponse
import org.wikipedia.util.DateUtil
import javax.inject.Inject

View file

@ -26,7 +26,7 @@ import javax.inject.Singleton;
import okhttp3.MediaType;
import okhttp3.MultipartBody;
import okhttp3.RequestBody;
import org.wikipedia.csrf.CsrfTokenClient;
import fr.free.nrw.commons.auth.csrf.CsrfTokenClient;
import org.wikipedia.dataclient.mwapi.MwException;
import timber.log.Timber;

View file

@ -9,7 +9,7 @@ import io.reactivex.Observable;
import javax.inject.Inject;
import javax.inject.Named;
import javax.inject.Singleton;
import org.wikipedia.csrf.CsrfTokenClient;
import fr.free.nrw.commons.auth.csrf.CsrfTokenClient;
import org.wikipedia.dataclient.mwapi.MwPostResponse;
import timber.log.Timber;

View file

@ -0,0 +1,111 @@
package fr.free.nrw.commons;
import androidx.annotation.NonNull;
import java.util.List;
import java.util.concurrent.AbstractExecutorService;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;
import okhttp3.Dispatcher;
import okhttp3.OkHttpClient;
import okhttp3.mockwebserver.MockResponse;
import org.junit.After;
import org.junit.Before;
import org.junit.runner.RunWith;
import org.robolectric.RobolectricTestRunner;
import org.wikipedia.AppAdapter;
import org.wikipedia.dataclient.Service;
import org.wikipedia.dataclient.WikiSite;
import org.wikipedia.json.GsonUtil;
import retrofit2.Retrofit;
import retrofit2.converter.gson.GsonConverterFactory;
@RunWith(RobolectricTestRunner.class)
public abstract class MockWebServerTest {
private OkHttpClient okHttpClient;
private final TestWebServer server = new TestWebServer();
@Before public void setUp() throws Throwable {
AppAdapter.set(new TestAppAdapter());
OkHttpClient.Builder builder = AppAdapter.get().getOkHttpClient(new WikiSite(Service.WIKIPEDIA_URL)).newBuilder();
okHttpClient = builder.dispatcher(new Dispatcher(new ImmediateExecutorService())).build();
server.setUp();
}
@After public void tearDown() throws Throwable {
server.tearDown();
}
@NonNull protected TestWebServer server() {
return server;
}
protected void enqueueFromFile(@NonNull String filename) throws Throwable {
String json = TestFileUtil.readRawFile(filename);
server.enqueue(json);
}
protected void enqueue404() {
final int code = 404;
server.enqueue(new MockResponse().setResponseCode(code).setBody("Not Found"));
}
protected void enqueueMalformed() {
server.enqueue("(╯°□°)╯︵ ┻━┻");
}
protected void enqueueEmptyJson() {
server.enqueue(new MockResponse().setBody("{}"));
}
@NonNull protected OkHttpClient okHttpClient() {
return okHttpClient;
}
@NonNull protected <T> T service(Class<T> clazz) {
return service(clazz, server().getUrl());
}
@NonNull protected <T> T service(Class<T> clazz, @NonNull String url) {
return new Retrofit.Builder()
.baseUrl(url)
.callbackExecutor(new ImmediateExecutor())
.client(okHttpClient)
.addConverterFactory(GsonConverterFactory.create(GsonUtil.getDefaultGson()))
.build()
.create(clazz);
}
public final class ImmediateExecutorService extends AbstractExecutorService {
@Override public void shutdown() {
throw new UnsupportedOperationException();
}
@NonNull @Override public List<Runnable> shutdownNow() {
throw new UnsupportedOperationException();
}
@Override public boolean isShutdown() {
throw new UnsupportedOperationException();
}
@Override public boolean isTerminated() {
throw new UnsupportedOperationException();
}
@Override public boolean awaitTermination(long l, @NonNull TimeUnit timeUnit)
throws InterruptedException {
throw new UnsupportedOperationException();
}
@Override public void execute(@NonNull Runnable runnable) {
runnable.run();
}
}
public class ImmediateExecutor implements Executor {
@Override
public void execute(@NonNull Runnable runnable) {
runnable.run();
}
}
}

View file

@ -0,0 +1,37 @@
package fr.free.nrw.commons;
import android.annotation.TargetApi;
import androidx.annotation.NonNull;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.StringWriter;
import java.nio.charset.StandardCharsets;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
public final class TestFileUtil {
private static final String RAW_DIR = "src/test/res/raw/";
public static File getRawFile(@NonNull String rawFileName) {
return new File(RAW_DIR + rawFileName);
}
public static String readRawFile(String basename) throws IOException {
return readFile(getRawFile(basename));
}
@TargetApi(19)
private static String readFile(File file) throws IOException {
return FileUtils.readFileToString(file, StandardCharsets.UTF_8);
}
@TargetApi(19)
public static String readStream(InputStream stream) throws IOException {
StringWriter writer = new StringWriter();
IOUtils.copy(stream, writer, StandardCharsets.UTF_8);
return writer.toString();
}
private TestFileUtil() { }
}

View file

@ -0,0 +1,56 @@
package fr.free.nrw.commons;
import androidx.annotation.NonNull;
import java.io.IOException;
import java.util.concurrent.TimeUnit;
import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
import okhttp3.mockwebserver.RecordedRequest;
public class TestWebServer {
public static final int TIMEOUT_DURATION = 5;
public static final TimeUnit TIMEOUT_UNIT = TimeUnit.SECONDS;
private final MockWebServer server;
public TestWebServer() {
server = new MockWebServer();
}
public void setUp() throws IOException {
server.start();
}
public void tearDown() throws IOException {
server.shutdown();
}
public String getUrl() {
return getUrl("");
}
public String getUrl(String path) {
return server.url(path).url().toString();
}
public int getRequestCount() {
return server.getRequestCount();
}
public void enqueue(@NonNull String body) {
enqueue(new MockResponse().setBody(body));
}
public void enqueue(MockResponse response) {
server.enqueue(response);
}
@NonNull public RecordedRequest takeRequest() throws InterruptedException {
RecordedRequest req = server.takeRequest(TIMEOUT_DURATION,
TIMEOUT_UNIT);
if (req == null) {
throw new InterruptedException("Timeout elapsed.");
}
return req;
}
}

View file

@ -9,8 +9,7 @@ import org.mockito.ArgumentMatchers
import org.mockito.Mock
import org.mockito.Mockito
import org.mockito.MockitoAnnotations
import org.wikipedia.csrf.CsrfTokenClient
import org.wikipedia.dataclient.Service
import fr.free.nrw.commons.auth.csrf.CsrfTokenClient
import org.wikipedia.edit.Edit
class PageEditClientTest {

View file

@ -12,12 +12,9 @@ import org.mockito.MockedStatic
import org.mockito.Mockito
import org.mockito.Mockito.`when`
import org.mockito.MockitoAnnotations
import org.powermock.api.mockito.PowerMockito
import org.powermock.core.classloader.annotations.PrepareForTest
import org.powermock.modules.junit4.PowerMockRunner
import org.robolectric.RobolectricTestRunner
import org.wikipedia.csrf.CsrfTokenClient
import org.wikipedia.dataclient.Service
import fr.free.nrw.commons.auth.csrf.CsrfTokenClient
@RunWith(RobolectricTestRunner::class)
@PrepareForTest(CommonsApplication::class)

View file

@ -0,0 +1,82 @@
package fr.free.nrw.commons.auth.csrf;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import androidx.annotation.NonNull;
import com.google.gson.stream.MalformedJsonException;
import fr.free.nrw.commons.MockWebServerTest;
import fr.free.nrw.commons.auth.csrf.CsrfTokenClient.Callback;
import org.junit.Test;
import org.mockito.ArgumentMatchers;
import org.mockito.Mockito;
import org.wikipedia.dataclient.Service;
import org.wikipedia.dataclient.WikiSite;
import org.wikipedia.dataclient.mwapi.MwException;
import org.wikipedia.dataclient.mwapi.MwQueryResponse;
import org.wikipedia.dataclient.okhttp.HttpStatusException;
import retrofit2.Call;
public class CsrfTokenClientTest extends MockWebServerTest {
private static final WikiSite TEST_WIKI = new WikiSite("test.wikipedia.org");
@NonNull private final CsrfTokenClient subject = new CsrfTokenClient(TEST_WIKI, TEST_WIKI);
@Test public void testRequestSuccess() throws Throwable {
String expected = "b6f7bd58c013ab30735cb19ecc0aa08258122cba+\\";
enqueueFromFile("csrf_token.json");
Callback cb = Mockito.mock(Callback.class);
request(cb);
server().takeRequest();
assertCallbackSuccess(cb, expected);
}
@Test public void testRequestResponseApiError() throws Throwable {
enqueueFromFile("api_error.json");
Callback cb = Mockito.mock(Callback.class);
request(cb);
server().takeRequest();
assertCallbackFailure(cb, MwException.class);
}
@Test public void testRequestResponseFailure() throws Throwable {
enqueue404();
Callback cb = Mockito.mock(Callback.class);
request(cb);
server().takeRequest();
assertCallbackFailure(cb, HttpStatusException.class);
}
@Test public void testRequestResponseMalformed() throws Throwable {
enqueueMalformed();
Callback cb = Mockito.mock(Callback.class);
request(cb);
server().takeRequest();
assertCallbackFailure(cb, MalformedJsonException.class);
}
private void assertCallbackSuccess(@NonNull Callback cb,
@NonNull String expected) {
verify(cb).success(ArgumentMatchers.eq(expected));
//noinspection unchecked
verify(cb, never()).failure(ArgumentMatchers.any(Throwable.class));
}
private void assertCallbackFailure(@NonNull Callback cb,
@NonNull Class<? extends Throwable> throwable) {
//noinspection unchecked
verify(cb, never()).success(ArgumentMatchers.any(String.class));
verify(cb).failure(ArgumentMatchers.isA(throwable));
}
private Call<MwQueryResponse> request(@NonNull Callback cb) {
return subject.request(service(Service.class), cb);
}
}

View file

@ -17,7 +17,7 @@ import org.mockito.Mockito.verify
import org.mockito.MockitoAnnotations
import org.robolectric.annotation.Config
import org.robolectric.annotation.LooperMode
import org.wikipedia.csrf.CsrfTokenClient
import fr.free.nrw.commons.auth.csrf.CsrfTokenClient
import org.wikipedia.dataclient.mwapi.MwQueryResponse
import org.wikipedia.dataclient.mwapi.MwQueryResult
import org.wikipedia.json.GsonUtil

View file

@ -6,7 +6,7 @@ import org.mockito.InjectMocks
import org.mockito.Mock
import org.mockito.Mockito
import org.mockito.MockitoAnnotations
import org.wikipedia.csrf.CsrfTokenClient
import fr.free.nrw.commons.auth.csrf.CsrfTokenClient
class WikiBaseClientUnitTest {

View file

@ -0,0 +1,10 @@
{
"errors": [
{
"code": "unknown_action",
"text": "Unrecognized value for parameter \"action\": oscillate."
}
],
"docref": "See https://en.wikipedia.org/w/api.php for API usage.",
"servedby": "mw1286"
}

View file

@ -0,0 +1,8 @@
{
"batchcomplete": true,
"query": {
"tokens": {
"csrftoken": "b6f7bd58c013ab30735cb19ecc0aa08258122cba+\\"
}
}
}