🤬
  • Add Run RPC protocol handler for the Java client side.

    PiperOrigin-RevId: 460236536
    Change-Id: I5d2c619c3e8ac84d3ad2f03fc26af8b537388b42
  • Loading...
  • John Y. Kim committed with Copybara-Service 2 years ago
    094dcf00
    1 parent 4f3d7079
Revision indexing in progress... (symbol navigation in revisions will be accurate after indexed)
  • ■ ■ ■ ■ ■ ■
    plugin/src/main/java/com/google/tsunami/plugin/PluginServiceClient.java
    skipped 16 lines
    17 17   
    18 18  import static com.google.common.base.Preconditions.checkNotNull;
    19 19   
     20 +import com.google.common.util.concurrent.Futures;
     21 +import com.google.common.util.concurrent.ListenableFuture;
    20 22  import com.google.common.util.concurrent.ListeningScheduledExecutorService;
    21 23  import com.google.tsunami.proto.PluginServiceGrpc;
    22 24  import com.google.tsunami.proto.PluginServiceGrpc.PluginServiceFutureStub;
     25 +import com.google.tsunami.proto.RunRequest;
     26 +import com.google.tsunami.proto.RunResponse;
    23 27  import io.grpc.Channel;
     28 +import java.time.Duration;
    24 29   
    25 30  /**
    26 31   * Client side gRPC handler for the PluginService RPC protocol. Main handler for all gRPC calls to
    27 32   * the language-specific servers.
    28 33   */
    29  -public class PluginServiceClient {
     34 +public final class PluginServiceClient {
    30 35   
    31 36   private final PluginServiceFutureStub pluginService;
    32 37   private final ListeningScheduledExecutorService scheduledExecutorService;
    skipped 1 lines
    34 39   PluginServiceClient(Channel channel, ListeningScheduledExecutorService service) {
    35 40   this.pluginService = PluginServiceGrpc.newFutureStub(checkNotNull(channel));
    36 41   this.scheduledExecutorService = checkNotNull(service);
     42 + }
     43 + 
     44 + /**
     45 + * Sends a run request to the gRPC language server with a specified deadline.
     46 + *
     47 + * @param request The main request containing plugins to run.
     48 + * @param deadline The timeout of the service.
     49 + * @return The future of the run response.
     50 + */
     51 + public ListenableFuture<RunResponse> runWithDeadline(RunRequest request, Duration deadline) {
     52 + return Futures.withTimeout(pluginService.run(request), deadline, scheduledExecutorService);
    37 53   }
    38 54  }
    39 55   
  • ■ ■ ■ ■ ■ ■
    plugin/src/test/java/com/google/tsunami/plugin/PluginServiceClientTest.java
    skipped 16 lines
    17 17   
    18 18  import static com.google.common.truth.Truth.assertThat;
    19 19   
     20 +import com.google.common.util.concurrent.ListenableFuture;
    20 21  import com.google.common.util.concurrent.ListeningScheduledExecutorService;
    21 22  import com.google.inject.Guice;
    22 23  import com.google.inject.Key;
    23 24  import com.google.tsunami.common.concurrent.ScheduledThreadPoolModule;
     25 +import com.google.tsunami.common.data.NetworkEndpointUtils;
     26 +import com.google.tsunami.proto.DetectionReport;
     27 +import com.google.tsunami.proto.DetectionReportList;
     28 +import com.google.tsunami.proto.MatchedPlugin;
     29 +import com.google.tsunami.proto.NetworkEndpoint;
     30 +import com.google.tsunami.proto.NetworkService;
     31 +import com.google.tsunami.proto.PluginDefinition;
     32 +import com.google.tsunami.proto.PluginInfo;
     33 +import com.google.tsunami.proto.PluginServiceGrpc.PluginServiceImplBase;
     34 +import com.google.tsunami.proto.RunRequest;
     35 +import com.google.tsunami.proto.RunResponse;
     36 +import com.google.tsunami.proto.TargetInfo;
     37 +import com.google.tsunami.proto.TransportProtocol;
    24 38  import io.grpc.inprocess.InProcessChannelBuilder;
    25 39  import io.grpc.inprocess.InProcessServerBuilder;
     40 +import io.grpc.stub.StreamObserver;
    26 41  import io.grpc.testing.GrpcCleanupRule;
    27 42  import io.grpc.util.MutableHandlerRegistry;
    28 43  import java.lang.annotation.Retention;
    29 44  import java.lang.annotation.RetentionPolicy;
     45 +import java.time.Duration;
     46 +import java.util.ArrayList;
     47 +import java.util.List;
    30 48  import javax.inject.Qualifier;
    31 49  import org.junit.Before;
    32 50  import org.junit.Rule;
    skipped 5 lines
    38 56  @RunWith(JUnit4.class)
    39 57  public final class PluginServiceClientTest {
    40 58   
     59 + // TODO(b/236740807): Create a wrapper for results and errors.
     60 + 
    41 61   // Useful test thread pool used for testing grpc handlers
    42 62   @Qualifier
    43 63   @Retention(RetentionPolicy.RUNTIME)
    skipped 4 lines
    48 68   private static final int THREAD_POOLS = 1;
    49 69   private static final String THREAD_POOL_NAME = "test";
    50 70   
     71 + private static final String PLUGIN_NAME = "test plugin";
     72 + private static final String PLUGIN_VERSION = "0.0.1";
     73 + private static final String PLUGIN_DESCRIPTION = "test description";
     74 + private static final String PLUGIN_AUTHOR = "tester";
     75 + 
     76 + private static final Duration DURATION_DEFAULT = Duration.ofSeconds(1);
     77 + 
    51 78   private PluginServiceClient pluginService;
    52 79   private final MutableHandlerRegistry serviceRegistry = new MutableHandlerRegistry();
    53 80   
    skipped 25 lines
    79 106   assertThat(pluginService).isNotNull();
    80 107   }
    81 108   
     109 + @Test
     110 + public void run_invalidRequest_returnNoDetectionReports() throws Exception {
     111 + RunRequest runRequest = RunRequest.getDefaultInstance();
     112 + PluginServiceImplBase runImpl =
     113 + new PluginServiceImplBase() {
     114 + @Override
     115 + public void run(RunRequest request, StreamObserver<RunResponse> responseObserver) {
     116 + responseObserver.onNext(RunResponse.getDefaultInstance());
     117 + responseObserver.onCompleted();
     118 + }
     119 + };
     120 + serviceRegistry.addService(runImpl);
     121 + 
     122 + ListenableFuture<RunResponse> run = pluginService.runWithDeadline(runRequest, DURATION_DEFAULT);
     123 + RunResponse runResponse = run.get();
     124 + 
     125 + assertThat(run.isDone()).isTrue();
     126 + assertThat(runResponse.hasReports()).isFalse();
     127 + }
     128 + 
     129 + @Test
     130 + public void run_singlePluginValidRequest_returnSingleDetectionReport() throws Exception {
     131 + RunRequest runRequest = createSinglePluginRunRequest();
     132 + PluginServiceImplBase runImpl =
     133 + new PluginServiceImplBase() {
     134 + @Override
     135 + public void run(RunRequest request, StreamObserver<RunResponse> responseObserver) {
     136 + DetectionReportList reportList =
     137 + DetectionReportList.newBuilder()
     138 + .addDetectionReports(
     139 + DetectionReport.newBuilder()
     140 + .setTargetInfo(request.getTarget())
     141 + .setNetworkService(request.getPlugins(0).getServices(0)))
     142 + .build();
     143 + responseObserver.onNext(RunResponse.newBuilder().setReports(reportList).build());
     144 + responseObserver.onCompleted();
     145 + }
     146 + };
     147 + serviceRegistry.addService(runImpl);
     148 + 
     149 + ListenableFuture<RunResponse> run = pluginService.runWithDeadline(runRequest, DURATION_DEFAULT);
     150 + RunResponse runResponse = run.get();
     151 + 
     152 + assertThat(run.isDone()).isTrue();
     153 + assertRunResponseContainsAllRunRequestParameters(runResponse, runRequest);
     154 + }
     155 + 
     156 + @Test
     157 + public void run_multiplePluginValidRequest_returnMultipleDetectionReports() throws Exception {
     158 + int numPluginsToTest = 5;
     159 + 
     160 + List<NetworkEndpoint> endpoints = new ArrayList<>(numPluginsToTest);
     161 + endpoints.add(NetworkEndpointUtils.forIpAndPort("1.1.1.1", 80));
     162 + endpoints.add(NetworkEndpointUtils.forIpAndPort("1.1.1.1", 443));
     163 + endpoints.add(NetworkEndpointUtils.forIpAndPort("1.1.1.1", 123));
     164 + endpoints.add(NetworkEndpointUtils.forIpAndPort("1.1.1.1", 456));
     165 + endpoints.add(NetworkEndpointUtils.forIpAndPort("1.1.1.1", 789));
     166 + 
     167 + PluginInfo.Builder pluginInfoBuilder =
     168 + PluginInfo.newBuilder()
     169 + .setType(PluginInfo.PluginType.VULN_DETECTION)
     170 + .setVersion(PLUGIN_VERSION)
     171 + .setDescription(PLUGIN_DESCRIPTION)
     172 + .setAuthor(PLUGIN_AUTHOR);
     173 + 
     174 + TargetInfo target = TargetInfo.newBuilder().addAllNetworkEndpoints(endpoints).build();
     175 + 
     176 + RunRequest.Builder runRequestBuilder = RunRequest.newBuilder().setTarget(target);
     177 + 
     178 + for (int i = 0; i < numPluginsToTest; i++) {
     179 + PluginInfo pluginInfo =
     180 + pluginInfoBuilder.setName(String.format(PLUGIN_NAME + " %d", i)).build();
     181 + NetworkService httpService =
     182 + NetworkService.newBuilder()
     183 + .setNetworkEndpoint(endpoints.get(i))
     184 + .setTransportProtocol(TransportProtocol.TCP)
     185 + .setServiceName("http")
     186 + .build();
     187 + runRequestBuilder.addPlugins(
     188 + MatchedPlugin.newBuilder()
     189 + .addServices(httpService)
     190 + .setPlugin(PluginDefinition.newBuilder().setInfo(pluginInfo).build()));
     191 + }
     192 + RunRequest runRequest = runRequestBuilder.build();
     193 + 
     194 + PluginServiceImplBase runImpl =
     195 + new PluginServiceImplBase() {
     196 + @Override
     197 + public void run(RunRequest request, StreamObserver<RunResponse> responseObserver) {
     198 + DetectionReportList.Builder reportListBuilder = DetectionReportList.newBuilder();
     199 + for (MatchedPlugin plugin : request.getPluginsList()) {
     200 + reportListBuilder.addDetectionReports(
     201 + DetectionReport.newBuilder()
     202 + .setTargetInfo(request.getTarget())
     203 + .setNetworkService(plugin.getServices(0)));
     204 + }
     205 + responseObserver.onNext(RunResponse.newBuilder().setReports(reportListBuilder).build());
     206 + responseObserver.onCompleted();
     207 + }
     208 + };
     209 + serviceRegistry.addService(runImpl);
     210 + 
     211 + ListenableFuture<RunResponse> run = pluginService.runWithDeadline(runRequest, DURATION_DEFAULT);
     212 + RunResponse runResponse = run.get();
     213 + 
     214 + assertThat(run.isDone()).isTrue();
     215 + assertThat(runResponse.getReports().getDetectionReportsCount()).isEqualTo(numPluginsToTest);
     216 + assertRunResponseContainsAllRunRequestParameters(runResponse, runRequest);
     217 + }
     218 + 
     219 + private void assertRunResponseContainsAllRunRequestParameters(
     220 + RunResponse response, RunRequest request) throws Exception {
     221 + for (MatchedPlugin plugin : request.getPluginsList()) {
     222 + DetectionReport expectedReport =
     223 + DetectionReport.newBuilder()
     224 + .setTargetInfo(request.getTarget())
     225 + .setNetworkService(plugin.getServices(0))
     226 + .build();
     227 + assertThat(response.getReports().getDetectionReportsList()).contains(expectedReport);
     228 + }
     229 + }
     230 + 
     231 + private PluginDefinition createSinglePluginDefinitionWithName(String name) {
     232 + PluginInfo pluginInfo =
     233 + PluginInfo.newBuilder()
     234 + .setType(PluginInfo.PluginType.VULN_DETECTION)
     235 + .setName(name)
     236 + .setVersion(PLUGIN_VERSION)
     237 + .setDescription(PLUGIN_DESCRIPTION)
     238 + .setAuthor(PLUGIN_AUTHOR)
     239 + .build();
     240 + return PluginDefinition.newBuilder().setInfo(pluginInfo).build();
     241 + }
     242 + 
     243 + private RunRequest createSinglePluginRunRequest() {
     244 + PluginDefinition singlePlugin = createSinglePluginDefinitionWithName(PLUGIN_NAME);
     245 + NetworkService httpService =
     246 + NetworkService.newBuilder()
     247 + .setNetworkEndpoint(NetworkEndpointUtils.forIpAndPort("1.1.1.1", 80))
     248 + .setTransportProtocol(TransportProtocol.TCP)
     249 + .setServiceName("http")
     250 + .build();
     251 + TargetInfo target =
     252 + TargetInfo.newBuilder().addNetworkEndpoints(httpService.getNetworkEndpoint()).build();
     253 + 
     254 + return RunRequest.newBuilder()
     255 + .setTarget(target)
     256 + .addPlugins(MatchedPlugin.newBuilder().addServices(httpService).setPlugin(singlePlugin))
     257 + .build();
     258 + }
     259 + // TODO(b/236740807): Add test case for errors related to RPC calls once wrapper CL is done.
    82 260  }
    83 261   
Please wait...
Page is in error, reload to recover